# agc-chatbot/db/db_utils.py
#
# 263 lines
# 6.7 KiB
# Python

import json
import mysql.connector
import numpy as np
from config import DB_CONFIG
import os
from dotenv import load_dotenv
from typing import List, Dict, Any, Optional
# Load variables from a .env file into the process environment.
# NOTE(review): `config` is imported above, before this call runs; if config
# reads environment variables at import time it will not see values loaded
# here -- confirm that config loads the .env file itself.
load_dotenv()
def get_db_connection():
    """Open a MySQL connection using the settings in ``DB_CONFIG``.

    The connection is returned open; the caller is responsible for
    closing it.
    """
    connection = mysql.connector.connect(**DB_CONFIG)
    return connection
def add_document(title, content, source=None, doc_type=None):
    """Insert a row into ``documents`` and return its generated ID.

    Note: a later ``add_document`` definition in this module rebinds the
    name, so after import callers get that version rather than this one.

    Args:
        title: Value for the ``title`` column.
        content: Value for the ``content`` column.
        source: Optional value for the ``source`` column.
        doc_type: Optional value for the ``doc_type`` column.

    Returns:
        The cursor's ``lastrowid`` after the insert.

    Raises:
        Any error raised by the connector propagates to the caller; the
        cursor and connection are closed first.
    """
    query = """
        INSERT INTO documents (title, content, source, doc_type)
        VALUES (%s, %s, %s, %s)
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(query, (title, content, source, doc_type))
            document_id = cursor.lastrowid
            conn.commit()
            return document_id
        finally:
            # Release the cursor even when execute/commit raises.
            cursor.close()
    finally:
        # Always hand the connection back, success or failure.
        conn.close()
def store_embedding(document_id, embedding):
    """Insert an embedding row for a document, stored as a JSON string.

    Note: a later ``store_embedding`` definition in this module rebinds the
    name, so after import callers get that version rather than this one.

    Args:
        document_id: Value for the ``document_id`` column.
        embedding: A numpy array (converted with ``tolist()``) or any other
            JSON-serializable value.

    Raises:
        TypeError: If ``embedding`` cannot be serialized to JSON.
        Any error raised by the connector propagates to the caller; the
        cursor and connection are closed first.
    """
    # Convert numpy array to list and store as JSON
    vector = embedding.tolist() if isinstance(embedding, np.ndarray) else embedding
    embedding_json = json.dumps(vector)
    query = """
        INSERT INTO embeddings (document_id, embedding)
        VALUES (%s, %s)
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(query, (document_id, embedding_json))
            conn.commit()
        finally:
            # Release the cursor even when execute/commit raises.
            cursor.close()
    finally:
        # Always hand the connection back, success or failure.
        conn.close()
def get_all_documents(include_embeddings: bool = False) -> List[Dict[str, Any]]:
    """Return every row of ``documents`` ordered by ``created_at`` descending.

    When ``include_embeddings`` is true the rows come from a LEFT JOIN on
    ``embeddings`` and also carry the ``embedding`` column, returned as
    fetched (it is not JSON-decoded here).

    If the connector raises a MySQL error it is printed and an empty list
    is returned. The cursor and connection are always closed.
    """
    connection = get_db_connection()
    cur = connection.cursor(dictionary=True)
    try:
        if include_embeddings:
            # Join so each document row also carries its embedding column.
            sql = (
                "\n"
                "SELECT d.*, e.embedding\n"
                "FROM documents d\n"
                "LEFT JOIN embeddings e ON d.id = e.document_id\n"
                "ORDER BY d.created_at DESC\n"
            )
        else:
            # Plain document rows only.
            sql = (
                "\n"
                "SELECT * FROM documents\n"
                "ORDER BY created_at DESC\n"
            )
        cur.execute(sql)
        return cur.fetchall()
    except mysql.connector.Error as err:
        print(f"Error fetching documents: {err}")
        return []
    finally:
        cur.close()
        connection.close()
