263 lines
6.7 KiB
Python
263 lines
6.7 KiB
Python
import json
|
|
import mysql.connector
|
|
import numpy as np
|
|
from config import DB_CONFIG
|
|
import os
|
|
from dotenv import load_dotenv
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
# Load environment variables
|
|
# NOTE(review): this runs at import time, but `from config import DB_CONFIG`
# above has already executed by then. If `config` reads environment variables
# while being imported, it does not benefit from this call -- confirm that
# `config` loads the .env file itself.
load_dotenv()
|
|
|
|
def get_db_connection():
    """Open and return a new MySQL connection built from ``DB_CONFIG``.

    A fresh connection is created on every call; this function never closes
    it, so the caller owns the connection and must close it.
    """
    connection = mysql.connector.connect(**DB_CONFIG)
    return connection
|
|
|
|
|
|
def add_document(title, content, source=None, doc_type=None):
    """Add a document to the database and return its ID.

    NOTE: this module defines ``add_document`` a second time further down.
    Python keeps the last definition, so that later one is what importers
    actually get; this one propagates driver errors instead of returning
    ``None``.

    Args:
        title: Value for the ``title`` column.
        content: Value for the ``content`` column.
        source: Optional value for the ``source`` column.
        doc_type: Optional value for the ``doc_type`` column.

    Returns:
        The ``lastrowid`` reported by the cursor after the INSERT.

    Raises:
        Any exception raised by the driver while executing or committing is
        propagated unchanged; the cursor and connection are closed first.
    """
    insert_sql = """
        INSERT INTO documents (title, content, source, doc_type)
        VALUES (%s, %s, %s, %s)
    """

    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(insert_sql, (title, content, source, doc_type))
            document_id = cursor.lastrowid
            conn.commit()
        finally:
            cursor.close()
    finally:
        # Release the connection even when the statement or commit fails;
        # previously a failure here leaked both cursor and connection.
        conn.close()

    return document_id
|
|
|
|
|
|
def store_embedding(document_id, embedding):
    """Store an embedding for a document as a JSON string.

    NOTE: this module defines ``store_embedding`` a second time further down.
    Python keeps the last definition, so that later one is what importers
    actually get.

    Args:
        document_id: Value for the ``document_id`` column of ``embeddings``.
        embedding: A numpy array (converted with ``tolist()``) or any other
            value that ``json.dumps`` can serialize.

    Raises:
        TypeError: If ``embedding`` is not JSON-serializable.
        Any exception raised by the driver while executing or committing is
        propagated unchanged; the cursor and connection are closed first.
    """
    # Serialize before touching the database so that bad input cannot leave
    # an open connection behind.
    if isinstance(embedding, np.ndarray):
        embedding = embedding.tolist()
    embedding_json = json.dumps(embedding)

    insert_sql = """
        INSERT INTO embeddings (document_id, embedding)
        VALUES (%s, %s)
    """

    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(insert_sql, (document_id, embedding_json))
            conn.commit()
        finally:
            cursor.close()
    finally:
        # Release the connection even when the statement or commit fails.
        conn.close()
|
|
|
|
def get_all_documents(include_embeddings: bool = False) -> List[Dict[str, Any]]:
    """Return every row of ``documents``, ordered by ``created_at`` descending.

    Args:
        include_embeddings: When true, ``embeddings`` is LEFT JOINed so each
            row also carries an ``embedding`` column. The value is returned
            exactly as the driver delivers it (it is not JSON-decoded here).

    Returns:
        A list of row dictionaries, or an empty list when the query raises
        ``mysql.connector.Error`` (the error is printed, not re-raised).
    """
    with_embeddings_sql = """
        SELECT d.*, e.embedding
        FROM documents d
        LEFT JOIN embeddings e ON d.id = e.document_id
        ORDER BY d.created_at DESC
    """
    documents_only_sql = """
        SELECT * FROM documents
        ORDER BY created_at DESC
    """

    connection = get_db_connection()
    row_cursor = connection.cursor(dictionary=True)

    try:
        if include_embeddings:
            statement = with_embeddings_sql
        else:
            statement = documents_only_sql
        row_cursor.execute(statement)
        return row_cursor.fetchall()
    except mysql.connector.Error as err:
        print(f"Error fetching documents: {err}")
        return []
    finally:
        row_cursor.close()
        connection.close()
