from functools import partial

import numpy as np
import torch
from torch.utils.data import DataLoader

from modules.dataset_tvrr import TrainDataset, QueryEvalDataset, CorpusEvalDataset
from utils.tensor_utils import pad_sequences_1d
| |
|
def _add_padded_feature(batch_data, batch, prefix, fixed_length):
    """Pad ``e[prefix + "_feat"]`` across ``batch`` and store feature and mask.

    Writes ``batch_data[prefix + "_feat"]`` and ``batch_data[prefix + "_mask"]``
    from elements 0 and 1 of the ``pad_sequences_1d`` result. ``fixed_length``
    is forwarded unchanged; ``None`` presumably means "pad to the longest item
    in the batch" -- confirm in ``utils.tensor_utils``.
    """
    padded = pad_sequences_1d(
        [e[prefix + "_feat"] for e in batch],
        dtype=torch.float32,
        fixed_length=fixed_length,
    )
    batch_data[prefix + "_feat"] = padded[0]
    batch_data[prefix + "_mask"] = padded[1]


def _build_match_labels(st_ed_indices, length):
    """Return a ``(len(st_ed_indices), length)`` ``torch.long`` label matrix.

    Row ``i`` is 1 on positions ``st..ed`` (``ed`` inclusive) taken from
    elements 0 and 1 of ``st_ed_indices[i]``, and 0 elsewhere. Spans reaching
    past ``length`` are clipped by slice semantics.
    """
    labels = torch.zeros((len(st_ed_indices), length), dtype=torch.long)
    for row, span in enumerate(st_ed_indices):
        # int() on a 0-d tensor works on any device; no numpy round-trip needed.
        st, ed = int(span[0]), int(span[1])
        labels[row, st:ed + 1] = 1
    return labels


def collate_fn(batch, task, fixed_length=128):
    """Collate a list of dataset items into a dict of batched tensors.

    Args:
        batch: list of per-item dicts produced by the dataset.
        task: ``"train"``, ``"corpus"`` or ``"eval"``; selects which item keys
            are read and which batch keys are produced.
        fixed_length: length that video/subtitle features are padded to, and
            the width of ``match_labels``. Defaults to 128, the value that was
            previously hard-coded.

    Returns:
        dict of tensors:
          * ``"train"``: ``simi``, ``query_feat``/``query_mask``,
            ``video_feat``/``video_mask``, ``sub_feat``/``sub_mask``,
            ``st_ed_indices`` (stacked along dim 0) and ``match_labels``.
          * ``"corpus"``: ``video_feat``/``video_mask``, ``sub_feat``/``sub_mask``.
          * ``"eval"``: ``query_feat``/``query_mask`` and ``query_id``.

    Raises:
        ValueError: if ``task`` is not one of the values above (previously an
            empty dict was returned silently).
    """
    batch_data = dict()

    if task == "train":
        batch_data["simi"] = torch.tensor([e["simi"] for e in batch])
        # Queries are padded without a fixed length; video/subtitle use fixed_length.
        _add_padded_feature(batch_data, batch, "query", None)
        _add_padded_feature(batch_data, batch, "video", fixed_length)
        _add_padded_feature(batch_data, batch, "sub", fixed_length)

        st_ed_indices = [e["st_ed_indices"] for e in batch]
        batch_data["st_ed_indices"] = torch.stack(st_ed_indices, dim=0)
        batch_data["match_labels"] = _build_match_labels(st_ed_indices, fixed_length)
    elif task == "corpus":
        _add_padded_feature(batch_data, batch, "video", fixed_length)
        _add_padded_feature(batch_data, batch, "sub", fixed_length)
    elif task == "eval":
        _add_padded_feature(batch_data, batch, "query", None)
        batch_data["query_id"] = torch.tensor([e["query_id"] for e in batch])
    else:
        raise ValueError(
            f"Unknown task {task!r}; expected 'train', 'corpus' or 'eval'"
        )

    return batch_data
| |
|
| |
|
| |
|
| |
|
def prepare_dataset(opt):
    """Build the train/corpus/val/test DataLoaders and the evaluation ground truth.

    Args:
        opt: options object. Attributes read here: ``train_path``,
            ``corpus_path``, ``val_path``, ``test_path``, ``desc_bert_path``,
            ``sub_bert_path``, ``video_feat_path``, ``max_desc_l``,
            ``max_ctx_l``, ``clip_length``, ``ctx_mode``, ``no_norm_vfeat``,
            ``no_norm_tfeat``, ``bsz``, ``bsz_eval`` and ``num_workers``.

    Returns:
        tuple ``(train_loader, corpus_loader, corpus_video_list, val_loader,
        test_loader, val_gt, test_gt)``.
    """

    def _make_loader(dataset, task, batch_size, shuffle):
        # functools.partial of the module-level collate_fn is picklable, unlike
        # a lambda, so worker processes also work under the "spawn" start
        # method (the default on Windows and macOS) when num_workers > 0.
        return DataLoader(
            dataset,
            collate_fn=partial(collate_fn, task=task),
            batch_size=batch_size,
            num_workers=opt.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )

    train_set = TrainDataset(
        data_path=opt.train_path,
        desc_bert_path=opt.desc_bert_path,
        sub_bert_path=opt.sub_bert_path,
        max_desc_len=opt.max_desc_l,
        max_ctx_len=opt.max_ctx_l,
        video_feat_path=opt.video_feat_path,
        clip_length=opt.clip_length,
        ctx_mode=opt.ctx_mode,
        normalize_vfeat=not opt.no_norm_vfeat,
        normalize_tfeat=not opt.no_norm_tfeat,
    )
    # NOTE(review): collate_fn pads video/subtitle features to its own
    # fixed_length (default 128), independently of opt.max_ctx_l -- confirm
    # the two agree.
    train_loader = _make_loader(train_set, "train", opt.bsz, shuffle=True)

    # NOTE(review): unlike TrainDataset, no normalize_vfeat/normalize_tfeat is
    # passed to CorpusEvalDataset or QueryEvalDataset; confirm their defaults
    # match opt.no_norm_vfeat / opt.no_norm_tfeat.
    corpus_set = CorpusEvalDataset(
        corpus_path=opt.corpus_path,
        max_ctx_len=opt.max_ctx_l,
        sub_bert_path=opt.sub_bert_path,
        video_feat_path=opt.video_feat_path,
        ctx_mode=opt.ctx_mode,
    )
    corpus_loader = _make_loader(corpus_set, "corpus", opt.bsz, shuffle=False)

    val_set = QueryEvalDataset(
        data_path=opt.val_path,
        desc_bert_path=opt.desc_bert_path,
        max_desc_len=opt.max_desc_l,
    )
    val_loader = _make_loader(val_set, "eval", opt.bsz_eval, shuffle=False)
    test_set = QueryEvalDataset(
        data_path=opt.test_path,
        desc_bert_path=opt.desc_bert_path,
        max_desc_len=opt.max_desc_l,
    )
    test_loader = _make_loader(test_set, "eval", opt.bsz_eval, shuffle=False)

    val_gt = val_set.ground_truth
    test_gt = test_set.ground_truth
    corpus_video_list = corpus_set.corpus_video_list
    return train_loader, corpus_loader, corpus_video_list, val_loader, test_loader, val_gt, test_gt
| |
|