Daankular's picture
download
raw
9.06 kB
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
"""Necks are the interface between a vision backbone and the rest of the detection model"""
from copy import deepcopy
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from ..model.data_misc import NestedTensor
class Sam3DualViTDetNeck(nn.Module):
    """SimpleFPN neck a la ViTDet (from detectron2, very lightly adapted).

    One convolutional branch is built per entry of ``scale_factors``. Every
    branch is applied to the *last* feature map returned by the trunk and
    projects it to ``d_model`` channels. In the "dual neck" setting a second
    set of branches with the same architecture but separate weights (the
    SAM2 neck) is kept alongside the SAM3 one.
    """

    def __init__(
        self,
        trunk: nn.Module,
        position_encoding: nn.Module,
        d_model: int,
        scale_factors=(4.0, 2.0, 1.0, 0.5),
        add_sam2_neck: bool = False,
    ):
        """
        :param trunk: the backbone; its ``channel_list[-1]`` gives the input
            channel count of every branch
        :param position_encoding: the positional encoding to use; it is called
            on each branch output
        :param d_model: the dimension of the model (output channels of every
            branch)
        :param scale_factors: one branch is created per entry; supported
            values are 4.0, 2.0, 1.0 and 0.5
        :param add_sam2_neck: if True, also keep a deep copy of the branches
            as a separate SAM2 neck
        :raises NotImplementedError: if a scale factor is not supported
        """
        super().__init__()
        self.trunk = trunk
        self.position_encoding = position_encoding
        self.scale_factors = scale_factors

        dim: int = self.trunk.channel_list[-1]
        self.convs = nn.ModuleList(
            [
                self._build_branch(dim, d_model, scale, use_bias=True)
                for scale in scale_factors
            ]
        )

        # Assumes sam2 neck is just a clone of the original neck: the copy
        # starts from identical weights but owns its parameters.
        self.sam2_convs = deepcopy(self.convs) if add_sam2_neck else None

    @staticmethod
    def _build_branch(
        dim: int, d_model: int, scale: float, use_bias: bool
    ) -> nn.Sequential:
        """Build the resampling + projection branch for a single scale factor."""
        branch = nn.Sequential()
        # NOTE: the sub-module names below end up in state_dict keys, so they
        # must stay exactly as they are for checkpoints to keep loading.
        if scale == 4.0:
            branch.add_module(
                "dconv_2x2_0",
                nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
            )
            branch.add_module("gelu", nn.GELU())
            branch.add_module(
                "dconv_2x2_1",
                nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
            )
            out_dim = dim // 4
        elif scale == 2.0:
            branch.add_module(
                "dconv_2x2",
                nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
            )
            out_dim = dim // 2
        elif scale == 1.0:
            out_dim = dim
        elif scale == 0.5:
            branch.add_module("maxpool_2x2", nn.MaxPool2d(kernel_size=2, stride=2))
            out_dim = dim
        else:
            raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

        # Project to d_model channels, then refine with a shape-preserving 3x3.
        branch.add_module(
            "conv_1x1",
            nn.Conv2d(
                in_channels=out_dim,
                out_channels=d_model,
                kernel_size=1,
                bias=use_bias,
            ),
        )
        branch.add_module(
            "conv_3x3",
            nn.Conv2d(
                in_channels=d_model,
                out_channels=d_model,
                kernel_size=3,
                padding=1,
                bias=use_bias,
            ),
        )
        return branch

    def _encode_position(self, features: torch.Tensor) -> torch.Tensor:
        """Return the positional encoding of ``features`` cast to its dtype."""
        return self.position_encoding(features).to(features.dtype)

    def forward(
        self, tensor_list: List[torch.Tensor]
    ) -> Tuple[
        List[torch.Tensor],
        List[torch.Tensor],
        Optional[List[torch.Tensor]],
        Optional[List[torch.Tensor]],
    ]:
        """Run the trunk and every neck branch.

        :param tensor_list: input forwarded unchanged to the trunk
        :return: ``(sam3_out, sam3_pos, sam2_out, sam2_pos)``. The first two
            are lists with one entry per scale factor (features and their
            positional encodings). The last two are the same for the SAM2
            neck, or ``None`` when it was not created.
        """
        xs = self.trunk(tensor_list)
        x = xs[-1]  # simpleFPN: only the last trunk feature map is used

        has_sam2_neck = self.sam2_convs is not None
        sam3_out, sam3_pos = [], []
        sam2_out, sam2_pos = ([], []) if has_sam2_neck else (None, None)

        for level, conv in enumerate(self.convs):
            sam3_x_out = conv(x)
            sam3_out.append(sam3_x_out)
            sam3_pos.append(self._encode_position(sam3_x_out))

            if has_sam2_neck:
                sam2_x_out = self.sam2_convs[level](x)
                sam2_out.append(sam2_x_out)
                sam2_pos.append(self._encode_position(sam2_x_out))

        return sam3_out, sam3_pos, sam2_out, sam2_pos
