| | import PIL |
| | import torch |
| |
|
| | from .prompts import GetPromptList |
| |
|
# Canonical bird-part order used throughout this module to (re)order the
# per-part description and mask dictionaries before scoring.
ORG_PART_ORDER = [
    'back', 'beak', 'belly', 'breast', 'crown', 'forehead',
    'eyes', 'legs', 'wings', 'nape', 'tail', 'throat',
]

# The same twelve parts arranged roughly head-to-tail. Not referenced in the
# visible code of this file — presumably consumed by callers (verify).
ORDERED_PARTS = [
    'crown', 'forehead', 'nape', 'eyes', 'beak', 'throat',
    'breast', 'belly', 'back', 'wings', 'legs', 'tail',
]
| |
|
def encode_descs_xclip(owlvit_det_processor: callable, model: callable, descs: list[str], device: str, max_batch_size: int = 512):
    """Encode text descriptions into OWL-ViT text embeddings, in batches.

    Args:
        owlvit_det_processor: HF processor used to tokenize the descriptions
            (called with ``text=..., padding="max_length", truncation=True``).
        model: Model exposing ``model.owlvit.get_text_features(**tokens)``.
        descs: Descriptions to encode; expected to be non-empty.
        device: Torch device string the tokens and result are moved to.
        max_batch_size: Upper bound on descriptions tokenized per forward pass.

    Returns:
        Float tensor of shape ``(len(descs), embed_dim)`` on ``device``.
    """
    # Ceiling division. The previous `len(descs) // max_batch_size + 1`
    # produced a trailing *empty* batch whenever len(descs) was an exact
    # multiple of max_batch_size, calling the processor on an empty list.
    total_num_batches = (len(descs) + max_batch_size - 1) // max_batch_size
    text_embeds = []
    with torch.no_grad():
        for batch_idx in range(total_num_batches):
            query_descs = descs[batch_idx * max_batch_size:(batch_idx + 1) * max_batch_size]
            query_tokens = owlvit_det_processor(
                text=query_descs, padding="max_length", truncation=True, return_tensors="pt"
            ).to(device)
            query_embeds = model.owlvit.get_text_features(**query_tokens)
            # Accumulate on CPU so the whole embedding table never has to
            # fit on the accelerator at once.
            text_embeds.append(query_embeds.cpu().float())
    return torch.cat(text_embeds, dim=0).to(device)
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
def xclip_pred(new_desc: dict,
               new_part_mask: dict,
               new_class: str,
               org_desc: str,
               image: PIL.Image.Image,
               model: callable,
               owlvit_processor: callable,
               device: str,
               return_img_embeds: bool = False,
               use_precompute_embeddings=True,
               image_name: str = None,):
    """Classify ``image`` against a per-part description set, optionally with
    one class's descriptions replaced by user-edited ones.

    Args:
        new_desc: Part-name -> description for the edited class (keys must
            cover ORG_PART_ORDER). Ignored when ``new_class`` is None.
        new_part_mask: Part-name -> 0/1 flag selecting which parts count
            toward the score. Ignored when ``new_class`` is None.
        new_class: Name of the class being edited, or None for a plain run.
        org_desc: Source passed to ``GetPromptList`` (original descriptions).
        image: Input PIL image.
        model: OWL-ViT-based classifier; called as
            ``model(image, part_embeds, query_embeds, None)`` and expected to
            return ``(pred_logits, part_logits, extras)``.
        owlvit_processor: HF processor for both image and text inputs.
        device: Torch device string.
        return_img_embeds: If True, also return the image feature map.
        use_precompute_embeddings: If True, load cached image embeddings from
            ``data/image_embeddings/<image_name>.pt`` instead of raw pixels.
        image_name: Basename of the cached embedding file; required when
            ``use_precompute_embeddings`` is True.

    Returns:
        A result dict (prediction, softmax score, per-part scores, and the
        corresponding "modified" entries when ``new_class`` is given), or
        ``(result_dict, feature_map)`` when ``return_img_embeds`` is True.
    """
    # Re-order the edited descriptions/masks into the canonical part order so
    # positions line up with the model's per-part logits.
    if new_class is not None:
        new_desc_ = {k: new_desc[k] for k in ORG_PART_ORDER}
        new_part_mask_ = {k: new_part_mask[k] for k in ORG_PART_ORDER}
        desc_mask = list(new_part_mask_.values())
    else:
        desc_mask = [1] * len(ORG_PART_ORDER)

    getprompt = GetPromptList(org_desc)
    if new_class is not None:
        # Register the edited class if it is new, then install its descriptions.
        if new_class not in getprompt.desc:
            getprompt.name2idx[new_class] = len(getprompt.name2idx)
        getprompt.desc[new_class] = list(new_desc_.values())

    idx2name = {idx: name for name, idx in getprompt.name2idx.items()}
    modified_class_idx = getprompt.name2idx[new_class] if new_class is not None else None

    # The classifier head must know the (possibly grown) number of classes.
    model.cls_head.num_classes = len(getprompt.name2idx)

    descs, class_idxs, class_mapping, org_desc_mapper, class_list = getprompt('chatgpt-no-template', max_len=12, pad=True)
    query_embeds = encode_descs_xclip(owlvit_processor, model, descs, device)

    with torch.no_grad():
        image_input = owlvit_processor(images=image, return_tensors='pt').to(device)
        part_embeds = owlvit_processor(text=[ORG_PART_ORDER], return_tensors="pt").to(device)

        feature_map = None
        if return_img_embeds:
            feature_map, _ = model.image_embedder(pixel_values=image_input['pixel_values'])

        if use_precompute_embeddings:
            # Fail loudly instead of building 'data/image_embeddings/None.pt'.
            if image_name is None:
                raise ValueError("image_name is required when use_precompute_embeddings=True")
            image_embeds = torch.load(f'data/image_embeddings/{image_name}.pt').to(device)
            # The model's extras dict is unused here; discard it explicitly.
            pred_logits, part_logits, _ = model(image_embeds, part_embeds, query_embeds, None)
        else:
            pred_logits, part_logits, _ = model(image_input, part_embeds, query_embeds, None)

        b, c, n = part_logits.shape
        # float32 (Python `float` means float64) avoids silently promoting the
        # masked logits and all downstream scores to double precision.
        mask = torch.tensor(desc_mask, dtype=torch.float32, device=device)
        mask = mask.unsqueeze(0).unsqueeze(0).repeat(b, c, 1)

        # Zero out masked-off parts, then score each class as the sum of its
        # remaining per-part logits.
        part_logits = part_logits * mask
        pred_logits = part_logits.sum(dim=-1)

        pred_class_idx = torch.argmax(pred_logits, dim=-1).cpu()
        pred_class_name = idx2name[pred_class_idx.item()]

        softmax_scores = torch.softmax(pred_logits, dim=-1).cpu()
        # Softmax score of the argmax class (equivalent to topk(k=1)).
        softmax_score_top1 = softmax_scores[0, pred_class_idx.item()].item()

        part_scores = part_logits[0, pred_class_idx].cpu().squeeze(0)
        part_scores_dict = dict(zip(ORG_PART_ORDER, part_scores.tolist()))

        if modified_class_idx is not None:
            modified_score = softmax_scores[0, modified_class_idx].item()
            modified_part_scores = part_logits[0, modified_class_idx].cpu()
            modified_part_scores_dict = dict(zip(ORG_PART_ORDER, modified_part_scores.tolist()))
        else:
            modified_score = None
            modified_part_scores_dict = None

        output_dict = {
            "pred_class": pred_class_name,
            "pred_score": softmax_score_top1,
            "pred_desc_scores": part_scores_dict,
            "descriptions": getprompt.desc[pred_class_name],
            "modified_class": new_class,
            "modified_score": modified_score,
            "modified_desc_scores": modified_part_scores_dict,
            "modified_descriptions": getprompt.desc[new_class] if new_class is not None else None,
        }
    return (output_dict, feature_map) if return_img_embeds else output_dict
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |