| from __future__ import annotations |
|
|
| import os |
| import re |
| import secrets |
| from contextlib import asynccontextmanager |
| from typing import Annotated |
|
|
| import torch |
| from fastapi import FastAPI, HTTPException, Security |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.security import APIKeyHeader |
| from pydantic import BaseModel, Field |
| from transformers import pipeline |
|
|
| |
|
|
# Hugging Face Hub id of the detector: loaded by lifespan() and echoed by the
# public health check.
MODEL_ID: str = "openai-community/roberta-base-openai-detector"
|
|
| |
# Shared secret that clients must send in the X-API-Key header. The app
# refuses to start without it, so /detect is never served unauthenticated.
API_KEY = os.environ.get("API_KEY", "")

if not API_KEY:
    raise RuntimeError(
        "API_KEY environment variable is not set. "
        "Add it in your HuggingFace Space → Settings → Variables and secrets."
    )
|
|
| |
# Security scheme that reads the key from the "X-API-Key" request header.
# auto_error=False lets a missing header arrive as None, so verify_api_key
# can produce the 401 response itself.
api_key_header = APIKeyHeader(auto_error=False, name="X-API-Key")
|
|
|
|
def verify_api_key(key: str | None = Security(api_key_header)) -> str:
    """Dependency: rejects requests with a missing or wrong API key.

    Args:
        key: Value of the ``X-API-Key`` header, or ``None`` when it is absent.

    Returns:
        The validated key.

    Raises:
        HTTPException: 401 when the key is missing or does not equal ``API_KEY``.
    """
    # Compare bytes rather than str: secrets.compare_digest raises TypeError
    # for str arguments containing non-ASCII characters, which would turn a
    # bogus header into a 500 instead of a 401. compare_digest is used (rather
    # than ==) to resist timing attacks on the key.
    if not key or not secrets.compare_digest(
        key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing API key. Pass it as the X-API-Key header.",
        )
    return key
