166 lines
5.8 KiB
Python
166 lines
5.8 KiB
Python
from fastapi import FastAPI, HTTPException
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel, field_serializer
|
|
from typing import List, Optional, Dict, Any
|
|
from datetime import datetime
|
|
from db.db_utils import get_all_documents, get_document
|
|
import uvicorn
|
|
import os
|
|
|
|
# Metadata (title / description / version) attached to the FastAPI application.
_API_METADATA = {
    "title": "AGC Document Chatbot API",
    "description": "API for Attorney General's Chambers Document Search and Chat System",
    "version": "1.0.0",
}

app = FastAPI(**_API_METADATA)
|
|
|
|
# Add CORS middleware.
#
# Allowed origins are read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable (e.g. "https://app.example.org,http://localhost:3000").
# When it is unset or contains no usable entries, every origin is allowed,
# which matches the previous hard-coded ["*"].
#
# NOTE(review): FastAPI's CORS docs state that allow_origins must not be ["*"]
# when allow_credentials=True (origins must be listed explicitly). Set
# CORS_ALLOW_ORIGINS to the real frontend origin(s) outside local development.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)
|
|
|
|
# Pydantic models for request/response
class DocumentResponse(BaseModel):
    """A single document, as returned by the /documents and /document endpoints."""

    id: int
    title: str
    content: str
    doc_type: str
    created_at: Optional[datetime] = None  # serialized as "YYYY-MM-DD HH:MM:SS"
    source: Optional[str] = None

    @field_serializer('created_at')
    def serialize_created_at(self, value: Optional[datetime]) -> Optional[str]:
        """Format ``created_at`` as ``YYYY-MM-DD HH:MM:SS``; ``None`` stays ``None``.

        The format string carries no sub-second or timezone fields, so those
        parts of the datetime are dropped from the output.
        """
        return None if value is None else value.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
class SearchRequest(BaseModel):
    """JSON body accepted by ``POST /search``."""

    # Free-text query; passed as the first argument to the search backend.
    query: str
    # Passed unchanged as the second argument to the search backend.
    profile_search: bool = False
|
|
|
|
class SearchResponse(BaseModel):
    """JSON body returned by ``POST /search``.

    The search backends' return values are validated against this model.
    """

    query: str
    enhanced_query: str  # presumably the backend's rewritten query — confirm
    documents: List[Dict[str, Any]]  # matched documents as plain dicts
    answer: str
|
|
|
|
# Routes
@app.get("/")
async def root():
    """Return a fixed welcome message."""
    greeting = "Welcome to AGC Document Chatbot API"
    return {"message": greeting}
|
|
|
|
@app.get("/ping")
async def ping():
    """Health check endpoint; always reports that the API is running."""
    payload = {"status": "ok"}
    payload["message"] = "API is running"
    return payload
|
|
|
|
@app.get("/documents", response_model=List[DocumentResponse])
async def list_documents(
    doc_type: Optional[str] = None,
    title_filter: Optional[str] = None
):
    """Get all documents with optional filtering.

    Args:
        doc_type: Keep only documents whose ``doc_type`` equals this value.
            ``None``, ``""`` or the literal ``"All Types"`` disables this filter.
        title_filter: Case-insensitive substring that must occur in the
            document's title or content. ``None`` or ``""`` disables it.

    Raises:
        HTTPException: 500 if fetching or filtering the documents fails.
    """
    try:
        documents = get_all_documents()

        # Apply filters
        if doc_type and doc_type != "All Types":
            documents = [doc for doc in documents if doc.get('doc_type') == doc_type]

        if title_filter:
            # Lower-case the needle once instead of once per document.
            needle = title_filter.lower()
            # `or ''` also covers keys that are present but None; .get(key, '')
            # would return that None and .lower() would then raise.
            documents = [
                doc for doc in documents
                if needle in (doc.get('title') or '').lower()
                or needle in (doc.get('content') or '').lower()
            ]

        return documents
    except Exception as e:
        print(f"Error in list_documents: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error fetching documents: {str(e)}")
|
|
|
|
@app.get("/documents/{document_id}", response_model=DocumentResponse)
async def get_document_by_id(document_id: int):
    """Fetch one document by its integer ID.

    Raises:
        HTTPException: 404 when ``get_document`` returns a falsy value,
            500 when the lookup raises anything other than HTTPException.
    """
    try:
        record = get_document(document_id)
        if record:
            return record
        raise HTTPException(status_code=404, detail="Document not found")
    except HTTPException:
        # Deliberate HTTP errors (including the 404 above) pass through untouched.
        raise
    except Exception as exc:
        print(f"Error in get_document_by_id: {exc}")
        raise HTTPException(status_code=500, detail=f"Error fetching document: {exc}")
