import os
import random
from abc import ABC
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from functools import partial
from itertools import chain
from typing import Any, Dict, List, Optional, Union

import pyarrow as pa
import pyarrow.parquet as pq
from omegaconf import DictConfig

from common.distributed import get_global_rank, get_world_size
from common.fs import copy, exists, listdir, mkdir, remove
from common.partition import partition_by_groups
from common.persistence.utils import get_local_path
from data.common.parquet_sampler import (
    IdentityParquetSampler,
    ParquetSampler,
    create_parquet_sampler,
)
from data.common.utils import filter_parquets, get_parquet_metadata

def save_and_copy(
    pa_table: pa.Table,
    local_path: str,
    target_path: str,
    row_group_size: int,
    executor: ThreadPoolExecutor,
    do_async: bool = False,
    futures: Optional[List[Future]] = None,
):
    """Write `pa_table` to `local_path`, then copy it to `target_path`.

    The write-and-copy runs on `executor`. With `do_async=True` the call
    returns immediately and the future is appended to `futures`; with
    `do_async=False` all pending futures are drained and the executor is
    shut down.
    """
    # Avoid the mutable-default pitfall: callers share state by passing
    # their own `futures` list explicitly.
    if futures is None:
        futures = []

    def _make_on_complete(local_path):
        def _on_complete(future):
            # Remove the local copy once the upload has finished.
            target_path = future.result()
            remove(local_path)
            print(f"Target path saved: {target_path}")

        return _on_complete

    def _fn(pa_table, local_path, target_path, row_group_size):
        pq.write_table(
            pa_table,
            local_path,
            row_group_size=row_group_size,
        )
        mkdir(os.path.dirname(target_path))
        copy(local_path, target_path)
        return target_path

    future = executor.submit(_fn, pa_table, local_path, target_path, row_group_size)
    future.add_done_callback(_make_on_complete(local_path))
    futures.append(future)

    if not do_async:
        # Synchronous path: wait for every queued write before returning.
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as exc:
                print(f"Generated an exception: {exc}")
        executor.shutdown(wait=True)

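
# Illustrative usage sketch (hypothetical paths and sizes, not part of this
# module's API): intermediate files are queued with do_async=True, and the
# final call with do_async=False drains the shared `futures` list.
#
#   executor = ThreadPoolExecutor(max_workers=4)
#   futures: List[Future] = []
#   table = pa.table({"id": ["a", "b"]})
#   save_and_copy(table, "/tmp/00000.parquet", "/mnt/out/00000.parquet",
#                 row_group_size=128, executor=executor, do_async=True,
#                 futures=futures)
#   save_and_copy(table, "/tmp/00001.parquet", "/mnt/out/00001.parquet",
#                 row_group_size=128, executor=executor, do_async=False,
#                 futures=futures)
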

@dataclass
class FileListOutput:
    """Source/target file pairing produced by `configure_task_path`."""

    existing_files: List[str]
    source_files: List[Any]
    target_files: List[str]


@dataclass
class PersistedParquet:
    path: str

    def save(
        self,
        row_group_size: int,
        executor: ThreadPoolExecutor,
        pa_table: Optional[pa.Table] = None,
        data_dict: Optional[Dict[str, List[Union[str, bytes]]]] = None,
        is_last_file: bool = False,
        futures: Optional[List[Future]] = None,
    ):
        # Exactly one of `pa_table` and `data_dict` must be provided.
        assert (pa_table is None) != (data_dict is None)
        local_path = get_local_path(self.path)
        if pa_table is None:
            # Build the table from the raw column dict, inferring the schema.
            schema = self.generate_schema_from_dict(data_dict)
            pa_table = pa.Table.from_pydict(data_dict, schema=schema)
        save_and_copy(
            pa_table,
            local_path=local_path,
            target_path=self.path,
            row_group_size=row_group_size,
            executor=executor,
            do_async=not is_last_file,
            futures=futures,
        )

    def generate_schema_from_dict(
        self,
        data_dict: Dict[str, List[Union[str, bytes]]],
    ) -> pa.Schema:
        # Infer each column's type from its first element.
        schema_dict = {}
        for key, value in data_dict.items():
            if isinstance(value[0], str):
                schema_dict[key] = pa.string()
            elif isinstance(value[0], bytes):
                schema_dict[key] = pa.binary()
            else:
                raise ValueError(
                    f"Unsupported data type for key '{key}': {type(value[0])}"
                )
        return pa.schema(schema_dict)

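
# Illustrative usage sketch (hypothetical values): passing `data_dict` instead
# of a prebuilt table lets generate_schema_from_dict infer string/binary columns.
#
#   parquet = PersistedParquet("/mnt/out/00000.parquet")
#   parquet.save(
#       row_group_size=64,
#       executor=ThreadPoolExecutor(max_workers=2),
#       data_dict={"caption": ["a cat"], "image": [b"\x89PNG..."]},
#       is_last_file=True,
#       futures=[],
#   )
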

class ParquetManager(ABC):
    """
    Base class for the DumpingManager and RepackingManager.
    """

    def __init__(
        self,
        task: Optional[DictConfig] = None,
        target_dir: str = ".",
    ):
        self.task = task
        self.target_dir = target_dir.rstrip("/")
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.futures: List[Future] = []

    def get_parquet_files(
        self,
        source_path: str,
        parquet_sampler: ParquetSampler = IdentityParquetSampler(),
        path_mode: str = "dir",
    ) -> List[str]:
        def _flatten(paths):
            # Accept a single path, a flat list, or a list of lists.
            if isinstance(paths, list):
                if any(isinstance(i, list) for i in paths):
                    return list(chain(*paths))
                return paths
            return [paths]

        file_paths = _flatten(source_path)
        if path_mode == "dir":
            file_paths = map(listdir, file_paths)
            if isinstance(parquet_sampler.size, float):
                # Fractional sample size: filter and sample within each
                # directory so every directory contributes proportionally.
                file_paths = map(filter_parquets, file_paths)
                file_paths = map(parquet_sampler, file_paths)
                file_paths = list(chain(*file_paths))
            else:
                # Absolute sample size: pool all directories first, then
                # filter and sample once over the combined listing.
                file_paths = chain(*file_paths)
                file_paths = parquet_sampler(filter_parquets(file_paths))

        return file_paths

    def save_parquet(
        self,
        *,
        file_name: str,
        row_group_size: int,
        pa_table: Optional[pa.Table] = None,
        data_dict: Optional[Dict[str, List[Union[str, bytes]]]] = None,
        override: bool = True,
        is_last_file: bool = False,
    ):
        # Skip files that already exist unless `override` is set.
        persist = self._get_parquet(file_name)
        if override or not exists(persist.path):
            persist.save(
                pa_table=pa_table,
                data_dict=data_dict,
                executor=self.executor,
                row_group_size=row_group_size,
                is_last_file=is_last_file,
                futures=self.futures,
            )

    def _get_parquet(self, file_name: str) -> PersistedParquet:
        return PersistedParquet(file_name)

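
# Illustrative usage sketch (hypothetical paths; `cfg` and `tables` are
# stand-ins for the caller's config and data): a concrete manager lists inputs
# once, then streams tables out through the shared executor, flushing on the
# last file.
#
#   manager = DumpingManager(task=cfg.task, target_dir="/mnt/dump")
#   files = manager.get_parquet_files("/mnt/raw", path_mode="dir")
#   for i, table in enumerate(tables):
#       manager.save_parquet(
#           file_name=f"/mnt/dump/{i:05d}.parquet",
#           row_group_size=128,
#           pa_table=table,
#           is_last_file=(i == len(tables) - 1),
#       )
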

class DumpingManager(ParquetManager):
    """
    Dumping manager handles parquet saving and resuming.
    """

    def __init__(
        self,
        task: DictConfig,
        target_dir: str,
    ):
        super().__init__(task=task, target_dir=target_dir)

    def generate_saving_path(self, file_path: str, rsplit: int):
        # Mirror the last `rsplit` components of the source path under
        # `<target_dir>/epoch_<epoch>/`.
        part_list = file_path.rsplit("/", rsplit)
        result_folder = "/".join(
            [self.target_dir, f"epoch_{self.task.epoch}"] + part_list[-rsplit:-1]
        )
        result_file = "/".join([result_folder, part_list[-1]])
        return result_folder, result_file

    def configure_task_path(self, source_path: str, rsplit: int, path_mode: str = "dir"):
        file_paths = self.get_parquet_files(
            source_path=source_path,
            path_mode=path_mode,
        )

        # Deterministic shuffle so every task and rank sees the same ordering.
        random.Random(0).shuffle(file_paths)

        # Shard first across tasks, then across distributed ranks.
        full_source_files = partition_by_groups(file_paths, self.task.total_count)[self.task.index]
        full_source_files = partition_by_groups(full_source_files, get_world_size())[
            get_global_rank()
        ]

        if not full_source_files:
            return FileListOutput([], [], [])

        generate_saving_path = partial(self.generate_saving_path, rsplit=rsplit)
        full_paths = map(generate_saving_path, full_source_files)
        full_target_folders, full_target_files = map(list, zip(*full_paths))
        full_target_folders = set(full_target_folders)
        full_target_file_set = set(full_target_files)

        # Collect outputs that already exist so finished work can be skipped.
        existing_file_paths = map(
            lambda folder: listdir(folder) if exists(folder) else [], full_target_folders
        )
        existing_file_paths = chain(*existing_file_paths)
        self.existing_files = list(
            filter(
                lambda path: path.endswith(".parquet") and path in full_target_file_set,
                existing_file_paths,
            )
        )

        # Keep only source/target pairs whose output is not yet written.
        existing_file_set = set(self.existing_files)
        filtered_pairs = list(
            filter(
                lambda pair: pair[1] not in existing_file_set,
                zip(full_source_files, full_target_files),
            )
        )
        if filtered_pairs:
            filtered_source_files, filtered_target_files = map(list, zip(*filtered_pairs))
        else:
            filtered_source_files, filtered_target_files = [], []

        skip_exists = self.task.skip_exists
        self.source_files = filtered_source_files if skip_exists else full_source_files
        self.target_files = filtered_target_files if skip_exists else full_target_files

        return FileListOutput(self.existing_files, self.source_files, self.target_files)

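
# Illustrative resume sketch (hypothetical config; `process_and_save` is a
# stand-in for the caller's work function): with task.skip_exists=True the
# returned FileListOutput pairs up only the work whose output is missing, so a
# restarted job picks up where it left off.
#
#   out = manager.configure_task_path("/mnt/raw", rsplit=2)
#   for src, dst in zip(out.source_files, out.target_files):
#       process_and_save(src, dst)
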

class RepackingManager(ParquetManager):
    """
    Repacking manager handles parquet splitting and saving.
    """

    def __init__(
        self,
        task: DictConfig,
        target_dir: str,
        repackaging: DictConfig,
    ):
        super().__init__(task=task, target_dir=target_dir)
        self.repackaging = repackaging

    def configure_task_path(
        self,
        source_path: str,
        parquet_sampler: Optional[DictConfig] = None,
        path_mode: str = "dir",
    ):
        parquet_sampler = create_parquet_sampler(config=parquet_sampler)
        file_paths = self.get_parquet_files(
            source_path=source_path,
            parquet_sampler=parquet_sampler,
            path_mode=path_mode,
        )

        # Deterministic shuffle so repacking is reproducible across runs.
        random.Random(0).shuffle(file_paths)
        target_dir = self.target_dir
        size = abs(parquet_sampler.size)

        if self.task:
            # Shard the inputs across tasks; each task writes to its own subfolder.
            file_paths = partition_by_groups(file_paths, self.task.total_count)[self.task.index]
            target_dir = os.path.join(target_dir, f"{self.task.total_count}_{self.task.index}")

            if size > 1:
                # Absolute sample size: take this task's share of the budget.
                size = len(
                    partition_by_groups(range(size), self.task.total_count)[self.task.index]
                )

        metadatas = get_parquet_metadata(file_paths, self.repackaging.num_processes)

        # Expand every file into (file_path, row_index) items so rows can be
        # redistributed independently of their source file.
        target_items = [
            (file_path, row)
            for file_path, metadata in zip(file_paths, metadatas)
            for row in range(metadata.num_rows)
        ]

        random.Random(0).shuffle(target_items)

        if size > 1:
            target_items = target_items[:size]

        # Partition the sampled rows into the requested number of output files.
        items_per_file = partition_by_groups(target_items, self.repackaging.num_files)

        target_files = [
            os.path.join(target_dir, f"{str(i).zfill(5)}.parquet")
            for i in range(self.repackaging.num_files)
        ]

        existing_file_paths = listdir(target_dir) if exists(target_dir) else []
        self.existing_files = [
            path for path in existing_file_paths if path.endswith(".parquet")
        ]
        self.source_files = items_per_file
        self.target_files = target_files

        return FileListOutput(self.existing_files, self.source_files, self.target_files)
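
# Illustrative repacking sketch (hypothetical config): every out.source_files[i]
# is the list of (file_path, row) items destined for out.target_files[i], so a
# driver can rebuild each shard by reading just those rows.
#
#   manager = RepackingManager(task=cfg.task, target_dir="/mnt/repack",
#                              repackaging=cfg.repackaging)
#   out = manager.configure_task_path("/mnt/dump", parquet_sampler=cfg.sampler)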