import glob
import json
import logging
import math
import os
import random
from collections import OrderedDict
from functools import partial
from multiprocessing import Value
from pathlib import Path

import braceexpand
import numpy as np
import pandas as pd
import torch
import webdataset as wds
from lightning_fabric.utilities.rank_zero import _get_rank
from PIL import Image
from torch.utils.data import Dataset, get_worker_info
from tqdm import tqdm
from webdataset.tariterators import (
    base_plus_ext,
    tar_file_expander,
    url_opener,
    valid_sample,
)


class GPSWebdataset(wds.DataPipeline):
    """WebDataset pipeline over .tar shards yielding images, optional precomputed
    embeddings, selected metadata attributes, and GPS coordinates in radians.

    ``root`` may be a single shard directory, a list of directories, or a
    space-separated string of directories.
    """

    def __init__(
        self,
        root,
        image_transforms=None,
        distributed=True,
        train=True,
        epoch=0,
        seed=3407,
        embedding_name=None,
        return_image=True,
        shard_shuffle_size=2000,
        shard_shuffle_initial=500,
        sample_shuffle_size=5000,
        sample_shuffle_initial=1000,
        metadata_attributes=[],
    ):
        self.image_transforms = image_transforms
        dataset_tar_files = []
        # A space-separated root is treated as several shard directories.
        if " " in root:
            root = root.split(" ")
            print(f"Using multiple datasets: {root}")
        if isinstance(root, str):
            # Collect and sort all tar files in the directory.
            tar_files = [f for f in os.listdir(root) if f.endswith(".tar")]
            tar_files.sort()
            first_tar_file = tar_files[0].split(".")[0]
            last_tar_file = tar_files[-1].split(".")[0]
            for tar_file in tar_files:
                dataset_tar_files.append(f"{root}/{tar_file}")
            dataset_pattern = f"{root}/{{{first_tar_file}..{last_tar_file}}}.tar"
            self.num_samples, _ = get_dataset_size(dataset_pattern)
        elif isinstance(root, list):
            num_samples = 0
            for r in root:
                tar_files = [f for f in os.listdir(r) if f.endswith(".tar")]
                tar_files.sort()
                first_tar_file = tar_files[0].split(".")[0]
                last_tar_file = tar_files[-1].split(".")[0]
                for tar_file in tar_files:
                    dataset_tar_files.append(f"{r}/{tar_file}")
                num_samples += get_dataset_size(
                    f"{r}/{{{first_tar_file}..{last_tar_file}}}.tar"
                )[0]
            self.num_samples = num_samples
        else:
            raise ValueError(
                f"root must be a string or a list of strings. Got {type(root)}"
            )
        rank = _get_rank()
        self.shared_epoch = SharedEpoch(epoch)
        pipeline = [wds.SimpleShardList(dataset_tar_files)]
        if distributed:
            if train:
                pipeline.extend(
                    [
                        # Deterministic, epoch-aware shard shuffle shared across ranks.
                        detshuffle2(
                            bufsize=shard_shuffle_size,
                            initial=shard_shuffle_initial,
                            seed=seed,
                            epoch=self.shared_epoch,
                        ),
                        wds.split_by_node,
                        wds.split_by_worker,
                        tarfile_to_samples_nothrow,
                        wds.shuffle(
                            bufsize=sample_shuffle_size,
                            initial=sample_shuffle_initial,
                        ),
                    ]
                )
            else:
                pipeline.extend(
                    [wds.split_by_node, wds.split_by_worker, tarfile_to_samples_nothrow]
                )
        else:
            if train:
                pipeline.extend(
                    [
                        wds.shuffle(
                            bufsize=shard_shuffle_size,
                            initial=shard_shuffle_initial,
                        ),
                        wds.split_by_worker,
                        tarfile_to_samples_nothrow,
                        wds.shuffle(
                            bufsize=sample_shuffle_size,
                            initial=sample_shuffle_initial,
                        ),
                    ]
                )
            else:
                pipeline.extend([wds.split_by_worker, tarfile_to_samples_nothrow])
        # Map raw shard keys to output keys and attach a per-key transform.
        outputs_transforms = OrderedDict()
        outputs_rename = OrderedDict()
        if return_image:
            outputs_rename["img.jpg"] = "jpg;png;webp;jpeg"
            outputs_transforms["img.jpg"] = (
                self.image_transforms
                if self.image_transforms is not None
                else lambda x: x
            )
        if embedding_name is not None:
            outputs_rename["emb.npy"] = f"{embedding_name}.npy"
            outputs_transforms["emb.npy"] = lambda x: torch.from_numpy(x)
        if metadata_attributes != []:
            for attr in metadata_attributes:
                outputs_rename[f"{attr}.json"] = "json"
                outputs_transforms[f"{attr}.json"] = partial(get_attr, attr=attr)
        outputs_rename["gps"] = "json"
        outputs_transforms["gps"] = get_gps
        pipeline.extend(
            [
                wds.rename(**outputs_rename),
                filter_dict_keys(*outputs_rename.keys(), handler=log_and_continue),
            ]
        )
        if return_image:
            pipeline.append(wds.decode("pilrgb", handler=log_and_continue))
        else:
            pipeline.append(wds.decode(handler=log_and_continue))
        pipeline.extend(
            [
                wds.map_dict(**outputs_transforms, handler=log_and_continue),
                # Drop the extensions from the output keys (e.g. "img.jpg" -> "img").
                wds.rename(
                    **{k.split(".")[0]: k for k in outputs_transforms.keys()},
                ),
            ]
        )
        super().__init__(*pipeline)

    def __len__(self):
        return self.num_samples
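

# Example usage (a sketch; the paths, transform, and metadata key below are
# placeholders, not values used anywhere in this module):
#
#   from torchvision import transforms
#
#   dataset = GPSWebdataset(
#       root="/data/gps_shards",  # hypothetical directory of .tar shards
#       image_transforms=transforms.Compose(
#           [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
#       ),
#       distributed=False,
#       train=True,
#       metadata_attributes=["country"],  # hypothetical metadata field
#   )
#   loader = wds.WebLoader(dataset, batch_size=64, num_workers=4)
#   for batch in loader:
#       images, gps = batch["img"], batch["gps"]  # gps holds (lat, lon) in radians
#       ...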


def normalize_gps(lat, lon):
    """Wrap latitude into [-90, 90] and longitude into [-180, 180] degrees."""
    lat = (lat + 90) % 360 - 90
    if lat > 90:
        lat = 180 - lat
        lon += 180
    lon = (lon + 180) % 360 - 180
    return lat, lon
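

# For example, a latitude that has drifted past the pole wraps onto the opposite
# meridian: normalize_gps(95.0, 10.0) returns (85.0, -170.0).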


def get_attr(metadata, attr):
    """Read a single attribute from the decoded metadata dict (NaN becomes "NaN")."""
    attr_value = metadata[attr]
    if isinstance(attr_value, float) and math.isnan(attr_value):
        return "NaN"
    else:
        return attr_value


def get_gps(metadata):
    """Parse latitude/longitude from the raw json payload and return them in radians."""
    datapoint = json.loads(metadata)
    lat, lon = normalize_gps(
        float(datapoint["latitude"]), float(datapoint["longitude"])
    )
    gps = torch.tensor([np.radians(lat), np.radians(lon)], dtype=torch.float)
    return gps


def get_dataset_size(shards):
    """Count samples across shards, caching per-shard counts in sizes.json."""
    shards_list, _ = expand_urls(shards)
    dir_path = os.path.dirname(shards_list[0])
    sizes_filename = os.path.join(dir_path, "sizes.json")
    if os.path.exists(sizes_filename):
        sizes = json.load(open(sizes_filename, "r"))
        total_size = sum(int(sizes[os.path.basename(shard)]) for shard in shards_list)
    else:
        # No cached sizes: iterate every shard once to count samples, then cache.
        total_size = 0
        sizes = {}
        for shard in tqdm(shards_list):
            dataset = wds.WebDataset(shard)
            num_samples = sum(1 for _ in dataset)
            total_size += num_samples
            sizes[os.path.basename(shard)] = num_samples
        print(f"Total number of samples: {total_size}")
        with open(sizes_filename, "w") as f:
            json.dump(sizes, f)
    num_shards = len(shards_list)
    return total_size, num_shards
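

# The cached sizes.json maps shard basenames to sample counts, for example
# (illustrative values only): {"000000.tar": 10000, "000001.tar": 9873}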


def expand_urls(urls, weights=None):
    if weights is None:
        expanded_urls = wds.shardlists.expand_urls(urls)
        return expanded_urls, None
    if isinstance(urls, str):
        urllist = urls.split("::")
        weights = weights.split("::")
        assert len(weights) == len(urllist), (
            f"Expected the number of data components ({len(urllist)}) "
            f"and weights ({len(weights)}) to match."
        )
        weights = [float(weight) for weight in weights]
        all_urls, all_weights = [], []
        for url, weight in zip(urllist, weights):
            expanded_url = list(braceexpand.braceexpand(url))
            expanded_weights = [weight for _ in expanded_url]
            all_urls.extend(expanded_url)
            all_weights.extend(expanded_weights)
        return all_urls, all_weights
    else:
        all_urls = list(urls)
        return all_urls, weights
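

# Example (illustrative shard paths):
#   expand_urls("a/{000..001}.tar::b/{000..000}.tar", weights="2::1")
#   -> (["a/000.tar", "a/001.tar", "b/000.tar"], [2.0, 2.0, 1.0])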


class SharedEpoch:
    """Epoch counter shared across dataloader worker processes (see detshuffle2)."""

    def __init__(self, epoch: int = 0):
        self.shared_epoch = Value("i", epoch)

    def set_value(self, epoch):
        self.shared_epoch.value = epoch

    def get_value(self):
        return self.shared_epoch.value


# _SHARD_SHUFFLE_SIZE = 256
# _SHARD_SHUFFLE_INITIAL = 128
# _SAMPLE_SHUFFLE_SIZE = 5000
# _SAMPLE_SHUFFLE_INITIAL = 1000


class detshuffle2(wds.PipelineStage):
    """Deterministic shard shuffle whose order depends only on (seed, epoch)."""

    def __init__(
        self,
        bufsize=1000,
        initial=100,
        seed=0,
        epoch=-1,
    ):
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        self.epoch = epoch

    def run(self, src):
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader
            # workers or train) situation as different workers may wrap at different
            # times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        rng = random.Random()
        if self.seed < 0:
            # A negative seed means: use the worker's own seed, which differs across
            # all nodes/workers.
            seed = pytorch_worker_seed(epoch)
        else:
            # This seed is deterministic AND the same across all nodes/workers in
            # each epoch.
            seed = self.seed + epoch
        rng.seed(seed)
        return wds.filters._shuffle(src, self.bufsize, self.initial, rng)
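

# With a non-negative seed the shard order changes every epoch yet stays identical on
# every rank and worker, e.g. seed=3407 shuffles with 3407 at epoch 0 and 3408 at epoch 1.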


def pytorch_worker_seed(increment=0):
    """Get the dataloader worker seed from pytorch."""
    worker_info = get_worker_info()
    if worker_info is not None:
        # favour using the seed already created for pytorch dataloader workers if it exists
        seed = worker_info.seed
        if increment:
            # space out seed increments so they can't overlap across workers in different iterations
            seed += increment * max(1, worker_info.num_workers)
        return seed
    # fallback to wds rank based seed
    return wds.utils.pytorch_worker_seed()


def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
    return True


def group_by_keys_nothrow(
    data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None
):
    """Group key/value pairs from a tar stream into samples, skipping invalid samples.

    :param keys: function that splits a file name into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME webdataset version throws if suffix in current_sample, but we have a
        # potential for this happening in the current LAION400m dataset if a tar ends
        # with the same prefix as the next begins; rare, but can happen since prefixes
        # aren't unique across tar files in that dataset.
        if (
            current_sample is None
            or prefix != current_sample["__key__"]
            or suffix in current_sample
        ):
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if valid_sample(current_sample):
        yield current_sample


def tarfile_to_samples_nothrow(src, handler=log_and_continue):
    # NOTE: this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_nothrow(files, handler=handler)
    return samples


def filter_no_caption_or_no_image(sample):
    """Keep only samples that contain both a caption and an image."""
    has_caption = "txt" in sample
    has_image = (
        "png" in sample or "jpg" in sample or "jpeg" in sample or "webp" in sample
    )
    return has_caption and has_image


def filter_metadata(sample, min_image_size, min_clip_score):
    """Keep only samples whose image is large enough and whose CLIP score is high enough."""
    metadata = json.loads(sample["json"])
    width = metadata["width"]
    height = metadata["height"]
    clip_score = metadata["clip_score"] / 100
    return (
        width >= min_image_size
        and height >= min_image_size
        and clip_score >= min_clip_score
    )
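

# These predicates can be plugged into a pipeline with wds.select; a sketch with
# hypothetical thresholds (not values used anywhere in this module):
#
#   pipeline.append(wds.select(filter_no_caption_or_no_image))
#   pipeline.append(
#       wds.select(partial(filter_metadata, min_image_size=256, min_clip_score=0.28))
#   )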


def _filter_dict_keys(
    data,
    *args,
    handler=wds.reraise_exception,
    missing_is_error=True,
    none_is_error=None,
):
    """Restrict dict samples to the requested keys."""
    if none_is_error is None:
        none_is_error = missing_is_error
    if len(args) == 1 and isinstance(args[0], str) and " " in args[0]:
        args = args[0].split()
    for sample in data:
        try:
            result = {
                f: wds.getfirst(sample, f, missing_is_error=missing_is_error)
                for f in args
            }
            if none_is_error and any(x is None for x in result.values()):
                raise ValueError(f"filter_dict_keys {args} got {sample.keys()}")
            yield result
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break


filter_dict_keys = wds.pipelinefilter(_filter_dict_keys)