|
|
|
|
@app.get("/document/{document_id}", response_model=DocumentResponse)
async def get_document_by_id_alt(document_id: str):
    """Alternative endpoint to get a specific document by ID (supports string IDs).

    Accepts a ``doc``-prefixed ID (``"doc42"``) or a plain integer ID
    (``"42"``). For a prefixed ID the numeric suffix is tried first; if it
    does not resolve to a document, the whole string is tried as an integer
    (which only succeeds when it has no ``doc`` prefix).

    Raises:
        HTTPException: 404 when no candidate parses to an ID that resolves to
            a document, 500 when the lookup itself raises.
    """
    # Candidate ID strings, in the order they are tried.
    candidates = []
    if document_id.startswith('doc'):
        candidates.append(document_id[3:])
    candidates.append(document_id)

    try:
        for candidate in candidates:
            # Only the int() conversion is guarded: a ValueError raised by
            # get_document must surface as a 500, not be masked as a 404.
            try:
                numeric_id = int(candidate)
            except ValueError:
                continue  # not numeric; try the next form
            document = get_document(numeric_id)
            if document:
                return document

        raise HTTPException(status_code=404, detail="Document not found")
    except HTTPException:
        raise
    except Exception as e:
        print(f"Error in get_document_by_id_alt: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error fetching document: {str(e)}")
|
|
|
|
@app.post("/search", response_model=SearchResponse)
async def search_documents(request: SearchRequest):
    """Search documents using enhanced RAG or simple search as fallback.

    If the ``OPENAI_API_KEY`` environment variable is set (non-empty),
    ``enhanced_rag_search`` is tried first. Any exception from importing or
    calling it is printed, and the request falls through to ``simple_search``.
    Both backends are imported lazily and called with
    ``(query, profile_search)``.

    Raises:
        HTTPException: 500 if the keyword-search path fails.
    """
    # NOTE(review): both backends are called synchronously inside an async
    # handler; if they do blocking I/O they stall the event loop — confirm.
    try:
        query, profile_search = request.query, request.profile_search

        openai_enabled = bool(os.getenv('OPENAI_API_KEY'))
        if openai_enabled:
            print("Using OpenAI-enhanced search")
            try:
                from embedding.enhanced_rag_service import enhanced_rag_search
                return enhanced_rag_search(query, profile_search)
            except Exception as rag_error:
                # Best-effort: any RAG failure degrades to keyword search.
                print(f"OpenAI search failed: {rag_error}, falling back to simple search")

        print("Using simple keyword search")
        from embedding.simple_search_service import simple_search
        return simple_search(query, profile_search)
    except Exception as exc:
        print(f"Error in search_documents: {exc}")
        raise HTTPException(status_code=500, detail=f"Error performing search: {exc}")
|
|
|
|
@app.get("/document-types")
async def get_document_types():
    """Get list of available document types.

    Returns:
        ``{"document_types": [...]}``: the distinct ``doc_type`` values across
        all documents, sorted. Documents whose type is missing or ``None`` are
        reported as ``"Unknown"``.

    Raises:
        HTTPException: 500 if fetching the documents fails.
    """
    try:
        documents = get_all_documents()
        # Map a None value (not just a missing key) to "Unknown": a None in the
        # set would make sorted() raise TypeError when compared with strings.
        doc_types = {
            'Unknown' if doc_type is None else doc_type
            for doc_type in (doc.get('doc_type') for doc in documents)
        }
        return {"document_types": sorted(doc_types)}
    except Exception as e:
        print(f"Error in get_document_types: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error fetching document types: {str(e)}")
|
|
|
|
if __name__ == "__main__":
    # Bind address, port and auto-reload can be overridden from the environment.
    # The defaults reproduce the original hard-coded settings
    # (0.0.0.0:8000 with reload on).
    # NOTE(review): reload is a development feature; set API_RELOAD=0 in production.
    # "api:app" is passed as an import string (uvicorn needs one for reload); it
    # assumes this module is importable as `api` from the working directory.
    _reload = os.getenv("API_RELOAD", "1").strip().lower() in ("1", "true", "yes", "on")
    uvicorn.run(
        "api:app",
        host=os.getenv("API_HOST", "0.0.0.0"),
        port=int(os.getenv("API_PORT", "8000")),
        reload=_reload,
    )