#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@File    : metauas.py
@Time    : 2025/03/26 23:46:12
@Author  : Bin-Bin Gao
@Email   : csgaobb@gmail.com
@Homepage: https://csgaobb.github.io/
@Version : 1.0
@Desc    : some classes and functions for MetaUAS
'''
import os
import random
import time

import cv2
import kornia as K
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import tqdm
from PIL import Image
from einops import rearrange
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms.functional import pil_to_tensor

from segmentation_models_pytorch.unet.model import UnetDecoder
from segmentation_models_pytorch.fpn.decoder import FPNDecoder
from segmentation_models_pytorch.encoders import get_encoder, get_preprocessing_params
from huggingface_hub import PyTorchModelHubMixin

def set_random_seed(seed=233, reproduce=False):
    """Seed Python, NumPy, and (CUDA) PyTorch RNGs; optionally force cudnn determinism."""
    np.random.seed(seed)
    torch.manual_seed(seed ** 2)
    torch.cuda.manual_seed(seed ** 3)
    random.seed(seed ** 4)

    if reproduce:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        torch.backends.cudnn.benchmark = True

def normalize(pred, max_value=None, min_value=None):
    """Min-max normalize a prediction map, using its own range when no bounds are given."""
    if max_value is None or min_value is None:
        return (pred - pred.min()) / (pred.max() - pred.min())
    else:
        return (pred - min_value) / (max_value - min_value)

def apply_ad_scoremap(image, scoremap, alpha=0.5):
    """Overlay a [0, 1] anomaly score map on an RGB image as a JET heatmap."""
    np_image = np.asarray(image, dtype=np.float32)
    scoremap = (scoremap * 255).astype(np.uint8)
    scoremap = cv2.applyColorMap(scoremap, cv2.COLORMAP_JET)
    scoremap = cv2.cvtColor(scoremap, cv2.COLOR_BGR2RGB)
    return (alpha * np_image + (1 - alpha) * scoremap).astype(np.uint8)

def read_image_as_tensor(path_to_image):
    """Load an image from disk as a float CHW tensor scaled to [0, 1]."""
    pil_image = Image.open(path_to_image).convert("RGB")
    image_as_tensor = pil_to_tensor(pil_image).float() / 255.0
    return image_as_tensor


def safely_load_state_dict(model, checkpoint):
    """Strictly load a checkpoint, mapping tensors to CPU so loading works without a GPU."""
    model.load_state_dict(torch.load(checkpoint, map_location="cpu"), strict=True)
    return model

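
# Example (a sketch, not part of the original file): visualizing an anomaly map
# with the helpers above; "query.png" and `anomaly_map` are illustrative names.
#
#   image = np.array(Image.open("query.png").convert("RGB"))
#   heat = apply_ad_scoremap(image, normalize(anomaly_map))  # anomaly_map: HxW float
#   plt.imshow(heat); plt.axis("off"); plt.show()
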
class AlignmentModule(nn.Module):
    """Align prompt features to query features, then fuse the two."""

    def __init__(self, input_channels=2048, hidden_channels=256, alignment_type="sa", fusion_policy='cat'):
        super().__init__()
        self.fusion_policy = fusion_policy
        self.alignment_layer = AlignmentLayer(input_channels, hidden_channels, alignment_type=alignment_type)

    def forward(self, query_features, prompt_features):
        if isinstance(prompt_features, list):
            # Multiple prompts: align each one to the query, then average.
            aligned_prompt = []
            for i in range(len(prompt_features)):
                aligned_prompt.append(self.alignment_layer(query_features, prompt_features[i]))
            aligned_prompt = torch.mean(torch.stack(aligned_prompt), 0)
        else:
            aligned_prompt = self.alignment_layer(query_features, prompt_features)

        if self.fusion_policy == 'cat':
            query_features = rearrange(
                [query_features, aligned_prompt], "two b c h w -> b (two c) h w"
            )
        elif self.fusion_policy == 'add':
            query_features = query_features + aligned_prompt
        elif self.fusion_policy == 'absdiff':
            query_features = (query_features - aligned_prompt).abs()

        return query_features

class AlignmentLayer(nn.Module):
    """Cross-attention that resamples prompt features onto the query's spatial layout."""

    def __init__(self, input_channels=2048, hidden_channels=256, alignment_type="sa"):
        super().__init__()
        self.alignment_type = alignment_type
        if alignment_type != "na":
            self.dimensionality_reduction = nn.Conv2d(
                input_channels, hidden_channels, kernel_size=1, stride=1, padding=0, bias=True
            )

    def forward(self, query_features, prompt_features):
        # no-alignment: return the prompt features unchanged
        if self.alignment_type == 'na':
            return prompt_features
        else:
            Q = self.dimensionality_reduction(query_features)
            K = self.dimensionality_reduction(prompt_features)
            V = rearrange(prompt_features, "b c h w -> b c (h w)")

            soft_attention_map = torch.einsum("bcij,bckl->bijkl", Q, K)
            soft_attention_map = rearrange(soft_attention_map, "b h1 w1 h2 w2 -> b h1 w1 (h2 w2)")
            soft_attention_map = nn.Softmax(dim=3)(soft_attention_map)

            # soft-alignment: attention-weighted average over all prompt locations
            if self.alignment_type == 'sa':
                aligned_features = torch.einsum("bijp,bcp->bcij", soft_attention_map, V)

            # hard-alignment: each query location keeps only its best-matching prompt location
            if self.alignment_type == 'ha':
                max_v, max_index = soft_attention_map.max(dim=-1, keepdim=True)
                hard_attention_map = (soft_attention_map == max_v).float()
                aligned_features = torch.einsum("bijp,bcp->bcij", hard_attention_map, V)

            return aligned_features

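
# Shape sketch for the attention above (an illustration, not from the original
# file): with query/prompt features of shape (b, c, h, w),
#   Q, K: (b, c', h, w) after the 1x1 reduction; V: (b, c, h*w)
#   einsum("bcij,bckl->bijkl", Q, K) -> (b, h, w, h, w) pairwise similarities
#   softmax over the flattened prompt grid (the last dim after rearrange), then
#   einsum("bijp,bcp->bcij", attn, V) -> prompt features on the query grid, (b, c, h, w)
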
class MetaUAS(pl.LightningModule):
    def __init__(self, encoder_name, decoder_name, encoder_depth, decoder_depth, num_alignment_layers, alignment_type, fusion_policy):
        super().__init__()
        self.encoder_name = encoder_name
        self.decoder_name = decoder_name
        self.encoder_depth = encoder_depth
        self.decoder_depth = decoder_depth
        self.num_alignment_layers = num_alignment_layers
        self.alignment_type = alignment_type
        self.fusion_policy = fusion_policy

        # Per-stage channel counts, hard-coded for the expected encoder
        # (they match an efficientnet-b4-style feature pyramid).
        align_input_channels = [448, 160, 56]
        align_hidden_channels = [224, 80, 28]
        encoder_channels = [3, 48, 32, 56, 160, 448]
        decoder_channels = [256, 128, 64, 64, 48]

        self.encoder = get_encoder(
            self.encoder_name,
            in_channels=3,
            depth=self.encoder_depth,
            weights="imagenet",
        )
        preparams = get_preprocessing_params(
            self.encoder_name,
            pretrained="imagenet"
        )
        self.preprocess = transforms.Normalize(preparams['mean'], preparams['std'])

        # The encoder is frozen; only the alignment and decoder weights train.
        self.encoder.eval()
        for param in self.encoder.parameters():
            param.requires_grad = False

        if self.decoder_name == "unet":
            encoder_out_channels = encoder_channels[self.encoder_depth - self.decoder_depth:]
            # With 'cat' fusion each aligned stage doubles its channels; the
            # decoder must be told how many stages were widened.
            if self.fusion_policy == 'cat':
                num_alignment_layers = self.num_alignment_layers
            elif self.fusion_policy == 'add' or self.fusion_policy == 'absdiff':
                num_alignment_layers = 0
            # Note: `num_coam_layers` is not a stock segmentation_models_pytorch
            # argument; this relies on the project's patched UnetDecoder.
            self.decoder = UnetDecoder(
                encoder_channels=encoder_out_channels,
                decoder_channels=decoder_channels,
                n_blocks=self.decoder_depth,
                attention_type="scse",
                num_coam_layers=num_alignment_layers,
            )
        elif self.decoder_name == "fpn":
            encoder_out_channels = encoder_channels
            if self.fusion_policy == 'cat':
                for i in range(self.num_alignment_layers):
                    encoder_out_channels[-(i + 1)] = 2 * encoder_out_channels[-(i + 1)]
            self.decoder = FPNDecoder(
                encoder_channels=encoder_out_channels,
                encoder_depth=self.encoder_depth,
                pyramid_channels=256,
                segmentation_channels=decoder_channels[-1],
                dropout=0.2,
                merge_policy="add",
            )
        elif self.decoder_name == "fpnadd":
            segmentation_channels = 256
            encoder_out_channels = encoder_channels
            if self.fusion_policy == 'cat':
                for i in range(self.num_alignment_layers):
                    encoder_out_channels[-(i + 1)] = 2 * encoder_out_channels[-(i + 1)]
            self.decoder = FPNDecoder(
                encoder_channels=encoder_out_channels,
                encoder_depth=self.encoder_depth,
                pyramid_channels=256,
                segmentation_channels=segmentation_channels,
                dropout=0.2,
                merge_policy="add",
            )
        elif self.decoder_name == "fpncat":
            segmentation_channels = 256
            encoder_out_channels = encoder_channels
            if self.fusion_policy == 'cat':
                for i in range(self.num_alignment_layers):
                    encoder_out_channels[-(i + 1)] = 2 * encoder_out_channels[-(i + 1)]
            self.decoder = FPNDecoder(
                encoder_channels=encoder_out_channels,
                encoder_depth=self.encoder_depth,
                pyramid_channels=256,
                segmentation_channels=segmentation_channels,
                dropout=0.2,
                merge_policy="cat",
            )
        if self.alignment_type in ("sa", "na", "ha"):
            self.alignment = nn.ModuleList(
                [
                    AlignmentModule(
                        input_channels=align_input_channels[i],
                        hidden_channels=align_hidden_channels[i],
                        alignment_type=self.alignment_type,
                        fusion_policy=self.fusion_policy,
                    )
                    for i in range(self.num_alignment_layers)
                ]
            )

        # 1x1 conv head that maps decoder features to a single-channel mask.
        if self.decoder_name == "fpncat":
            self.mask_head = nn.Conv2d(
                segmentation_channels * 4,
                1,
                kernel_size=1,
                stride=1,
                padding=0,
            )
        elif self.decoder_name == "fpnadd":
            self.mask_head = nn.Conv2d(
                segmentation_channels,
                1,
                kernel_size=1,
                stride=1,
                padding=0,
            )
        else:
            self.mask_head = nn.Conv2d(
                decoder_channels[-1],
                1,
                kernel_size=1,
                stride=1,
                padding=0,
            )
    def forward(self, batch):
        query_input = self.preprocess(batch["query_image"])
        prompt_input = self.preprocess(batch["prompt_image"])

        # The encoder is frozen, so features are extracted without gradients.
        with torch.no_grad():
            query_encoded_features = self.encoder(query_input)
            prompt_encoded_features = self.encoder(prompt_input)

        # Align and fuse the deepest feature maps, from coarsest to finest.
        for i in range(len(self.alignment)):
            query_encoded_features[-(i + 1)] = self.alignment[i](
                query_encoded_features[-(i + 1)], prompt_encoded_features[-(i + 1)]
            )

        query_decoded_features = self.decoder(*query_encoded_features[self.encoder_depth - self.decoder_depth:])

        if self.decoder_name in ("fpn", "fpncat", "fpnadd"):
            # FPN outputs at 1/4 resolution; upsample back to the input size.
            output = F.interpolate(self.mask_head(query_decoded_features), scale_factor=4, mode="bilinear", align_corners=False)
        elif self.decoder_name == "unet":
            if self.decoder_depth == 4:
                output = F.interpolate(self.mask_head(query_decoded_features), scale_factor=2, mode="bilinear", align_corners=False)
            if self.decoder_depth == 5:
                # A depth-5 U-Net decoder already outputs at full resolution.
                output = self.mask_head(query_decoded_features)

        return output.sigmoid()
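

# Minimal usage sketch (an assumption, not part of the original file). The
# encoder name and all constructor values below are inferred from the
# hard-coded channel lists, which match efficientnet-b4; `get_encoder` will
# download ImageNet weights, and the UnetDecoder call only works with the
# project's patched segmentation_models_pytorch (see the note in __init__).
if __name__ == "__main__":
    set_random_seed(233)

    model = MetaUAS(
        encoder_name="efficientnet-b4",  # assumed: must yield [3, 48, 32, 56, 160, 448] channels
        decoder_name="unet",
        encoder_depth=5,
        decoder_depth=5,
        num_alignment_layers=3,
        alignment_type="sa",
        fusion_policy="cat",
    )
    model.eval()

    # One normal prompt image and one query image, both 3x256x256 in [0, 1];
    # random tensors stand in for real images read via read_image_as_tensor().
    batch = {
        "query_image": torch.rand(1, 3, 256, 256),
        "prompt_image": torch.rand(1, 3, 256, 256),
    }
    with torch.no_grad():
        anomaly_map = model(batch)  # (1, 1, 256, 256), values in (0, 1)
    print(anomaly_map.shape)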