class Sam3TriViTDetNeck(nn.Module):
    """SimpleFPN neck with three heads (sam3, interactive, propagation).

    The three heads share one architecture (one branch per entry of
    ``scale_factors``) but own separate weights. Every branch is applied to
    the last feature map returned by the trunk.
    """

    def __init__(
        self,
        trunk: nn.Module,
        position_encoding: nn.Module,
        d_model: int,
        neck_norm=None,
        scale_factors=(4.0, 2.0, 1.0),
    ):
        """
        :param trunk: the backbone; its ``channel_list[-1]`` gives the input
            channel count of every branch
        :param position_encoding: the positional encoding to use; it is called
            on each branch output
        :param d_model: the dimension of the model (output channels of every
            branch)
        :param neck_norm: when not None, the 1x1 and 3x3 convolutions are
            created without a bias term
        :param scale_factors: one branch is created per entry; supported
            values are 4.0, 2.0, 1.0 and 0.5
        :raises NotImplementedError: if a scale factor is not supported
        """
        super().__init__()
        self.trunk = trunk
        self.position_encoding = position_encoding
        self.scale_factors = scale_factors

        # NOTE(review): neck_norm only toggles the conv bias here; no
        # normalization layer is created in this class - confirm intended.
        use_bias = neck_norm is None
        dim = self.trunk.channel_list[-1]
        self.convs = nn.ModuleList(
            [
                self._build_branch(dim, d_model, scale, use_bias=use_bias)
                for scale in scale_factors
            ]
        )

        # Assumes the new necks are just clones of the original neck: each
        # copy starts from identical weights but owns its parameters.
        self.interactive_convs = deepcopy(self.convs)
        self.propagation_convs = deepcopy(self.convs)

    @staticmethod
    def _build_branch(
        dim: int, d_model: int, scale: float, use_bias: bool
    ) -> nn.Sequential:
        """Build the resampling + projection branch for a single scale factor."""
        branch = nn.Sequential()
        # NOTE: the sub-module names below end up in state_dict keys, so they
        # must stay exactly as they are for checkpoints to keep loading.
        if scale == 4.0:
            branch.add_module(
                "dconv_2x2_0",
                nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
            )
            branch.add_module("gelu", nn.GELU())
            branch.add_module(
                "dconv_2x2_1",
                nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2),
            )
            out_dim = dim // 4
        elif scale == 2.0:
            branch.add_module(
                "dconv_2x2",
                nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2),
            )
            out_dim = dim // 2
        elif scale == 1.0:
            out_dim = dim
        elif scale == 0.5:
            branch.add_module("maxpool_2x2", nn.MaxPool2d(kernel_size=2, stride=2))
            out_dim = dim
        else:
            raise NotImplementedError(f"scale_factor={scale} is not supported yet.")

        # Project to d_model channels, then refine with a shape-preserving 3x3.
        # The transposed convolutions above always keep their default bias;
        # only these two layers honour use_bias.
        branch.add_module(
            "conv_1x1",
            nn.Conv2d(
                in_channels=out_dim,
                out_channels=d_model,
                kernel_size=1,
                bias=use_bias,
            ),
        )
        branch.add_module(
            "conv_3x3",
            nn.Conv2d(
                in_channels=d_model,
                out_channels=d_model,
                kernel_size=3,
                padding=1,
                bias=use_bias,
            ),
        )
        return branch

    def _run_branch(self, branch, x_data, x_mask):
        """Apply one branch; return its NestedTensor output and position encoding."""
        features = branch(x_data)
        # NOTE(review): the trunk mask is attached as-is and is not resized to
        # this level's resolution - confirm downstream consumers expect that.
        wrapped = NestedTensor(features, x_mask)
        position = self.position_encoding(features).to(features.dtype)
        return wrapped, position

    def forward(
        self,
        tensor_list,
        *,
        need_sam3_out: bool = True,
        need_interactive_out: bool = True,
        need_propagation_out: bool = True,
    ):
        """Run the trunk and the requested heads.

        :param tensor_list: input forwarded unchanged to the trunk
        :param need_sam3_out: compute the sam3 head
        :param need_interactive_out: compute the interactive head
        :param need_propagation_out: compute the propagation head
        :return: ``(sam3_out, sam3_pos, interactive_out, interactive_pos,
            propagation_out, propagation_pos)``. Each ``*_out`` is a list of
            NestedTensor and each ``*_pos`` a list of positional encodings,
            one entry per scale factor. The lists of a head that was not
            requested are empty. The trunk is run regardless of the flags.
        """
        xs = self.trunk(tensor_list)
        x = xs[-1]  # simpleFPN: only the last trunk feature map is used

        # OSS trunk returns plain tensors; onevision trunk returns NestedTensors.
        # Use getattr to handle both in a torch.compile-friendly way.
        x_data = getattr(x, "tensors", x)
        x_mask = getattr(x, "mask", None)

        sam3_out, sam3_pos = [], []
        interactive_out, interactive_pos = [], []
        propagation_out, propagation_pos = [], []

        for conv, interactive_conv, propagation_conv in zip(
            self.convs, self.interactive_convs, self.propagation_convs
        ):
            if need_sam3_out:
                out, pos = self._run_branch(conv, x_data, x_mask)
                sam3_out.append(out)
                sam3_pos.append(pos)
            if need_interactive_out:
                out, pos = self._run_branch(interactive_conv, x_data, x_mask)
                interactive_out.append(out)
                interactive_pos.append(pos)
            if need_propagation_out:
                out, pos = self._run_branch(propagation_conv, x_data, x_mask)
                propagation_out.append(out)
                propagation_pos.append(pos)

        return (
            sam3_out,
            sam3_pos,
            interactive_out,
            interactive_pos,
            propagation_out,
            propagation_pos,
        )

Xet Storage Details

Size:
9.06 kB
·
Xet hash:
c307d20a8330cdb4c8ab02224e0c04269d6e5b6725d950723b293825f129ab0f

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.