# howmuch/app.py - Fashion Outfit Detector with Live Prices
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from PIL import Image, ImageDraw, ImageFont
import torch
import requests
import os
# SerpAPI key for Google Shopping price lookups (read from the environment; get_price falls back to a flat $10.00 without it)
SERPAPI_KEY = os.environ.get("SERPAPI_KEY")
# Load model and processor
model_name = "valentinafeve/yolos-fashionpedia"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForObjectDetection.from_pretrained(model_name)
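# yolos-fashionpedia is a YOLOS object-detection checkpoint fine-tuned for Fashionpedia-style fashion categories;
# the processor handles image resizing/normalization and the model predicts bounding boxes plus class logits.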
# Fashion categories
CATS = [
    'shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jacket', 'vest', 'pants', 'shorts', 'skirt',
    'coat', 'dress', 'jumpsuit', 'cape', 'glasses', 'hat', 'headband, head covering, hair accessory', 'tie', 'glove',
    'watch', 'belt', 'leg warmer', 'tights, stockings', 'sock', 'shoe', 'bag, wallet', 'scarf', 'umbrella', 'hood',
    'collar', 'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline', 'buckle', 'zipper', 'applique', 'bead', 'bow',
    'flower', 'fringe', 'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel'
]
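# Replace the checkpoint's label maps with the Fashionpedia category names listed above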
model.config.id2label = {i: label for i, label in enumerate(CATS)}
model.config.label2id = {label: i for i, label in model.config.id2label.items()}
# Main outfit labels only: the first 27 categories are whole garments and accessories (through 'umbrella');
# the remaining entries are garment parts and trims, which are skipped during detection
main_labels = set(CATS[:27])

def get_price(item_name):
    """Fetch average price from Google Shopping via SerpAPI."""
    if not SERPAPI_KEY:
        # No API key configured: skip the network call and use the flat fallback price
        return 10.0
    try:
        url = "https://serpapi.com/search.json"
        params = {
            "q": f"{item_name} price",
            "tbm": "shop",
            "api_key": SERPAPI_KEY,
            "num": 10
        }
        response = requests.get(url, params=params, timeout=10)
        response.raise_for_status()
        data = response.json()
        prices = []
        if "shopping_results" in data:
            for result in data["shopping_results"]:
                if "price" in result:
                    price_str = result["price"].replace("$", "").replace(",", "")
                    try:
                        prices.append(float(price_str))
                    except ValueError:
                        continue
        return round(sum(prices) / len(prices), 2) if prices else 10.0
    except Exception as e:
        print(f"Error fetching price for {item_name}: {e}")
        return 10.0
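
# Example (hypothetical query and values): get_price("denim jacket") would return the mean of the prices
# SerpAPI reports for that query (e.g. 42.37), or the 10.0 fallback when the request fails or no prices are found.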

def detect_fashion_items(image):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Prepare inputs
    inputs = processor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)
    # Post-process
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, threshold=0.5, target_sizes=target_sizes
    )[0]
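    # `results` contains "scores", "labels", and "boxes" (xmin, ymin, xmax, ymax in pixel coordinates)
    # for every detection that scored above the 0.5 threshold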
    # Filter to main labels and pick best per label
    best_per_label = {}
    for score, label_id, box in zip(results["scores"], results["labels"], results["boxes"]):
        label = model.config.id2label[label_id.item()]
        if label not in main_labels:
            continue
        score_val = score.item()
        if label not in best_per_label or score_val > best_per_label[label]["score"]:
            best_per_label[label] = {
                "score": score_val,
                "box": [round(i, 2) for i in box.tolist()],
                "label": label,
                "price": get_price(label)
            }
    # Draw on image
    image = image.convert("RGBA")  # RGBA so the shadow fill below can carry an alpha value
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", 18)  # Bold, 18px
    except OSError:
        font = ImageFont.load_default()
    for item in best_per_label.values():
        box = item["box"]
        label = item["label"]
        score = item["score"]
        price = item["price"]
        # Draw bounding box
        draw.rectangle(box, outline="blue", width=3)
        # Draw label
        label_text = f"{label}: {score:.2f}"
        draw.text((box[0], box[1] - 50), label_text, fill="blue", font=font)
        # Draw price tag (yellow, rounded, with an offset shadow rectangle underneath)
        tag_x = box[0]
        tag_y = box[1] - 80  # Above label
        tag_width = 120
        tag_height = 40
        draw.rounded_rectangle(
            [tag_x + 2, tag_y + 2, tag_x + tag_width + 2, tag_y + tag_height + 2],
            radius=10,
            fill=(0, 0, 0, 64)  # Shadow; ImageDraw writes this RGBA value directly rather than compositing it
        )
        draw.rounded_rectangle(
            [tag_x, tag_y, tag_x + tag_width, tag_y + tag_height],
            radius=10,
            fill="yellow",
            outline="black",
            width=2
        )
        price_text = f"${price:.2f}"
        text_bbox = draw.textbbox((0, 0), price_text, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        draw.text((tag_x + (tag_width - text_width) // 2, tag_y + 10), price_text, fill="black", font=font)
    # Convert back to RGB
    image = image.convert("RGB")
    # Calculate total price
    total_price = sum(item["price"] for item in best_per_label.values())
    return image, f"Total Outfit Price: ${total_price:.2f}"

# Gradio interface
with gr.Blocks(title="Fashion Outfit Detector with Live Prices") as iface:
    gr.Markdown("### Fashion Outfit Detector with Live Prices\nUpload an image to detect unique outfit items with real-time prices from Google Shopping.")
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload a fashion image")
        result_image = gr.Image(type="pil", label="Detected Outfits with Prices")
    total_price_output = gr.Textbox(label="Total Price")
    # Submit button
    submit_btn = gr.Button("Detect Outfits")
    submit_btn.click(
        fn=detect_fashion_items,
        inputs=image_input,
        outputs=[result_image, total_price_output]
    )
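
# share=True makes iface.launch expose a temporary public *.gradio.live URL in addition to the local server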
if __name__ == "__main__":
    iface.launch(share=True)