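"""Enhanced RAG search over legal documents using OpenAI embeddings and chat completions."""
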
import os
import json
from typing import List, Dict, Any

import numpy as np
from openai import OpenAI
from dotenv import load_dotenv

from db.db_utils import get_all_documents

# Load environment variables
load_dotenv()

# Configure OpenAI client
client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

EMBEDDING_MODEL = "text-embedding-ada-002"
CHAT_MODEL = "gpt-3.5-turbo"


def generate_embedding(text: str) -> List[float]:
    """Generate an embedding for the given text using OpenAI's API"""
    try:
        response = client.embeddings.create(
            model=EMBEDDING_MODEL,
            input=text
        )
        return response.data[0].embedding
    except Exception as e:
        print(f"Error generating embedding: {e}")
        # Return a zero vector of the model's output size (1536 dimensions
        # for text-embedding-ada-002) so callers always get a valid vector
        return [0.0] * 1536


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Calculate cosine similarity between two vectors"""
    a = np.array(a)
    b = np.array(b)
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return np.dot(a, b) / (norm_a * norm_b)


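# Usage note: identical vectors score 1.0, orthogonal vectors 0.0, and the
# zero-norm guard above returns 0.0 rather than dividing by zero. For example:
#   cosine_similarity([1.0, 0.0], [1.0, 0.0])  # -> 1.0
#   cosine_similarity([1.0, 0.0], [0.0, 1.0])  # -> 0.0

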
def enhance_search_query(query: str) -> str:
    """Enhance the search query using OpenAI's chat model"""
    try:
        response = client.chat.completions.create(
            model=CHAT_MODEL,
            messages=[
                {"role": "system", "content": "You are a legal search expert. Your task is to enhance the given search query to improve search results in a legal document database. Keep the enhanced query concise and focused on the key legal concepts and facts."},
                {"role": "user", "content": f"Please enhance this search query for searching in legal documents: {query}"}
            ],
            temperature=0.3,
            max_tokens=100
        )

        enhanced_query = response.choices[0].message.content.strip()
        return enhanced_query
    except Exception as e:
        print(f"Error enhancing query: {e}")
        # Fall back to the original query if enhancement fails
        return query


def get_relevant_documents(query: str, documents: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
    """Get the most relevant documents for the query"""
    try:
        # Generate embedding for the query
        query_embedding = generate_embedding(query)

        # Calculate similarities with all documents
        results = []
        for doc in documents:
            try:
                if doc.get('embedding'):
                    # Embeddings are stored as JSON-encoded lists
                    doc_embedding = json.loads(doc['embedding'])
                    similarity = cosine_similarity(query_embedding, doc_embedding)

                    # Create a preview of the content
                    content = doc.get('content', '')
                    content_preview = content[:300] + "..." if len(content) > 300 else content

                    results.append({
                        'id': doc['id'],
                        'title': doc.get('title', 'Untitled'),
                        'content': content,
                        'content_preview': content_preview,
                        'doc_type': doc.get('doc_type', 'Unknown'),
                        'similarity': similarity
                    })
            except Exception as e:
                print(f"Error processing document {doc.get('id')}: {e}")
                continue

        # Sort by similarity and return top_k results
        results.sort(key=lambda x: x['similarity'], reverse=True)
        return results[:top_k]

    except Exception as e:
        print(f"Error in get_relevant_documents: {e}")
        return []


def generate_answer(query: str, relevant_docs: List[Dict[str, Any]]) -> str:
    """Generate an answer based on the query and relevant documents"""
    if not relevant_docs:
        return "I couldn't find any relevant documents to answer your question. Please try rephrasing your query."

    try:
        # Prepare context from relevant documents (limit context size)
        context_parts = []
        for i, doc in enumerate(relevant_docs[:3]):  # Limit to top 3 documents
            content = doc.get('content', '')[:500]  # Limit content length
            context_parts.append(f"Document {i+1}:\nTitle: {doc.get('title', 'Untitled')}\nContent: {content}")

        context = "\n\n".join(context_parts)

        response = client.chat.completions.create(
            model=CHAT_MODEL,
            messages=[
                {"role": "system", "content": """You are a legal assistant helping to answer questions about legal cases.
                Use the provided document context to answer questions accurately and professionally.
                If the information is not available in the context, say so clearly."""},
                {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}\n\nPlease provide a clear and accurate answer based on the above context:"}
            ],
            temperature=0.5,
            max_tokens=500
        )

        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error generating answer: {e}")
        return "I apologize, but I encountered an error while generating the answer. Please try rephrasing your question."


def enhanced_rag_search(query: str, profile_search: bool = False) -> Dict[str, Any]:
    """
    Perform enhanced RAG search on the documents

    Args:
        query: The search query
        profile_search: Whether to search in user profiles (not used currently)

    Returns:
        Dict containing search results and generated answer
    """
    try:
        print(f"Processing search query: {query}")

        # Check if OpenAI API key is available
        if not os.getenv('OPENAI_API_KEY'):
            print("OpenAI API key not found")
            return {
                "query": query,
                "enhanced_query": query,
                "documents": [],
                "answer": "OpenAI API key is not configured. Please set the OPENAI_API_KEY environment variable."
            }

        # Enhance the query
        enhanced_query = enhance_search_query(query)
        print(f"Enhanced query: {enhanced_query}")

        # Get all documents with their embeddings
        documents = get_all_documents(include_embeddings=True)
        print(f"Found {len(documents)} documents")

        # Filter documents that have embeddings
        docs_with_embeddings = [doc for doc in documents if doc.get('embedding')]
        print(f"Documents with embeddings: {len(docs_with_embeddings)}")

        # Get relevant documents
        relevant_docs = get_relevant_documents(enhanced_query, docs_with_embeddings)
        print(f"Found {len(relevant_docs)} relevant documents")

        # Generate answer
        answer = generate_answer(query, relevant_docs)
        print(f"Generated answer: {answer[:100]}...")

        return {
            "query": query,
            "enhanced_query": enhanced_query,
            "documents": relevant_docs,
            "answer": answer
        }
    except Exception as e:
        print(f"Error in enhanced RAG search: {e}")
        return {
            "query": query,
            "enhanced_query": query,
            "documents": [],
            "answer": f"I apologize, but I encountered an error while processing your query: {str(e)}. Please try again."
        }


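# --- Usage sketch -----------------------------------------------------------
# A minimal, illustrative entry point, assuming OPENAI_API_KEY is set and the
# database already holds documents with stored embeddings; the example query
# below is hypothetical.
if __name__ == "__main__":
    result = enhanced_rag_search("What were the key findings in the Smith case?")
    print(f"Enhanced query: {result['enhanced_query']}")
    print(f"Answer: {result['answer']}")
    for doc in result["documents"]:
        print(f"- {doc['title']} (similarity: {doc['similarity']:.3f})")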