Spaces:
Sleeping
Sleeping
| import os | |
| import pandas as pd | |
| from PIL import Image | |
| from torchvision import transforms | |
| from torch.utils.data import Dataset | |
class ImageCaptionDataset(Dataset):
    """
    PyTorch Dataset yielding (image, caption) pairs described by a CSV file.

    The first CSV column holds image paths (resolved against a root
    directory) and the second column holds the matching captions.

    Attributes:
        df (pandas.DataFrame): Table loaded from the caption CSV file.
        image_path (str): Root directory that CSV image paths are resolved against.
        transform (callable): Transformation applied to every loaded image.
    """

    def __init__(self, caption_file: str, file_path: str, transform=None):
        """
        Initialize the dataset from a caption CSV file and an image root.

        Args:
            caption_file (str): Path to the CSV file where each row has an
                image path (first column) and a caption (second column).
            file_path (str): Directory the image paths in the CSV are
                relative to. A trailing path separator is optional.
            transform (callable, optional): Transform applied to each image.
                When omitted, images are resized to 224x224 and converted
                to tensors.
        """
        self.df = pd.read_csv(caption_file)
        self.image_path = file_path
        # Compare against None rather than relying on truthiness: a valid
        # transform object may be falsy (e.g. an empty nn.Sequential).
        if transform is not None:
            self.transform = transform
        else:
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),  # Fixed 224x224 input size (e.g. for ViT)
                # ToTensor already scales 8-bit pixel values into [0, 1].
                transforms.ToTensor(),
                # NOTE: ImageNet mean/std normalization is intentionally not
                # applied here (mean=[0.485, 0.456, 0.406],
                # std=[0.229, 0.224, 0.225]).
            ])

    def __len__(self) -> int:
        """
        Return the total number of samples in the dataset.
        """
        return len(self.df)

    def _resolve_image_path(self, img_path):
        """
        Combine the image root with a path taken from the CSV.

        A separator is inserted when the root lacks a trailing one, so both
        "images" and "images/" work as the root directory.
        """
        root = os.fspath(self.image_path)
        if img_path.startswith(("/", os.sep)):
            # os.path.join would discard the root for an entry that starts
            # with a separator; keep plain concatenation for such entries.
            return root + img_path
        return os.path.join(root, img_path)

    def __getitem__(self, idx):
        """
        Retrieve an image and its corresponding caption by index.

        Args:
            idx (int): Positional index of the data item.

        Returns:
            tuple: (image, caption) where image is the transformed image and
            caption is the value stored in the second CSV column.
        """
        img_path = self.df.iloc[idx, 0]  # First column: image path
        caption = self.df.iloc[idx, 1]  # Second column: caption

        # The context manager closes the underlying file deterministically;
        # convert() returns an independent, fully loaded RGB image.
        with Image.open(self._resolve_image_path(img_path)) as source:
            image = source.convert('RGB')

        # The attribute may have been reset to None after construction.
        if self.transform is not None:
            image = self.transform(image)
        return image, caption