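"""Enhanced RAG search over a legal-document database.

Embeds the user's query with OpenAI, ranks stored documents by cosine
similarity against their precomputed embeddings, and asks a chat model to
answer from the top matches. Note: this module uses the legacy (pre-1.0)
`openai` SDK interface (`openai.Embedding.create` /
`openai.ChatCompletion.create`).
"""
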
import os
import json
import numpy as np
from typing import List, Dict, Any
from db.db_utils import get_all_documents, get_document
import openai
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Configure OpenAI (legacy pre-1.0 SDK: key is set as a module-level attribute)
openai.api_key = os.getenv('OPENAI_API_KEY')
EMBEDDING_MODEL = "text-embedding-ada-002"
CHAT_MODEL = "gpt-3.5-turbo"
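
# Note: text-embedding-ada-002 produces 1536-dimensional vectors. The stored
# document embeddings must come from the same model, or the cosine-similarity
# ranking below is meaningless.
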
def generate_embedding(text: str) -> List[float]:
    """Generate embedding for the given text using OpenAI's API"""
    response = openai.Embedding.create(
        model=EMBEDDING_MODEL,
        input=text
    )
    return response['data'][0]['embedding']

def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Calculate cosine similarity between two vectors"""
    a = np.array(a)
    b = np.array(b)
    # Assumes non-zero vectors; OpenAI embeddings are never all-zero,
    # so the norms in the denominator are safe.
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
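
# Example (illustrative values only):
#   cosine_similarity([1.0, 0.0], [1.0, 0.0])  -> 1.0   (identical direction)
#   cosine_similarity([1.0, 0.0], [0.0, 1.0])  -> 0.0   (orthogonal)
#   cosine_similarity([1.0, 0.0], [-1.0, 0.0]) -> -1.0  (opposite direction)
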
def enhance_search_query(query: str) -> str:
    """Enhance the search query using OpenAI's chat model"""
    try:
        messages = [
            {"role": "system", "content": "You are a legal search expert. Your task is to enhance the given search query to improve search results in a legal document database. Keep the enhanced query concise and focused on the key legal concepts and facts."},
            {"role": "user", "content": f"Please enhance this search query for searching in legal documents: {query}"}
        ]

        response = openai.ChatCompletion.create(
            model=CHAT_MODEL,
            messages=messages,
            temperature=0.3,
            max_tokens=100
        )

        enhanced_query = response['choices'][0]['message']['content'].strip()
        return enhanced_query
    except Exception as e:
        # Fall back to the raw query rather than failing the whole search
        print(f"Error enhancing query: {e}")
        return query
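
# Illustrative only (actual model output varies): a query like "slip and fall
# at a grocery store" might come back as "premises liability negligence
# slip-and-fall grocery store duty of care".
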
def get_relevant_documents(query: str, documents: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
    """Get the most relevant documents for the query"""
    # Generate embedding for the query
    query_embedding = generate_embedding(query)

    # Calculate similarities with all documents
    results = []
    for doc in documents:
        try:
            # Embeddings are stored as JSON-encoded lists
            doc_embedding = json.loads(doc['embedding'])
            similarity = cosine_similarity(query_embedding, doc_embedding)

            # Create a preview of the content
            content = doc['content']
            content_preview = content[:300] + "..." if len(content) > 300 else content

            results.append({
                'id': doc['id'],
                'title': doc['title'],
                'content': doc['content'],
                'content_preview': content_preview,
                'doc_type': doc['doc_type'],
                'similarity': similarity
            })
        except Exception as e:
            # Skip documents with missing or malformed embeddings
            print(f"Error processing document {doc.get('id')}: {e}")
            continue

    # Sort by similarity and return top_k results
    results.sort(key=lambda x: x['similarity'], reverse=True)
    return results[:top_k]
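
# This is a brute-force scan: one cosine computation per stored document.
# Fine for small collections; a vector index (e.g. FAISS or pgvector) would
# be the usual next step if the corpus grows.
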
def generate_answer(query: str, relevant_docs: List[Dict[str, Any]]) -> str:
    """Generate an answer based on the query and relevant documents"""
    # Prepare context from relevant documents
    context = "\n\n".join([
        f"Document {i+1}:\nTitle: {doc['title']}\nContent: {doc['content']}"
        for i, doc in enumerate(relevant_docs)
    ])

    try:
        messages = [
            {"role": "system", "content": """You are a legal assistant helping to answer questions about legal cases.
            Use the provided document context to answer questions accurately and professionally.
            If the information is not available in the context, say so clearly."""},
            {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}\n\nPlease provide a clear and accurate answer based on the above context:"}
        ]

        response = openai.ChatCompletion.create(
            model=CHAT_MODEL,
            messages=messages,
            temperature=0.5,
            max_tokens=500
        )

        return response['choices'][0]['message']['content'].strip()
    except Exception as e:
        print(f"Error generating answer: {e}")
        return "I apologize, but I encountered an error while generating the answer. Please try rephrasing your question."
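
# Caveat: full document bodies go into the prompt, so a handful of long
# documents can exceed the model's context window. Truncating each document
# or summarizing before prompting are common mitigations.
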
def enhanced_rag_search(query: str, profile_search: bool = False) -> Dict[str, Any]:
    """
    Perform enhanced RAG search on the documents

    Args:
        query: The search query
        profile_search: Whether to search in user profiles (not used currently)

    Returns:
        Dict containing search results and generated answer
    """
    try:
        # Enhance the query
        enhanced_query = enhance_search_query(query)

        # Get all documents with their embeddings
        documents = get_all_documents(include_embeddings=True)

        # Get relevant documents (retrieval uses the enhanced query)
        relevant_docs = get_relevant_documents(enhanced_query, documents)

        # Generate answer (answering uses the user's original query)
        answer = generate_answer(query, relevant_docs)

        return {
            "query": query,
            "enhanced_query": enhanced_query,
            "documents": relevant_docs,
            "answer": answer
        }
    except Exception as e:
        print(f"Error in enhanced RAG search: {e}")
        return {
            "query": query,
            "enhanced_query": query,
            "documents": [],
            "answer": "I apologize, but I encountered an error while processing your query. Please try again."
        }
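

# Minimal usage sketch (assumes OPENAI_API_KEY is set and db.db_utils is
# backed by a populated document store with precomputed embeddings):
if __name__ == "__main__":
    result = enhanced_rag_search("What was the ruling on the breach of contract claim?")
    print("Enhanced query:", result["enhanced_query"])
    for doc in result["documents"]:
        print(f"- {doc['title']} (similarity: {doc['similarity']:.3f})")
    print("\nAnswer:\n" + result["answer"])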