|
|
|
|
| |
|
|
# Text-classification pipeline. It stays None until lifespan() loads the
# model; detect() responds with 503 while it is None.
classifier = None
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the detector pipeline at startup and release it at shutdown.

    FastAPI runs the code before ``yield`` once before serving requests and
    the code after it when the application shuts down.

    Args:
        app: The application being started (unused; required by the
            lifespan signature).
    """
    global classifier
    # device=0 selects the first CUDA GPU; -1 keeps the pipeline on CPU.
    device = 0 if torch.cuda.is_available() else -1
    print(f"Loading model {MODEL_ID} …")
    classifier = pipeline(
        "text-classification",
        model=MODEL_ID,
        device=device,
    )
    print("Model ready.")
    try:
        yield
    finally:
        # Drop the module-level reference so the model can be garbage-collected
        # and any late request gets detect()'s 503 "not loaded" response.
        classifier = None
|
|
|
|
| |
|
|
# The ASGI application. The model is loaded/unloaded by `lifespan`;
# /detect is guarded by verify_api_key while / is public.
app = FastAPI(
    lifespan=lifespan,
    title="AI Text Detector API",
    version="2.0.0",
    description="Detects whether text is human-written or AI-generated. Requires X-API-Key header.",
)

# Allow browser calls from any origin, restricted to GET and POST; any
# request header (including X-API-Key) may be sent.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["POST", "GET"],
    allow_origins=["*"],
)
|
|
| |
|
|
def split_into_chunks(text: str, max_words: int = 80) -> list[str]:
    """Split *text* into chunks of at most *max_words* words for classification.

    Each non-blank line is a paragraph; chunks never span paragraphs. Within
    a paragraph, sentences (split on whitespace following ``.``, ``!`` or
    ``?``) are packed greedily into the current chunk until adding the next
    one would exceed *max_words*. A single sentence longer than *max_words*
    is cut into consecutive windows of *max_words* words, so a run-on
    sentence is never handed to the classifier as one oversized input.

    Args:
        text: The input text.
        max_words: Word budget per chunk (whitespace-separated words).

    Returns:
        The chunks in reading order. For blank *text* the result is
        ``[text.strip()]``, i.e. ``[""]``.

    Raises:
        ValueError: If *max_words* is less than 1.
    """
    if max_words < 1:
        raise ValueError("max_words must be at least 1")

    chunks: list[str] = []
    paragraphs = [p.strip() for p in text.split("\n") if p.strip()] or [text.strip()]

    for para in paragraphs:
        current = ""
        for sentence in re.split(r"(?<=[.!?])\s+", para):
            words = sentence.split()
            if len(words) > max_words:
                # Hard-split an over-long sentence into word windows.
                pieces = [
                    " ".join(words[start:start + max_words])
                    for start in range(0, len(words), max_words)
                ]
            else:
                pieces = [sentence]

            for piece in pieces:
                if len((current + " " + piece).split()) > max_words:
                    # Budget exceeded: flush what we have and start afresh.
                    if current.strip():
                        chunks.append(current.strip())
                    current = piece
                else:
                    current = (current + " " + piece).strip()

        if current.strip():
            chunks.append(current.strip())

    # Empty only when every paragraph was blank.
    return chunks or [text.strip()]
|
|
|
|
| |
|
|
class DetectRequest(BaseModel):
    # Request body for POST /detect: one required string field, 1 to 10,000
    # characters long (validated by pydantic before the endpoint runs).
    text: str = Field(
        ...,
        description="Text to analyse (max 10,000 characters)",
        max_length=10_000,
        min_length=1,
    )
|
|
|
|
class ChunkResult(BaseModel):
    # Per-chunk result as filled in by detect(); field order is the JSON key
    # order of each entry in DetectResponse.chunks.
    text: str  # chunk text as passed to the classifier
    ai_probability: float  # probability of AI-generated text, rounded to 4 places
    human_probability: float  # 1 - ai_probability, rounded to 4 places
    label: str  # "AI" when ai_probability >= 0.5, otherwise "Human"
    confidence: float  # probability of the chosen label, rounded to 4 places
|
|
|
|
class DetectResponse(BaseModel):
    # Response body of POST /detect as built by detect(); field order is the
    # JSON key order.
    label: str  # overall "AI" when ai_probability >= 0.5, otherwise "Human"
    ai_probability: float  # word-count-weighted mean of the chunk AI probabilities
    human_probability: float  # 1 - ai_probability
    confidence: float  # probability of the overall label
    chunks: list[ChunkResult]  # per-chunk results in text order
    total_chunks: int  # number of chunks
    ai_chunks: int  # chunks whose AI probability is >= 0.5
    human_chunks: int  # total_chunks - ai_chunks
|
|
|
|
| |
|
|
@app.get("/", tags=["health"])
async def health():
    """Public health-check — no API key required."""
    # Reports the configured model id only; it does not inspect whether the
    # classifier has finished loading.
    return {"status": "ok", "model": MODEL_ID}
|
|
|
|
@app.post(
    "/detect",
    response_model=DetectResponse,
    tags=["detection"],
    dependencies=[Security(verify_api_key)],
)
async def detect(body: DetectRequest):
    """Classify text as human-written or AI-generated.

    The text is split into chunks, each chunk is scored by the classifier,
    and the overall AI probability is the word-count-weighted mean of the
    chunk scores. Requires the X-API-Key header.
    """
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet — try again shortly.")

    # min_length=1 still admits whitespace-only text, which would produce
    # zero total words and a division by zero in the weighted average below.
    if not body.text.strip():
        raise HTTPException(
            status_code=422,
            detail="Text must contain at least one non-whitespace character.",
        )

    chunks = split_into_chunks(body.text)
    # NOTE(review): this call is synchronous and runs on the event loop, so it
    # blocks other requests during inference. Moving it to a worker thread
    # would likely also need serialization (HF fast tokenizers are reported to
    # fail under concurrent use) — left unchanged here.
    raw = classifier(chunks, truncation=True, max_length=512, batch_size=8)

    chunk_results: list[ChunkResult] = []
    ai_probs: list[float] = []
    word_counts: list[int] = []

    for chunk, res in zip(chunks, raw):
        # Each result holds the top label and its score. "Fake" is treated as
        # the AI-generated class; for any other label the complement is used,
        # which assumes a two-class model (presumably "Fake"/"Real").
        ai_prob = res["score"] if res["label"] == "Fake" else 1.0 - res["score"]
        human_prob = 1.0 - ai_prob
        is_ai = ai_prob >= 0.5
        label = "AI" if is_ai else "Human"
        conf = ai_prob if is_ai else human_prob

        chunk_results.append(
            ChunkResult(
                text=chunk,
                ai_probability=round(ai_prob, 4),
                human_probability=round(human_prob, 4),
                label=label,
                confidence=round(conf, 4),
            )
        )
        ai_probs.append(ai_prob)
        word_counts.append(len(chunk.split()))

    # Non-blank text always yields at least one word (checked above), so the
    # weighted average is well-defined. Unrounded probabilities are used.
    total_words = sum(word_counts)
    avg_ai = sum(p * w for p, w in zip(ai_probs, word_counts)) / total_words
    avg_human = 1.0 - avg_ai
    overall_label = "AI" if avg_ai >= 0.5 else "Human"
    overall_conf = avg_ai if overall_label == "AI" else avg_human
    ai_chunks = sum(1 for p in ai_probs if p >= 0.5)

    return DetectResponse(
        label=overall_label,
        ai_probability=round(avg_ai, 4),
        human_probability=round(avg_human, 4),
        confidence=round(overall_conf, 4),
        chunks=chunk_results,
        total_chunks=len(chunks),
        ai_chunks=ai_chunks,
        human_chunks=len(chunks) - ai_chunks,
    )