| import os |
| from typing import List |
| from fastapi import FastAPI, HTTPException, UploadFile, File |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel |
| import asyncio |
| import tempfile |
| from aimakerspace.vectordatabase import VectorDatabase |
| from aimakerspace.openai_utils.chatmodel import ChatOpenAI |
|
|
| from app import ( |
| RetrievalAugmentedQAPipeline, |
| process_file, |
| system_role_prompt, |
| user_role_prompt, |
| ) |
|
|
app = FastAPI()

# Origins permitted to call the API from a browser. The local dev front-ends
# are listed explicitly, but the trailing "*" already admits every origin, so
# the explicit entries are informational only.
_CORS_ALLOWED_ORIGINS = [
    "http://localhost:3001",
    "http://localhost:7860",
    "http://localhost",
    "*",
]

# NOTE(review): a wildcard origin combined with allow_credentials=True lets
# Starlette's CORSMiddleware reflect the caller's Origin header (preflights and
# cookie-bearing requests), i.e. any site may make credentialed requests.
# Confirm this is intended for the installed Starlette version before exposing
# the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)
|
|
class ChatResponse(BaseModel):
    """Response body of ``POST /api/chat/{pipeline_id}``.

    Field declaration order is significant: it fixes the key order of the
    serialized JSON and of the generated OpenAPI schema.
    """

    # The pipeline's answer, with all streamed chunks concatenated.
    response: str
    # Context returned by the pipeline alongside the answer; the tuple layout
    # is defined by RetrievalAugmentedQAPipeline in `app` — confirm there.
    context: List[tuple]
|
|
class ChatRequest(BaseModel):
    """JSON request body accepted by ``POST /api/chat/{pipeline_id}``."""

    # The user's question; passed unchanged to the pipeline's arun_pipeline.
    query: str
|
|
@app.post("/api/upload", response_model=dict)
async def upload_file(file: UploadFile = File(...)):
    """Index an uploaded document and register a RAG pipeline for it.

    The upload is written to a temporary file, converted to text chunks by
    ``process_file``, embedded into a fresh ``VectorDatabase`` and wrapped in a
    ``RetrievalAugmentedQAPipeline`` that is stored in ``app.pipelines``.

    Args:
        file: The multipart file sent by the client.

    Returns:
        ``{"pipeline_id": <str>, "message": "File processed successfully"}``.
        Pass ``pipeline_id`` to ``POST /api/chat/{pipeline_id}``.

    Raises:
        HTTPException: 500 carrying the underlying error message if any step
            fails.
    """
    temp_path = None
    try:
        # Keep only a genuine extension from the client-supplied name. Dot-less
        # names get no suffix, and basename() keeps path separators out of the
        # temp-file name. A missing filename (None) is tolerated.
        suffix = os.path.splitext(os.path.basename(file.filename or ""))[1]

        # delete=False so the file survives the `with` block and can be
        # reopened by name in process_file; it is removed in `finally` below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_path = temp_file.name
            temp_file.write(await file.read())

        texts = process_file(temp_path, file.filename)
        vector_db = await VectorDatabase().abuild_from_list(texts)
        pipeline = RetrievalAugmentedQAPipeline(
            vector_db_retriever=vector_db,
            llm=ChatOpenAI(),
        )

        # The registry lives on the app object and is created on first upload.
        if not hasattr(app, "pipelines"):
            app.pipelines = {}
        # There is no await between reading len() and storing the entry, so
        # concurrent uploads in this process cannot receive the same id.
        # NOTE(review): ids are sequential and therefore guessable, so any
        # client can chat against any other client's document. Consider
        # uuid4 if uploads are private.
        pipeline_id = str(len(app.pipelines))
        app.pipelines[pipeline_id] = pipeline

        return {"pipeline_id": pipeline_id, "message": "File processed successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Remove the temp file on success AND on failure, so that errors in
        # processing or indexing do not leave uploads behind on disk.
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                pass  # best-effort cleanup must not mask the request's outcome
|
|
@app.post("/api/chat/{pipeline_id}", response_model=ChatResponse)
async def chat(pipeline_id: str, request: ChatRequest):
    """Answer ``request.query`` using the pipeline registered as ``pipeline_id``.

    The pipeline's streamed answer is collected into a single string and
    returned together with the context the pipeline reports.

    Args:
        pipeline_id: Id returned by ``POST /api/upload``.
        request: Body carrying the user's ``query``.

    Returns:
        A ``ChatResponse`` with the full answer text and the pipeline context.

    Raises:
        HTTPException: 404 if no pipeline is registered under ``pipeline_id``.
            500 carrying the underlying error message if the pipeline fails.
    """
    # Resolve the pipeline outside the try block. Otherwise the broad
    # `except Exception` below would catch this 404 and turn it into a 500.
    pipelines = getattr(app, "pipelines", {})
    if pipeline_id not in pipelines:
        raise HTTPException(
            status_code=404,
            detail="Pipeline not found. Please upload a file first.",
        )
    pipeline = pipelines[pipeline_id]

    try:
        result = await pipeline.arun_pipeline(request.query)
        # result["response"] is consumed as an async iterator of text chunks.
        # Join once rather than growing a string with repeated `+=`.
        chunks = [chunk async for chunk in result["response"]]
        return ChatResponse(response="".join(chunks), context=result["context"])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
if __name__ == "__main__":
    # Script entry point: serve the API with uvicorn on all interfaces.
    # The import is kept inside the guard, so importing this module (for
    # example under an external ASGI server) does not load uvicorn.
    from uvicorn import run as serve

    serve(app, port=8000, host="0.0.0.0")