Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoImageProcessor, AutoModelForObjectDetection | |
| from PIL import Image, ImageDraw, ImageFont | |
| import torch | |
| import requests | |
| import os | |
# SerpAPI key, read from the environment (None when the variable is unset).
SERPAPI_KEY = os.getenv("SERPAPI_KEY")

# Load the detection model and its matching image processor.
model_name = "valentinafeve/yolos-fashionpedia"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForObjectDetection.from_pretrained(model_name)

# Fashion categories. A category's position in this list is the class id
# it is mapped to in the model config below.
CATS = [
    'shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan',
    'jacket', 'vest', 'pants', 'shorts', 'skirt', 'coat', 'dress',
    'jumpsuit', 'cape', 'glasses', 'hat',
    'headband, head covering, hair accessory', 'tie', 'glove', 'watch',
    'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe',
    'bag, wallet', 'scarf', 'umbrella', 'hood', 'collar', 'lapel',
    'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper',
    'applique', 'bead', 'bow', 'flower', 'fringe', 'ribbon', 'rivet',
    'ruffle', 'sequin', 'tassel',
]

# Overwrite the config's label maps with the category names above.
model.config.id2label = dict(enumerate(CATS))
model.config.label2id = {
    name: class_id for class_id, name in model.config.id2label.items()
}

# Main outfit labels only: the first 27 categories
# ('shirt, blouse' through 'umbrella'); the rest are garment parts/trims.
main_labels = set(CATS[:27])
def _extract_price(result):
    """Return one shopping result's price as a float, or None if unusable."""
    if not isinstance(result, dict):
        return None
    # Prefer the numeric field when the API supplies it; it avoids parsing
    # a display string altogether.
    extracted = result.get("extracted_price")
    if isinstance(extracted, (int, float)) and not isinstance(extracted, bool):
        return float(extracted)
    raw = result.get("price")
    if not isinstance(raw, str):
        return None
    try:
        # Same cleanup as before: assumes "$1,234.56"-style display strings.
        return float(raw.replace("$", "").replace(",", "").strip())
    except ValueError:
        return None


