| | """ |
| | Utilities for working with the local dataset cache. |
| | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp |
| | Copyright by the AllenNLP authors. |
| | """ |
| |
|
| | import os |
| | import logging |
| | import shutil |
| | import tempfile |
| | import json |
| | from urllib.parse import urlparse |
| | from pathlib import Path |
| | from typing import Optional, Tuple, Union, IO, Callable, Set |
| | from hashlib import sha256 |
| | from functools import wraps |
| |
|
| | from tqdm import tqdm |
| |
|
| | import boto3 |
| | from botocore.exceptions import ClientError |
| | import requests |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
# Default directory for cached downloads; users can override it by setting the
# PYTORCH_PRETRAINED_BERT_CACHE environment variable.
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
                                               Path.home() / '.pytorch_pretrained_bert'))
| |
|
| |
|
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
    """
    Convert `url` into a hashed filename in a repeatable way.

    If `etag` is specified, append its hash to the url's, delimited by a
    period, so that a new remote version (new ETag) maps to a new filename.

    Args:
        url: the URL to derive a cache filename from.
        etag: optional HTTP ETag of the resource (may be ``None``).

    Returns:
        A deterministic, filesystem-safe hex-digest filename.
    """
    # sha256 of the URL gives a stable, collision-resistant, safe filename.
    url_bytes = url.encode('utf-8')
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode('utf-8')
        etag_hash = sha256(etag_bytes)
        filename += '.' + etag_hash.hexdigest()

    return filename
| |
|
| |
|
def filename_to_url(filename: str,
                    cache_dir: Optional[Union[str, Path]] = None) -> Tuple[str, Optional[str]]:
    """
    Return the url and etag (which may be ``None``) stored for `filename`.

    Reads the ``<filename>.json`` sidecar metadata file that sits next to the
    cached file.

    Args:
        filename: hashed cache filename (as produced by ``url_to_filename``).
        cache_dir: cache directory; defaults to ``PYTORCH_PRETRAINED_BERT_CACHE``.

    Raises:
        FileNotFoundError: if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise FileNotFoundError("file {} not found".format(cache_path))

    # Metadata lives next to the cached file in a '.json' sidecar.
    meta_path = cache_path + '.json'
    if not os.path.exists(meta_path):
        raise FileNotFoundError("file {} not found".format(meta_path))

    with open(meta_path) as meta_file:
        metadata = json.load(meta_file)
    url = metadata['url']
    etag = metadata['etag']

    return url, etag
| |
|
| |
|
def cached_path(url_or_filename: Union[str, Path],
                cache_dir: Optional[Union[str, Path]] = None) -> str:
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.

    Args:
        url_or_filename: http(s)/s3 URL or local filesystem path.
        cache_dir: cache directory for downloads; defaults to
            ``PYTORCH_PRETRAINED_BERT_CACHE``.

    Raises:
        FileNotFoundError: if a schemeless input is not an existing file.
        ValueError: if the input is neither a supported URL nor a local path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ('http', 'https', 's3'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif parsed.scheme == '':
        # File, but it doesn't exist.
        raise FileNotFoundError("file {} not found".format(url_or_filename))
    else:
        # Something unknown (e.g. an unsupported scheme).
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
| |
|
| |
|
def split_s3_path(url: str) -> Tuple[str, str]:
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    bucket_name, s3_path = parsed.netloc, parsed.path
    if not bucket_name or not s3_path:
        raise ValueError("bad s3 path {}".format(url))
    # urlparse leaves a leading '/' on the key; strip exactly one.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path
| |
|
| |
|
def s3_request(func: Callable):
    """
    Decorator for S3 access functions that turns a 404 ``ClientError`` into a
    ``FileNotFoundError`` naming the offending URL; any other error is
    re-raised untouched.
    """

    @wraps(func)
    def inner(url: str, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            # botocore reports the HTTP status as a string code.
            if int(exc.response["Error"]["Code"]) == 404:
                raise FileNotFoundError("file {} not found".format(url))
            raise

    return inner
| |
|
| |
|
@s3_request
def s3_etag(url: str) -> Optional[str]:
    """Return the ETag of the S3 object at `url` (``None`` if S3 reports none)."""
    bucket_name, s3_path = split_s3_path(url)
    resource = boto3.resource("s3")
    return resource.Object(bucket_name, s3_path).e_tag
| |
|
| |
|
@s3_request
def s3_get(url: str, temp_file: IO) -> None:
    """Download the S3 object at `url` directly into the open `temp_file`."""
    bucket_name, s3_path = split_s3_path(url)
    bucket = boto3.resource("s3").Bucket(bucket_name)
    bucket.download_fileobj(s3_path, temp_file)
| |
|
| |
|
def http_get(url: str, temp_file: IO) -> None:
    """
    Stream the content at `url` into `temp_file`, showing a byte-count
    progress bar.

    Args:
        url: http(s) URL to download.
        temp_file: open binary-writable file object to receive the bytes.
    """
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    # Content-Length may be absent (e.g. chunked transfer); tqdm then shows
    # an unbounded counter instead of a percentage.
    total = int(content_length) if content_length is not None else None
    # Using tqdm as a context manager guarantees the progress bar is closed
    # even if the download raises part-way through (the original leaked it).
    with tqdm(unit="B", total=total) as progress:
        for chunk in req.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive chunks
                progress.update(len(chunk))
                temp_file.write(chunk)
| |
|
| |
|
def get_from_cache(url: str, cache_dir: Optional[Union[str, Path]] = None) -> str:
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    Args:
        url: http(s) or s3 URL of the file to fetch.
        cache_dir: directory to cache into; defaults to
            ``PYTORCH_PRETRAINED_BERT_CACHE``.

    Returns:
        Filesystem path of the cached file.

    Raises:
        IOError: if the HEAD request for a non-S3 URL does not return 200.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    # Fetch the ETag first: it is folded into the cache filename, so a changed
    # remote file (new ETag) gets a fresh cache entry instead of a stale hit.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError("HEAD request failed for url {} with status code {}"
                          .format(url, response.status_code))
        etag = response.headers.get("ETag")  # may legitimately be None

    filename = url_to_filename(url, etag)

    # Final location the cached file lives at (may already exist).
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to a temporary file first, then copy into the cache once
        # the download completes, so a partial download never sits at
        # cache_path. The NamedTemporaryFile is auto-deleted on close.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object from the appropriate backend.
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # We are copying the file before closing it, so flush to make
            # sure all buffered bytes hit the temp file.
            temp_file.flush()
            # shutil.copyfileobj() starts copying from the current position,
            # so rewind to the beginning first.
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            # Record the url/etag so filename_to_url() can invert the mapping.
            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w') as meta_file:
                json.dump(meta, meta_file)

            # Deletion happens automatically when the 'with' block closes it.
            logger.info("removing temp file %s", temp_file.name)

    return cache_path
| |
|
| |
|
def read_set_from_file(filename: str) -> Set[str]:
    """
    Extract a de-duped collection (set) of text from a file.
    Expected file format is one item per line.
    """
    with open(filename, 'r', encoding='utf-8') as source:
        # rstrip drops the trailing newline (and any trailing whitespace).
        return {line.rstrip() for line in source}
| |
|
| |
|
def get_file_extension(path: str, dot=True, lower: bool = True):
    """Return the extension of `path`, optionally without the leading dot
    (``dot=False``) and optionally lowercased (``lower=True``, the default)."""
    _, extension = os.path.splitext(path)
    if not dot:
        extension = extension[1:]
    return extension.lower() if lower else extension
| |
|