Spaces:
Runtime error
Runtime error
| from fastai.vision.all import * | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uvicorn | |
| import logging | |
| import tempfile | |
| from pathlib import Path | |
| import firebase_admin | |
| from firebase_admin import credentials, firestore, storage | |
| from pydantic import BaseModel | |
| import torch | |
| from transformers import AutoImageProcessor, AutoModelForObjectDetection | |
| from PIL import Image, ImageDraw, ImageFont | |
| import cv2 | |
| import random | |
# Hugging Face Hub checkpoint id. The same id is used for the image processor
# and the detection model so that preprocessing and the label space match.
_FASHION_CHECKPOINT = "valentinafeve/yolos-fashionpedia"

processor = AutoImageProcessor.from_pretrained(_FASHION_CHECKPOINT)
model = AutoModelForObjectDetection.from_pretrained(_FASHION_CHECKPOINT)
# Fashionpedia class names in model label-id order. detect_fashion() turns a
# predicted label id into a name by indexing this list, so the order matters.
FASHION_CATEGORIES = [
    "shirt, blouse",  # 0
    "top, t-shirt, sweatshirt",  # 1
    "sweater",  # 2
    "cardigan",  # 3
    "jacket",  # 4
    "vest",  # 5
    "pants",  # 6
    "shorts",  # 7
    "skirt",  # 8
    "coat",  # 9
    "dress",  # 10
    "jumpsuit",  # 11
    "cape",  # 12
    "glasses",  # 13
    "hat",  # 14
    "headband, head covering, hair accessory",  # 15
    "tie",  # 16
    "glove",  # 17
    "watch",  # 18
    "belt",  # 19
    "leg warmer",  # 20
    "tights, stockings",  # 21
    "sock",  # 22
    "shoe",  # 23
    "bag, wallet",  # 24
    "scarf",  # 25
    "umbrella",  # 26
    "hood",  # 27
    "collar",  # 28
    "lapel",  # 29
    "epaulette",  # 30
    "sleeve",  # 31
    "pocket",  # 32
    "neckline",  # 33
    "buckle",  # 34
    "zipper",  # 35
    "applique",  # 36
    "bead",  # 37
    "bow",  # 38
    "flower",  # 39
    "fringe",  # 40
    "ribbon",  # 41
    "rivet",  # 42
    "ruffle",  # 43
    "sequin",  # 44
    "tassel",  # 45
]
def detect_fashion(image, score_threshold=0.5):
    """Run the Fashionpedia detector on one image and return confident detections.

    Args:
        image: A PIL image. ``image.size`` is read as ``(width, height)`` and
            reversed to build the ``(height, width)`` target size that the
            post-processor uses to scale boxes back to pixel coordinates.
        score_threshold: Only detections whose score is strictly greater than
            this value are kept. Defaults to 0.5, the cutoff previously
            hard-coded here.

    Returns:
        A list of ``(label, score, box)`` tuples, where ``label`` is the
        ``FASHION_CATEGORIES`` name for the predicted class id, ``score`` is a
        Python float, and ``box`` is the box tensor from
        ``post_process_object_detection`` converted to a list (documented by
        transformers as ``[x_min, y_min, x_max, y_max]``).
    """
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: without no_grad, every call would build (and hold) an
    # autograd graph for the whole forward pass.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); the post-processor expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    # The post-processor itself keeps only scores strictly above `threshold`,
    # so no second filtering pass is needed.
    results = processor.post_process_object_detection(
        outputs, threshold=score_threshold, target_sizes=target_sizes
    )[0]

    return [
        (FASHION_CATEGORIES[label.item()], score.item(), box.tolist())
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"])
    ]
# Labels (as produced by detect_fashion) that count as formal workplace
# attire. Built once at import time instead of on every call.
FORMAL_WORKPLACE_ATTIRE = frozenset({
    "shirt, blouse",
    "jacket",
    "tie",
    "coat",
    "sweater",
    "cardigan",
})


def check_dress_code(detected_items, formal_attire=FORMAL_WORKPLACE_ATTIRE):
    """Return whether any detected item counts as formal workplace attire.

    Args:
        detected_items: Iterable of detections such as the ``(label, score, box)``
            tuples returned by ``detect_fashion``. Only element 0 (the label)
            is inspected.
        formal_attire: Collection of labels that count as compliant. Defaults
            to ``FORMAL_WORKPLACE_ATTIRE``.

    Returns:
        ``True`` if at least one item's label is in ``formal_attire``,
        otherwise ``False`` (including for an empty input).
    """
    return any(item[0] in formal_attire for item in detected_items)
