---
license: mit
language:
- en
base_model:
- microsoft/Florence-2-large
datasets:
- diffusers/ShotDEAD-v0
---
# Shot Categorizer 🎬

<div align="center">
<img src="assets/header.jpg"/>
</div>

A shot categorization model fine-tuned from [`microsoft/Florence-2-large`](https://huggingface.co/microsoft/Florence-2-large).
It predicts metadata about a shot, such as its color, lighting, lighting type, and composition, which can then be used
to curate datasets of different kinds.

Training configuration:

* Batch size: 16
* Gradient accumulation steps: 4
* Learning rate: 1e-6
* Epochs: 20
* Max grad norm: 1.0
* Hardware: 8x H100 GPUs

Training used FP16 mixed precision with DeepSpeed ZeRO Stage 2. The vision tower of the model
was kept frozen during training. The model was trained on the
[diffusers/ShotDEAD-v0](https://huggingface.co/datasets/diffusers/ShotDEAD-v0) dataset.

Training code is available [here](https://github.com/huggingface/movie-shot-categorizer).
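
As a rough illustration, the training settings listed above might map onto a DeepSpeed configuration like the one sketched below. This is not the authors' actual config file (see the training repository for that). The helper name `build_deepspeed_config` is hypothetical, and the sketch assumes that the listed batch size of 16 is per GPU, which the model card does not state. Learning rate and epoch count are left out because they are typically set in the training script rather than in this file.

```python
import json

# Values taken from the training configuration listed above.
MICRO_BATCH_SIZE = 16  # ASSUMPTION: the listed batch size is per GPU
GRAD_ACCUM_STEPS = 4
MAX_GRAD_NORM = 1.0
NUM_GPUS = 8


def build_deepspeed_config() -> dict:
    """Return an illustrative DeepSpeed config dict (hypothetical helper)."""
    return {
        "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
        "gradient_accumulation_steps": GRAD_ACCUM_STEPS,
        # DeepSpeed expects train_batch_size to equal
        # micro batch size * accumulation steps * number of GPUs.
        "train_batch_size": MICRO_BATCH_SIZE * GRAD_ACCUM_STEPS * NUM_GPUS,
        "gradient_clipping": MAX_GRAD_NORM,
        "fp16": {"enabled": True},
        "zero_optimization": {"stage": 2},
    }


if __name__ == "__main__":
    print(json.dumps(build_deepspeed_config(), indent=2))
```

Under the per-GPU assumption, the effective batch size works out to 16 × 4 × 8 = 512 samples per optimizer step. If 16 is instead the global batch size, the figure would differ.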

## Inference

```py
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Load the fine-tuned checkpoint in half precision on the GPU.
folder_path = "diffusers/shot-categorizer-v0"
model = (
    AutoModelForCausalLM.from_pretrained(folder_path, torch_dtype=torch.float16, trust_remote_code=True)
    .to("cuda")
    .eval()
)
processor = AutoProcessor.from_pretrained(folder_path, trust_remote_code=True)

# Each task prompt asks the model for one kind of shot metadata.
prompts = ["<COLOR>", "<LIGHTING>", "<LIGHTING_TYPE>", "<COMPOSITION>"]
img_path = "./assets/image_3.jpg"
image = Image.open(img_path).convert("RGB")

# inference_mode() already disables gradient tracking, so no_grad() is not needed.
with torch.inference_mode():
    for prompt in prompts:
        inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda", torch.float16)
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        # Decode with special tokens kept, then let the processor build the {task: answer} dict.
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            generated_text, task=prompt, image_size=(image.width, image.height)
        )
        print(parsed_answer)
```

For the example image, this should print:

```text
{'<COLOR>': 'Cool, Saturated, Cyan, Blue'}
{'<LIGHTING>': 'Soft light, Low contrast'}
{'<LIGHTING_TYPE>': 'Daylight, Sunny'}
{'<COMPOSITION>': 'Left heavy'}
```
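
For dataset curation it can help to turn these per-prompt dicts into one structured record per image. The sketch below is one possible way to do that. It assumes the answers are comma-separated label strings, as in the output above. The helper names `split_labels` and `merge_answers` are hypothetical and not part of the model or processor API.

```python
from __future__ import annotations


def split_labels(parsed_answer: dict[str, str]) -> dict[str, list[str]]:
    """Turn {'<COLOR>': 'Cool, Saturated'} into {'color': ['Cool', 'Saturated']}."""
    labels = {}
    for task, value in parsed_answer.items():
        key = task.strip("<>").lower()  # '<LIGHTING_TYPE>' -> 'lighting_type'
        # ASSUMPTION: labels are separated by commas, as in the example output.
        labels[key] = [part.strip() for part in value.split(",") if part.strip()]
    return labels


def merge_answers(answers: list[dict[str, str]]) -> dict[str, list[str]]:
    """Combine the per-prompt dicts for one image into a single metadata record."""
    record = {}
    for answer in answers:
        record.update(split_labels(answer))
    return record


# The four dicts from the example output.
answers = [
    {"<COLOR>": "Cool, Saturated, Cyan, Blue"},
    {"<LIGHTING>": "Soft light, Low contrast"},
    {"<LIGHTING_TYPE>": "Daylight, Sunny"},
    {"<COMPOSITION>": "Left heavy"},
]
record = merge_answers(answers)
print(record)
```

A record in this shape makes simple filters straightforward, for example keeping only images whose `lighting_type` list contains `"Daylight"`.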