|
|
|
|
def get_document(doc_id: int) -> Optional[Dict[str, Any]]:
    """Fetch one document, together with its ``embedding`` column, by ID.

    Args:
        doc_id: Value matched against ``documents.id``.

    Returns:
        The first matching row as a dictionary, ``None`` when no row matches,
        or ``None`` when the query raises ``mysql.connector.Error`` (the
        error is printed, not re-raised). The ``embedding`` value is returned
        as delivered by the driver; it is not JSON-decoded here.
    """
    lookup_sql = """
        SELECT d.*, e.embedding
        FROM documents d
        LEFT JOIN embeddings e ON d.id = e.document_id
        WHERE d.id = %s
    """

    connection = get_db_connection()
    row_cursor = connection.cursor(dictionary=True)

    try:
        row_cursor.execute(lookup_sql, (doc_id,))
        row = row_cursor.fetchone()
        return row
    except mysql.connector.Error as err:
        print(f"Error fetching document {doc_id}: {err}")
        return None
    finally:
        row_cursor.close()
        connection.close()
|
|
|
|
def get_all_embeddings():
    """Get all embeddings with their associated documents.

    Returns:
        A list of dictionaries with the keys ``id``, ``document_id``,
        ``embedding``, ``title``, ``content``, ``source`` and ``doc_type``.
        ``embedding`` is decoded from its stored JSON text with
        ``json.loads``.

    Raises:
        Any exception raised by the driver or by ``json.loads`` is
        propagated; the cursor and connection are closed in every case.
    """
    query = """
        SELECT e.id, e.document_id, e.embedding, d.title, d.content, d.source, d.doc_type
        FROM embeddings e
        JOIN documents d ON e.document_id = d.id
    """

    conn = get_db_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        try:
            cursor.execute(query)
            results = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        # Release the connection before decoding, so that a failed query or
        # a malformed JSON payload can no longer leak it.
        conn.close()

    # Convert JSON strings to Python lists
    for result in results:
        result['embedding'] = json.loads(result['embedding'])

    return results
|
|
|
|
def log_search(query, results):
    """Log a search query and its results.

    NOTE: this module defines ``log_search`` a second time further down.
    Python keeps the last definition, so that later one is what importers
    actually get.

    Args:
        query: The search text, stored in the ``query`` column.
        results: Any JSON-serializable value; it is stored in the ``results``
            column as the string produced by ``json.dumps``.

    Raises:
        TypeError: If ``results`` is not JSON-serializable.
        Any exception raised by the driver while executing or committing is
        propagated unchanged; the cursor and connection are closed first.
    """
    # Serialize before touching the database so that bad input cannot leave
    # an open connection behind.
    results_json = json.dumps(results)

    query_sql = """
        INSERT INTO search_logs (query, results)
        VALUES (%s, %s)
    """

    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(query_sql, (query, results_json))
            conn.commit()
        finally:
            cursor.close()
    finally:
        # Release the connection even when the statement or commit fails.
        conn.close()
|
|
|
|
def search_documents_by_keyword(keyword, limit=10):
    """Basic keyword search in documents.

    Runs a ``MATCH(title, content) AGAINST(... IN NATURAL LANGUAGE MODE)``
    query against ``documents``.

    Args:
        keyword: Search text, bound as a statement parameter.
        limit: Maximum number of rows to return, bound as a statement
            parameter (default 10).

    Returns:
        A list of matching rows as dictionaries.

    Raises:
        Any exception raised by the driver is propagated unchanged; the
        cursor and connection are closed first.
    """
    query = """
        SELECT * FROM documents
        WHERE MATCH(title, content) AGAINST(%s IN NATURAL LANGUAGE MODE)
        LIMIT %s
    """

    conn = get_db_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        try:
            cursor.execute(query, (keyword, limit))
            results = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        # Release the connection even when the query fails; previously a
        # failure here leaked both cursor and connection.
        conn.close()

    return results
|
|
|
|
def get_search_history() -> List[Dict[str, Any]]:
    """Return up to 100 rows of ``search_logs``, ordered by ``created_at`` descending.

    Returns:
        A list of row dictionaries, or an empty list when the query raises
        ``mysql.connector.Error`` (the error is printed, not re-raised).
    """
    history_sql = """
        SELECT * FROM search_logs
        ORDER BY created_at DESC
        LIMIT 100
    """

    connection = get_db_connection()
    row_cursor = connection.cursor(dictionary=True)

    try:
        row_cursor.execute(history_sql)
        rows = row_cursor.fetchall()
        return rows
    except mysql.connector.Error as err:
        print(f"Error fetching search history: {err}")
        return []
    finally:
        row_cursor.close()
        connection.close()