def get_document(doc_id: int) -> Optional[Dict[str, Any]]:
    """Fetch one document row, with its ``embedding`` column, by ID.

    Returns the first matching row as a dict, or ``None`` when no row
    matches. If the connector raises a MySQL error it is printed and
    ``None`` is returned. The cursor and connection are always closed.
    """
    connection = get_db_connection()
    cur = connection.cursor(dictionary=True)
    # LEFT JOIN keeps the document row even when it has no embedding.
    sql = (
        "\n"
        "SELECT d.*, e.embedding\n"
        "FROM documents d\n"
        "LEFT JOIN embeddings e ON d.id = e.document_id\n"
        "WHERE d.id = %s\n"
    )
    try:
        cur.execute(sql, (doc_id,))
        return cur.fetchone()
    except mysql.connector.Error as err:
        print(f"Error fetching document {doc_id}: {err}")
        return None
    finally:
        cur.close()
        connection.close()
def get_all_embeddings():
    """Return every embedding joined with its document's fields.

    Returns:
        A list of dicts with keys ``id``, ``document_id``, ``embedding``,
        ``title``, ``content``, ``source`` and ``doc_type``. The
        ``embedding`` value is decoded from its stored JSON text.

    Raises:
        Any error raised by the connector, or by ``json.loads`` on a stored
        value, propagates to the caller; the cursor and connection are
        closed first.
    """
    query = """
        SELECT e.id, e.document_id, e.embedding, d.title, d.content, d.source, d.doc_type
        FROM embeddings e
        JOIN documents d ON e.document_id = d.id
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        try:
            cursor.execute(query)
            results = cursor.fetchall()
        finally:
            # Release the cursor even when the query raises.
            cursor.close()
    finally:
        # Always hand the connection back, success or failure.
        conn.close()

    # Convert JSON strings to Python lists
    for result in results:
        result['embedding'] = json.loads(result['embedding'])
    return results
def log_search(query, results):
    """Insert a row into ``search_logs`` with the query and JSON-encoded results.

    Note: a later ``log_search`` definition in this module rebinds the name,
    so after import callers get that version rather than this one.

    Args:
        query: Value for the ``query`` column.
        results: JSON-serializable value stored in the ``results`` column.

    Raises:
        TypeError: If ``results`` cannot be serialized to JSON.
        Any error raised by the connector propagates to the caller; the
        cursor and connection are closed first.
    """
    results_json = json.dumps(results)
    query_sql = """
        INSERT INTO search_logs (query, results)
        VALUES (%s, %s)
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(query_sql, (query, results_json))
            conn.commit()
        finally:
            # Release the cursor even when execute/commit raises.
            cursor.close()
    finally:
        # Always hand the connection back, success or failure.
        conn.close()