async def process_file(file_data: FileProcess):
    """Download a video from Firebase Storage, score dress-code compliance, and store the result.

    NOTE(review): ``FileProcess``, ``logger``, ``bucket`` and ``db`` are not
    defined in the visible code. Confirm they are created at module level.

    Args:
        file_data: Request payload. Only ``file_data.file_path`` (the object
            path inside the Storage bucket) is read.

    Returns:
        A dict with ``message`` and ``result`` keys. ``result`` is
        ``{"type": "video", "data": {"result": <average compliance>}}``. If the
        Firestore write fails, an ``error`` key with the exception text is
        added and the call still succeeds.

    Raises:
        HTTPException: 400 if the file extension is not a supported video
            type; 500 if downloading or processing the file fails.
    """
    logger.info(f"Processing file from Firebase Storage: {file_data.file_path}")

    # Path.suffix is "" for names without an extension and ignores dots in
    # directory names. split('.')[-1] could return a whole path containing '/'.
    suffix = Path(file_data.file_path).suffix
    file_type = suffix[1:].lower()
    # Reject unsupported types before downloading, and outside the generic
    # handler below, so the client actually receives the 400.
    if file_type not in ('mp4', 'avi', 'mov', 'wmv'):
        raise HTTPException(status_code=400, detail="Unsupported file type")

    tmp_file_path = None
    try:
        blob = bucket.blob(file_data.file_path)
        # Reserve a named temp path and close our handle before the download writes to it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file_path = Path(tmp_file.name)
        blob.download_to_filename(str(tmp_file_path))
        logger.info(f"File downloaded temporarily at: {tmp_file_path}")

        # NOTE(review): the download and model inference are blocking calls
        # inside an async handler and stall the event loop. Consider
        # fastapi.concurrency.run_in_threadpool.
        average_compliance, _ = process_video(str(tmp_file_path))
        result = {"type": "video", "data": {"result": average_compliance}}
        logger.info(f"Processing complete. Result: {result}")
    except Exception as e:
        logger.error(f"Error processing file: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}") from e
    finally:
        # Also runs when the download itself fails, so the temp file never leaks.
        if tmp_file_path is not None:
            tmp_file_path.unlink(missing_ok=True)

    # Storing the result is best-effort: a Firestore failure is reported in the
    # response rather than failing the request.
    try:
        db.collection('results').add(result)
    except Exception as e:
        logger.error(f"Failed to store result in Firebase: {str(e)}")
        return {"message": "File processed successfully, but failed to store in Firebase", "result": result,
                "error": str(e)}
    return {"message": "File processed successfully", "result": result}
def process_video(video_path, num_frames=10, rng=None):
    """Estimate dress-code compliance by sampling random frames from a video.

    Args:
        video_path: Path of a video file to open with ``cv2.VideoCapture``.
        num_frames: Maximum number of distinct frames to sample. Fewer are
            sampled when the video reports fewer frames.
        rng: Optional object with a ``sample`` method (e.g. a seeded
            ``random.Random``) for reproducible frame selection. Defaults to
            the ``random`` module's global generator.

    Returns:
        ``(average_compliance, compliance_results)``: the fraction of
        successfully read sampled frames that ``check_dress_code`` judged
        compliant, and the per-frame booleans in ascending frame order.

    Raises:
        ValueError: If the video cannot be opened or no sampled frame could
            be read.
    """
    sampler = random if rng is None else rng
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")
        # The reported frame count may be 0 or negative when the backend cannot
        # determine it. Clamp both values so random.sample never gets a negative size.
        total_frames = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)
        sample_size = max(min(num_frames, total_frames), 0)
        frame_indices = sorted(sampler.sample(range(total_frames), sample_size))

        compliance_results = []
        for frame_index in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ret, frame = cap.read()
            if not ret:
                # Unreadable frames are skipped, not counted as non-compliant.
                continue
            # OpenCV decodes frames as BGR; the detector is given an RGB PIL image.
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            compliance_results.append(check_dress_code(detect_fashion(image)))
    finally:
        # Release the capture even if decoding or detection raises.
        cap.release()

    if not compliance_results:
        raise ValueError(f"No readable frames sampled from video: {video_path}")
    average_compliance = sum(compliance_results) / len(compliance_results)
    return average_compliance, compliance_results
if __name__ == "__main__":
    # Local entry point: serve the FastAPI app with uvicorn on all interfaces, port 8000.
    # NOTE(review): `logger` and `app` are not defined in the visible code.
    # Confirm they are created at module level before this runs.
    logger.info("Starting the Dress Code Compliance API")
    uvicorn.run(app, host="0.0.0.0", port=8000)