|
|
|
|
def add_document(title: str, content: str, source: str = None, doc_type: str = None) -> Optional[int]:
    """Insert a row into ``documents`` and report its new ID.

    This is the second ``add_document`` in the module and therefore the one
    importers receive.

    Args:
        title: Value for the ``title`` column.
        content: Value for the ``content`` column.
        source: Optional value for the ``source`` column.
        doc_type: Optional value for the ``doc_type`` column.

    Returns:
        The cursor's ``lastrowid`` after a committed insert, or ``None`` when
        ``mysql.connector.Error`` is raised (the error is printed and the
        transaction rolled back).
    """
    insert_sql = """
        INSERT INTO documents (title, content, source, doc_type)
        VALUES (%s, %s, %s, %s)
    """
    row_values = (title, content, source, doc_type)

    connection = get_db_connection()
    writer = connection.cursor()

    try:
        writer.execute(insert_sql, row_values)
        new_id = writer.lastrowid
        connection.commit()
        return new_id
    except mysql.connector.Error as err:
        print(f"Error adding document: {err}")
        connection.rollback()
        return None
    finally:
        writer.close()
        connection.close()
|
|
|
|
def store_embedding(doc_id: int, embedding: List[float]) -> bool:
    """Store embedding for a document.

    This is the second ``store_embedding`` in the module and therefore the
    one importers receive.

    The vector is serialized to a JSON string before it is bound to the
    INSERT. Binding the raw list (or numpy array) made the driver reject the
    parameter, so the call always ended in the error branch and returned
    ``False``. JSON is also the format ``get_all_embeddings`` decodes with
    ``json.loads``.

    Args:
        doc_id: Value for the ``document_id`` column of ``embeddings``.
        embedding: A list of floats or a numpy array. A value that is already
            ``str``/``bytes``/``bytearray`` is assumed to be pre-serialized
            and is stored unchanged.

    Returns:
        ``True`` once the insert is committed; ``False`` when the embedding
        cannot be serialized or ``mysql.connector.Error`` is raised (the
        error is printed and, for database errors, the transaction is rolled
        back).
    """
    if isinstance(embedding, (str, bytes, bytearray)):
        # Already serialized by the caller; encoding it again would wrap the
        # payload in a JSON string literal.
        embedding_json = embedding
    else:
        try:
            embedding_json = json.dumps(
                embedding.tolist() if isinstance(embedding, np.ndarray) else embedding
            )
        except (TypeError, ValueError) as err:
            # Keep the bool contract: report the failure instead of raising.
            print(f"Error storing embedding for document {doc_id}: {err}")
            return False

    conn = get_db_connection()
    cursor = conn.cursor()

    try:
        cursor.execute("""
            INSERT INTO embeddings (document_id, embedding)
            VALUES (%s, %s)
        """, (doc_id, embedding_json))

        conn.commit()
        return True

    except mysql.connector.Error as err:
        print(f"Error storing embedding for document {doc_id}: {err}")
        conn.rollback()
        return False

    finally:
        cursor.close()
        conn.close()
|
|
|
|
def log_search(query: str, results: Dict[str, Any]) -> bool:
    """Log a search query and its results.

    This is the second ``log_search`` in the module and therefore the one
    importers receive.

    ``results`` is serialized with ``json.dumps`` before it is bound to the
    INSERT. Binding the raw dict made the driver reject the parameter, so the
    call always ended in the error branch and returned ``False``.

    Args:
        query: The search text, stored in the ``query`` column.
        results: A JSON-serializable value describing the results. A value
            that is already ``str``/``bytes``/``bytearray`` is assumed to be
            pre-serialized and is stored unchanged.

    Returns:
        ``True`` once the insert is committed; ``False`` when ``results``
        cannot be serialized or ``mysql.connector.Error`` is raised (the
        error is printed and, for database errors, the transaction is rolled
        back).
    """
    if isinstance(results, (str, bytes, bytearray)):
        # Already serialized by the caller; encoding it again would wrap the
        # payload in a JSON string literal.
        results_json = results
    else:
        try:
            results_json = json.dumps(results)
        except (TypeError, ValueError) as err:
            # Keep the bool contract: report the failure instead of raising.
            print(f"Error logging search: {err}")
            return False

    conn = get_db_connection()
    cursor = conn.cursor()

    try:
        cursor.execute("""
            INSERT INTO search_logs (query, results)
            VALUES (%s, %s)
        """, (query, results_json))

        conn.commit()
        return True

    except mysql.connector.Error as err:
        print(f"Error logging search: {err}")
        conn.rollback()
        return False

    finally:
        cursor.close()
        conn.close()