def search_documents_by_keyword(keyword, limit=10):
    """Run a natural-language full-text search over ``title`` and ``content``.

    Args:
        keyword: Search text bound into ``AGAINST(... IN NATURAL LANGUAGE MODE)``.
        limit: Maximum number of rows to return (bound into ``LIMIT``).

    Returns:
        The matching ``documents`` rows as a list of dicts.

    Raises:
        Any error raised by the connector propagates to the caller; the
        cursor and connection are closed first.
    """
    query = """
        SELECT * FROM documents
        WHERE MATCH(title, content) AGAINST(%s IN NATURAL LANGUAGE MODE)
        LIMIT %s
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor(dictionary=True)
        try:
            cursor.execute(query, (keyword, limit))
            return cursor.fetchall()
        finally:
            # Release the cursor even when the query raises.
            cursor.close()
    finally:
        # Always hand the connection back, success or failure.
        conn.close()
def get_search_history() -> List[Dict[str, Any]]:
    """Return up to 100 ``search_logs`` rows ordered by ``created_at`` descending.

    If the connector raises a MySQL error it is printed and an empty list
    is returned. The cursor and connection are always closed.
    """
    connection = get_db_connection()
    cur = connection.cursor(dictionary=True)
    sql = (
        "\n"
        "SELECT * FROM search_logs\n"
        "ORDER BY created_at DESC\n"
        "LIMIT 100\n"
    )
    try:
        cur.execute(sql)
        return cur.fetchall()
    except mysql.connector.Error as err:
        print(f"Error fetching search history: {err}")
        return []
    finally:
        cur.close()
        connection.close()
def add_document(title: str, content: str, source: str = None, doc_type: str = None) -> Optional[int]:
    """Insert a row into ``documents`` and return its generated ID.

    Returns the cursor's ``lastrowid`` after a successful commit. If the
    connector raises a MySQL error it is printed, the transaction is
    rolled back, and ``None`` is returned. The cursor and connection are
    always closed.
    """
    connection = get_db_connection()
    cur = connection.cursor()
    insert_sql = (
        "\n"
        "INSERT INTO documents (title, content, source, doc_type)\n"
        "VALUES (%s, %s, %s, %s)\n"
    )
    row_values = (title, content, source, doc_type)
    try:
        cur.execute(insert_sql, row_values)
        new_id = cur.lastrowid
        connection.commit()
    except mysql.connector.Error as err:
        print(f"Error adding document: {err}")
        connection.rollback()
        return None
    else:
        return new_id
    finally:
        cur.close()
        connection.close()
def store_embedding(doc_id: int, embedding: List[float]) -> bool:
    """Store an embedding for a document as JSON text.

    The embedding is serialized with ``json.dumps`` before it is bound to
    the INSERT, because the connector cannot convert a Python list
    parameter and ``get_all_embeddings`` reads the column back with
    ``json.loads``.

    Args:
        doc_id: Value for the ``document_id`` column.
        embedding: A list of floats or a numpy array (converted with
            ``tolist()``). A ``str``/``bytes`` value is treated as already
            serialized and bound unchanged.

    Returns:
        ``True`` when the row was inserted and committed. ``False`` when
        the embedding cannot be serialized or the connector raises a MySQL
        error (the error is printed and the transaction rolled back).
    """
    try:
        if isinstance(embedding, (str, bytes, bytearray)):
            # Caller already serialized it; encoding again would double-quote it.
            payload = embedding
        else:
            vector = embedding.tolist() if isinstance(embedding, np.ndarray) else embedding
            payload = json.dumps(vector)
    except (TypeError, ValueError) as err:
        print(f"Error storing embedding for document {doc_id}: {err}")
        return False

    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        cursor.execute("""
            INSERT INTO embeddings (document_id, embedding)
            VALUES (%s, %s)
        """, (doc_id, payload))
        conn.commit()
        return True
    except mysql.connector.Error as err:
        print(f"Error storing embedding for document {doc_id}: {err}")
        conn.rollback()
        return False
    finally:
        cursor.close()
        conn.close()
def log_search(query: str, results: Dict[str, Any]) -> bool:
    """Log a search query and its results to ``search_logs``.

    The results are serialized with ``json.dumps`` before they are bound
    to the INSERT, because the connector cannot convert a Python dict
    parameter.

    Args:
        query: Value for the ``query`` column.
        results: JSON-serializable results. A ``str``/``bytes`` value is
            treated as already serialized and bound unchanged.

    Returns:
        ``True`` when the row was inserted and committed. ``False`` when
        the results cannot be serialized or the connector raises a MySQL
        error (the error is printed and the transaction rolled back).
    """
    try:
        if isinstance(results, (str, bytes, bytearray)):
            # Caller already serialized it; encoding again would double-quote it.
            payload = results
        else:
            payload = json.dumps(results)
    except (TypeError, ValueError) as err:
        print(f"Error logging search: {err}")
        return False

    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        cursor.execute("""
            INSERT INTO search_logs (query, results)
            VALUES (%s, %s)
        """, (query, payload))
        conn.commit()
        return True
    except mysql.connector.Error as err:
        print(f"Error logging search: {err}")
        conn.rollback()
        return False
    finally:
        cursor.close()
        conn.close()