def get_price(item_name, *, default=10.0, timeout=10.0):
    """Fetch the average price for *item_name* from Google Shopping via SerpAPI.

    This is a best-effort lookup: a missing API key, a network/HTTP failure,
    an undecodable response, or a response without usable prices all yield
    ``default`` instead of an exception.

    Args:
        item_name: Search term; the query sent is ``"<item_name> price"``.
        default: Value returned when no price can be determined.
        timeout: Seconds to wait for SerpAPI before giving up, so a stalled
            connection cannot block the caller indefinitely.

    Returns:
        The mean of the usable result prices rounded to 2 decimals, or
        ``default``.
    """
    if not SERPAPI_KEY:
        # Without a key the request cannot succeed; skip the round trip.
        print(f"Error fetching price for {item_name}: SERPAPI_KEY is not set")
        return default

    params = {
        "q": f"{item_name} price",
        "tbm": "shop",
        "api_key": SERPAPI_KEY,
        "num": 10,
    }
    try:
        response = requests.get(
            "https://serpapi.com/search.json", params=params, timeout=timeout
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # HTTP error messages embed the request URL, which carries the API
        # key as a query parameter; redact it before it reaches the logs.
        message = str(e).replace(SERPAPI_KEY, "***")
        print(f"Error fetching price for {item_name}: {message}")
        return default

    results = data.get("shopping_results") if isinstance(data, dict) else None
    prices = []
    if isinstance(results, list):
        for result in results:
            # One malformed entry must not discard the other prices.
            price = _extract_price(result)
            if price is not None:
                prices.append(price)

    return round(sum(prices) / len(prices), 2) if prices else default
def _draw_detections(image, items):
    """Draw each item's box, score label and price tag onto *image* in place.

    *image* must be an RGB PIL image; *items* is an iterable of dicts with
    "box", "label", "score" and "price" keys.
    """
    # Drawing in "RGBA" mode on an RGB image alpha-blends translucent fills
    # into the picture. Drawing on an RGBA image instead just overwrites the
    # pixels, and the later RGB conversion would make the shadow opaque.
    draw = ImageDraw.Draw(image, "RGBA")
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", 18)  # Bold, 18px
    except OSError:
        # Font file not available on this machine.
        font = ImageFont.load_default()

    tag_width, tag_height = 120, 40
    gap = 4      # vertical spacing between tag, label and box
    shadow = 2   # shadow offset in pixels

    for item in items:
        box = item["box"]
        draw.rectangle(box, outline="blue", width=3)

        label_text = f"{item['label']}: {item['score']:.2f}"
        price_text = f"${item['price']:.2f}"
        label_height = draw.textbbox((0, 0), label_text, font=font)[3]

        # Stack the price tag above the label above the box, and clamp the
        # stack to the canvas so items near an edge keep their annotations.
        max_x = max(image.width - tag_width - shadow, 0)
        tag_x = min(max(box[0], 0), max_x)
        tag_y = max(box[1] - (tag_height + label_height + 2 * gap), 0)
        label_y = tag_y + tag_height + gap

        # Price tag: translucent shadow first, then the yellow tag on top.
        draw.rounded_rectangle(
            [tag_x + shadow, tag_y + shadow,
             tag_x + tag_width + shadow, tag_y + tag_height + shadow],
            radius=10,
            fill=(0, 0, 0, 64),
        )
        draw.rounded_rectangle(
            [tag_x, tag_y, tag_x + tag_width, tag_y + tag_height],
            radius=10,
            fill="yellow",
            outline="black",
            width=2,
        )
        price_bbox = draw.textbbox((0, 0), price_text, font=font)
        price_width = price_bbox[2] - price_bbox[0]
        draw.text(
            (tag_x + (tag_width - price_width) // 2, tag_y + 10),
            price_text,
            fill="black",
            font=font,
        )

        draw.text((tag_x, label_y), label_text, fill="blue", font=font)
    return image


def detect_fashion_items(image, threshold=0.5):
    """Detect main outfit items, price them, and return an annotated image.

    Args:
        image: PIL image to analyse, or None when nothing was uploaded.
        threshold: Minimum detection score passed to the post-processor.

    Returns:
        A ``(image, text)`` tuple: the annotated RGB image (None when no
        image was given) and a summary line with the total outfit price.
    """
    if image is None:
        # Gradio passes None when the button is clicked without an upload.
        return None, "Please upload an image first."

    # Work on an RGB copy so grayscale, palette or alpha-channel uploads are
    # handled in one known mode and the caller's image is never modified.
    image = image.convert("RGB")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Prepare inputs
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process; PIL's size is (width, height), hence the reversal.
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]

    # Keep only main outfit labels, and the highest-scoring box per label.
    best_per_label = {}
    for score, label_id, box in zip(
        results["scores"], results["labels"], results["boxes"]
    ):
        label = model.config.id2label[label_id.item()]
        if label not in main_labels:
            continue
        score_val = score.item()
        current = best_per_label.get(label)
        if current is None or score_val > current["score"]:
            best_per_label[label] = {
                "score": score_val,
                "box": [round(coord, 2) for coord in box.tolist()],
                "label": label,
            }

    # Look prices up only after selection, so each surviving label costs
    # exactly one network call.
    for item in best_per_label.values():
        item["price"] = get_price(item["label"])

    _draw_detections(image, best_per_label.values())

    total_price = sum(item["price"] for item in best_per_label.values())
    return image, f"Total Outfit Price: ${total_price:.2f}"
# Gradio interface: input image and annotated result side by side, a total
# price readout, and a button that runs the detector.
with gr.Blocks(title="Fashion Outfit Detector with Live Prices") as iface:
    gr.Markdown(
        "### Fashion Outfit Detector with Live Prices\n"
        "Upload an image to detect unique outfit items with real-time prices from Google Shopping."
    )

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload a fashion image")
        result_image = gr.Image(type="pil", label="Detected Outfits with Prices")

    total_price_output = gr.Textbox(label="Total Price")

    # Submit button: feeds the uploaded image to the detector and routes its
    # two return values to the result image and the price box.
    submit_btn = gr.Button("Detect Outfits")
    submit_btn.click(
        fn=detect_fashion_items,
        inputs=image_input,
        outputs=[result_image, total_price_output],
    )


if __name__ == "__main__":
    iface.launch(share=True)