SAM3TrackerVideoProcessor: Batched multi-object inference produces corrupted masks when objects have different point counts

#132
by aptech0081 - opened


To reproduce

When using Sam3TrackerVideoProcessor.add_inputs_to_inference_session() to track multiple objects with different numbers of prompt points in a single batched call, the resulting masks for some objects are corrupted/incorrect.

Minimal reproduction example:

from transformers import Sam3TrackerVideoModel, Sam3TrackerVideoProcessor
from transformers.video_utils import load_video
import torch

device = "cuda"
model = Sam3TrackerVideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
processor = Sam3TrackerVideoProcessor.from_pretrained("facebook/sam3")

# Load video
video_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4"
video_frames, _ = load_video(video_url)

# Initialize session
inference_session = processor.init_video_session(
    video=video_frames,
    inference_device=device,
    dtype=torch.bfloat16,
)

# Define objects with DIFFERENT point counts
obj_ids = [1, 2, 3]

# Object 1: 2 points (bbox corners)
# Object 2: 2 points (bbox corners)  
# Object 3: 12 points (bbox + positive/negative clicks)
input_points = [[
    [[100, 100], [200, 200]],  # Object 1: 2 points
    [[300, 100], [400, 200]],  # Object 2: 2 points
    [[150, 150], [250, 250], [160, 160], [170, 170], [180, 180], [190, 190], 
     [200, 160], [210, 170], [220, 180], [230, 190], [240, 200], [250, 210]],  # Object 3: 12 points
]]

# Labels: 2=box_top_left, 3=box_bottom_right, 1=positive, 0=negative
input_labels = [[
    [2, 3],  # Object 1: box corners
    [2, 3],  # Object 2: box corners
    [2, 3, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],  # Object 3: box + pos + neg points
]]

# Pad shorter objects to match the longest one (the batched API requires
# rectangular point/label arrays)
max_points = max(len(pts) for pts in input_points[0])
for i in range(len(input_points[0])):
    while len(input_points[0][i]) < max_points:
        input_points[0][i].append([-10, -10])  # pad coordinate
        input_labels[0][i].append(-1)  # pad label

processor.add_inputs_to_inference_session(
    inference_session=inference_session,
    frame_idx=0,
    obj_ids=obj_ids,
    input_points=input_points,
    input_labels=input_labels,
)

# Run inference
outputs = model(inference_session=inference_session, frame_idx=0)
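For reference, the ad-hoc padding loop above can be factored into a small standalone helper (plain Python, no transformers dependency; the name `pad_prompts` and the pad values `(-10, -10)` / `-1` are just the conventions used in this report, not an official API):

```python
def pad_prompts(points, labels, pad_point=(-10, -10), pad_label=-1):
    """Pad every object's point/label lists to the longest list in the batch."""
    max_points = max(len(pts) for pts in points)
    padded_points, padded_labels = [], []
    for pts, lbs in zip(points, labels):
        pad = max_points - len(pts)
        padded_points.append([list(p) for p in pts] + [list(pad_point)] * pad)
        padded_labels.append(list(lbs) + [pad_label] * pad)
    return padded_points, padded_labels
```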

Observed behavior:

When running the above code:

  • Object 1's mask is corrupted/incorrect (affected by the batching)
  • Object 2's mask is sometimes correct, sometimes affected
  • Object 3's mask is typically correct

The corruption appears to be related to the first object(s) in the batch when there's a large disparity in point counts between objects.

Expected behavior:

All objects should produce correct masks regardless of how many points other objects in the batch have. The padding with -10 coordinates and -1 labels should be properly ignored.
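As a toy illustration of what "properly ignored" would mean here (assuming the `-1` pad-label convention from the snippet above), a per-point validity mask is recoverable from the labels alone:

```python
# Hypothetical sketch: derive a validity mask from the pad labels (-1).
labels = [
    [2, 3, -1, -1],  # bbox-only object, padded to 4 entries
    [2, 3, 1, 0],    # object with real positive/negative clicks
]
valid = [[label != -1 for label in row] for row in labels]
# Only the True entries should contribute to the prompt embedding.
```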

Workaround attempts that did NOT work:

  1. Reordering objects (complex first, then simple) - still corrupts masks
  2. Using input_boxes parameter for bbox-only objects - causes "maskmem_features cannot be empty" error when combined with input_points
  3. Making separate calls for different object types - causes "maskmem_features cannot be empty when not is_initial_conditioning_frame" error

What DOES work:

  • All objects have the same number of points (no padding needed) - masks are correct
  • Only bbox objects OR only multi-point objects (not mixed) - masks are correct

Root cause hypothesis

The batched inference seems to have issues when the attention mechanism processes objects with vastly different numbers of "real" points vs padded points. The padding values (-10 for coordinates, -1 for labels) may not be properly masked out during the prompt encoding step.
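A toy sketch of that hypothesis (standalone Python, not the actual SAM3 prompt encoder): if padded keys are not masked before the softmax, they still receive nonzero attention weight and pull the output toward the pad embeddings.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# One query's attention scores against three prompt points; the last is padding.
scores = [2.0, 1.0, 1.5]
unmasked = softmax(scores)  # the pad point still gets a substantial weight
masked = softmax([2.0, 1.0, float("-inf")])  # masking drives its weight to 0
```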

Questions

  1. Is mixing objects with different point counts in a single add_inputs_to_inference_session() call supported?
  2. If so, is there a different padding strategy that should be used?
  3. Should input_boxes and input_points be usable together in the same call for different objects?

Environment

transformers>=4.48.0
torch==2.10.0
accelerate
