# AU-LLM-Demo / app.py
import os
import cv2
import numpy as np
import torch
import re
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
from peft import PeftModel
OFFLOAD_DIR = "offload"
os.makedirs(OFFLOAD_DIR, exist_ok=True)
# ----------------------------
# 1. CONFIG
# ----------------------------
BASE_MODEL = "unsloth/Qwen3-1.7B"
HF_CONFUSION = "lakki03/qwen-au-confusion-AUplusDesc"
HF_ENGAGEMENT = "lakki03/qwen-au-engagement-AUplusDesc"
HF_BOREDOM = "lakki03/qwen-au-boredom-AUplusDesc"
HF_FRUSTRATION = "lakki03/qwen-au-frustration-AUplusDesc"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
# same AU sets we used during training
AUS_PER_LABEL = {
    "Confusion": [4, 7, 15, 17, 23],
    "Engagement": [1, 2, 6, 12, 25],
    "Boredom": [1, 2, 4, 7, 15],
    "Frustration": [4, 7, 9, 23, 24],
}
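# Human-readable names for each FACS action unit id, used in prompts and descriptions.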
AU_NAMES_RICH = {
    1: "Inner brow raiser",
    2: "Outer brow raiser",
    4: "Brow lowerer",
    6: "Cheek raiser",
    7: "Lid tightener",
    9: "Nose wrinkler",
    12: "Lip corner puller (smile)",
    15: "Lip corner depressor",
    17: "Chin raiser",
    23: "Lip tightener",
    24: "Lip pressor",
    25: "Lips part",
}
SCALE_TEXT = {
    "Confusion": "confusion level on a 0–3 scale (0 = not confused, 3 = highly confused).",
    "Engagement": "engagement level on a 0–3 scale (0 = disengaged, 3 = highly engaged).",
    "Boredom": "boredom level on a 0–3 scale (0 = not bored, 3 = very bored).",
    "Frustration": "frustration level on a 0–3 scale (0 = calm, 3 = highly frustrated).",
}
# ----------------------------
# 2. LOAD BASE + 4 LoRA MODELS
# ----------------------------
print("Loading base model:", BASE_MODEL)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
def load_lora(repo_id: str):
    """Load one LoRA adapter on top of the base model."""
    # Base model; let HF/accelerate place layers automatically (GPU + CPU)
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=DTYPE,
        device_map="auto",
    )
    # IMPORTANT: give accelerate an offload folder
    model = PeftModel.from_pretrained(
        base,
        repo_id,
        device_map="auto",
        offload_folder=OFFLOAD_DIR,  # <-- fixes the offload_dir error
    )
    model.eval()
    return model
print("Loading LoRA models (this happens once at startup)...")
model_confusion = load_lora(HF_CONFUSION)
model_engagement = load_lora(HF_ENGAGEMENT)
model_boredom = load_lora(HF_BOREDOM)
model_frustration = load_lora(HF_FRUSTRATION)
print("All 4 models loaded.")
# ----------------------------
# 3. LOGITS PROCESSOR: force 0–3
# ----------------------------
class Only0123(LogitsProcessor):
    """Mask the vocabulary so that only the digit tokens '0'-'3' can be generated."""

    def __init__(self, tok):
        self.allowed = torch.tensor([tok.convert_tokens_to_ids(t) for t in ["0", "1", "2", "3"]])

    def __call__(self, input_ids, scores):
        # Keep the scores of the four allowed digits; push everything else to -inf.
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.allowed] = 0.0
        return scores + mask
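# This assumes "0", "1", "2", "3" each map to a single token in the Qwen tokenizer
# (true for typical BPE vocabularies); combined with max_new_tokens=1 in run_one_model,
# the model can only ever answer with one of these four digits.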
# ----------------------------
# 4. VERY SIMPLE "FAKE AU" EXTRACTOR
# (FOR DEMO ONLY)
# ----------------------------
def approximate_aus_from_video(video_path: str, aus):
    """
    DEMO-ONLY:
    We sample a few frames and derive 0–100 'AU intensities'
    from simple brightness/contrast statistics.
    For a real system, replace this with your AU extractor
    (e.g., OpenFace / Py-Feat / your AU JSON pipeline).
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError("Could not open video")
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    sample_idx = np.linspace(0, frame_count - 1, num=min(16, frame_count)).astype(int)
    values = []
    for idx in sample_idx:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ret, frame = cap.read()
        if not ret:
            continue
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        mean = float(np.mean(gray))
        std = float(np.std(gray))
        values.append((mean, std))
    cap.release()
    if not values:
        # fallback if no frame could be read
        values = [(80.0, 40.0)]
    means = np.mean(np.array(values), axis=0)  # shape (2,): (mean brightness, mean contrast)
    mu, sigma = float(means[0]), float(means[1])
    # crude mapping to the 0–100 range
    base_intensity = np.clip((mu - 60.0) / 2.0 + 50.0, 0, 100)
    var_intensity = np.clip((sigma - 20.0) * 2.0 + 50.0, 0, 100)
    result = {}
    for i, au in enumerate(aus):
        # alternate between the two statistics just to get some diversity across AUs
        val = base_intensity if (i % 2 == 0) else var_intensity
        result[au] = float(val)
    return result
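# OPTIONAL (sketch): if you already have per-frame AU intensities from a real extractor
# (OpenFace, Py-Feat, your own AU JSON pipeline), you can bypass the pixel-statistics
# approximation above. The JSON layout assumed here ({"AU4": [per-frame values on a
# 0-100 scale], ...}) is a hypothetical example, not something this demo produces.
def aus_from_precomputed_json(json_path: str, aus):
    """Return {au_id: mean intensity on a 0-100 scale} from a precomputed AU JSON file."""
    import json
    with open(json_path) as f:
        per_frame = json.load(f)
    means = {}
    for au in aus:
        frames = per_frame.get(f"AU{au}", [])
        # average over frames; clamp to the 0-100 range expected by the prompts
        means[au] = float(np.clip(np.mean(frames), 0, 100)) if frames else 0.0
    return means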
def build_rule_description(label_name: str, means: dict, thr: float = 60.0):
    """
    Simple rule-based text similar to what we used in training:
    - If AU mean >= thr → 'high'
    - else → 'low'
    """
    parts = []
    for au in AUS_PER_LABEL[label_name]:
        v = means.get(au, 0.0)
        level = "high" if v >= thr else "low"
        name = AU_NAMES_RICH.get(au, f"AU{au}")
        parts.append(f"{name} is {level} (mean {v:.1f})")
    if not parts:
        return "Facial activity appears minimal."
    if label_name == "Confusion":
        prefix = "Overall, the face shows signs related to confusion:"
    elif label_name == "Engagement":
        prefix = "Overall, the face shows signs related to engagement:"
    elif label_name == "Boredom":
        prefix = "Overall, the face shows signs related to boredom:"
    else:
        prefix = "Overall, the face shows signs related to frustration:"
    return prefix + " " + "; ".join(parts)
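# Example output for "Confusion" with hypothetical means {4: 72.0, 7: 41.5, 15: 38.0, 17: 65.2, 23: 44.9}:
#   "Overall, the face shows signs related to confusion: Brow lowerer is high (mean 72.0);
#    Lid tightener is low (mean 41.5); Lip corner depressor is low (mean 38.0);
#    Chin raiser is high (mean 65.2); Lip tightener is low (mean 44.9)"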
def make_prompt(label_name: str, means: dict):
    aus = AUS_PER_LABEL[label_name]
    au_lines = []
    for au in aus:
        name = AU_NAMES_RICH.get(au, f"AU{au}")
        val = means.get(au, 0.0)
        au_lines.append(f"AU{au} ({name}): mean={val:.1f}")
    au_block = "\n".join(au_lines)
    desc = build_rule_description(label_name, means)
    prompt = (
        "You are given facial action unit (AU) features and a short description "
        "for a learner during a task.\n"
        "AU values are on a 0–100 scale.\n\n"
        f"AU summary:\n{au_block}\n\n"
        f"Description:\n{desc}\n\n"
        f"Predict the {SCALE_TEXT[label_name]}\n"
        "Answer with a single digit 0, 1, 2, or 3."
    )
    return prompt
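# Example prompt (Confusion, hypothetical means; Description comes from build_rule_description):
#
#   You are given facial action unit (AU) features and a short description for a learner during a task.
#   AU values are on a 0–100 scale.
#
#   AU summary:
#   AU4 (Brow lowerer): mean=72.0
#   AU7 (Lid tightener): mean=41.5
#   ...
#
#   Description:
#   Overall, the face shows signs related to confusion: ...
#
#   Predict the confusion level on a 0–3 scale (0 = not confused, 3 = highly confused).
#   Answer with a single digit 0, 1, 2, or 3.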
def run_one_model(model, label_name: str, means: dict) -> int:
    prompt = make_prompt(label_name, means)
    # figure out the right device (for auto-sharded models this is e.g. "cuda:0")
    device = getattr(model, "device", DEVICE)
    with torch.no_grad():
        toks = tokenizer(prompt, return_tensors="pt").to(device)
        out = model.generate(
            **toks,
            max_new_tokens=1,
            do_sample=False,
            logits_processor=LogitsProcessorList([Only0123(tokenizer)]),
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    text = tokenizer.decode(
        out[0, toks["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )
    # A digit should always be present thanks to Only0123; -1 signals a parsing failure.
    m = re.search(r"[0-3]", text)
    return int(m.group()) if m else -1
# ----------------------------
# 5. GRADIO PIPELINE
# ----------------------------
def analyze_video(video_file):
    if video_file is None:
        return "Please upload a video.", None, None, None, None
    video_path = video_file
    results = {}
    # For each label, approximate AUs from the video and run its LoRA model.
    # (We recompute means separately per label so each uses its own AU set.)
    for label, model in [
        ("Confusion", model_confusion),
        ("Engagement", model_engagement),
        ("Boredom", model_boredom),
        ("Frustration", model_frustration),
    ]:
        aus = AUS_PER_LABEL[label]
        means = approximate_aus_from_video(video_path, aus)
        pred = run_one_model(model, label, means)
        results[label] = (means, pred)
    # Build a pretty text summary
    lines = []
    for label in ["Confusion", "Engagement", "Boredom", "Frustration"]:
        means, pred = results[label]
        au_txt = ", ".join([f"AU{au}={means[au]:.1f}" for au in AUS_PER_LABEL[label]])
        lines.append(f"{label}: {pred} | AUs: {au_txt}")
    summary = "\n".join(lines)
    return (
        summary,
        results["Confusion"][1],
        results["Engagement"][1],
        results["Boredom"][1],
        results["Frustration"][1],
    )
with gr.Blocks() as demo:
    gr.Markdown(
        "# AU-LLM Demo\nUpload a short face video and get predicted affect labels (0–3).\n\n"
        "**Note:** AU extraction here is a simple placeholder. For real use, plug in your AU extractor."
    )
    with gr.Row():
        video_input = gr.Video(label="Input video (.mp4)", sources=["upload"])
    with gr.Row():
        confusion_out = gr.Number(label="Confusion (0–3)", precision=0)
        engagement_out = gr.Number(label="Engagement (0–3)", precision=0)
        boredom_out = gr.Number(label="Boredom (0–3)", precision=0)
        frustration_out = gr.Number(label="Frustration (0–3)", precision=0)
    summary_box = gr.Textbox(label="Raw output", lines=6)
    analyze_btn = gr.Button("Analyze video")
    analyze_btn.click(
        fn=analyze_video,
        inputs=video_input,
        outputs=[summary_box, confusion_out, engagement_out, boredom_out, frustration_out],
    )
if __name__ == "__main__":
demo.launch()