agc-chatbot/embedding/enhanced_rag_service.py

185 lines
7.3 KiB
Python

import os
import json
import numpy as np
from typing import List, Dict, Any
from db.db_utils import get_all_documents, get_document
from openai import OpenAI
from dotenv import load_dotenv
# Load environment variables (e.g. from a local .env file) before reading them below.
load_dotenv()

# Configure the OpenAI client.
# The openai>=1.0 `OpenAI` constructor raises when no API key is available. Building it
# unconditionally would make this module fail at import time, and the
# "API key is not configured" response in enhanced_rag_search could never be reached.
# Only create the client when a key is set. Every helper in this module wraps its
# `client` call in `except Exception`, so a `None` client degrades to their fallbacks.
_openai_api_key = os.getenv('OPENAI_API_KEY')
client = OpenAI(api_key=_openai_api_key) if _openai_api_key else None

# Model used for all embeddings (text-embedding-ada-002 produces 1536-dimensional vectors).
EMBEDDING_MODEL = "text-embedding-ada-002"
# Model used for query rewriting and answer generation.
CHAT_MODEL = "gpt-3.5-turbo"
def generate_embedding(text: str) -> List[float]:
    """Generate an embedding for ``text`` using OpenAI's embeddings API.

    Args:
        text: The text to embed.

    Returns:
        The embedding vector produced by EMBEDDING_MODEL. If ``text`` is empty
        (or ``None``), or if the API call fails for any reason, an all-zero
        vector of length 1536 is returned instead, so callers never see an
        exception. An all-zero vector therefore means "no usable embedding".
    """
    # Dimension of text-embedding-ada-002 vectors; used for the fallback vector.
    fallback_dim = 1536
    # The embeddings endpoint rejects empty input. Don't spend a network round-trip
    # (and an error log line) on a request that is guaranteed to fail.
    if not text:
        return [0.0] * fallback_dim
    try:
        response = client.embeddings.create(
            model=EMBEDDING_MODEL,
            input=text
        )
        return response.data[0].embedding
    except Exception as e:
        print(f"Error generating embedding: {e}")
        # Return a dummy embedding of appropriate size if API fails
        return [0.0] * fallback_dim
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Return the cosine of the angle between vectors ``a`` and ``b``.

    A zero-magnitude vector has no direction. In that case the similarity is
    reported as 0.0 instead of dividing by zero.
    """
    vectors = [np.array(v) for v in (a, b)]
    magnitudes = [np.linalg.norm(v) for v in vectors]
    # Guard against division by zero when either input is the zero vector.
    if 0 in magnitudes:
        return 0.0
    dot_product = np.dot(vectors[0], vectors[1])
    return dot_product / (magnitudes[0] * magnitudes[1])
def enhance_search_query(query: str) -> str:
    """Rewrite ``query`` with the chat model to improve legal-document retrieval.

    Args:
        query: The user's original search query.

    Returns:
        The rewritten query. ``query`` is returned unchanged when it is blank,
        when the model returns no text, or when the API call fails.
    """
    # A blank query has nothing to enhance. Sending it anyway would make the
    # model's reply (e.g. a request for more detail) become the search text.
    if not query or not query.strip():
        return query
    try:
        response = client.chat.completions.create(
            model=CHAT_MODEL,
            messages=[
                {"role": "system", "content": "You are a legal search expert. Your task is to enhance the given search query to improve search results in a legal document database. Keep the enhanced query concise and focused on the key legal concepts and facts."},
                {"role": "user", "content": f"Please enhance this search query for searching in legal documents: {query}"}
            ],
            temperature=0.3,
            max_tokens=100
        )
        # `message.content` may be None or empty. Never replace the user's query
        # with an empty string, which would then be embedded and searched.
        enhanced_query = (response.choices[0].message.content or "").strip()
        if not enhanced_query:
            print("Query enhancement returned no text; using original query")
            return query
        return enhanced_query
    except Exception as e:
        print(f"Error enhancing query: {e}")
        return query
def get_relevant_documents(query: str, documents: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
    """Rank ``documents`` by cosine similarity to ``query`` and return the best matches.

    Args:
        query: Text to embed and compare against each document.
        documents: Document dicts. A document is scored only when its
            ``'embedding'`` value is truthy. That value is normally a
            JSON-encoded list of floats, but an already-decoded list is also
            accepted. ``'id'`` is required; documents without it are logged
            and skipped. ``'title'``, ``'content'`` and ``'doc_type'`` are optional.
        top_k: Maximum number of results to return. A value <= 0 yields ``[]``.

    Returns:
        Up to ``top_k`` result dicts with keys ``id``, ``title``, ``content``,
        ``content_preview``, ``doc_type`` and ``similarity``, sorted by
        descending similarity. Returns ``[]`` in three cases: there is nothing
        to score, the query could not be embedded, or an unexpected error occurs.
    """
    # Nothing can be returned, so skip the embedding API call entirely.
    if not documents or (top_k is not None and top_k <= 0):
        return []
    try:
        # Generate embedding for the query
        query_embedding = generate_embedding(query)
        # generate_embedding signals failure with an all-zero vector. Every
        # similarity would then be 0.0, and the "top" results would be an
        # arbitrary slice of the input passed off as relevant context.
        if not any(query_embedding):
            print("Query embedding unavailable; skipping document ranking")
            return []

        results = []
        for doc in documents:
            try:
                raw_embedding = doc.get('embedding')
                if not raw_embedding:
                    continue
                # Embeddings are stored as JSON text; tolerate an already-parsed list too.
                if isinstance(raw_embedding, (str, bytes, bytearray)):
                    doc_embedding = json.loads(raw_embedding)
                else:
                    doc_embedding = raw_embedding
                similarity = cosine_similarity(query_embedding, doc_embedding)
                # `or ''` also covers a stored NULL content, which would otherwise
                # make len() raise and silently drop the document.
                content = doc.get('content') or ''
                content_preview = content[:300] + "..." if len(content) > 300 else content
                results.append({
                    'id': doc['id'],
                    'title': doc.get('title', 'Untitled'),
                    'content': content,
                    'content_preview': content_preview,
                    'doc_type': doc.get('doc_type', 'Unknown'),
                    'similarity': similarity
                })
            except Exception as e:
                # Best effort: one malformed document must not abort the whole search.
                print(f"Error processing document {doc.get('id')}: {e}")
                continue
        # Sort by similarity and return top_k results
        results.sort(key=lambda x: x['similarity'], reverse=True)
        return results[:top_k]
    except Exception as e:
        print(f"Error in get_relevant_documents: {e}")
        return []
def generate_answer(query: str, relevant_docs: List[Dict[str, Any]], *, max_docs: int = 3, max_chars: int = 500) -> str:
    """Answer ``query`` with the chat model, using ``relevant_docs`` as context.

    Args:
        query: The user's original question.
        relevant_docs: Documents in relevance order. Only ``'title'`` and
            ``'content'`` are read from each one.
        max_docs: Only the first ``max_docs`` documents are included in the prompt.
        max_chars: Each document's content is truncated to this many characters.

    Returns:
        The model's answer. A fixed message is returned instead in two cases:
        a "no relevant documents" message when ``relevant_docs`` is empty, and
        an apology when the model returns no text or the API call fails.
    """
    if not relevant_docs:
        return "I couldn't find any relevant documents to answer your question. Please try rephrasing your query."
    try:
        # Keep the prompt small: only the top documents, each truncated.
        context_parts = []
        for i, doc in enumerate(relevant_docs[:max_docs], start=1):
            # `or ''` also covers a stored NULL content, which would otherwise raise on slicing.
            content = (doc.get('content') or '')[:max_chars]
            context_parts.append(f"Document {i}:\nTitle: {doc.get('title', 'Untitled')}\nContent: {content}")
        context = "\n\n".join(context_parts)
        system_prompt = (
            "You are a legal assistant helping to answer questions about legal cases.\n"
            "Use the provided document context to answer questions accurately and professionally.\n"
            "If the information is not available in the context, say so clearly."
        )
        response = client.chat.completions.create(
            model=CHAT_MODEL,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}\n\nPlease provide a clear and accurate answer based on the above context:"}
            ],
            temperature=0.5,
            max_tokens=500
        )
        # `message.content` may be None or empty. Never hand a blank answer to the user.
        answer = (response.choices[0].message.content or "").strip()
        if answer:
            return answer
        print("Answer generation returned no text")
    except Exception as e:
        print(f"Error generating answer: {e}")
    return "I apologize, but I encountered an error while generating the answer. Please try rephrasing your question."
def enhanced_rag_search(query: str, profile_search: bool = False) -> Dict[str, Any]:
    """
    Run the full retrieval-augmented generation pipeline for ``query``.

    The pipeline has three steps:

    1. Rewrite the query with the chat model.
    2. Score every stored document that has an embedding against the
       rewritten query.
    3. Ask the chat model to answer the *original* query from the best matches.

    Args:
        query: The user's search query.
        profile_search: Accepted for interface compatibility; currently ignored.

    Returns:
        Dict with keys ``query``, ``enhanced_query``, ``documents`` and
        ``answer``. If the API key is missing or any step raises, the dict
        still has all four keys: ``enhanced_query`` is the original query,
        ``documents`` is empty, and ``answer`` describes the problem.
    """
    def _fallback(message: str) -> Dict[str, Any]:
        # Result shape used whenever the pipeline cannot run to completion.
        return {
            "query": query,
            "enhanced_query": query,
            "documents": [],
            "answer": message,
        }

    try:
        print(f"Processing search query: {query}")

        # Without a key every OpenAI call would fail; report it explicitly instead.
        if not os.getenv('OPENAI_API_KEY'):
            print("OpenAI API key not found")
            return _fallback("OpenAI API key is not configured. Please set the OPENAI_API_KEY environment variable.")

        enhanced_query = enhance_search_query(query)
        print(f"Enhanced query: {enhanced_query}")

        all_docs = get_all_documents(include_embeddings=True)
        print(f"Found {len(all_docs)} documents")

        # Only documents with a stored embedding can be ranked.
        embedded_docs = list(filter(lambda d: d.get('embedding'), all_docs))
        print(f"Documents with embeddings: {len(embedded_docs)}")

        matches = get_relevant_documents(enhanced_query, embedded_docs)
        print(f"Found {len(matches)} relevant documents")

        # Answer the user's own wording; the rewritten query is only for retrieval.
        answer = generate_answer(query, matches)
        print(f"Generated answer: {answer[:100]}...")
    except Exception as e:
        print(f"Error in enhanced RAG search: {e}")
        return _fallback(f"I apologize, but I encountered an error while processing your query: {str(e)}. Please try again.")

    return {
        "query": query,
        "enhanced_query": enhanced_query,
        "documents": matches,
        "answer": answer,
    }