SAM3TrackerVideoProcessor: Batched multi-object inference produces corrupted masks when objects have different point counts
To reproduce
When using `Sam3TrackerVideoProcessor.add_inputs_to_inference_session()` to track multiple objects with different numbers of prompt points in a single batched call, the resulting masks for some objects are corrupted/incorrect.
Minimal reproduction example:
```python
import torch
from transformers import Sam3TrackerVideoModel, Sam3TrackerVideoProcessor
from transformers.video_utils import load_video

device = "cuda"
model = Sam3TrackerVideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
processor = Sam3TrackerVideoProcessor.from_pretrained("facebook/sam3")

# Load video
video_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4"
video_frames, _ = load_video(video_url)

# Initialize session
inference_session = processor.init_video_session(
    video=video_frames,
    inference_device=device,
    dtype=torch.bfloat16,
)

# Define objects with DIFFERENT point counts
obj_ids = [1, 2, 3]

# Object 1: 2 points (bbox corners)
# Object 2: 2 points (bbox corners)
# Object 3: 12 points (bbox + positive/negative clicks)
input_points = [[
    [[100, 100], [200, 200]],  # Object 1: 2 points
    [[300, 100], [400, 200]],  # Object 2: 2 points
    [[150, 150], [250, 250], [160, 160], [170, 170], [180, 180], [190, 190],
     [200, 160], [210, 170], [220, 180], [230, 190], [240, 200], [250, 210]],  # Object 3: 12 points
]]

# Labels: 2=box_top_left, 3=box_bottom_right, 1=positive, 0=negative
input_labels = [[
    [2, 3],  # Object 1: box corners
    [2, 3],  # Object 2: box corners
    [2, 3, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],  # Object 3: box + pos + neg points
]]

# Pad shorter objects to match longest (as required by the API)
max_points = 12
for i in range(len(input_points[0])):
    while len(input_points[0][i]) < max_points:
        input_points[0][i].append([-10, -10])  # pad with -10
        input_labels[0][i].append(-1)  # pad label

processor.add_inputs_to_inference_session(
    inference_session=inference_session,
    frame_idx=0,
    obj_ids=obj_ids,
    input_points=input_points,
    input_labels=input_labels,
)

# Run inference
outputs = model(inference_session=inference_session, frame_idx=0)
```
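For reference, the manual padding loop above can be factored into a small standalone helper that reproduces the same sentinel scheme (`-10` coordinates, `-1` labels). The function name is mine, not part of the transformers API:

```python
def pad_point_prompts(points, labels, pad_coord=(-10, -10), pad_label=-1):
    """Pad every object's point/label lists to the longest object's length.

    points: list of per-object point lists, e.g. [[[x, y], ...], ...]
    labels: list of per-object label lists, same lengths as `points`
    """
    max_points = max(len(obj_pts) for obj_pts in points)
    padded_points, padded_labels = [], []
    for obj_pts, obj_lbls in zip(points, labels):
        n_pad = max_points - len(obj_pts)
        # Fresh list per pad slot so entries are not aliased
        padded_points.append(obj_pts + [list(pad_coord) for _ in range(n_pad)])
        padded_labels.append(obj_lbls + [pad_label] * n_pad)
    return padded_points, padded_labels
```

With the objects above, `pad_point_prompts(input_points[0], input_labels[0])` brings objects 1 and 2 up to 12 slots each while leaving object 3 unchanged.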
Observed behavior:
When running the above code:
- Object 1's mask is corrupted/incorrect (affected by the batching)
- Object 2's mask is sometimes correct, sometimes affected
- Object 3's mask is typically correct
The corruption appears to be related to the first object(s) in the batch when there's a large disparity in point counts between objects.
Expected behavior:
All objects should produce correct masks regardless of how many points other objects in the batch have. The padding with -10 coordinates and -1 labels should be properly ignored.
Workaround attempts that did NOT work:
- Reordering objects (complex first, then simple) - still corrupts masks
- Using the `input_boxes` parameter for bbox-only objects - causes a "maskmem_features cannot be empty" error when combined with `input_points`
- Making separate calls for different object types - causes a "maskmem_features cannot be empty when not is_initial_conditioning_frame" error
What DOES work:
- All objects have the same number of points (no padding needed) - masks are correct
- Only bbox objects OR only multi-point objects (not mixed) - masks are correct
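Since equal point counts are the one configuration that reliably works, one untested workaround idea is to equalize counts by repeating each object's last real point/label instead of padding with sentinels, so every slot is a "real" point. Whether duplicate identical prompts are truly harmless for this model is an assumption; the helper name is mine:

```python
def equalize_by_repeating(points, labels):
    """Bring every object up to the longest point count by repeating its
    last real point and label, avoiding (-10, -1) sentinel padding entirely.

    Assumes each object has at least one point.
    """
    max_points = max(len(obj_pts) for obj_pts in points)
    out_points, out_labels = [], []
    for obj_pts, obj_lbls in zip(points, labels):
        n_extra = max_points - len(obj_pts)
        out_points.append(obj_pts + [obj_pts[-1]] * n_extra)
        out_labels.append(obj_lbls + [obj_lbls[-1]] * n_extra)
    return out_points, out_labels
```

Note the caveat: for bbox-only objects this repeats the box-bottom-right label (3), and I have not verified how the prompt encoder treats duplicated box-corner tokens.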
Root cause hypothesis
The batched inference seems to have issues when the attention mechanism processes objects with vastly different numbers of "real" points vs padded points. The padding values (-10 for coordinates, -1 for labels) may not be properly masked out during the prompt encoding step.
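If that hypothesis is right, the expected behavior would be for the prompt encoder to route padded slots to a dedicated "not a point" embedding rather than treating (-10, -10) as a real coordinate, as SAM-style encoders do. A minimal sketch of that handling (function, names, and shapes are illustrative, not the actual Sam3 code):

```python
import torch

def embed_point_labels(labels, label_embeds, not_a_point_embed):
    """labels: (num_objects, num_points) ints in {-1, 0, 1, 2, 3}.

    Real labels index into `label_embeds`; padded slots (label == -1)
    are overwritten with the null embedding so they contribute no
    spatial signal to downstream attention.
    """
    emb = label_embeds[labels.clamp(min=0)]  # (O, P, D); -1 temporarily maps to 0
    emb[labels == -1] = not_a_point_embed    # mask out the padded slots
    return emb
```

If the padded slots instead keep embeddings derived from the sentinel values, objects with many padded slots (here, objects 1 and 2 with 10 of 12) would be exactly the ones whose masks degrade, which matches the observed behavior.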
Questions
- Is mixing objects with different point counts in a single `add_inputs_to_inference_session()` call supported?
- If so, is there a different padding strategy that should be used?
- Should `input_boxes` and `input_points` be usable together in the same call for different objects?
Environment
- transformers>=4.48.0
- torch==2.10.0
- accelerate