diff --git a/.gitattributes b/.gitattributes index fa85d7b4b731dcd08bbecc23adac0d96215ceb33..2cef747209554e21db58b2502f4d197f2a81c5f2 100644 --- a/.gitattributes +++ b/.gitattributes @@ -3,3 +3,6 @@ *.h5 filter=lfs diff=lfs merge=lfs -text *.pth filter=lfs diff=lfs merge=lfs -text *.tar filter=lfs diff=lfs merge=lfs -text +model_output/incremental_1_logs/tokenizer.model filter=lfs diff=lfs merge=lfs -text +merged_tinyllama_logger/tokenizer.model filter=lfs diff=lfs merge=lfs -text +model_output/incremental_1_logs/checkpoint-575/scheduler.pt filter=lfs diff=lfs merge=lfs -text diff --git a/merged_tinyllama_logger/tokenizer.model b/merged_tinyllama_logger/tokenizer.model new file mode 100644 index 0000000000000000000000000000000000000000..6c00c742ce03c627d6cd5b795984876fa49fa899 --- /dev/null +++ b/merged_tinyllama_logger/tokenizer.model @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347 +size 499723 diff --git a/model_output/incremental_1_logs/checkpoint-575/rng_state.pth b/model_output/incremental_1_logs/checkpoint-575/rng_state.pth new file mode 100644 index 0000000000000000000000000000000000000000..fdc46132cd31424ec7440ac6848ed2dc91fe2560 --- /dev/null +++ b/model_output/incremental_1_logs/checkpoint-575/rng_state.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db60df842ad12cea70ac8238ba2be455622ac32e26d4cb842a044126eddafb9d +size 14244 diff --git a/model_output/incremental_1_logs/checkpoint-575/scheduler.pt b/model_output/incremental_1_logs/checkpoint-575/scheduler.pt new file mode 100644 index 0000000000000000000000000000000000000000..7cfca5462f24863713f7cc1b617ee8de6d6fa7e8 --- /dev/null +++ b/model_output/incremental_1_logs/checkpoint-575/scheduler.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e164bde1413c3212f149a47133d7d6fc680e60537dbd755c476e37f7a92bf822 +size 1064 diff --git a/model_output/incremental_1_logs/tokenizer.model b/model_output/incremental_1_logs/tokenizer.model new file mode 100644 index 0000000000000000000000000000000000000000..6c00c742ce03c627d6cd5b795984876fa49fa899 --- /dev/null +++ b/model_output/incremental_1_logs/tokenizer.model @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347 +size 499723 diff --git a/model_output/incremental_1_logs/training_args.bin b/model_output/incremental_1_logs/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..18f5b155265f235de6e730f68a99508b4627e0f5 --- /dev/null +++ b/model_output/incremental_1_logs/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:290f55cf786a3f16cb2d84b8109bbdadf1005b363093d17cbd0bdbdb9cb95fdd +size 5176 diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d49ab5f94d98dac7be4ab16751004dfd0ef8c006 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/_cli_utils.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/_cli_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec4892e00d5dc5c52d3a078359f5e54902bfc356 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/_cli_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/upload.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/upload.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4ab774217cc66a42611a537e922aa98ddcc45a8 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/upload.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/upload_large_folder.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/upload_large_folder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..124c5ed7ca845460765dd90451f24bbd3174fddb Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/cli/__pycache__/upload_large_folder.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/__init__.py b/phivenv/Lib/site-packages/huggingface_hub/commands/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49d088214505b9604964ab142e7f8a5b38ccd5ef --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/commands/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from argparse import _SubParsersAction + + +class BaseHuggingfaceCLICommand(ABC): + @staticmethod + @abstractmethod + def register_subcommand(parser: _SubParsersAction): + raise NotImplementedError() + + @abstractmethod + def run(self): + raise NotImplementedError() diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42bf3d9a0b3f5e2d9e124fc4bea667e7cc70664e Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/_cli_utils.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/_cli_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c499bc1fec2df3bede1b3224e3253b05fca935fe Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/_cli_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/delete_cache.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/delete_cache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7086af7f83544c9d97fd72c9c8bc67db653f13d1 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/delete_cache.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/download.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/download.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..610909fac274375501e492182616693adb034198 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/download.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/env.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/env.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8166470a901574b437b599bfe9f9642799557a38 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/env.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/huggingface_cli.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/huggingface_cli.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..625f2fa6b6205e0830a8eea9c66f5e4da7ba299a Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/huggingface_cli.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/lfs.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/lfs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3b3b6bcac10824d70e0d50fd9d2007cb95dc166 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/lfs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/repo.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/repo.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b675d0cd67968004aab4d5526444d07abbd27cf0 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/repo.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/repo_files.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/repo_files.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ea15addd2a44afe6ee066edd72156da9595e5d5 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/repo_files.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/scan_cache.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/scan_cache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2384c87b4c4b50e02a59d3275d3447eaa67d129d Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/scan_cache.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/tag.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/tag.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5b027b7de1ff20f0297d747bbd66b2fcd3af842 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/tag.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/upload.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/upload.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..729bb112838ef5d780f975fcbf6a50bceb157ddc Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/upload.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/upload_large_folder.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/upload_large_folder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..612eedfd8c4543c7283b6f627af81d1dbc981b08 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/upload_large_folder.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/user.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/user.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a33be49f3e71cbf3afa644fcf3a7b2ed99c54cc Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/user.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/version.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/version.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13f241e9de08b6e9b667f98c59a0826f583b815a Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/commands/__pycache__/version.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/_cli_utils.py b/phivenv/Lib/site-packages/huggingface_hub/commands/_cli_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bf4a1c0373b4d4bb71a3f4e8ea39da5a01cc79a7 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/commands/_cli_utils.py @@ -0,0 +1,74 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains a utility for good-looking prints.""" + +import os +from typing import List, Union + + +class ANSI: + """ + Helper for en.wikipedia.org/wiki/ANSI_escape_code + """ + + _bold = "\u001b[1m" + _gray = "\u001b[90m" + _red = "\u001b[31m" + _reset = "\u001b[0m" + _yellow = "\u001b[33m" + + @classmethod + def bold(cls, s: str) -> str: + return cls._format(s, cls._bold) + + @classmethod + def gray(cls, s: str) -> str: + return cls._format(s, cls._gray) + + @classmethod + def red(cls, s: str) -> str: + return cls._format(s, cls._bold + cls._red) + + @classmethod + def yellow(cls, s: str) -> str: + return cls._format(s, cls._yellow) + + @classmethod + def _format(cls, s: str, code: str) -> str: + if os.environ.get("NO_COLOR"): + # See https://no-color.org/ + return s + return f"{code}{s}{cls._reset}" + + +def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str: + """ + Inspired by: + + - stackoverflow.com/a/8356620/593036 + - stackoverflow.com/questions/9535954/printing-lists-as-tabular-data + """ + col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)] + row_format = ("{{:{}}} " * len(headers)).format(*col_widths) + lines = [] + lines.append(row_format.format(*headers)) + lines.append(row_format.format(*["-" * w for w in col_widths])) + for row in rows: + lines.append(row_format.format(*row)) + return "\n".join(lines) + + +def show_deprecation_warning(old_command: str, new_command: str): + """Show a yellow warning about deprecated CLI command.""" + print(ANSI.yellow(f"⚠️ Warning: '{old_command}' is deprecated. Use '{new_command}' instead.")) diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/delete_cache.py b/phivenv/Lib/site-packages/huggingface_hub/commands/delete_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..25fffd07eef4d2dd92d22fc698cf965435c9c66a --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/commands/delete_cache.py @@ -0,0 +1,476 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains command to delete some revisions from the HF cache directory. + +Usage: + huggingface-cli delete-cache + huggingface-cli delete-cache --disable-tui + huggingface-cli delete-cache --dir ~/.cache/huggingface/hub + huggingface-cli delete-cache --sort=size + +NOTE: + This command is based on `InquirerPy` to build the multiselect menu in the terminal. + This dependency has to be installed with `pip install huggingface_hub[cli]`. Since + we want to avoid as much as possible cross-platform issues, I chose a library that + is built on top of `python-prompt-toolkit` which seems to be a reference in terminal + GUI (actively maintained on both Unix and Windows, 7.9k stars). + + For the moment, the TUI feature is in beta. + + See: + - https://github.com/kazhala/InquirerPy + - https://inquirerpy.readthedocs.io/en/latest/ + - https://github.com/prompt-toolkit/python-prompt-toolkit + + Other solutions could have been: + - `simple_term_menu`: would be good as well for our use case but some issues suggest + that Windows is less supported. + See: https://github.com/IngoMeyer441/simple-term-menu + - `PyInquirer`: very similar to `InquirerPy` but older and not maintained anymore. + In particular, no support of Python3.10. + See: https://github.com/CITGuru/PyInquirer + - `pick` (or `pickpack`): easy to use and flexible but built on top of Python's + standard library `curses` that is specific to Unix (not implemented on Windows). + See https://github.com/wong2/pick and https://github.com/anafvana/pickpack. + - `inquirer`: lot of traction (700 stars) but explicitly states "experimental + support of Windows". Not built on top of `python-prompt-toolkit`. + See https://github.com/magmax/python-inquirer + +TODO: add support for `huggingface-cli delete-cache aaaaaa bbbbbb cccccc (...)` ? +TODO: add "--keep-last" arg to delete revisions that are not on `main` ref +TODO: add "--filter" arg to filter repositories by name ? +TODO: add "--limit" arg to limit to X repos ? +TODO: add "-y" arg for immediate deletion ? +See discussions in https://github.com/huggingface/huggingface_hub/issues/1025. +""" + +import os +from argparse import Namespace, _SubParsersAction +from functools import wraps +from tempfile import mkstemp +from typing import Any, Callable, Iterable, List, Literal, Optional, Union + +from ..utils import CachedRepoInfo, CachedRevisionInfo, HFCacheInfo, scan_cache_dir +from . import BaseHuggingfaceCLICommand +from ._cli_utils import ANSI, show_deprecation_warning + + +try: + from InquirerPy import inquirer + from InquirerPy.base.control import Choice + from InquirerPy.separator import Separator + + _inquirer_py_available = True +except ImportError: + _inquirer_py_available = False + +SortingOption_T = Literal["alphabetical", "lastUpdated", "lastUsed", "size"] + + +def require_inquirer_py(fn: Callable) -> Callable: + """Decorator to flag methods that require `InquirerPy`.""" + + # TODO: refactor this + imports in a unified pattern across codebase + @wraps(fn) + def _inner(*args, **kwargs): + if not _inquirer_py_available: + raise ImportError( + "The `delete-cache` command requires extra dependencies to work with" + " the TUI.\nPlease run `pip install huggingface_hub[cli]` to install" + " them.\nOtherwise, disable TUI using the `--disable-tui` flag." + ) + + return fn(*args, **kwargs) + + return _inner + + +# Possibility for the user to cancel deletion +_CANCEL_DELETION_STR = "CANCEL_DELETION" + + +class DeleteCacheCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + delete_cache_parser = parser.add_parser("delete-cache", help="Delete revisions from the cache directory.") + + delete_cache_parser.add_argument( + "--dir", + type=str, + default=None, + help="cache directory (optional). Default to the default HuggingFace cache.", + ) + + delete_cache_parser.add_argument( + "--disable-tui", + action="store_true", + help=( + "Disable Terminal User Interface (TUI) mode. Useful if your" + " platform/terminal doesn't support the multiselect menu." + ), + ) + + delete_cache_parser.add_argument( + "--sort", + nargs="?", + choices=["alphabetical", "lastUpdated", "lastUsed", "size"], + help=( + "Sort repositories by the specified criteria. Options: " + "'alphabetical' (A-Z), " + "'lastUpdated' (newest first), " + "'lastUsed' (most recent first), " + "'size' (largest first)." + ), + ) + + delete_cache_parser.set_defaults(func=DeleteCacheCommand) + + def __init__(self, args: Namespace) -> None: + self.cache_dir: Optional[str] = args.dir + self.disable_tui: bool = args.disable_tui + self.sort_by: Optional[SortingOption_T] = args.sort + + def run(self): + """Run `delete-cache` command with or without TUI.""" + show_deprecation_warning("huggingface-cli delete-cache", "hf cache delete") + + # Scan cache directory + hf_cache_info = scan_cache_dir(self.cache_dir) + + # Manual review from the user + if self.disable_tui: + selected_hashes = _manual_review_no_tui(hf_cache_info, preselected=[], sort_by=self.sort_by) + else: + selected_hashes = _manual_review_tui(hf_cache_info, preselected=[], sort_by=self.sort_by) + + # If deletion is not cancelled + if len(selected_hashes) > 0 and _CANCEL_DELETION_STR not in selected_hashes: + confirm_message = _get_expectations_str(hf_cache_info, selected_hashes) + " Confirm deletion ?" + + # Confirm deletion + if self.disable_tui: + confirmed = _ask_for_confirmation_no_tui(confirm_message) + else: + confirmed = _ask_for_confirmation_tui(confirm_message) + + # Deletion is confirmed + if confirmed: + strategy = hf_cache_info.delete_revisions(*selected_hashes) + print("Start deletion.") + strategy.execute() + print( + f"Done. Deleted {len(strategy.repos)} repo(s) and" + f" {len(strategy.snapshots)} revision(s) for a total of" + f" {strategy.expected_freed_size_str}." + ) + return + + # Deletion is cancelled + print("Deletion is cancelled. Do nothing.") + + +def _get_repo_sorting_key(repo: CachedRepoInfo, sort_by: Optional[SortingOption_T] = None): + if sort_by == "alphabetical": + return (repo.repo_type, repo.repo_id.lower()) # by type then name + elif sort_by == "lastUpdated": + return -max(rev.last_modified for rev in repo.revisions) # newest first + elif sort_by == "lastUsed": + return -repo.last_accessed # most recently used first + elif sort_by == "size": + return -repo.size_on_disk # largest first + else: + return (repo.repo_type, repo.repo_id) # default stable order + + +@require_inquirer_py +def _manual_review_tui( + hf_cache_info: HFCacheInfo, + preselected: List[str], + sort_by: Optional[SortingOption_T] = None, +) -> List[str]: + """Ask the user for a manual review of the revisions to delete. + + Displays a multi-select menu in the terminal (TUI). + """ + # Define multiselect list + choices = _get_tui_choices_from_scan( + repos=hf_cache_info.repos, + preselected=preselected, + sort_by=sort_by, + ) + checkbox = inquirer.checkbox( + message="Select revisions to delete:", + choices=choices, # List of revisions with some pre-selection + cycle=False, # No loop between top and bottom + height=100, # Large list if possible + # We use the instruction to display to the user the expected effect of the + # deletion. + instruction=_get_expectations_str( + hf_cache_info, + selected_hashes=[c.value for c in choices if isinstance(c, Choice) and c.enabled], + ), + # We use the long instruction to should keybindings instructions to the user + long_instruction="Press to select, to validate and to quit without modification.", + # Message that is displayed once the user validates its selection. + transformer=lambda result: f"{len(result)} revision(s) selected.", + ) + + # Add a callback to update the information line when a revision is + # selected/unselected + def _update_expectations(_) -> None: + # Hacky way to dynamically set an instruction message to the checkbox when + # a revision hash is selected/unselected. + checkbox._instruction = _get_expectations_str( + hf_cache_info, + selected_hashes=[choice["value"] for choice in checkbox.content_control.choices if choice["enabled"]], + ) + + checkbox.kb_func_lookup["toggle"].append({"func": _update_expectations}) + + # Finally display the form to the user. + try: + return checkbox.execute() + except KeyboardInterrupt: + return [] # Quit without deletion + + +@require_inquirer_py +def _ask_for_confirmation_tui(message: str, default: bool = True) -> bool: + """Ask for confirmation using Inquirer.""" + return inquirer.confirm(message, default=default).execute() + + +def _get_tui_choices_from_scan( + repos: Iterable[CachedRepoInfo], + preselected: List[str], + sort_by: Optional[SortingOption_T] = None, +) -> List: + """Build a list of choices from the scanned repos. + + Args: + repos (*Iterable[`CachedRepoInfo`]*): + List of scanned repos on which we want to delete revisions. + preselected (*List[`str`]*): + List of revision hashes that will be preselected. + sort_by (*Optional[SortingOption_T]*): + Sorting direction. Choices: "alphabetical", "lastUpdated", "lastUsed", "size". + + Return: + The list of choices to pass to `inquirer.checkbox`. + """ + choices: List[Union[Choice, Separator]] = [] + + # First choice is to cancel the deletion + choices.append( + Choice( + _CANCEL_DELETION_STR, + name="None of the following (if selected, nothing will be deleted).", + enabled=False, + ) + ) + + # Sort repos based on specified criteria + sorted_repos = sorted(repos, key=lambda repo: _get_repo_sorting_key(repo, sort_by)) + + for repo in sorted_repos: + # Repo as separator + choices.append( + Separator( + f"\n{repo.repo_type.capitalize()} {repo.repo_id} ({repo.size_on_disk_str}," + f" used {repo.last_accessed_str})" + ) + ) + for revision in sorted(repo.revisions, key=_revision_sorting_order): + # Revision as choice + choices.append( + Choice( + revision.commit_hash, + name=( + f"{revision.commit_hash[:8]}:" + f" {', '.join(sorted(revision.refs)) or '(detached)'} #" + f" modified {revision.last_modified_str}" + ), + enabled=revision.commit_hash in preselected, + ) + ) + + # Return choices + return choices + + +def _manual_review_no_tui( + hf_cache_info: HFCacheInfo, + preselected: List[str], + sort_by: Optional[SortingOption_T] = None, +) -> List[str]: + """Ask the user for a manual review of the revisions to delete. + + Used when TUI is disabled. Manual review happens in a separate tmp file that the + user can manually edit. + """ + # 1. Generate temporary file with delete commands. + fd, tmp_path = mkstemp(suffix=".txt") # suffix to make it easier to find by editors + os.close(fd) + + lines = [] + + sorted_repos = sorted(hf_cache_info.repos, key=lambda repo: _get_repo_sorting_key(repo, sort_by)) + + for repo in sorted_repos: + lines.append( + f"\n# {repo.repo_type.capitalize()} {repo.repo_id} ({repo.size_on_disk_str}," + f" used {repo.last_accessed_str})" + ) + for revision in sorted(repo.revisions, key=_revision_sorting_order): + lines.append( + # Deselect by prepending a '#' + f"{'' if revision.commit_hash in preselected else '#'} " + f" {revision.commit_hash} # Refs:" + # Print `refs` as comment on same line + f" {', '.join(sorted(revision.refs)) or '(detached)'} # modified" + # Print `last_modified` as comment on same line + f" {revision.last_modified_str}" + ) + + with open(tmp_path, "w") as f: + f.write(_MANUAL_REVIEW_NO_TUI_INSTRUCTIONS) + f.write("\n".join(lines)) + + # 2. Prompt instructions to user. + instructions = f""" + TUI is disabled. In order to select which revisions you want to delete, please edit + the following file using the text editor of your choice. Instructions for manual + editing are located at the beginning of the file. Edit the file, save it and confirm + to continue. + File to edit: {ANSI.bold(tmp_path)} + """ + print("\n".join(line.strip() for line in instructions.strip().split("\n"))) + + # 3. Wait for user confirmation. + while True: + selected_hashes = _read_manual_review_tmp_file(tmp_path) + if _ask_for_confirmation_no_tui( + _get_expectations_str(hf_cache_info, selected_hashes) + " Continue ?", + default=False, + ): + break + + # 4. Return selected_hashes sorted to maintain stable order + os.remove(tmp_path) + return sorted(selected_hashes) # Sort to maintain stable order + + +def _ask_for_confirmation_no_tui(message: str, default: bool = True) -> bool: + """Ask for confirmation using pure-python.""" + YES = ("y", "yes", "1") + NO = ("n", "no", "0") + DEFAULT = "" + ALL = YES + NO + (DEFAULT,) + full_message = message + (" (Y/n) " if default else " (y/N) ") + while True: + answer = input(full_message).lower() + if answer == DEFAULT: + return default + if answer in YES: + return True + if answer in NO: + return False + print(f"Invalid input. Must be one of {ALL}") + + +def _get_expectations_str(hf_cache_info: HFCacheInfo, selected_hashes: List[str]) -> str: + """Format a string to display to the user how much space would be saved. + + Example: + ``` + >>> _get_expectations_str(hf_cache_info, selected_hashes) + '7 revisions selected counting for 4.3G.' + ``` + """ + if _CANCEL_DELETION_STR in selected_hashes: + return "Nothing will be deleted." + strategy = hf_cache_info.delete_revisions(*selected_hashes) + return f"{len(selected_hashes)} revisions selected counting for {strategy.expected_freed_size_str}." + + +def _read_manual_review_tmp_file(tmp_path: str) -> List[str]: + """Read the manually reviewed instruction file and return a list of revision hash. + + Example: + ```txt + # This is the tmp file content + ### + + # Commented out line + 123456789 # revision hash + + # Something else + # a_newer_hash # 2 days ago + an_older_hash # 3 days ago + ``` + + ```py + >>> _read_manual_review_tmp_file(tmp_path) + ['123456789', 'an_older_hash'] + ``` + """ + with open(tmp_path) as f: + content = f.read() + + # Split lines + lines = [line.strip() for line in content.split("\n")] + + # Filter commented lines + selected_lines = [line for line in lines if not line.startswith("#")] + + # Select only before comment + selected_hashes = [line.split("#")[0].strip() for line in selected_lines] + + # Return revision hashes + return [hash for hash in selected_hashes if len(hash) > 0] + + +_MANUAL_REVIEW_NO_TUI_INSTRUCTIONS = f""" +# INSTRUCTIONS +# ------------ +# This is a temporary file created by running `huggingface-cli delete-cache` with the +# `--disable-tui` option. It contains a set of revisions that can be deleted from your +# local cache directory. +# +# Please manually review the revisions you want to delete: +# - Revision hashes can be commented out with '#'. +# - Only non-commented revisions in this file will be deleted. +# - Revision hashes that are removed from this file are ignored as well. +# - If `{_CANCEL_DELETION_STR}` line is uncommented, the all cache deletion is cancelled and +# no changes will be applied. +# +# Once you've manually reviewed this file, please confirm deletion in the terminal. This +# file will be automatically removed once done. +# ------------ + +# KILL SWITCH +# ------------ +# Un-comment following line to completely cancel the deletion process +# {_CANCEL_DELETION_STR} +# ------------ + +# REVISIONS +# ------------ +""".strip() + + +def _revision_sorting_order(revision: CachedRevisionInfo) -> Any: + # Sort by last modified (oldest first) + return revision.last_modified diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/download.py b/phivenv/Lib/site-packages/huggingface_hub/commands/download.py new file mode 100644 index 0000000000000000000000000000000000000000..0dd2c1070ead01f9ad6855de3929928d268279c2 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/commands/download.py @@ -0,0 +1,204 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains command to download files from the Hub with the CLI. + +Usage: + huggingface-cli download --help + + # Download file + huggingface-cli download gpt2 config.json + + # Download entire repo + huggingface-cli download fffiloni/zeroscope --repo-type=space --revision=refs/pr/78 + + # Download repo with filters + huggingface-cli download gpt2 --include="*.safetensors" + + # Download with token + huggingface-cli download Wauplin/private-model --token=hf_*** + + # Download quietly (no progress bar, no warnings, only the returned path) + huggingface-cli download gpt2 config.json --quiet + + # Download to local dir + huggingface-cli download gpt2 --local-dir=./models/gpt2 +""" + +import warnings +from argparse import Namespace, _SubParsersAction +from typing import List, Optional + +from huggingface_hub import logging +from huggingface_hub._snapshot_download import snapshot_download +from huggingface_hub.commands import BaseHuggingfaceCLICommand +from huggingface_hub.file_download import hf_hub_download +from huggingface_hub.utils import disable_progress_bars, enable_progress_bars + +from ._cli_utils import show_deprecation_warning + + +logger = logging.get_logger(__name__) + + +class DownloadCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + download_parser = parser.add_parser("download", help="Download files from the Hub") + download_parser.add_argument( + "repo_id", type=str, help="ID of the repo to download from (e.g. `username/repo-name`)." + ) + download_parser.add_argument( + "filenames", type=str, nargs="*", help="Files to download (e.g. `config.json`, `data/metadata.jsonl`)." + ) + download_parser.add_argument( + "--repo-type", + choices=["model", "dataset", "space"], + default="model", + help="Type of repo to download from (defaults to 'model').", + ) + download_parser.add_argument( + "--revision", + type=str, + help="An optional Git revision id which can be a branch name, a tag, or a commit hash.", + ) + download_parser.add_argument( + "--include", nargs="*", type=str, help="Glob patterns to match files to download." + ) + download_parser.add_argument( + "--exclude", nargs="*", type=str, help="Glob patterns to exclude from files to download." + ) + download_parser.add_argument( + "--cache-dir", type=str, help="Path to the directory where to save the downloaded files." + ) + download_parser.add_argument( + "--local-dir", + type=str, + help=( + "If set, the downloaded file will be placed under this directory. Check out" + " https://huggingface.co/docs/huggingface_hub/guides/download#download-files-to-local-folder for more" + " details." + ), + ) + download_parser.add_argument( + "--local-dir-use-symlinks", + choices=["auto", "True", "False"], + help=("Deprecated and ignored. Downloading to a local directory does not use symlinks anymore."), + ) + download_parser.add_argument( + "--force-download", + action="store_true", + help="If True, the files will be downloaded even if they are already cached.", + ) + download_parser.add_argument( + "--resume-download", + action="store_true", + help="Deprecated and ignored. Downloading a file to local dir always attempts to resume previously interrupted downloads (unless hf-transfer is enabled).", + ) + download_parser.add_argument( + "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens" + ) + download_parser.add_argument( + "--quiet", + action="store_true", + help="If True, progress bars are disabled and only the path to the download files is printed.", + ) + download_parser.add_argument( + "--max-workers", + type=int, + default=8, + help="Maximum number of workers to use for downloading files. Default is 8.", + ) + download_parser.set_defaults(func=DownloadCommand) + + def __init__(self, args: Namespace) -> None: + self.token = args.token + self.repo_id: str = args.repo_id + self.filenames: List[str] = args.filenames + self.repo_type: str = args.repo_type + self.revision: Optional[str] = args.revision + self.include: Optional[List[str]] = args.include + self.exclude: Optional[List[str]] = args.exclude + self.cache_dir: Optional[str] = args.cache_dir + self.local_dir: Optional[str] = args.local_dir + self.force_download: bool = args.force_download + self.resume_download: Optional[bool] = args.resume_download or None + self.quiet: bool = args.quiet + self.max_workers: int = args.max_workers + + if args.local_dir_use_symlinks is not None: + warnings.warn( + "Ignoring --local-dir-use-symlinks. Downloading to a local directory does not use symlinks anymore.", + FutureWarning, + ) + + def run(self) -> None: + show_deprecation_warning("huggingface-cli download", "hf download") + + if self.quiet: + disable_progress_bars() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + print(self._download()) # Print path to downloaded files + enable_progress_bars() + else: + logging.set_verbosity_info() + print(self._download()) # Print path to downloaded files + logging.set_verbosity_warning() + + def _download(self) -> str: + # Warn user if patterns are ignored + if len(self.filenames) > 0: + if self.include is not None and len(self.include) > 0: + warnings.warn("Ignoring `--include` since filenames have being explicitly set.") + if self.exclude is not None and len(self.exclude) > 0: + warnings.warn("Ignoring `--exclude` since filenames have being explicitly set.") + + # Single file to download: use `hf_hub_download` + if len(self.filenames) == 1: + return hf_hub_download( + repo_id=self.repo_id, + repo_type=self.repo_type, + revision=self.revision, + filename=self.filenames[0], + cache_dir=self.cache_dir, + resume_download=self.resume_download, + force_download=self.force_download, + token=self.token, + local_dir=self.local_dir, + library_name="huggingface-cli", + ) + + # Otherwise: use `snapshot_download` to ensure all files comes from same revision + elif len(self.filenames) == 0: + allow_patterns = self.include + ignore_patterns = self.exclude + else: + allow_patterns = self.filenames + ignore_patterns = None + + return snapshot_download( + repo_id=self.repo_id, + repo_type=self.repo_type, + revision=self.revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + resume_download=self.resume_download, + force_download=self.force_download, + cache_dir=self.cache_dir, + token=self.token, + local_dir=self.local_dir, + library_name="huggingface-cli", + max_workers=self.max_workers, + ) diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/env.py b/phivenv/Lib/site-packages/huggingface_hub/commands/env.py new file mode 100644 index 0000000000000000000000000000000000000000..ad674738b2f137ec0b79c11ef35057a351de6d86 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/commands/env.py @@ -0,0 +1,39 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains command to print information about the environment. + +Usage: + huggingface-cli env +""" + +from argparse import _SubParsersAction + +from ..utils import dump_environment_info +from . import BaseHuggingfaceCLICommand +from ._cli_utils import show_deprecation_warning + + +class EnvironmentCommand(BaseHuggingfaceCLICommand): + def __init__(self, args): + self.args = args + + @staticmethod + def register_subcommand(parser: _SubParsersAction): + env_parser = parser.add_parser("env", help="Print information about the environment.") + env_parser.set_defaults(func=EnvironmentCommand) + + def run(self) -> None: + show_deprecation_warning("huggingface-cli env", "hf env") + + dump_environment_info() diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/huggingface_cli.py b/phivenv/Lib/site-packages/huggingface_hub/commands/huggingface_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..697c85d1e386d9c954be0f8112cb12e1bc84e7fe --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/commands/huggingface_cli.py @@ -0,0 +1,65 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser + +from huggingface_hub.commands._cli_utils import show_deprecation_warning +from huggingface_hub.commands.delete_cache import DeleteCacheCommand +from huggingface_hub.commands.download import DownloadCommand +from huggingface_hub.commands.env import EnvironmentCommand +from huggingface_hub.commands.lfs import LfsCommands +from huggingface_hub.commands.repo import RepoCommands +from huggingface_hub.commands.repo_files import RepoFilesCommand +from huggingface_hub.commands.scan_cache import ScanCacheCommand +from huggingface_hub.commands.tag import TagCommands +from huggingface_hub.commands.upload import UploadCommand +from huggingface_hub.commands.upload_large_folder import UploadLargeFolderCommand +from huggingface_hub.commands.user import UserCommands +from huggingface_hub.commands.version import VersionCommand + + +def main(): + parser = ArgumentParser("huggingface-cli", usage="huggingface-cli []") + commands_parser = parser.add_subparsers(help="huggingface-cli command helpers") + + # Register commands + DownloadCommand.register_subcommand(commands_parser) + UploadCommand.register_subcommand(commands_parser) + RepoFilesCommand.register_subcommand(commands_parser) + EnvironmentCommand.register_subcommand(commands_parser) + UserCommands.register_subcommand(commands_parser) + RepoCommands.register_subcommand(commands_parser) + LfsCommands.register_subcommand(commands_parser) + ScanCacheCommand.register_subcommand(commands_parser) + DeleteCacheCommand.register_subcommand(commands_parser) + TagCommands.register_subcommand(commands_parser) + VersionCommand.register_subcommand(commands_parser) + + # Experimental + UploadLargeFolderCommand.register_subcommand(commands_parser) + + # Let's go + args = parser.parse_args() + if not hasattr(args, "func"): + show_deprecation_warning("huggingface-cli", "hf") + parser.print_help() + exit(1) + + # Run + service = args.func(args) + service.run() + + +if __name__ == "__main__": + main() diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/lfs.py b/phivenv/Lib/site-packages/huggingface_hub/commands/lfs.py new file mode 100644 index 0000000000000000000000000000000000000000..e510e345e6a4bf6da03f71b35cbfa2a4f0eb7325 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/commands/lfs.py @@ -0,0 +1,200 @@ +""" +Implementation of a custom transfer agent for the transfer type "multipart" for +git-lfs. + +Inspired by: +github.com/cbartz/git-lfs-swift-transfer-agent/blob/master/git_lfs_swift_transfer.py + +Spec is: github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md + + +To launch debugger while developing: + +``` [lfs "customtransfer.multipart"] +path = /path/to/huggingface_hub/.env/bin/python args = -m debugpy --listen 5678 +--wait-for-client +/path/to/huggingface_hub/src/huggingface_hub/commands/huggingface_cli.py +lfs-multipart-upload ```""" + +import json +import os +import subprocess +import sys +from argparse import _SubParsersAction +from typing import Dict, List, Optional + +from huggingface_hub.commands import BaseHuggingfaceCLICommand +from huggingface_hub.lfs import LFS_MULTIPART_UPLOAD_COMMAND + +from ..utils import get_session, hf_raise_for_status, logging +from ..utils._lfs import SliceFileObj + + +logger = logging.get_logger(__name__) + + +class LfsCommands(BaseHuggingfaceCLICommand): + """ + Implementation of a custom transfer agent for the transfer type "multipart" + for git-lfs. This lets users upload large files >5GB 🔥. Spec for LFS custom + transfer agent is: + https://github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md + + This introduces two commands to the CLI: + + 1. $ huggingface-cli lfs-enable-largefiles + + This should be executed once for each model repo that contains a model file + >5GB. It's documented in the error message you get if you just try to git + push a 5GB file without having enabled it before. + + 2. $ huggingface-cli lfs-multipart-upload + + This command is called by lfs directly and is not meant to be called by the + user. + """ + + @staticmethod + def register_subcommand(parser: _SubParsersAction): + enable_parser = parser.add_parser( + "lfs-enable-largefiles", help="Configure your repository to enable upload of files > 5GB." + ) + enable_parser.add_argument("path", type=str, help="Local path to repository you want to configure.") + enable_parser.set_defaults(func=lambda args: LfsEnableCommand(args)) + + # Command will get called by git-lfs, do not call it directly. + upload_parser = parser.add_parser(LFS_MULTIPART_UPLOAD_COMMAND, add_help=False) + upload_parser.set_defaults(func=lambda args: LfsUploadCommand(args)) + + +class LfsEnableCommand: + def __init__(self, args): + self.args = args + + def run(self): + local_path = os.path.abspath(self.args.path) + if not os.path.isdir(local_path): + print("This does not look like a valid git repo.") + exit(1) + subprocess.run( + "git config lfs.customtransfer.multipart.path huggingface-cli".split(), + check=True, + cwd=local_path, + ) + subprocess.run( + f"git config lfs.customtransfer.multipart.args {LFS_MULTIPART_UPLOAD_COMMAND}".split(), + check=True, + cwd=local_path, + ) + print("Local repo set up for largefiles") + + +def write_msg(msg: Dict): + """Write out the message in Line delimited JSON.""" + msg_str = json.dumps(msg) + "\n" + sys.stdout.write(msg_str) + sys.stdout.flush() + + +def read_msg() -> Optional[Dict]: + """Read Line delimited JSON from stdin.""" + msg = json.loads(sys.stdin.readline().strip()) + + if "terminate" in (msg.get("type"), msg.get("event")): + # terminate message received + return None + + if msg.get("event") not in ("download", "upload"): + logger.critical("Received unexpected message") + sys.exit(1) + + return msg + + +class LfsUploadCommand: + def __init__(self, args) -> None: + self.args = args + + def run(self) -> None: + # Immediately after invoking a custom transfer process, git-lfs + # sends initiation data to the process over stdin. + # This tells the process useful information about the configuration. + init_msg = json.loads(sys.stdin.readline().strip()) + if not (init_msg.get("event") == "init" and init_msg.get("operation") == "upload"): + write_msg({"error": {"code": 32, "message": "Wrong lfs init operation"}}) + sys.exit(1) + + # The transfer process should use the information it needs from the + # initiation structure, and also perform any one-off setup tasks it + # needs to do. It should then respond on stdout with a simple empty + # confirmation structure, as follows: + write_msg({}) + + # After the initiation exchange, git-lfs will send any number of + # transfer requests to the stdin of the transfer process, in a serial sequence. + while True: + msg = read_msg() + if msg is None: + # When all transfers have been processed, git-lfs will send + # a terminate event to the stdin of the transfer process. + # On receiving this message the transfer process should + # clean up and terminate. No response is expected. + sys.exit(0) + + oid = msg["oid"] + filepath = msg["path"] + completion_url = msg["action"]["href"] + header = msg["action"]["header"] + chunk_size = int(header.pop("chunk_size")) + presigned_urls: List[str] = list(header.values()) + + # Send a "started" progress event to allow other workers to start. + # Otherwise they're delayed until first "progress" event is reported, + # i.e. after the first 5GB by default (!) + write_msg( + { + "event": "progress", + "oid": oid, + "bytesSoFar": 1, + "bytesSinceLast": 0, + } + ) + + parts = [] + with open(filepath, "rb") as file: + for i, presigned_url in enumerate(presigned_urls): + with SliceFileObj( + file, + seek_from=i * chunk_size, + read_limit=chunk_size, + ) as data: + r = get_session().put(presigned_url, data=data) + hf_raise_for_status(r) + parts.append( + { + "etag": r.headers.get("etag"), + "partNumber": i + 1, + } + ) + # In order to support progress reporting while data is uploading / downloading, + # the transfer process should post messages to stdout + write_msg( + { + "event": "progress", + "oid": oid, + "bytesSoFar": (i + 1) * chunk_size, + "bytesSinceLast": chunk_size, + } + ) + # Not precise but that's ok. + + r = get_session().post( + completion_url, + json={ + "oid": oid, + "parts": parts, + }, + ) + hf_raise_for_status(r) + + write_msg({"event": "complete", "oid": oid}) diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/repo.py b/phivenv/Lib/site-packages/huggingface_hub/commands/repo.py new file mode 100644 index 0000000000000000000000000000000000000000..fe75349d67bdc0314afe737daa7224b2a090f810 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/commands/repo.py @@ -0,0 +1,151 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains commands to interact with repositories on the Hugging Face Hub. + +Usage: + # create a new dataset repo on the Hub + huggingface-cli repo create my-cool-dataset --repo-type=dataset + + # create a private model repo on the Hub + huggingface-cli repo create my-cool-model --private +""" + +import argparse +from argparse import _SubParsersAction +from typing import Optional + +from huggingface_hub.commands import BaseHuggingfaceCLICommand +from huggingface_hub.commands._cli_utils import ANSI +from huggingface_hub.constants import SPACES_SDK_TYPES +from huggingface_hub.hf_api import HfApi +from huggingface_hub.utils import logging + +from ._cli_utils import show_deprecation_warning + + +logger = logging.get_logger(__name__) + + +class RepoCommands(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + repo_parser = parser.add_parser("repo", help="{create} Commands to interact with your huggingface.co repos.") + repo_subparsers = repo_parser.add_subparsers(help="huggingface.co repos related commands") + repo_create_parser = repo_subparsers.add_parser("create", help="Create a new repo on huggingface.co") + repo_create_parser.add_argument( + "repo_id", + type=str, + help="The ID of the repo to create to (e.g. `username/repo-name`). The username is optional and will be set to your username if not provided.", + ) + repo_create_parser.add_argument( + "--repo-type", + type=str, + help='Optional: set to "dataset" or "space" if creating a dataset or space, default is model.', + ) + repo_create_parser.add_argument( + "--space_sdk", + type=str, + help='Optional: Hugging Face Spaces SDK type. Required when --type is set to "space".', + choices=SPACES_SDK_TYPES, + ) + repo_create_parser.add_argument( + "--private", + action="store_true", + help="Whether to create a private repository. Defaults to public unless the organization's default is private.", + ) + repo_create_parser.add_argument( + "--token", + type=str, + help="Hugging Face token. Will default to the locally saved token if not provided.", + ) + repo_create_parser.add_argument( + "--exist-ok", + action="store_true", + help="Do not raise an error if repo already exists.", + ) + repo_create_parser.add_argument( + "--resource-group-id", + type=str, + help="Resource group in which to create the repo. Resource groups is only available for Enterprise Hub organizations.", + ) + repo_create_parser.add_argument( + "--type", + type=str, + help="[Deprecated]: use --repo-type instead.", + ) + repo_create_parser.add_argument( + "-y", + "--yes", + action="store_true", + help="[Deprecated] no effect.", + ) + repo_create_parser.add_argument( + "--organization", type=str, help="[Deprecated] Pass the organization namespace directly in the repo_id." + ) + repo_create_parser.set_defaults(func=lambda args: RepoCreateCommand(args)) + + +class RepoCreateCommand: + def __init__(self, args: argparse.Namespace): + self.repo_id: str = args.repo_id + self.repo_type: Optional[str] = args.repo_type or args.type + self.space_sdk: Optional[str] = args.space_sdk + self.organization: Optional[str] = args.organization + self.yes: bool = args.yes + self.private: bool = args.private + self.token: Optional[str] = args.token + self.exist_ok: bool = args.exist_ok + self.resource_group_id: Optional[str] = args.resource_group_id + + if args.type is not None: + print( + ANSI.yellow( + "The --type argument is deprecated and will be removed in a future version. Use --repo-type instead." + ) + ) + if self.organization is not None: + print( + ANSI.yellow( + "The --organization argument is deprecated and will be removed in a future version. Pass the organization namespace directly in the repo_id." + ) + ) + if self.yes: + print( + ANSI.yellow( + "The --yes argument is deprecated and will be removed in a future version. It does not have any effect." + ) + ) + + self._api = HfApi() + + def run(self): + show_deprecation_warning("huggingface-cli repo", "hf repo") + + if self.organization is not None: + if "/" in self.repo_id: + print(ANSI.red("You cannot pass both --organization and a repo_id with a namespace.")) + exit(1) + self.repo_id = f"{self.organization}/{self.repo_id}" + + repo_url = self._api.create_repo( + repo_id=self.repo_id, + repo_type=self.repo_type, + private=self.private, + token=self.token, + exist_ok=self.exist_ok, + resource_group_id=self.resource_group_id, + space_sdk=self.space_sdk, + ) + print(f"Successfully created {ANSI.bold(repo_url.repo_id)} on the Hub.") + print(f"Your repo is now available at {ANSI.bold(repo_url)}") diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/repo_files.py b/phivenv/Lib/site-packages/huggingface_hub/commands/repo_files.py new file mode 100644 index 0000000000000000000000000000000000000000..da9685315ea67dc9d1e9921ecb2656244cae8783 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/commands/repo_files.py @@ -0,0 +1,132 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains command to update or delete files in a repository using the CLI. + +Usage: + # delete all + huggingface-cli repo-files delete "*" + + # delete single file + huggingface-cli repo-files delete file.txt + + # delete single folder + huggingface-cli repo-files delete folder/ + + # delete multiple + huggingface-cli repo-files delete file.txt folder/ file2.txt + + # delete multiple patterns + huggingface-cli repo-files delete file.txt "*.json" "folder/*.parquet" + + # delete from different revision / repo-type + huggingface-cli repo-files delete file.txt --revision=refs/pr/1 --repo-type=dataset +""" + +from argparse import _SubParsersAction +from typing import List, Optional + +from huggingface_hub import logging +from huggingface_hub.commands import BaseHuggingfaceCLICommand +from huggingface_hub.hf_api import HfApi + +from ._cli_utils import show_deprecation_warning + + +logger = logging.get_logger(__name__) + + +class DeleteFilesSubCommand: + def __init__(self, args) -> None: + self.args = args + self.repo_id: str = args.repo_id + self.repo_type: Optional[str] = args.repo_type + self.revision: Optional[str] = args.revision + self.api: HfApi = HfApi(token=args.token, library_name="huggingface-cli") + self.patterns: List[str] = args.patterns + self.commit_message: Optional[str] = args.commit_message + self.commit_description: Optional[str] = args.commit_description + self.create_pr: bool = args.create_pr + self.token: Optional[str] = args.token + + def run(self) -> None: + show_deprecation_warning("huggingface-cli repo-files", "hf repo-files") + + logging.set_verbosity_info() + url = self.api.delete_files( + delete_patterns=self.patterns, + repo_id=self.repo_id, + repo_type=self.repo_type, + revision=self.revision, + commit_message=self.commit_message, + commit_description=self.commit_description, + create_pr=self.create_pr, + ) + print(f"Files correctly deleted from repo. Commit: {url}.") + logging.set_verbosity_warning() + + +class RepoFilesCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + repo_files_parser = parser.add_parser("repo-files", help="Manage files in a repo on the Hub") + repo_files_parser.add_argument( + "repo_id", type=str, help="The ID of the repo to manage (e.g. `username/repo-name`)." + ) + repo_files_subparsers = repo_files_parser.add_subparsers( + help="Action to execute against the files.", + required=True, + ) + delete_subparser = repo_files_subparsers.add_parser( + "delete", + help="Delete files from a repo on the Hub", + ) + delete_subparser.set_defaults(func=lambda args: DeleteFilesSubCommand(args)) + delete_subparser.add_argument( + "patterns", + nargs="+", + type=str, + help="Glob patterns to match files to delete.", + ) + delete_subparser.add_argument( + "--repo-type", + choices=["model", "dataset", "space"], + default="model", + help="Type of the repo to upload to (e.g. `dataset`).", + ) + delete_subparser.add_argument( + "--revision", + type=str, + help=( + "An optional Git revision to push to. It can be a branch name " + "or a PR reference. If revision does not" + " exist and `--create-pr` is not set, a branch will be automatically created." + ), + ) + delete_subparser.add_argument( + "--commit-message", type=str, help="The summary / title / first line of the generated commit." + ) + delete_subparser.add_argument( + "--commit-description", type=str, help="The description of the generated commit." + ) + delete_subparser.add_argument( + "--create-pr", action="store_true", help="Whether to create a new Pull Request for these changes." + ) + repo_files_parser.add_argument( + "--token", + type=str, + help="A User Access Token generated from https://huggingface.co/settings/tokens", + ) + + repo_files_parser.set_defaults(func=RepoFilesCommand) diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/scan_cache.py b/phivenv/Lib/site-packages/huggingface_hub/commands/scan_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..45662fb973597493b0ee6121ac86636d2ffaa6c7 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/commands/scan_cache.py @@ -0,0 +1,183 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains command to scan the HF cache directory. + +Usage: + huggingface-cli scan-cache + huggingface-cli scan-cache -v + huggingface-cli scan-cache -vvv + huggingface-cli scan-cache --dir ~/.cache/huggingface/hub +""" + +import time +from argparse import Namespace, _SubParsersAction +from typing import Optional + +from ..utils import CacheNotFound, HFCacheInfo, scan_cache_dir +from . import BaseHuggingfaceCLICommand +from ._cli_utils import ANSI, show_deprecation_warning, tabulate + + +class ScanCacheCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + scan_cache_parser = parser.add_parser("scan-cache", help="Scan cache directory.") + + scan_cache_parser.add_argument( + "--dir", + type=str, + default=None, + help="cache directory to scan (optional). Default to the default HuggingFace cache.", + ) + scan_cache_parser.add_argument( + "-v", + "--verbose", + action="count", + default=0, + help="show a more verbose output", + ) + scan_cache_parser.set_defaults(func=ScanCacheCommand) + + def __init__(self, args: Namespace) -> None: + self.verbosity: int = args.verbose + self.cache_dir: Optional[str] = args.dir + + def run(self): + show_deprecation_warning("huggingface-cli scan-cache", "hf cache scan") + + try: + t0 = time.time() + hf_cache_info = scan_cache_dir(self.cache_dir) + t1 = time.time() + except CacheNotFound as exc: + cache_dir = exc.cache_dir + print(f"Cache directory not found: {cache_dir}") + return + + self._print_hf_cache_info_as_table(hf_cache_info) + + print( + f"\nDone in {round(t1 - t0, 1)}s. Scanned {len(hf_cache_info.repos)} repo(s)" + f" for a total of {ANSI.red(hf_cache_info.size_on_disk_str)}." + ) + if len(hf_cache_info.warnings) > 0: + message = f"Got {len(hf_cache_info.warnings)} warning(s) while scanning." + if self.verbosity >= 3: + print(ANSI.gray(message)) + for warning in hf_cache_info.warnings: + print(ANSI.gray(warning)) + else: + print(ANSI.gray(message + " Use -vvv to print details.")) + + def _print_hf_cache_info_as_table(self, hf_cache_info: HFCacheInfo) -> None: + print(get_table(hf_cache_info, verbosity=self.verbosity)) + + +def get_table(hf_cache_info: HFCacheInfo, *, verbosity: int = 0) -> str: + """Generate a table from the [`HFCacheInfo`] object. + + Pass `verbosity=0` to get a table with a single row per repo, with columns + "repo_id", "repo_type", "size_on_disk", "nb_files", "last_accessed", "last_modified", "refs", "local_path". + + Pass `verbosity=1` to get a table with a row per repo and revision (thus multiple rows can appear for a single repo), with columns + "repo_id", "repo_type", "revision", "size_on_disk", "nb_files", "last_modified", "refs", "local_path". + + Example: + ```py + >>> from huggingface_hub.utils import scan_cache_dir + >>> from huggingface_hub.commands.scan_cache import get_table + + >>> hf_cache_info = scan_cache_dir() + HFCacheInfo(...) + + >>> print(get_table(hf_cache_info, verbosity=0)) + REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH + --------------------------------------------------- --------- ------------ -------- ------------- ------------- ---- -------------------------------------------------------------------------------------------------- + roberta-base model 2.7M 5 1 day ago 1 week ago main C:\\Users\\admin\\.cache\\huggingface\\hub\\models--roberta-base + suno/bark model 8.8K 1 1 week ago 1 week ago main C:\\Users\\admin\\.cache\\huggingface\\hub\\models--suno--bark + t5-base model 893.8M 4 4 days ago 7 months ago main C:\\Users\\admin\\.cache\\huggingface\\hub\\models--t5-base + t5-large model 3.0G 4 5 weeks ago 5 months ago main C:\\Users\\admin\\.cache\\huggingface\\hub\\models--t5-large + + >>> print(get_table(hf_cache_info, verbosity=1)) + REPO ID REPO TYPE REVISION SIZE ON DISK NB FILES LAST_MODIFIED REFS LOCAL PATH + --------------------------------------------------- --------- ---------------------------------------- ------------ -------- ------------- ---- ----------------------------------------------------------------------------------------------------------------------------------------------------- + roberta-base model e2da8e2f811d1448a5b465c236feacd80ffbac7b 2.7M 5 1 week ago main C:\\Users\\admin\\.cache\\huggingface\\hub\\models--roberta-base\\snapshots\\e2da8e2f811d1448a5b465c236feacd80ffbac7b + suno/bark model 70a8a7d34168586dc5d028fa9666aceade177992 8.8K 1 1 week ago main C:\\Users\\admin\\.cache\\huggingface\\hub\\models--suno--bark\\snapshots\\70a8a7d34168586dc5d028fa9666aceade177992 + t5-base model a9723ea7f1b39c1eae772870f3b547bf6ef7e6c1 893.8M 4 7 months ago main C:\\Users\\admin\\.cache\\huggingface\\hub\\models--t5-base\\snapshots\\a9723ea7f1b39c1eae772870f3b547bf6ef7e6c1 + t5-large model 150ebc2c4b72291e770f58e6057481c8d2ed331a 3.0G 4 5 months ago main C:\\Users\\admin\\.cache\\huggingface\\hub\\models--t5-large\\snapshots\\150ebc2c4b72291e770f58e6057481c8d2ed331a ``` + ``` + + Args: + hf_cache_info ([`HFCacheInfo`]): + The HFCacheInfo object to print. + verbosity (`int`, *optional*): + The verbosity level. Defaults to 0. + + Returns: + `str`: The table as a string. + """ + if verbosity == 0: + return tabulate( + rows=[ + [ + repo.repo_id, + repo.repo_type, + "{:>12}".format(repo.size_on_disk_str), + repo.nb_files, + repo.last_accessed_str, + repo.last_modified_str, + ", ".join(sorted(repo.refs)), + str(repo.repo_path), + ] + for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path) + ], + headers=[ + "REPO ID", + "REPO TYPE", + "SIZE ON DISK", + "NB FILES", + "LAST_ACCESSED", + "LAST_MODIFIED", + "REFS", + "LOCAL PATH", + ], + ) + else: + return tabulate( + rows=[ + [ + repo.repo_id, + repo.repo_type, + revision.commit_hash, + "{:>12}".format(revision.size_on_disk_str), + revision.nb_files, + revision.last_modified_str, + ", ".join(sorted(revision.refs)), + str(revision.snapshot_path), + ] + for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path) + for revision in sorted(repo.revisions, key=lambda revision: revision.commit_hash) + ], + headers=[ + "REPO ID", + "REPO TYPE", + "REVISION", + "SIZE ON DISK", + "NB FILES", + "LAST_MODIFIED", + "REFS", + "LOCAL PATH", + ], + ) diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/tag.py b/phivenv/Lib/site-packages/huggingface_hub/commands/tag.py new file mode 100644 index 0000000000000000000000000000000000000000..405d407f8135d940cf078f905a6e66acd4b1dacc --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/commands/tag.py @@ -0,0 +1,161 @@ +# coding=utf-8 +# Copyright 2024-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains commands to perform tag management with the CLI. + +Usage Examples: + - Create a tag: + $ huggingface-cli tag user/my-model 1.0 --message "First release" + $ huggingface-cli tag user/my-model 1.0 -m "First release" --revision develop + $ huggingface-cli tag user/my-dataset 1.0 -m "First release" --repo-type dataset + $ huggingface-cli tag user/my-space 1.0 + - List all tags: + $ huggingface-cli tag -l user/my-model + $ huggingface-cli tag --list user/my-dataset --repo-type dataset + - Delete a tag: + $ huggingface-cli tag -d user/my-model 1.0 + $ huggingface-cli tag --delete user/my-dataset 1.0 --repo-type dataset + $ huggingface-cli tag -d user/my-space 1.0 -y +""" + +from argparse import Namespace, _SubParsersAction + +from requests.exceptions import HTTPError + +from huggingface_hub.commands import BaseHuggingfaceCLICommand +from huggingface_hub.constants import ( + REPO_TYPES, +) +from huggingface_hub.hf_api import HfApi + +from ..errors import HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError +from ._cli_utils import ANSI, show_deprecation_warning + + +class TagCommands(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + tag_parser = parser.add_parser("tag", help="(create, list, delete) tags for a repo in the hub") + + tag_parser.add_argument("repo_id", type=str, help="The ID of the repo to tag (e.g. `username/repo-name`).") + tag_parser.add_argument("tag", nargs="?", type=str, help="The name of the tag for creation or deletion.") + tag_parser.add_argument("-m", "--message", type=str, help="The description of the tag to create.") + tag_parser.add_argument("--revision", type=str, help="The git revision to tag.") + tag_parser.add_argument( + "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens." + ) + tag_parser.add_argument( + "--repo-type", + choices=["model", "dataset", "space"], + default="model", + help="Set the type of repository (model, dataset, or space).", + ) + tag_parser.add_argument("-y", "--yes", action="store_true", help="Answer Yes to prompts automatically.") + + tag_parser.add_argument("-l", "--list", action="store_true", help="List tags for a repository.") + tag_parser.add_argument("-d", "--delete", action="store_true", help="Delete a tag for a repository.") + + tag_parser.set_defaults(func=lambda args: handle_commands(args)) + + +def handle_commands(args: Namespace): + show_deprecation_warning("huggingface-cli tag", "hf repo tag") + + if args.list: + return TagListCommand(args) + elif args.delete: + return TagDeleteCommand(args) + else: + return TagCreateCommand(args) + + +class TagCommand: + def __init__(self, args: Namespace): + self.args = args + self.api = HfApi(token=self.args.token) + self.repo_id = self.args.repo_id + self.repo_type = self.args.repo_type + if self.repo_type not in REPO_TYPES: + print("Invalid repo --repo-type") + exit(1) + + +class TagCreateCommand(TagCommand): + def run(self): + print(f"You are about to create tag {ANSI.bold(self.args.tag)} on {self.repo_type} {ANSI.bold(self.repo_id)}") + + try: + self.api.create_tag( + repo_id=self.repo_id, + tag=self.args.tag, + tag_message=self.args.message, + revision=self.args.revision, + repo_type=self.repo_type, + ) + except RepositoryNotFoundError: + print(f"{self.repo_type.capitalize()} {ANSI.bold(self.repo_id)} not found.") + exit(1) + except RevisionNotFoundError: + print(f"Revision {ANSI.bold(self.args.revision)} not found.") + exit(1) + except HfHubHTTPError as e: + if e.response.status_code == 409: + print(f"Tag {ANSI.bold(self.args.tag)} already exists on {ANSI.bold(self.repo_id)}") + exit(1) + raise e + + print(f"Tag {ANSI.bold(self.args.tag)} created on {ANSI.bold(self.repo_id)}") + + +class TagListCommand(TagCommand): + def run(self): + try: + refs = self.api.list_repo_refs( + repo_id=self.repo_id, + repo_type=self.repo_type, + ) + except RepositoryNotFoundError: + print(f"{self.repo_type.capitalize()} {ANSI.bold(self.repo_id)} not found.") + exit(1) + except HTTPError as e: + print(e) + print(ANSI.red(e.response.text)) + exit(1) + if len(refs.tags) == 0: + print("No tags found") + exit(0) + print(f"Tags for {self.repo_type} {ANSI.bold(self.repo_id)}:") + for tag in refs.tags: + print(tag.name) + + +class TagDeleteCommand(TagCommand): + def run(self): + print(f"You are about to delete tag {ANSI.bold(self.args.tag)} on {self.repo_type} {ANSI.bold(self.repo_id)}") + + if not self.args.yes: + choice = input("Proceed? [Y/n] ").lower() + if choice not in ("", "y", "yes"): + print("Abort") + exit() + try: + self.api.delete_tag(repo_id=self.repo_id, tag=self.args.tag, repo_type=self.repo_type) + except RepositoryNotFoundError: + print(f"{self.repo_type.capitalize()} {ANSI.bold(self.repo_id)} not found.") + exit(1) + except RevisionNotFoundError: + print(f"Tag {ANSI.bold(self.args.tag)} not found on {ANSI.bold(self.repo_id)}") + exit(1) + print(f"Tag {ANSI.bold(self.args.tag)} deleted on {ANSI.bold(self.repo_id)}") diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/upload.py b/phivenv/Lib/site-packages/huggingface_hub/commands/upload.py new file mode 100644 index 0000000000000000000000000000000000000000..c778555cda56eb17c905f0728fef6712acc75cb8 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/commands/upload.py @@ -0,0 +1,318 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains command to upload a repo or file with the CLI. + +Usage: + # Upload file (implicit) + huggingface-cli upload my-cool-model ./my-cool-model.safetensors + + # Upload file (explicit) + huggingface-cli upload my-cool-model ./my-cool-model.safetensors model.safetensors + + # Upload directory (implicit). If `my-cool-model/` is a directory it will be uploaded, otherwise an exception is raised. + huggingface-cli upload my-cool-model + + # Upload directory (explicit) + huggingface-cli upload my-cool-model ./models/my-cool-model . + + # Upload filtered directory (example: tensorboard logs except for the last run) + huggingface-cli upload my-cool-model ./model/training /logs --include "*.tfevents.*" --exclude "*20230905*" + + # Upload with wildcard + huggingface-cli upload my-cool-model "./model/training/*.safetensors" + + # Upload private dataset + huggingface-cli upload Wauplin/my-cool-dataset ./data . --repo-type=dataset --private + + # Upload with token + huggingface-cli upload Wauplin/my-cool-model --token=hf_**** + + # Sync local Space with Hub (upload new files, delete removed files) + huggingface-cli upload Wauplin/space-example --repo-type=space --exclude="/logs/*" --delete="*" --commit-message="Sync local Space with Hub" + + # Schedule commits every 30 minutes + huggingface-cli upload Wauplin/my-cool-model --every=30 +""" + +import os +import time +import warnings +from argparse import Namespace, _SubParsersAction +from typing import List, Optional + +from huggingface_hub import logging +from huggingface_hub._commit_scheduler import CommitScheduler +from huggingface_hub.commands import BaseHuggingfaceCLICommand +from huggingface_hub.constants import HF_HUB_ENABLE_HF_TRANSFER +from huggingface_hub.errors import RevisionNotFoundError +from huggingface_hub.hf_api import HfApi +from huggingface_hub.utils import disable_progress_bars, enable_progress_bars +from huggingface_hub.utils._runtime import is_xet_available + +from ._cli_utils import show_deprecation_warning + + +logger = logging.get_logger(__name__) + + +class UploadCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + upload_parser = parser.add_parser("upload", help="Upload a file or a folder to a repo on the Hub") + upload_parser.add_argument( + "repo_id", type=str, help="The ID of the repo to upload to (e.g. `username/repo-name`)." + ) + upload_parser.add_argument( + "local_path", + nargs="?", + help="Local path to the file or folder to upload. Wildcard patterns are supported. Defaults to current directory.", + ) + upload_parser.add_argument( + "path_in_repo", + nargs="?", + help="Path of the file or folder in the repo. Defaults to the relative path of the file or folder.", + ) + upload_parser.add_argument( + "--repo-type", + choices=["model", "dataset", "space"], + default="model", + help="Type of the repo to upload to (e.g. `dataset`).", + ) + upload_parser.add_argument( + "--revision", + type=str, + help=( + "An optional Git revision to push to. It can be a branch name or a PR reference. If revision does not" + " exist and `--create-pr` is not set, a branch will be automatically created." + ), + ) + upload_parser.add_argument( + "--private", + action="store_true", + help=( + "Whether to create a private repo if repo doesn't exist on the Hub. Ignored if the repo already" + " exists." + ), + ) + upload_parser.add_argument("--include", nargs="*", type=str, help="Glob patterns to match files to upload.") + upload_parser.add_argument( + "--exclude", nargs="*", type=str, help="Glob patterns to exclude from files to upload." + ) + upload_parser.add_argument( + "--delete", + nargs="*", + type=str, + help="Glob patterns for file to be deleted from the repo while committing.", + ) + upload_parser.add_argument( + "--commit-message", type=str, help="The summary / title / first line of the generated commit." + ) + upload_parser.add_argument("--commit-description", type=str, help="The description of the generated commit.") + upload_parser.add_argument( + "--create-pr", action="store_true", help="Whether to upload content as a new Pull Request." + ) + upload_parser.add_argument( + "--every", + type=float, + help="If set, a background job is scheduled to create commits every `every` minutes.", + ) + upload_parser.add_argument( + "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens" + ) + upload_parser.add_argument( + "--quiet", + action="store_true", + help="If True, progress bars are disabled and only the path to the uploaded files is printed.", + ) + upload_parser.set_defaults(func=UploadCommand) + + def __init__(self, args: Namespace) -> None: + self.repo_id: str = args.repo_id + self.repo_type: Optional[str] = args.repo_type + self.revision: Optional[str] = args.revision + self.private: bool = args.private + + self.include: Optional[List[str]] = args.include + self.exclude: Optional[List[str]] = args.exclude + self.delete: Optional[List[str]] = args.delete + + self.commit_message: Optional[str] = args.commit_message + self.commit_description: Optional[str] = args.commit_description + self.create_pr: bool = args.create_pr + self.api: HfApi = HfApi(token=args.token, library_name="huggingface-cli") + self.quiet: bool = args.quiet # disable warnings and progress bars + + # Check `--every` is valid + if args.every is not None and args.every <= 0: + raise ValueError(f"`every` must be a positive value (got '{args.every}')") + self.every: Optional[float] = args.every + + # Resolve `local_path` and `path_in_repo` + repo_name: str = args.repo_id.split("/")[-1] # e.g. "Wauplin/my-cool-model" => "my-cool-model" + self.local_path: str + self.path_in_repo: str + + if args.local_path is not None and any(c in args.local_path for c in ["*", "?", "["]): + if args.include is not None: + raise ValueError("Cannot set `--include` when passing a `local_path` containing a wildcard.") + if args.path_in_repo is not None and args.path_in_repo != ".": + raise ValueError("Cannot set `path_in_repo` when passing a `local_path` containing a wildcard.") + self.local_path = "." + self.include = args.local_path + self.path_in_repo = "." + elif args.local_path is None and os.path.isfile(repo_name): + # Implicit case 1: user provided only a repo_id which happen to be a local file as well => upload it with same name + self.local_path = repo_name + self.path_in_repo = repo_name + elif args.local_path is None and os.path.isdir(repo_name): + # Implicit case 2: user provided only a repo_id which happen to be a local folder as well => upload it at root + self.local_path = repo_name + self.path_in_repo = "." + elif args.local_path is None: + # Implicit case 3: user provided only a repo_id that does not match a local file or folder + # => the user must explicitly provide a local_path => raise exception + raise ValueError(f"'{repo_name}' is not a local file or folder. Please set `local_path` explicitly.") + elif args.path_in_repo is None and os.path.isfile(args.local_path): + # Explicit local path to file, no path in repo => upload it at root with same name + self.local_path = args.local_path + self.path_in_repo = os.path.basename(args.local_path) + elif args.path_in_repo is None: + # Explicit local path to folder, no path in repo => upload at root + self.local_path = args.local_path + self.path_in_repo = "." + else: + # Finally, if both paths are explicit + self.local_path = args.local_path + self.path_in_repo = args.path_in_repo + + def run(self) -> None: + show_deprecation_warning("huggingface-cli upload", "hf upload") + + if self.quiet: + disable_progress_bars() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + print(self._upload()) + enable_progress_bars() + else: + logging.set_verbosity_info() + print(self._upload()) + logging.set_verbosity_warning() + + def _upload(self) -> str: + if os.path.isfile(self.local_path): + if self.include is not None and len(self.include) > 0: + warnings.warn("Ignoring `--include` since a single file is uploaded.") + if self.exclude is not None and len(self.exclude) > 0: + warnings.warn("Ignoring `--exclude` since a single file is uploaded.") + if self.delete is not None and len(self.delete) > 0: + warnings.warn("Ignoring `--delete` since a single file is uploaded.") + + if not is_xet_available() and not HF_HUB_ENABLE_HF_TRANSFER: + logger.info( + "Consider using `hf_transfer` for faster uploads. This solution comes with some limitations. See" + " https://huggingface.co/docs/huggingface_hub/hf_transfer for more details." + ) + + # Schedule commits if `every` is set + if self.every is not None: + if os.path.isfile(self.local_path): + # If file => watch entire folder + use allow_patterns + folder_path = os.path.dirname(self.local_path) + path_in_repo = ( + self.path_in_repo[: -len(self.local_path)] # remove filename from path_in_repo + if self.path_in_repo.endswith(self.local_path) + else self.path_in_repo + ) + allow_patterns = [self.local_path] + ignore_patterns = [] + else: + folder_path = self.local_path + path_in_repo = self.path_in_repo + allow_patterns = self.include or [] + ignore_patterns = self.exclude or [] + if self.delete is not None and len(self.delete) > 0: + warnings.warn("Ignoring `--delete` when uploading with scheduled commits.") + + scheduler = CommitScheduler( + folder_path=folder_path, + repo_id=self.repo_id, + repo_type=self.repo_type, + revision=self.revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + path_in_repo=path_in_repo, + private=self.private, + every=self.every, + hf_api=self.api, + ) + print(f"Scheduling commits every {self.every} minutes to {scheduler.repo_id}.") + try: # Block main thread until KeyboardInterrupt + while True: + time.sleep(100) + except KeyboardInterrupt: + scheduler.stop() + return "Stopped scheduled commits." + + # Otherwise, create repo and proceed with the upload + if not os.path.isfile(self.local_path) and not os.path.isdir(self.local_path): + raise FileNotFoundError(f"No such file or directory: '{self.local_path}'.") + repo_id = self.api.create_repo( + repo_id=self.repo_id, + repo_type=self.repo_type, + exist_ok=True, + private=self.private, + space_sdk="gradio" if self.repo_type == "space" else None, + # ^ We don't want it to fail when uploading to a Space => let's set Gradio by default. + # ^ I'd rather not add CLI args to set it explicitly as we already have `huggingface-cli repo create` for that. + ).repo_id + + # Check if branch already exists and if not, create it + if self.revision is not None and not self.create_pr: + try: + self.api.repo_info(repo_id=repo_id, repo_type=self.repo_type, revision=self.revision) + except RevisionNotFoundError: + logger.info(f"Branch '{self.revision}' not found. Creating it...") + self.api.create_branch(repo_id=repo_id, repo_type=self.repo_type, branch=self.revision, exist_ok=True) + # ^ `exist_ok=True` to avoid race concurrency issues + + # File-based upload + if os.path.isfile(self.local_path): + return self.api.upload_file( + path_or_fileobj=self.local_path, + path_in_repo=self.path_in_repo, + repo_id=repo_id, + repo_type=self.repo_type, + revision=self.revision, + commit_message=self.commit_message, + commit_description=self.commit_description, + create_pr=self.create_pr, + ) + + # Folder-based upload + else: + return self.api.upload_folder( + folder_path=self.local_path, + path_in_repo=self.path_in_repo, + repo_id=repo_id, + repo_type=self.repo_type, + revision=self.revision, + commit_message=self.commit_message, + commit_description=self.commit_description, + create_pr=self.create_pr, + allow_patterns=self.include, + ignore_patterns=self.exclude, + delete_patterns=self.delete, + ) diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/upload_large_folder.py b/phivenv/Lib/site-packages/huggingface_hub/commands/upload_large_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..3105ba3f57f5644aa18e627aa5d1d18e61515ae7 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/commands/upload_large_folder.py @@ -0,0 +1,131 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains command to upload a large folder with the CLI.""" + +import os +from argparse import Namespace, _SubParsersAction +from typing import List, Optional + +from huggingface_hub import logging +from huggingface_hub.commands import BaseHuggingfaceCLICommand +from huggingface_hub.hf_api import HfApi +from huggingface_hub.utils import disable_progress_bars + +from ._cli_utils import ANSI, show_deprecation_warning + + +logger = logging.get_logger(__name__) + + +class UploadLargeFolderCommand(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + subparser = parser.add_parser("upload-large-folder", help="Upload a large folder to a repo on the Hub") + subparser.add_argument( + "repo_id", type=str, help="The ID of the repo to upload to (e.g. `username/repo-name`)." + ) + subparser.add_argument("local_path", type=str, help="Local path to the file or folder to upload.") + subparser.add_argument( + "--repo-type", + choices=["model", "dataset", "space"], + help="Type of the repo to upload to (e.g. `dataset`).", + ) + subparser.add_argument( + "--revision", + type=str, + help=("An optional Git revision to push to. It can be a branch name or a PR reference."), + ) + subparser.add_argument( + "--private", + action="store_true", + help=( + "Whether to create a private repo if repo doesn't exist on the Hub. Ignored if the repo already exists." + ), + ) + subparser.add_argument("--include", nargs="*", type=str, help="Glob patterns to match files to upload.") + subparser.add_argument("--exclude", nargs="*", type=str, help="Glob patterns to exclude from files to upload.") + subparser.add_argument( + "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens" + ) + subparser.add_argument( + "--num-workers", type=int, help="Number of workers to use to hash, upload and commit files." + ) + subparser.add_argument("--no-report", action="store_true", help="Whether to disable regular status report.") + subparser.add_argument("--no-bars", action="store_true", help="Whether to disable progress bars.") + subparser.set_defaults(func=UploadLargeFolderCommand) + + def __init__(self, args: Namespace) -> None: + self.repo_id: str = args.repo_id + self.local_path: str = args.local_path + self.repo_type: str = args.repo_type + self.revision: Optional[str] = args.revision + self.private: bool = args.private + + self.include: Optional[List[str]] = args.include + self.exclude: Optional[List[str]] = args.exclude + + self.api: HfApi = HfApi(token=args.token, library_name="huggingface-cli") + + self.num_workers: Optional[int] = args.num_workers + self.no_report: bool = args.no_report + self.no_bars: bool = args.no_bars + + if not os.path.isdir(self.local_path): + raise ValueError("Large upload is only supported for folders.") + + def run(self) -> None: + show_deprecation_warning("huggingface-cli upload-large-folder", "hf upload-large-folder") + + logging.set_verbosity_info() + + print( + ANSI.yellow( + "You are about to upload a large folder to the Hub using `huggingface-cli upload-large-folder`. " + "This is a new feature so feedback is very welcome!\n" + "\n" + "A few things to keep in mind:\n" + " - Repository limits still apply: https://huggingface.co/docs/hub/repositories-recommendations\n" + " - Do not start several processes in parallel.\n" + " - You can interrupt and resume the process at any time. " + "The script will pick up where it left off except for partially uploaded files that would have to be entirely reuploaded.\n" + " - Do not upload the same folder to several repositories. If you need to do so, you must delete the `./.cache/huggingface/` folder first.\n" + "\n" + f"Some temporary metadata will be stored under `{self.local_path}/.cache/huggingface`.\n" + " - You must not modify those files manually.\n" + " - You must not delete the `./.cache/huggingface/` folder while a process is running.\n" + " - You can delete the `./.cache/huggingface/` folder to reinitialize the upload state when process is not running. Files will have to be hashed and preuploaded again, except for already committed files.\n" + "\n" + "If the process output is too verbose, you can disable the progress bars with `--no-bars`. " + "You can also entirely disable the status report with `--no-report`.\n" + "\n" + "For more details, run `huggingface-cli upload-large-folder --help` or check the documentation at " + "https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-large-folder." + ) + ) + + if self.no_bars: + disable_progress_bars() + + self.api.upload_large_folder( + repo_id=self.repo_id, + folder_path=self.local_path, + repo_type=self.repo_type, + revision=self.revision, + private=self.private, + allow_patterns=self.include, + ignore_patterns=self.exclude, + num_workers=self.num_workers, + print_report=not self.no_report, + ) diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/user.py b/phivenv/Lib/site-packages/huggingface_hub/commands/user.py new file mode 100644 index 0000000000000000000000000000000000000000..e46fd3e51277fd05258b71569cb3f52a78c791bb --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/commands/user.py @@ -0,0 +1,208 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains commands to authenticate to the Hugging Face Hub and interact with your repositories. + +Usage: + # login and save token locally. + huggingface-cli login --token=hf_*** --add-to-git-credential + + # switch between tokens + huggingface-cli auth switch + + # list all tokens + huggingface-cli auth list + + # logout from a specific token, if no token-name is provided, all tokens will be deleted from your machine. + huggingface-cli logout --token-name=your_token_name + + # find out which huggingface.co account you are logged in as + huggingface-cli whoami +""" + +from argparse import _SubParsersAction +from typing import List, Optional + +from requests.exceptions import HTTPError + +from huggingface_hub.commands import BaseHuggingfaceCLICommand +from huggingface_hub.constants import ENDPOINT +from huggingface_hub.hf_api import HfApi + +from .._login import auth_list, auth_switch, login, logout +from ..utils import get_stored_tokens, get_token, logging +from ._cli_utils import ANSI, show_deprecation_warning + + +logger = logging.get_logger(__name__) + +try: + from InquirerPy import inquirer + from InquirerPy.base.control import Choice + + _inquirer_py_available = True +except ImportError: + _inquirer_py_available = False + + +class UserCommands(BaseHuggingfaceCLICommand): + @staticmethod + def register_subcommand(parser: _SubParsersAction): + login_parser = parser.add_parser("login", help="Log in using a token from huggingface.co/settings/tokens") + login_parser.add_argument( + "--token", + type=str, + help="Token generated from https://huggingface.co/settings/tokens", + ) + login_parser.add_argument( + "--add-to-git-credential", + action="store_true", + help="Optional: Save token to git credential helper.", + ) + login_parser.set_defaults(func=lambda args: LoginCommand(args)) + whoami_parser = parser.add_parser("whoami", help="Find out which huggingface.co account you are logged in as.") + whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args)) + + logout_parser = parser.add_parser("logout", help="Log out") + logout_parser.add_argument( + "--token-name", + type=str, + help="Optional: Name of the access token to log out from.", + ) + logout_parser.set_defaults(func=lambda args: LogoutCommand(args)) + + auth_parser = parser.add_parser("auth", help="Other authentication related commands") + auth_subparsers = auth_parser.add_subparsers(help="Authentication subcommands") + auth_switch_parser = auth_subparsers.add_parser("switch", help="Switch between access tokens") + auth_switch_parser.add_argument( + "--token-name", + type=str, + help="Optional: Name of the access token to switch to.", + ) + auth_switch_parser.add_argument( + "--add-to-git-credential", + action="store_true", + help="Optional: Save token to git credential helper.", + ) + auth_switch_parser.set_defaults(func=lambda args: AuthSwitchCommand(args)) + auth_list_parser = auth_subparsers.add_parser("list", help="List all stored access tokens") + auth_list_parser.set_defaults(func=lambda args: AuthListCommand(args)) + + +class BaseUserCommand: + def __init__(self, args): + self.args = args + self._api = HfApi() + + +class LoginCommand(BaseUserCommand): + def run(self): + show_deprecation_warning("huggingface-cli login", "hf auth login") + + logging.set_verbosity_info() + login( + token=self.args.token, + add_to_git_credential=self.args.add_to_git_credential, + ) + + +class LogoutCommand(BaseUserCommand): + def run(self): + show_deprecation_warning("huggingface-cli logout", "hf auth logout") + + logging.set_verbosity_info() + logout(token_name=self.args.token_name) + + +class AuthSwitchCommand(BaseUserCommand): + def run(self): + show_deprecation_warning("huggingface-cli auth switch", "hf auth switch") + + logging.set_verbosity_info() + token_name = self.args.token_name + if token_name is None: + token_name = self._select_token_name() + + if token_name is None: + print("No token name provided. Aborting.") + exit() + auth_switch(token_name, add_to_git_credential=self.args.add_to_git_credential) + + def _select_token_name(self) -> Optional[str]: + token_names = list(get_stored_tokens().keys()) + + if not token_names: + logger.error("No stored tokens found. Please login first.") + return None + + if _inquirer_py_available: + return self._select_token_name_tui(token_names) + # if inquirer is not available, use a simpler terminal UI + print("Available stored tokens:") + for i, token_name in enumerate(token_names, 1): + print(f"{i}. {token_name}") + while True: + try: + choice = input("Enter the number of the token to switch to (or 'q' to quit): ") + if choice.lower() == "q": + return None + index = int(choice) - 1 + if 0 <= index < len(token_names): + return token_names[index] + else: + print("Invalid selection. Please try again.") + except ValueError: + print("Invalid input. Please enter a number or 'q' to quit.") + + def _select_token_name_tui(self, token_names: List[str]) -> Optional[str]: + choices = [Choice(token_name, name=token_name) for token_name in token_names] + try: + return inquirer.select( + message="Select a token to switch to:", + choices=choices, + default=None, + ).execute() + except KeyboardInterrupt: + logger.info("Token selection cancelled.") + return None + + +class AuthListCommand(BaseUserCommand): + def run(self): + show_deprecation_warning("huggingface-cli auth list", "hf auth list") + + logging.set_verbosity_info() + auth_list() + + +class WhoamiCommand(BaseUserCommand): + def run(self): + show_deprecation_warning("huggingface-cli whoami", "hf auth whoami") + + token = get_token() + if token is None: + print("Not logged in") + exit() + try: + info = self._api.whoami(token) + print(ANSI.bold("user: "), info["name"]) + orgs = [org["name"] for org in info["orgs"]] + if orgs: + print(ANSI.bold("orgs: "), ",".join(orgs)) + + if ENDPOINT != "https://huggingface.co": + print(f"Authenticated through private endpoint: {ENDPOINT}") + except HTTPError as e: + print(e) + print(ANSI.red(e.response.text)) + exit(1) diff --git a/phivenv/Lib/site-packages/huggingface_hub/commands/version.py b/phivenv/Lib/site-packages/huggingface_hub/commands/version.py new file mode 100644 index 0000000000000000000000000000000000000000..10d341bcdb93e0616fcf80370ac8dde63b15ce9c --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/commands/version.py @@ -0,0 +1,40 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains command to print information about the version. + +Usage: + huggingface-cli version +""" + +from argparse import _SubParsersAction + +from huggingface_hub import __version__ + +from . import BaseHuggingfaceCLICommand +from ._cli_utils import show_deprecation_warning + + +class VersionCommand(BaseHuggingfaceCLICommand): + def __init__(self, args): + self.args = args + + @staticmethod + def register_subcommand(parser: _SubParsersAction): + version_parser = parser.add_parser("version", help="Print information about the huggingface-cli version.") + version_parser.set_defaults(func=VersionCommand) + + def run(self) -> None: + show_deprecation_warning("huggingface-cli version", "hf version") + + print(f"huggingface_hub version: {__version__}") diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/__init__.py b/phivenv/Lib/site-packages/huggingface_hub/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e5e7a53c742a7e8f9e44cbcabc25fd42874a8b1 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/__pycache__/_common.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/__pycache__/_common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b7499bf964f3fc69f1d45fcd08947ba084043b7 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/__pycache__/_common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_client.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_client.py new file mode 100644 index 0000000000000000000000000000000000000000..ecd23dc35466e7282a75d7a6eed7da38a28753f8 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_client.py @@ -0,0 +1,3552 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Related resources: +# https://huggingface.co/tasks +# https://huggingface.co/docs/huggingface.js/inference/README +# https://github.com/huggingface/huggingface.js/tree/main/packages/inference/src +# https://github.com/huggingface/text-generation-inference/tree/main/clients/python +# https://github.com/huggingface/text-generation-inference/blob/main/clients/python/text_generation/client.py +# https://huggingface.slack.com/archives/C03E4DQ9LAJ/p1680169099087869 +# https://github.com/huggingface/unity-api#tasks +# +# Some TODO: +# - add all tasks +# +# NOTE: the philosophy of this client is "let's make it as easy as possible to use it, even if less optimized". Some +# examples of how it translates: +# - Timeout / Server unavailable is handled by the client in a single "timeout" parameter. +# - Files can be provided as bytes, file paths, or URLs and the client will try to "guess" the type. +# - Images are parsed as PIL.Image for easier manipulation. +# - Provides a "recommended model" for each task => suboptimal but user-wise quicker to get a first script running. +# - Only the main parameters are publicly exposed. Power users can always read the docs for more options. +import base64 +import logging +import re +import warnings +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union, overload + +from requests import HTTPError + +from huggingface_hub import constants +from huggingface_hub.errors import BadRequestError, InferenceTimeoutError +from huggingface_hub.inference._common import ( + TASKS_EXPECTING_IMAGES, + ContentT, + ModelStatus, + RequestParameters, + _b64_encode, + _b64_to_image, + _bytes_to_dict, + _bytes_to_image, + _bytes_to_list, + _get_unsupported_text_generation_kwargs, + _import_numpy, + _open_as_binary, + _set_unsupported_text_generation_kwargs, + _stream_chat_completion_response, + _stream_text_generation_response, + raise_text_generation_error, +) +from huggingface_hub.inference._generated.types import ( + AudioClassificationOutputElement, + AudioClassificationOutputTransform, + AudioToAudioOutputElement, + AutomaticSpeechRecognitionOutput, + ChatCompletionInputGrammarType, + ChatCompletionInputMessage, + ChatCompletionInputStreamOptions, + ChatCompletionInputTool, + ChatCompletionInputToolChoiceClass, + ChatCompletionInputToolChoiceEnum, + ChatCompletionOutput, + ChatCompletionStreamOutput, + DocumentQuestionAnsweringOutputElement, + FillMaskOutputElement, + ImageClassificationOutputElement, + ImageClassificationOutputTransform, + ImageSegmentationOutputElement, + ImageSegmentationSubtask, + ImageToImageTargetSize, + ImageToTextOutput, + ImageToVideoTargetSize, + ObjectDetectionOutputElement, + Padding, + QuestionAnsweringOutputElement, + SummarizationOutput, + SummarizationTruncationStrategy, + TableQuestionAnsweringOutputElement, + TextClassificationOutputElement, + TextClassificationOutputTransform, + TextGenerationInputGrammarType, + TextGenerationOutput, + TextGenerationStreamOutput, + TextToSpeechEarlyStoppingEnum, + TokenClassificationAggregationStrategy, + TokenClassificationOutputElement, + TranslationOutput, + TranslationTruncationStrategy, + VisualQuestionAnsweringOutputElement, + ZeroShotClassificationOutputElement, + ZeroShotImageClassificationOutputElement, +) +from huggingface_hub.inference._providers import PROVIDER_OR_POLICY_T, get_provider_helper +from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status +from huggingface_hub.utils._auth import get_token +from huggingface_hub.utils._deprecation import _deprecate_method + + +if TYPE_CHECKING: + import numpy as np + from PIL.Image import Image + +logger = logging.getLogger(__name__) + + +MODEL_KWARGS_NOT_USED_REGEX = re.compile(r"The following `model_kwargs` are not used by the model: \[(.*?)\]") + + +class InferenceClient: + """ + Initialize a new Inference Client. + + [`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used + seamlessly with either the (free) Inference API, self-hosted Inference Endpoints, or third-party Inference Providers. + + Args: + model (`str`, `optional`): + The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct` + or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is + automatically selected for the task. + Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2 + arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL. + provider (`str`, *optional*): + Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`. + Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. + If model is a URL or `base_url` is passed, then `provider` is not used. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2 + arguments are mutually exclusive and have the exact same behavior. + timeout (`float`, `optional`): + The maximum number of seconds to wait for a response from the server. Defaults to None, meaning it will loop until the server is available. + headers (`Dict[str, str]`, `optional`): + Additional headers to send to the server. By default only the authorization and user-agent headers are sent. + Values in this dictionary will override the default values. + bill_to (`str`, `optional`): + The billing account to use for the requests. By default the requests are billed on the user's account. + Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub. + cookies (`Dict[str, str]`, `optional`): + Additional cookies to send to the server. + proxies (`Any`, `optional`): + Proxies to use for the request. + base_url (`str`, `optional`): + Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`] + follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None. + api_key (`str`, `optional`): + Token to use for authentication. This is a duplicated argument from `token` to make [`InferenceClient`] + follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None. + """ + + def __init__( + self, + model: Optional[str] = None, + *, + provider: Optional[PROVIDER_OR_POLICY_T] = None, + token: Optional[str] = None, + timeout: Optional[float] = None, + headers: Optional[Dict[str, str]] = None, + cookies: Optional[Dict[str, str]] = None, + proxies: Optional[Any] = None, + bill_to: Optional[str] = None, + # OpenAI compatibility + base_url: Optional[str] = None, + api_key: Optional[str] = None, + ) -> None: + if model is not None and base_url is not None: + raise ValueError( + "Received both `model` and `base_url` arguments. Please provide only one of them." + " `base_url` is an alias for `model` to make the API compatible with OpenAI's client." + " If using `base_url` for chat completion, the `/chat/completions` suffix path will be appended to the base url." + " When passing a URL as `model`, the client will not append any suffix path to it." + ) + if token is not None and api_key is not None: + raise ValueError( + "Received both `token` and `api_key` arguments. Please provide only one of them." + " `api_key` is an alias for `token` to make the API compatible with OpenAI's client." + " It has the exact same behavior as `token`." + ) + token = token if token is not None else api_key + if isinstance(token, bool): + # Legacy behavior: previously is was possible to pass `token=False` to disable authentication. This is not + # supported anymore as authentication is required. Better to explicitly raise here rather than risking + # sending the locally saved token without the user knowing about it. + if token is False: + raise ValueError( + "Cannot use `token=False` to disable authentication as authentication is required to run Inference." + ) + warnings.warn( + "Using `token=True` to automatically use the locally saved token is deprecated and will be removed in a future release. " + "Please use `token=None` instead (default).", + DeprecationWarning, + ) + token = get_token() + + self.model: Optional[str] = base_url or model + self.token: Optional[str] = token + + self.headers = {**headers} if headers is not None else {} + if bill_to is not None: + if ( + constants.HUGGINGFACE_HEADER_X_BILL_TO in self.headers + and self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] != bill_to + ): + warnings.warn( + f"Overriding existing '{self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO]}' value in headers with '{bill_to}'.", + UserWarning, + ) + self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] = bill_to + + if token is not None and not token.startswith("hf_"): + warnings.warn( + "You've provided an external provider's API key, so requests will be billed directly by the provider. " + "The `bill_to` parameter is only applicable for Hugging Face billing and will be ignored.", + UserWarning, + ) + + # Configure provider + self.provider = provider + + self.cookies = cookies + self.timeout = timeout + self.proxies = proxies + + def __repr__(self): + return f"" + + @overload + def _inner_post( # type: ignore[misc] + self, request_parameters: RequestParameters, *, stream: Literal[False] = ... + ) -> bytes: ... + + @overload + def _inner_post( # type: ignore[misc] + self, request_parameters: RequestParameters, *, stream: Literal[True] = ... + ) -> Iterable[bytes]: ... + + @overload + def _inner_post( + self, request_parameters: RequestParameters, *, stream: bool = False + ) -> Union[bytes, Iterable[bytes]]: ... + + def _inner_post( + self, request_parameters: RequestParameters, *, stream: bool = False + ) -> Union[bytes, Iterable[bytes]]: + """Make a request to the inference server.""" + # TODO: this should be handled in provider helpers directly + if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers: + request_parameters.headers["Accept"] = "image/png" + + with _open_as_binary(request_parameters.data) as data_as_binary: + try: + response = get_session().post( + request_parameters.url, + json=request_parameters.json, + data=data_as_binary, + headers=request_parameters.headers, + cookies=self.cookies, + timeout=self.timeout, + stream=stream, + proxies=self.proxies, + ) + except TimeoutError as error: + # Convert any `TimeoutError` to a `InferenceTimeoutError` + raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore + + try: + hf_raise_for_status(response) + return response.iter_lines() if stream else response.content + except HTTPError as error: + if error.response.status_code == 422 and request_parameters.task != "unknown": + msg = str(error.args[0]) + if len(error.response.text) > 0: + msg += f"\n{error.response.text}\n" + error.args = (msg,) + error.args[1:] + raise + + def audio_classification( + self, + audio: ContentT, + *, + model: Optional[str] = None, + top_k: Optional[int] = None, + function_to_apply: Optional["AudioClassificationOutputTransform"] = None, + ) -> List[AudioClassificationOutputElement]: + """ + Perform audio classification on the provided audio content. + + Args: + audio (Union[str, Path, bytes, BinaryIO]): + The audio content to classify. It can be raw audio bytes, a local audio file, or a URL pointing to an + audio file. + model (`str`, *optional*): + The model to use for audio classification. Can be a model ID hosted on the Hugging Face Hub + or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for + audio classification will be used. + top_k (`int`, *optional*): + When specified, limits the output to the top K most probable classes. + function_to_apply (`"AudioClassificationOutputTransform"`, *optional*): + The function to apply to the model outputs in order to retrieve the scores. + + Returns: + `List[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.audio_classification("audio.flac") + [ + AudioClassificationOutputElement(score=0.4976358711719513, label='hap'), + AudioClassificationOutputElement(score=0.3677836060523987, label='neu'), + ... + ] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="audio-classification", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=audio, + parameters={"function_to_apply": function_to_apply, "top_k": top_k}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + return AudioClassificationOutputElement.parse_obj_as_list(response) + + def audio_to_audio( + self, + audio: ContentT, + *, + model: Optional[str] = None, + ) -> List[AudioToAudioOutputElement]: + """ + Performs multiple tasks related to audio-to-audio depending on the model (eg: speech enhancement, source separation). + + Args: + audio (Union[str, Path, bytes, BinaryIO]): + The audio content for the model. It can be raw audio bytes, a local audio file, or a URL pointing to an + audio file. + model (`str`, *optional*): + The model can be any model which takes an audio file and returns another audio file. Can be a model ID hosted on the Hugging Face Hub + or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for + audio_to_audio will be used. + + Returns: + `List[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob. + + Raises: + `InferenceTimeoutError`: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> audio_output = client.audio_to_audio("audio.flac") + >>> for i, item in enumerate(audio_output): + >>> with open(f"output_{i}.flac", "wb") as f: + f.write(item.blob) + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="audio-to-audio", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=audio, + parameters={}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + audio_output = AudioToAudioOutputElement.parse_obj_as_list(response) + for item in audio_output: + item.blob = base64.b64decode(item.blob) + return audio_output + + def automatic_speech_recognition( + self, + audio: ContentT, + *, + model: Optional[str] = None, + extra_body: Optional[Dict] = None, + ) -> AutomaticSpeechRecognitionOutput: + """ + Perform automatic speech recognition (ASR or audio-to-text) on the given audio content. + + Args: + audio (Union[str, Path, bytes, BinaryIO]): + The content to transcribe. It can be raw audio bytes, local audio file, or a URL to an audio file. + model (`str`, *optional*): + The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. If not provided, the default recommended model for ASR will be used. + extra_body (`Dict`, *optional*): + Additional provider-specific parameters to pass to the model. Refer to the provider's documentation + for supported parameters. + Returns: + [`AutomaticSpeechRecognitionOutput`]: An item containing the transcribed text and optionally the timestamp chunks. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.automatic_speech_recognition("hello_world.flac").text + "hello world" + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="automatic-speech-recognition", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=audio, + parameters={**(extra_body or {})}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response) + + @overload + def chat_completion( # type: ignore + self, + messages: List[Union[Dict, ChatCompletionInputMessage]], + *, + model: Optional[str] = None, + stream: Literal[False] = False, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[ChatCompletionInputGrammarType] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stream_options: Optional[ChatCompletionInputStreamOptions] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[ChatCompletionInputTool]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + extra_body: Optional[Dict] = None, + ) -> ChatCompletionOutput: ... + + @overload + def chat_completion( # type: ignore + self, + messages: List[Union[Dict, ChatCompletionInputMessage]], + *, + model: Optional[str] = None, + stream: Literal[True] = True, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[ChatCompletionInputGrammarType] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stream_options: Optional[ChatCompletionInputStreamOptions] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[ChatCompletionInputTool]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + extra_body: Optional[Dict] = None, + ) -> Iterable[ChatCompletionStreamOutput]: ... + + @overload + def chat_completion( + self, + messages: List[Union[Dict, ChatCompletionInputMessage]], + *, + model: Optional[str] = None, + stream: bool = False, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[ChatCompletionInputGrammarType] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stream_options: Optional[ChatCompletionInputStreamOptions] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[ChatCompletionInputTool]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + extra_body: Optional[Dict] = None, + ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: ... + + def chat_completion( + self, + messages: List[Union[Dict, ChatCompletionInputMessage]], + *, + model: Optional[str] = None, + stream: bool = False, + # Parameters from ChatCompletionInput (handled manually) + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[ChatCompletionInputGrammarType] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stream_options: Optional[ChatCompletionInputStreamOptions] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[ChatCompletionInputTool]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + extra_body: Optional[Dict] = None, + ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: + """ + A method for completing conversations using a specified language model. + + + + The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client. + Inputs and outputs are strictly the same and using either syntax will yield the same results. + Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility) + for more details about OpenAI's compatibility. + + + + + You can pass provider-specific parameters to the model by using the `extra_body` argument. + + + Args: + messages (List of [`ChatCompletionInputMessage`]): + Conversation history consisting of roles and content pairs. + model (`str`, *optional*): + The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used. + See https://huggingface.co/tasks/text-generation for more details. + If `model` is a model ID, it is passed to the server as the `model` parameter. If you want to define a + custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`]. + frequency_penalty (`float`, *optional*): + Penalizes new tokens based on their existing frequency + in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0. + logit_bias (`List[float]`, *optional*): + Adjusts the likelihood of specific tokens appearing in the generated output. + logprobs (`bool`, *optional*): + Whether to return log probabilities of the output tokens or not. If true, returns the log + probabilities of each output token returned in the content of message. + max_tokens (`int`, *optional*): + Maximum number of tokens allowed in the response. Defaults to 100. + n (`int`, *optional*): + The number of completions to generate for each prompt. + presence_penalty (`float`, *optional*): + Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the + text so far, increasing the model's likelihood to talk about new topics. + response_format ([`ChatCompletionInputGrammarType`], *optional*): + Grammar constraints. Can be either a JSONSchema or a regex. + seed (Optional[`int`], *optional*): + Seed for reproducible control flow. Defaults to None. + stop (`List[str]`, *optional*): + Up to four strings which trigger the end of the response. + Defaults to None. + stream (`bool`, *optional*): + Enable realtime streaming of responses. Defaults to False. + stream_options ([`ChatCompletionInputStreamOptions`], *optional*): + Options for streaming completions. + temperature (`float`, *optional*): + Controls randomness of the generations. Lower values ensure + less random completions. Range: [0, 2]. Defaults to 1.0. + top_logprobs (`int`, *optional*): + An integer between 0 and 5 specifying the number of most likely tokens to return at each token + position, each with an associated log probability. logprobs must be set to true if this parameter is + used. + top_p (`float`, *optional*): + Fraction of the most likely next words to sample from. + Must be between 0 and 1. Defaults to 1.0. + tool_choice ([`ChatCompletionInputToolChoiceClass`] or [`ChatCompletionInputToolChoiceEnum`], *optional*): + The tool to use for the completion. Defaults to "auto". + tool_prompt (`str`, *optional*): + A prompt to be appended before the tools. + tools (List of [`ChatCompletionInputTool`], *optional*): + A list of tools the model may call. Currently, only functions are supported as a tool. Use this to + provide a list of functions the model may generate JSON inputs for. + extra_body (`Dict`, *optional*): + Additional provider-specific parameters to pass to the model. Refer to the provider's documentation + for supported parameters. + Returns: + [`ChatCompletionOutput`] or Iterable of [`ChatCompletionStreamOutput`]: + Generated text returned from the server: + - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default). + - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`]. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + + ```py + >>> from huggingface_hub import InferenceClient + >>> messages = [{"role": "user", "content": "What is the capital of France?"}] + >>> client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct") + >>> client.chat_completion(messages, max_tokens=100) + ChatCompletionOutput( + choices=[ + ChatCompletionOutputComplete( + finish_reason='eos_token', + index=0, + message=ChatCompletionOutputMessage( + role='assistant', + content='The capital of France is Paris.', + name=None, + tool_calls=None + ), + logprobs=None + ) + ], + created=1719907176, + id='', + model='meta-llama/Meta-Llama-3-8B-Instruct', + object='text_completion', + system_fingerprint='2.0.4-sha-f426a33', + usage=ChatCompletionOutputUsage( + completion_tokens=8, + prompt_tokens=17, + total_tokens=25 + ) + ) + ``` + + Example using streaming: + ```py + >>> from huggingface_hub import InferenceClient + >>> messages = [{"role": "user", "content": "What is the capital of France?"}] + >>> client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct") + >>> for token in client.chat_completion(messages, max_tokens=10, stream=True): + ... print(token) + ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content='The', role='assistant'), index=0, finish_reason=None)], created=1710498504) + ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504) + (...) + ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504) + ``` + + Example using OpenAI's syntax: + ```py + # instead of `from openai import OpenAI` + from huggingface_hub import InferenceClient + + # instead of `client = OpenAI(...)` + client = InferenceClient( + base_url=..., + api_key=..., + ) + + output = client.chat.completions.create( + model="meta-llama/Meta-Llama-3-8B-Instruct", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Count to 10"}, + ], + stream=True, + max_tokens=1024, + ) + + for chunk in output: + print(chunk.choices[0].delta.content) + ``` + + Example using a third-party provider directly with extra (provider-specific) parameters. Usage will be billed on your Together AI account. + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="together", # Use Together AI provider + ... api_key="", # Pass your Together API key directly + ... ) + >>> client.chat_completion( + ... model="meta-llama/Meta-Llama-3-8B-Instruct", + ... messages=[{"role": "user", "content": "What is the capital of France?"}], + ... extra_body={"safety_model": "Meta-Llama/Llama-Guard-7b"}, + ... ) + ``` + + Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account. + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="sambanova", # Use Sambanova provider + ... api_key="hf_...", # Pass your HF token + ... ) + >>> client.chat_completion( + ... model="meta-llama/Meta-Llama-3-8B-Instruct", + ... messages=[{"role": "user", "content": "What is the capital of France?"}], + ... ) + ``` + + Example using Image + Text as input: + ```py + >>> from huggingface_hub import InferenceClient + + # provide a remote URL + >>> image_url ="https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" + # or a base64-encoded image + >>> image_path = "/path/to/image.jpeg" + >>> with open(image_path, "rb") as f: + ... base64_image = base64.b64encode(f.read()).decode("utf-8") + >>> image_url = f"data:image/jpeg;base64,{base64_image}" + + >>> client = InferenceClient("meta-llama/Llama-3.2-11B-Vision-Instruct") + >>> output = client.chat.completions.create( + ... messages=[ + ... { + ... "role": "user", + ... "content": [ + ... { + ... "type": "image_url", + ... "image_url": {"url": image_url}, + ... }, + ... { + ... "type": "text", + ... "text": "Describe this image in one sentence.", + ... }, + ... ], + ... }, + ... ], + ... ) + >>> output + The image depicts the iconic Statue of Liberty situated in New York Harbor, New York, on a clear day. + ``` + + Example using tools: + ```py + >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") + >>> messages = [ + ... { + ... "role": "system", + ... "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.", + ... }, + ... { + ... "role": "user", + ... "content": "What's the weather like the next 3 days in San Francisco, CA?", + ... }, + ... ] + >>> tools = [ + ... { + ... "type": "function", + ... "function": { + ... "name": "get_current_weather", + ... "description": "Get the current weather", + ... "parameters": { + ... "type": "object", + ... "properties": { + ... "location": { + ... "type": "string", + ... "description": "The city and state, e.g. San Francisco, CA", + ... }, + ... "format": { + ... "type": "string", + ... "enum": ["celsius", "fahrenheit"], + ... "description": "The temperature unit to use. Infer this from the users location.", + ... }, + ... }, + ... "required": ["location", "format"], + ... }, + ... }, + ... }, + ... { + ... "type": "function", + ... "function": { + ... "name": "get_n_day_weather_forecast", + ... "description": "Get an N-day weather forecast", + ... "parameters": { + ... "type": "object", + ... "properties": { + ... "location": { + ... "type": "string", + ... "description": "The city and state, e.g. San Francisco, CA", + ... }, + ... "format": { + ... "type": "string", + ... "enum": ["celsius", "fahrenheit"], + ... "description": "The temperature unit to use. Infer this from the users location.", + ... }, + ... "num_days": { + ... "type": "integer", + ... "description": "The number of days to forecast", + ... }, + ... }, + ... "required": ["location", "format", "num_days"], + ... }, + ... }, + ... }, + ... ] + + >>> response = client.chat_completion( + ... model="meta-llama/Meta-Llama-3-70B-Instruct", + ... messages=messages, + ... tools=tools, + ... tool_choice="auto", + ... max_tokens=500, + ... ) + >>> response.choices[0].message.tool_calls[0].function + ChatCompletionOutputFunctionDefinition( + arguments={ + 'location': 'San Francisco, CA', + 'format': 'fahrenheit', + 'num_days': 3 + }, + name='get_n_day_weather_forecast', + description=None + ) + ``` + + Example using response_format: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") + >>> messages = [ + ... { + ... "role": "user", + ... "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I saw and when?", + ... }, + ... ] + >>> response_format = { + ... "type": "json", + ... "value": { + ... "properties": { + ... "location": {"type": "string"}, + ... "activity": {"type": "string"}, + ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5}, + ... "animals": {"type": "array", "items": {"type": "string"}}, + ... }, + ... "required": ["location", "activity", "animals_seen", "animals"], + ... }, + ... } + >>> response = client.chat_completion( + ... messages=messages, + ... response_format=response_format, + ... max_tokens=500, + ... ) + >>> response.choices[0].message.content + '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}' + ``` + """ + # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently. + # `self.model` takes precedence over 'model' argument for building URL. + # `model` takes precedence for payload value. + model_id_or_url = self.model or model + payload_model = model or self.model + + # Get the provider helper + provider_helper = get_provider_helper( + self.provider, + task="conversational", + model=model_id_or_url + if model_id_or_url is not None and model_id_or_url.startswith(("http://", "https://")) + else payload_model, + ) + + # Prepare the payload + parameters = { + "model": payload_model, + "frequency_penalty": frequency_penalty, + "logit_bias": logit_bias, + "logprobs": logprobs, + "max_tokens": max_tokens, + "n": n, + "presence_penalty": presence_penalty, + "response_format": response_format, + "seed": seed, + "stop": stop, + "temperature": temperature, + "tool_choice": tool_choice, + "tool_prompt": tool_prompt, + "tools": tools, + "top_logprobs": top_logprobs, + "top_p": top_p, + "stream": stream, + "stream_options": stream_options, + **(extra_body or {}), + } + request_parameters = provider_helper.prepare_request( + inputs=messages, + parameters=parameters, + headers=self.headers, + model=model_id_or_url, + api_key=self.token, + ) + data = self._inner_post(request_parameters, stream=stream) + + if stream: + return _stream_chat_completion_response(data) # type: ignore[arg-type] + + return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type] + + def document_question_answering( + self, + image: ContentT, + question: str, + *, + model: Optional[str] = None, + doc_stride: Optional[int] = None, + handle_impossible_answer: Optional[bool] = None, + lang: Optional[str] = None, + max_answer_len: Optional[int] = None, + max_question_len: Optional[int] = None, + max_seq_len: Optional[int] = None, + top_k: Optional[int] = None, + word_boxes: Optional[List[Union[List[float], str]]] = None, + ) -> List[DocumentQuestionAnsweringOutputElement]: + """ + Answer questions on document images. + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The input image for the context. It can be raw bytes, an image file, or a URL to an online image. + question (`str`): + Question to be answered. + model (`str`, *optional*): + The model to use for the document question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used. + Defaults to None. + doc_stride (`int`, *optional*): + If the words in the document are too long to fit with the question for the model, it will be split in + several chunks with some overlap. This argument controls the size of that overlap. + handle_impossible_answer (`bool`, *optional*): + Whether to accept impossible as an answer + lang (`str`, *optional*): + Language to use while running OCR. Defaults to english. + max_answer_len (`int`, *optional*): + The maximum length of predicted answers (e.g., only answers with a shorter length are considered). + max_question_len (`int`, *optional*): + The maximum length of the question after tokenization. It will be truncated if needed. + max_seq_len (`int`, *optional*): + The maximum length of the total sentence (context + question) in tokens of each chunk passed to the + model. The context will be split in several chunks (using doc_stride as overlap) if needed. + top_k (`int`, *optional*): + The number of answers to return (will be chosen by order of likelihood). Can return less than top_k + answers if there are not enough options available within the context. + word_boxes (`List[Union[List[float], str`, *optional*): + A list of words and bounding boxes (normalized 0->1000). If provided, the inference will skip the OCR + step and use the provided bounding boxes instead. + Returns: + `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?") + [DocumentQuestionAnsweringOutputElement(answer='us-001', end=16, score=0.9999666213989258, start=16)] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="document-question-answering", model=model_id) + inputs: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} + request_parameters = provider_helper.prepare_request( + inputs=inputs, + parameters={ + "doc_stride": doc_stride, + "handle_impossible_answer": handle_impossible_answer, + "lang": lang, + "max_answer_len": max_answer_len, + "max_question_len": max_question_len, + "max_seq_len": max_seq_len, + "top_k": top_k, + "word_boxes": word_boxes, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response) + + def feature_extraction( + self, + text: str, + *, + normalize: Optional[bool] = None, + prompt_name: Optional[str] = None, + truncate: Optional[bool] = None, + truncation_direction: Optional[Literal["Left", "Right"]] = None, + model: Optional[str] = None, + ) -> "np.ndarray": + """ + Generate embeddings for a given text. + + Args: + text (`str`): + The text to embed. + model (`str`, *optional*): + The model to use for the feature extraction task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended feature extraction model will be used. + Defaults to None. + normalize (`bool`, *optional*): + Whether to normalize the embeddings or not. + Only available on server powered by Text-Embedding-Inference. + prompt_name (`str`, *optional*): + The name of the prompt that should be used by for encoding. If not set, no prompt will be applied. + Must be a key in the `Sentence Transformers` configuration `prompts` dictionary. + For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ",...}, + then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?" + because the prompt text will be prepended before any text to encode. + truncate (`bool`, *optional*): + Whether to truncate the embeddings or not. + Only available on server powered by Text-Embedding-Inference. + truncation_direction (`Literal["Left", "Right"]`, *optional*): + Which side of the input should be truncated when `truncate=True` is passed. + + Returns: + `np.ndarray`: The embedding representing the input text as a float32 numpy array. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.feature_extraction("Hi, who are you?") + array([[ 2.424802 , 2.93384 , 1.1750331 , ..., 1.240499, -0.13776633, -0.7889173 ], + [-0.42943227, -0.6364878 , -1.693462 , ..., 0.41978157, -2.4336355 , 0.6162071 ], + ..., + [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32) + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="feature-extraction", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=text, + parameters={ + "normalize": normalize, + "prompt_name": prompt_name, + "truncate": truncate, + "truncation_direction": truncation_direction, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + np = _import_numpy() + return np.array(provider_helper.get_response(response), dtype="float32") + + def fill_mask( + self, + text: str, + *, + model: Optional[str] = None, + targets: Optional[List[str]] = None, + top_k: Optional[int] = None, + ) -> List[FillMaskOutputElement]: + """ + Fill in a hole with a missing word (token to be precise). + + Args: + text (`str`): + a string to be filled from, must contain the [MASK] token (check model card for exact name of the mask). + model (`str`, *optional*): + The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used. + targets (`List[str`, *optional*): + When passed, the model will limit the scores to the passed targets instead of looking up in the whole + vocabulary. If the provided targets are not in the model vocab, they will be tokenized and the first + resulting token will be used (with a warning, and that might be slower). + top_k (`int`, *optional*): + When passed, overrides the number of predictions to return. + Returns: + `List[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated + probability, token reference, and completed text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.fill_mask("The goal of life is .") + [ + FillMaskOutputElement(score=0.06897063553333282, token=11098, token_str=' happiness', sequence='The goal of life is happiness.'), + FillMaskOutputElement(score=0.06554922461509705, token=45075, token_str=' immortality', sequence='The goal of life is immortality.') + ] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="fill-mask", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=text, + parameters={"targets": targets, "top_k": top_k}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + return FillMaskOutputElement.parse_obj_as_list(response) + + def image_classification( + self, + image: ContentT, + *, + model: Optional[str] = None, + function_to_apply: Optional["ImageClassificationOutputTransform"] = None, + top_k: Optional[int] = None, + ) -> List[ImageClassificationOutputElement]: + """ + Perform image classification on the given image using the specified model. + + Args: + image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): + The image to classify. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. + model (`str`, *optional*): + The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a + deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used. + function_to_apply (`"ImageClassificationOutputTransform"`, *optional*): + The function to apply to the model outputs in order to retrieve the scores. + top_k (`int`, *optional*): + When specified, limits the output to the top K most probable classes. + Returns: + `List[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") + [ImageClassificationOutputElement(label='Blenheim spaniel', score=0.9779096841812134), ...] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="image-classification", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=image, + parameters={"function_to_apply": function_to_apply, "top_k": top_k}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + return ImageClassificationOutputElement.parse_obj_as_list(response) + + def image_segmentation( + self, + image: ContentT, + *, + model: Optional[str] = None, + mask_threshold: Optional[float] = None, + overlap_mask_area_threshold: Optional[float] = None, + subtask: Optional["ImageSegmentationSubtask"] = None, + threshold: Optional[float] = None, + ) -> List[ImageSegmentationOutputElement]: + """ + Perform image segmentation on the given image using the specified model. + + + + You must have `PIL` installed if you want to work with images (`pip install Pillow`). + + + + Args: + image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): + The image to segment. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. + model (`str`, *optional*): + The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a + deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used. + mask_threshold (`float`, *optional*): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*): + Mask overlap threshold to eliminate small, disconnected segments. + subtask (`"ImageSegmentationSubtask"`, *optional*): + Segmentation task to be performed, depending on model capabilities. + threshold (`float`, *optional*): + Probability threshold to filter out predicted masks. + Returns: + `List[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.image_segmentation("cat.jpg") + [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=), ...] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="image-segmentation", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=image, + parameters={ + "mask_threshold": mask_threshold, + "overlap_mask_area_threshold": overlap_mask_area_threshold, + "subtask": subtask, + "threshold": threshold, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + output = ImageSegmentationOutputElement.parse_obj_as_list(response) + for item in output: + item.mask = _b64_to_image(item.mask) # type: ignore [assignment] + return output + + def image_to_image( + self, + image: ContentT, + prompt: Optional[str] = None, + *, + negative_prompt: Optional[str] = None, + num_inference_steps: Optional[int] = None, + guidance_scale: Optional[float] = None, + model: Optional[str] = None, + target_size: Optional[ImageToImageTargetSize] = None, + **kwargs, + ) -> "Image": + """ + Perform image-to-image translation using a specified model. + + + + You must have `PIL` installed if you want to work with images (`pip install Pillow`). + + + + Args: + image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): + The input image for translation. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. + prompt (`str`, *optional*): + The text prompt to guide the image generation. + negative_prompt (`str`, *optional*): + One prompt to guide what NOT to include in image generation. + num_inference_steps (`int`, *optional*): + For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + guidance_scale (`float`, *optional*): + For diffusion models. A higher guidance scale value encourages the model to generate images closely + linked to the text prompt at the expense of lower image quality. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + target_size (`ImageToImageTargetSize`, *optional*): + The size in pixel of the output image. + + Returns: + `Image`: The translated image. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> image = client.image_to_image("cat.jpg", prompt="turn the cat into a tiger") + >>> image.save("tiger.jpg") + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="image-to-image", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=image, + parameters={ + "prompt": prompt, + "negative_prompt": negative_prompt, + "target_size": target_size, + "num_inference_steps": num_inference_steps, + "guidance_scale": guidance_scale, + **kwargs, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + response = provider_helper.get_response(response, request_parameters) + return _bytes_to_image(response) + + def image_to_video( + self, + image: ContentT, + *, + model: Optional[str] = None, + prompt: Optional[str] = None, + negative_prompt: Optional[str] = None, + num_frames: Optional[float] = None, + num_inference_steps: Optional[int] = None, + guidance_scale: Optional[float] = None, + seed: Optional[int] = None, + target_size: Optional[ImageToVideoTargetSize] = None, + **kwargs, + ) -> bytes: + """ + Generate a video from an input image. + + Args: + image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): + The input image to generate a video from. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + prompt (`str`, *optional*): + The text prompt to guide the video generation. + negative_prompt (`str`, *optional*): + One prompt to guide what NOT to include in video generation. + num_frames (`float`, *optional*): + The num_frames parameter determines how many video frames are generated. + num_inference_steps (`int`, *optional*): + For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + guidance_scale (`float`, *optional*): + For diffusion models. A higher guidance scale value encourages the model to generate videos closely + linked to the text prompt at the expense of lower image quality. + seed (`int`, *optional*): + The seed to use for the video generation. + target_size (`ImageToVideoTargetSize`, *optional*): + The size in pixel of the output video frames. + num_inference_steps (`int`, *optional*): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + seed (`int`, *optional*): + Seed for the random number generator. + + Returns: + `bytes`: The generated video. + + Examples: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> video = client.image_to_video("cat.jpg", model="Wan-AI/Wan2.2-I2V-A14B", prompt="turn the cat into a tiger") + >>> with open("tiger.mp4", "wb") as f: + ... f.write(video) + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="image-to-video", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=image, + parameters={ + "prompt": prompt, + "negative_prompt": negative_prompt, + "num_frames": num_frames, + "num_inference_steps": num_inference_steps, + "guidance_scale": guidance_scale, + "seed": seed, + "target_size": target_size, + **kwargs, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + response = provider_helper.get_response(response, request_parameters) + return response + + def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput: + """ + Takes an input image and return text. + + Models can have very different outputs depending on your use case (image captioning, optical character recognition + (OCR), Pix2Struct, etc). Please have a look to the model card to learn more about a model's specificities. + + Args: + image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): + The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + [`ImageToTextOutput`]: The generated text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.image_to_text("cat.jpg") + 'a cat standing in a grassy field ' + >>> client.image_to_text("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") + 'a dog laying on the grass next to a flower pot ' + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="image-to-text", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=image, + parameters={}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + output = ImageToTextOutput.parse_obj(response) + return output[0] if isinstance(output, list) else output + + def object_detection( + self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None + ) -> List[ObjectDetectionOutputElement]: + """ + Perform object detection on the given image using the specified model. + + + + You must have `PIL` installed if you want to work with images (`pip install Pillow`). + + + + Args: + image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): + The image to detect objects on. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. + model (`str`, *optional*): + The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a + deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used. + threshold (`float`, *optional*): + The probability necessary to make a prediction. + Returns: + `List[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + `ValueError`: + If the request output is not a List. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.object_detection("people.jpg") + [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="object-detection", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=image, + parameters={"threshold": threshold}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + return ObjectDetectionOutputElement.parse_obj_as_list(response) + + def question_answering( + self, + question: str, + context: str, + *, + model: Optional[str] = None, + align_to_words: Optional[bool] = None, + doc_stride: Optional[int] = None, + handle_impossible_answer: Optional[bool] = None, + max_answer_len: Optional[int] = None, + max_question_len: Optional[int] = None, + max_seq_len: Optional[int] = None, + top_k: Optional[int] = None, + ) -> Union[QuestionAnsweringOutputElement, List[QuestionAnsweringOutputElement]]: + """ + Retrieve the answer to a question from a given text. + + Args: + question (`str`): + Question to be answered. + context (`str`): + The context of the question. + model (`str`): + The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. + align_to_words (`bool`, *optional*): + Attempts to align the answer to real words. Improves quality on space separated languages. Might hurt + on non-space-separated languages (like Japanese or Chinese) + doc_stride (`int`, *optional*): + If the context is too long to fit with the question for the model, it will be split in several chunks + with some overlap. This argument controls the size of that overlap. + handle_impossible_answer (`bool`, *optional*): + Whether to accept impossible as an answer. + max_answer_len (`int`, *optional*): + The maximum length of predicted answers (e.g., only answers with a shorter length are considered). + max_question_len (`int`, *optional*): + The maximum length of the question after tokenization. It will be truncated if needed. + max_seq_len (`int`, *optional*): + The maximum length of the total sentence (context + question) in tokens of each chunk passed to the + model. The context will be split in several chunks (using docStride as overlap) if needed. + top_k (`int`, *optional*): + The number of answers to return (will be chosen by order of likelihood). Note that we return less than + topk answers if there are not enough options available within the context. + + Returns: + Union[`QuestionAnsweringOutputElement`, List[`QuestionAnsweringOutputElement`]]: + When top_k is 1 or not provided, it returns a single `QuestionAnsweringOutputElement`. + When top_k is greater than 1, it returns a list of `QuestionAnsweringOutputElement`. + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.") + QuestionAnsweringOutputElement(answer='Clara', end=16, score=0.9326565265655518, start=11) + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="question-answering", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs={"question": question, "context": context}, + parameters={ + "align_to_words": align_to_words, + "doc_stride": doc_stride, + "handle_impossible_answer": handle_impossible_answer, + "max_answer_len": max_answer_len, + "max_question_len": max_question_len, + "max_seq_len": max_seq_len, + "top_k": top_k, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + # Parse the response as a single `QuestionAnsweringOutputElement` when top_k is 1 or not provided, or a list of `QuestionAnsweringOutputElement` to ensure backward compatibility. + output = QuestionAnsweringOutputElement.parse_obj(response) + return output + + def sentence_similarity( + self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None + ) -> List[float]: + """ + Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings. + + Args: + sentence (`str`): + The main sentence to compare to others. + other_sentences (`List[str]`): + The list of sentences to compare to. + model (`str`, *optional*): + The model to use for the sentence similarity task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended sentence similarity model will be used. + Defaults to None. + + Returns: + `List[float]`: The embedding representing the input text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.sentence_similarity( + ... "Machine learning is so easy.", + ... other_sentences=[ + ... "Deep learning is so straightforward.", + ... "This is so difficult, like rocket science.", + ... "I can't believe how much I struggled with this.", + ... ], + ... ) + [0.7785726189613342, 0.45876261591911316, 0.2906220555305481] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="sentence-similarity", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs={"source_sentence": sentence, "sentences": other_sentences}, + parameters={}, + extra_payload={}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + return _bytes_to_list(response) + + def summarization( + self, + text: str, + *, + model: Optional[str] = None, + clean_up_tokenization_spaces: Optional[bool] = None, + generate_parameters: Optional[Dict[str, Any]] = None, + truncation: Optional["SummarizationTruncationStrategy"] = None, + ) -> SummarizationOutput: + """ + Generate a summary of a given text using a specified model. + + Args: + text (`str`): + The input text to summarize. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. If not provided, the default recommended model for summarization will be used. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether to clean up the potential extra spaces in the text output. + generate_parameters (`Dict[str, Any]`, *optional*): + Additional parametrization of the text generation algorithm. + truncation (`"SummarizationTruncationStrategy"`, *optional*): + The truncation strategy to use. + Returns: + [`SummarizationOutput`]: The generated summary text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.summarization("The Eiffel tower...") + SummarizationOutput(generated_text="The Eiffel tower is one of the most famous landmarks in the world....") + ``` + """ + parameters = { + "clean_up_tokenization_spaces": clean_up_tokenization_spaces, + "generate_parameters": generate_parameters, + "truncation": truncation, + } + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="summarization", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=text, + parameters=parameters, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + return SummarizationOutput.parse_obj_as_list(response)[0] + + def table_question_answering( + self, + table: Dict[str, Any], + query: str, + *, + model: Optional[str] = None, + padding: Optional["Padding"] = None, + sequential: Optional[bool] = None, + truncation: Optional[bool] = None, + ) -> TableQuestionAnsweringOutputElement: + """ + Retrieve the answer to a question from information given in a table. + + Args: + table (`str`): + A table of data represented as a dict of lists where entries are headers and the lists are all the + values, all lists must have the same size. + query (`str`): + The query in plain text that you want to ask the table. + model (`str`): + The model to use for the table-question-answering task. Can be a model ID hosted on the Hugging Face + Hub or a URL to a deployed Inference Endpoint. + padding (`"Padding"`, *optional*): + Activates and controls padding. + sequential (`bool`, *optional*): + Whether to do inference sequentially or as a batch. Batching is faster, but models like SQA require the + inference to be done sequentially to extract relations within sequences, given their conversational + nature. + truncation (`bool`, *optional*): + Activates and controls truncation. + + Returns: + [`TableQuestionAnsweringOutputElement`]: a table question answering output containing the answer, coordinates, cells and the aggregator used. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> query = "How many stars does the transformers repository have?" + >>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]} + >>> client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq") + TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE') + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="table-question-answering", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs={"query": query, "table": table}, + parameters={"model": model, "padding": padding, "sequential": sequential, "truncation": truncation}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response) + + def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]: + """ + Classifying a target category (a group) based on a set of attributes. + + Args: + table (`Dict[str, Any]`): + Set of attributes to classify. + model (`str`, *optional*): + The model to use for the tabular classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended tabular classification model will be used. + Defaults to None. + + Returns: + `List`: a list of labels, one per row in the initial table. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> table = { + ... "fixed_acidity": ["7.4", "7.8", "10.3"], + ... "volatile_acidity": ["0.7", "0.88", "0.32"], + ... "citric_acid": ["0", "0", "0.45"], + ... "residual_sugar": ["1.9", "2.6", "6.4"], + ... "chlorides": ["0.076", "0.098", "0.073"], + ... "free_sulfur_dioxide": ["11", "25", "5"], + ... "total_sulfur_dioxide": ["34", "67", "13"], + ... "density": ["0.9978", "0.9968", "0.9976"], + ... "pH": ["3.51", "3.2", "3.23"], + ... "sulphates": ["0.56", "0.68", "0.82"], + ... "alcohol": ["9.4", "9.8", "12.6"], + ... } + >>> client.tabular_classification(table=table, model="julien-c/wine-quality") + ["5", "5", "5"] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="tabular-classification", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=None, + extra_payload={"table": table}, + parameters={}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + return _bytes_to_list(response) + + def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]: + """ + Predicting a numerical target value given a set of attributes/features in a table. + + Args: + table (`Dict[str, Any]`): + Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical. + model (`str`, *optional*): + The model to use for the tabular regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended tabular regression model will be used. + Defaults to None. + + Returns: + `List`: a list of predicted numerical target values. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> table = { + ... "Height": ["11.52", "12.48", "12.3778"], + ... "Length1": ["23.2", "24", "23.9"], + ... "Length2": ["25.4", "26.3", "26.5"], + ... "Length3": ["30", "31.2", "31.1"], + ... "Species": ["Bream", "Bream", "Bream"], + ... "Width": ["4.02", "4.3056", "4.6961"], + ... } + >>> client.tabular_regression(table, model="scikit-learn/Fish-Weight") + [110, 120, 130] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="tabular-regression", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=None, + parameters={}, + extra_payload={"table": table}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + return _bytes_to_list(response) + + def text_classification( + self, + text: str, + *, + model: Optional[str] = None, + top_k: Optional[int] = None, + function_to_apply: Optional["TextClassificationOutputTransform"] = None, + ) -> List[TextClassificationOutputElement]: + """ + Perform text classification (e.g. sentiment-analysis) on the given text. + + Args: + text (`str`): + A string to be classified. + model (`str`, *optional*): + The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used. + Defaults to None. + top_k (`int`, *optional*): + When specified, limits the output to the top K most probable classes. + function_to_apply (`"TextClassificationOutputTransform"`, *optional*): + The function to apply to the model outputs in order to retrieve the scores. + + Returns: + `List[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.text_classification("I like you") + [ + TextClassificationOutputElement(label='POSITIVE', score=0.9998695850372314), + TextClassificationOutputElement(label='NEGATIVE', score=0.0001304351753788069), + ] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="text-classification", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=text, + parameters={ + "function_to_apply": function_to_apply, + "top_k": top_k, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value] + + @overload + def text_generation( + self, + prompt: str, + *, + details: Literal[True], + stream: Literal[True], + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> Iterable[TextGenerationStreamOutput]: ... + + @overload + def text_generation( + self, + prompt: str, + *, + details: Literal[True], + stream: Optional[Literal[False]] = None, + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> TextGenerationOutput: ... + + @overload + def text_generation( + self, + prompt: str, + *, + details: Optional[Literal[False]] = None, + stream: Literal[True], + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = None, # Manual default value + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> Iterable[str]: ... + + @overload + def text_generation( + self, + prompt: str, + *, + details: Optional[Literal[False]] = None, + stream: Optional[Literal[False]] = None, + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> str: ... + + @overload + def text_generation( + self, + prompt: str, + *, + details: Optional[bool] = None, + stream: Optional[bool] = None, + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]: ... + + def text_generation( + self, + prompt: str, + *, + details: Optional[bool] = None, + stream: Optional[bool] = None, + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]: + """ + Given a prompt, generate the following text. + + + + If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method. + It accepts a list of messages instead of a single text prompt and handles the chat templating for you. + + + + Args: + prompt (`str`): + Input text. + details (`bool`, *optional*): + By default, text_generation returns a string. Pass `details=True` if you want a detailed output (tokens, + probabilities, seed, finish reason, etc.). Only available for models running on with the + `text-generation-inference` backend. + stream (`bool`, *optional*): + By default, text_generation returns the full generated text. Pass `stream=True` if you want a stream of + tokens to be returned. Only available for models running on with the `text-generation-inference` + backend. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + adapter_id (`str`, *optional*): + Lora adapter id. + best_of (`int`, *optional*): + Generate best_of sequences and return the one if the highest token logprobs. + decoder_input_details (`bool`, *optional*): + Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken + into account. Defaults to `False`. + do_sample (`bool`, *optional*): + Activate logits sampling + frequency_penalty (`float`, *optional*): + Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in + the text so far, decreasing the model's likelihood to repeat the same line verbatim. + grammar ([`TextGenerationInputGrammarType`], *optional*): + Grammar constraints. Can be either a JSONSchema or a regex. + max_new_tokens (`int`, *optional*): + Maximum number of generated tokens. Defaults to 100. + repetition_penalty (`float`, *optional*): + The parameter for repetition penalty. 1.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + return_full_text (`bool`, *optional*): + Whether to prepend the prompt to the generated text + seed (`int`, *optional*): + Random sampling seed + stop (`List[str]`, *optional*): + Stop generating tokens if a member of `stop` is generated. + stop_sequences (`List[str]`, *optional*): + Deprecated argument. Use `stop` instead. + temperature (`float`, *optional*): + The value used to module the logits distribution. + top_n_tokens (`int`, *optional*): + Return information about the `top_n_tokens` most likely tokens at each generation step, instead of + just the sampled token. + top_k (`int`, *optional`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`, *optional`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. + truncate (`int`, *optional`): + Truncate inputs tokens to the given size. + typical_p (`float`, *optional`): + Typical Decoding mass + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + watermark (`bool`, *optional*): + Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + + Returns: + `Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]`: + Generated text returned from the server: + - if `stream=False` and `details=False`, the generated text is returned as a `str` (default) + - if `stream=True` and `details=False`, the generated text is returned token by token as a `Iterable[str]` + - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.TextGenerationOutput`] + - if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.TextGenerationStreamOutput`] + + Raises: + `ValidationError`: + If input values are not valid. No HTTP call is made to the server. + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + + # Case 1: generate text + >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12) + '100% open source and built to be easy to use.' + + # Case 2: iterate over the generated tokens. Useful for large generation. + >>> for token in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True): + ... print(token) + 100 + % + open + source + and + built + to + be + easy + to + use + . + + # Case 3: get more details about the generation process. + >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True) + TextGenerationOutput( + generated_text='100% open source and built to be easy to use.', + details=TextGenerationDetails( + finish_reason='length', + generated_tokens=12, + seed=None, + prefill=[ + TextGenerationPrefillOutputToken(id=487, text='The', logprob=None), + TextGenerationPrefillOutputToken(id=53789, text=' hugging', logprob=-13.171875), + (...) + TextGenerationPrefillOutputToken(id=204, text=' ', logprob=-7.0390625) + ], + tokens=[ + TokenElement(id=1425, text='100', logprob=-1.0175781, special=False), + TokenElement(id=16, text='%', logprob=-0.0463562, special=False), + (...) + TokenElement(id=25, text='.', logprob=-0.5703125, special=False) + ], + best_of_sequences=None + ) + ) + + # Case 4: iterate over the generated tokens with more details. + # Last object is more complete, containing the full generated text and the finish reason. + >>> for details in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True): + ... print(details) + ... + TextGenerationStreamOutput(token=TokenElement(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement( + id=25, + text='.', + logprob=-0.5703125, + special=False), + generated_text='100% open source and built to be easy to use.', + details=TextGenerationStreamOutputStreamDetails(finish_reason='length', generated_tokens=12, seed=None) + ) + + # Case 5: generate constrained output using grammar + >>> response = client.text_generation( + ... prompt="I saw a puppy a cat and a raccoon during my bike ride in the park", + ... model="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1", + ... max_new_tokens=100, + ... repetition_penalty=1.3, + ... grammar={ + ... "type": "json", + ... "value": { + ... "properties": { + ... "location": {"type": "string"}, + ... "activity": {"type": "string"}, + ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5}, + ... "animals": {"type": "array", "items": {"type": "string"}}, + ... }, + ... "required": ["location", "activity", "animals_seen", "animals"], + ... }, + ... }, + ... ) + >>> json.loads(response) + { + "activity": "bike riding", + "animals": ["puppy", "cat", "raccoon"], + "animals_seen": 3, + "location": "park" + } + ``` + """ + if decoder_input_details and not details: + warnings.warn( + "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that" + " the output from the server will be truncated." + ) + decoder_input_details = False + + if stop_sequences is not None: + warnings.warn( + "`stop_sequences` is a deprecated argument for `text_generation` task" + " and will be removed in version '0.28.0'. Use `stop` instead.", + FutureWarning, + ) + if stop is None: + stop = stop_sequences # use deprecated arg if provided + + # Build payload + parameters = { + "adapter_id": adapter_id, + "best_of": best_of, + "decoder_input_details": decoder_input_details, + "details": details, + "do_sample": do_sample, + "frequency_penalty": frequency_penalty, + "grammar": grammar, + "max_new_tokens": max_new_tokens, + "repetition_penalty": repetition_penalty, + "return_full_text": return_full_text, + "seed": seed, + "stop": stop, + "temperature": temperature, + "top_k": top_k, + "top_n_tokens": top_n_tokens, + "top_p": top_p, + "truncate": truncate, + "typical_p": typical_p, + "watermark": watermark, + } + + # Remove some parameters if not a TGI server + unsupported_kwargs = _get_unsupported_text_generation_kwargs(model) + if len(unsupported_kwargs) > 0: + # The server does not support some parameters + # => means it is not a TGI server + # => remove unsupported parameters and warn the user + + ignored_parameters = [] + for key in unsupported_kwargs: + if parameters.get(key): + ignored_parameters.append(key) + parameters.pop(key, None) + if len(ignored_parameters) > 0: + warnings.warn( + "API endpoint/model for text-generation is not served via TGI. Ignoring following parameters:" + f" {', '.join(ignored_parameters)}.", + UserWarning, + ) + if details: + warnings.warn( + "API endpoint/model for text-generation is not served via TGI. Parameter `details=True` will" + " be ignored meaning only the generated text will be returned.", + UserWarning, + ) + details = False + if stream: + raise ValueError( + "API endpoint/model for text-generation is not served via TGI. Cannot return output as a stream." + " Please pass `stream=False` as input." + ) + + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="text-generation", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=prompt, + parameters=parameters, + extra_payload={"stream": stream}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + + # Handle errors separately for more precise error messages + try: + bytes_output = self._inner_post(request_parameters, stream=stream or False) + except HTTPError as e: + match = MODEL_KWARGS_NOT_USED_REGEX.search(str(e)) + if isinstance(e, BadRequestError) and match: + unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")] + _set_unsupported_text_generation_kwargs(model, unused_params) + return self.text_generation( # type: ignore + prompt=prompt, + details=details, + stream=stream, + model=model_id, + adapter_id=adapter_id, + best_of=best_of, + decoder_input_details=decoder_input_details, + do_sample=do_sample, + frequency_penalty=frequency_penalty, + grammar=grammar, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty, + return_full_text=return_full_text, + seed=seed, + stop=stop, + temperature=temperature, + top_k=top_k, + top_n_tokens=top_n_tokens, + top_p=top_p, + truncate=truncate, + typical_p=typical_p, + watermark=watermark, + ) + raise_text_generation_error(e) + + # Parse output + if stream: + return _stream_text_generation_response(bytes_output, details) # type: ignore + + data = _bytes_to_dict(bytes_output) # type: ignore[arg-type] + + # Data can be a single element (dict) or an iterable of dicts where we select the first element of. + if isinstance(data, list): + data = data[0] + response = provider_helper.get_response(data, request_parameters) + return TextGenerationOutput.parse_obj_as_instance(response) if details else response["generated_text"] + + def text_to_image( + self, + prompt: str, + *, + negative_prompt: Optional[str] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = None, + guidance_scale: Optional[float] = None, + model: Optional[str] = None, + scheduler: Optional[str] = None, + seed: Optional[int] = None, + extra_body: Optional[Dict[str, Any]] = None, + ) -> "Image": + """ + Generate an image based on a given text using a specified model. + + + + You must have `PIL` installed if you want to work with images (`pip install Pillow`). + + + + + You can pass provider-specific parameters to the model by using the `extra_body` argument. + + + Args: + prompt (`str`): + The prompt to generate an image from. + negative_prompt (`str`, *optional*): + One prompt to guide what NOT to include in image generation. + height (`int`, *optional*): + The height in pixels of the output image + width (`int`, *optional*): + The width in pixels of the output image + num_inference_steps (`int`, *optional*): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*): + A higher guidance scale value encourages the model to generate images closely linked to the text + prompt, but values too high may cause saturation and other artifacts. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. If not provided, the default recommended text-to-image model will be used. + Defaults to None. + scheduler (`str`, *optional*): + Override the scheduler with a compatible one. + seed (`int`, *optional*): + Seed for the random number generator. + extra_body (`Dict[str, Any]`, *optional*): + Additional provider-specific parameters to pass to the model. Refer to the provider's documentation + for supported parameters. + + Returns: + `Image`: The generated image. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + + >>> image = client.text_to_image("An astronaut riding a horse on the moon.") + >>> image.save("astronaut.png") + + >>> image = client.text_to_image( + ... "An astronaut riding a horse on the moon.", + ... negative_prompt="low resolution, blurry", + ... model="stabilityai/stable-diffusion-2-1", + ... ) + >>> image.save("better_astronaut.png") + ``` + Example using a third-party provider directly. Usage will be billed on your fal.ai account. + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="fal-ai", # Use fal.ai provider + ... api_key="fal-ai-api-key", # Pass your fal.ai API key + ... ) + >>> image = client.text_to_image( + ... "A majestic lion in a fantasy forest", + ... model="black-forest-labs/FLUX.1-schnell", + ... ) + >>> image.save("lion.png") + ``` + + Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account. + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="replicate", # Use replicate provider + ... api_key="hf_...", # Pass your HF token + ... ) + >>> image = client.text_to_image( + ... "An astronaut riding a horse on the moon.", + ... model="black-forest-labs/FLUX.1-dev", + ... ) + >>> image.save("astronaut.png") + ``` + + Example using Replicate provider with extra parameters + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="replicate", # Use replicate provider + ... api_key="hf_...", # Pass your HF token + ... ) + >>> image = client.text_to_image( + ... "An astronaut riding a horse on the moon.", + ... model="black-forest-labs/FLUX.1-schnell", + ... extra_body={"output_quality": 100}, + ... ) + >>> image.save("astronaut.png") + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="text-to-image", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=prompt, + parameters={ + "negative_prompt": negative_prompt, + "height": height, + "width": width, + "num_inference_steps": num_inference_steps, + "guidance_scale": guidance_scale, + "scheduler": scheduler, + "seed": seed, + **(extra_body or {}), + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + response = provider_helper.get_response(response) + return _bytes_to_image(response) + + def text_to_video( + self, + prompt: str, + *, + model: Optional[str] = None, + guidance_scale: Optional[float] = None, + negative_prompt: Optional[List[str]] = None, + num_frames: Optional[float] = None, + num_inference_steps: Optional[int] = None, + seed: Optional[int] = None, + extra_body: Optional[Dict[str, Any]] = None, + ) -> bytes: + """ + Generate a video based on a given text. + + + You can pass provider-specific parameters to the model by using the `extra_body` argument. + + + Args: + prompt (`str`): + The prompt to generate a video from. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. If not provided, the default recommended text-to-video model will be used. + Defaults to None. + guidance_scale (`float`, *optional*): + A higher guidance scale value encourages the model to generate videos closely linked to the text + prompt, but values too high may cause saturation and other artifacts. + negative_prompt (`List[str]`, *optional*): + One or several prompt to guide what NOT to include in video generation. + num_frames (`float`, *optional*): + The num_frames parameter determines how many video frames are generated. + num_inference_steps (`int`, *optional*): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + seed (`int`, *optional*): + Seed for the random number generator. + extra_body (`Dict[str, Any]`, *optional*): + Additional provider-specific parameters to pass to the model. Refer to the provider's documentation + for supported parameters. + + Returns: + `bytes`: The generated video. + + Example: + + Example using a third-party provider directly. Usage will be billed on your fal.ai account. + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="fal-ai", # Using fal.ai provider + ... api_key="fal-ai-api-key", # Pass your fal.ai API key + ... ) + >>> video = client.text_to_video( + ... "A majestic lion running in a fantasy forest", + ... model="tencent/HunyuanVideo", + ... ) + >>> with open("lion.mp4", "wb") as file: + ... file.write(video) + ``` + + Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account. + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="replicate", # Using replicate provider + ... api_key="hf_...", # Pass your HF token + ... ) + >>> video = client.text_to_video( + ... "A cat running in a park", + ... model="genmo/mochi-1-preview", + ... ) + >>> with open("cat.mp4", "wb") as file: + ... file.write(video) + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="text-to-video", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=prompt, + parameters={ + "guidance_scale": guidance_scale, + "negative_prompt": negative_prompt, + "num_frames": num_frames, + "num_inference_steps": num_inference_steps, + "seed": seed, + **(extra_body or {}), + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + response = provider_helper.get_response(response, request_parameters) + return response + + def text_to_speech( + self, + text: str, + *, + model: Optional[str] = None, + do_sample: Optional[bool] = None, + early_stopping: Optional[Union[bool, "TextToSpeechEarlyStoppingEnum"]] = None, + epsilon_cutoff: Optional[float] = None, + eta_cutoff: Optional[float] = None, + max_length: Optional[int] = None, + max_new_tokens: Optional[int] = None, + min_length: Optional[int] = None, + min_new_tokens: Optional[int] = None, + num_beam_groups: Optional[int] = None, + num_beams: Optional[int] = None, + penalty_alpha: Optional[float] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + use_cache: Optional[bool] = None, + extra_body: Optional[Dict[str, Any]] = None, + ) -> bytes: + """ + Synthesize an audio of a voice pronouncing a given text. + + + You can pass provider-specific parameters to the model by using the `extra_body` argument. + + + Args: + text (`str`): + The text to synthesize. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. If not provided, the default recommended text-to-speech model will be used. + Defaults to None. + do_sample (`bool`, *optional*): + Whether to use sampling instead of greedy decoding when generating new tokens. + early_stopping (`Union[bool, "TextToSpeechEarlyStoppingEnum"]`, *optional*): + Controls the stopping condition for beam-based methods. + epsilon_cutoff (`float`, *optional*): + If set to float strictly between 0 and 1, only tokens with a conditional probability greater than + epsilon_cutoff will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on + the size of the model. See [Truncation Sampling as Language Model + Desmoothing](https://hf.co/papers/2210.15191) for more details. + eta_cutoff (`float`, *optional*): + Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly + between 0 and 1, a token is only considered if it is greater than either eta_cutoff or sqrt(eta_cutoff) + * exp(-entropy(softmax(next_token_logits))). The latter term is intuitively the expected next token + probability, scaled by sqrt(eta_cutoff). In the paper, suggested values range from 3e-4 to 2e-3, + depending on the size of the model. See [Truncation Sampling as Language Model + Desmoothing](https://hf.co/papers/2210.15191) for more details. + max_length (`int`, *optional*): + The maximum length (in tokens) of the generated text, including the input. + max_new_tokens (`int`, *optional*): + The maximum number of tokens to generate. Takes precedence over max_length. + min_length (`int`, *optional*): + The minimum length (in tokens) of the generated text, including the input. + min_new_tokens (`int`, *optional*): + The minimum number of tokens to generate. Takes precedence over min_length. + num_beam_groups (`int`, *optional*): + Number of groups to divide num_beams into in order to ensure diversity among different groups of beams. + See [this paper](https://hf.co/papers/1610.02424) for more details. + num_beams (`int`, *optional*): + Number of beams to use for beam search. + penalty_alpha (`float`, *optional*): + The value balances the model confidence and the degeneration penalty in contrastive search decoding. + temperature (`float`, *optional*): + The value used to modulate the next token probabilities. + top_k (`int`, *optional*): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`, *optional*): + If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to + top_p or higher are kept for generation. + typical_p (`float`, *optional*): + Local typicality measures how similar the conditional probability of predicting a target token next is + to the expected conditional probability of predicting a random token next, given the partial text + already generated. If set to float < 1, the smallest set of the most locally typical tokens with + probabilities that add up to typical_p or higher are kept for generation. See [this + paper](https://hf.co/papers/2202.00666) for more details. + use_cache (`bool`, *optional*): + Whether the model should use the past last key/values attentions to speed up decoding + extra_body (`Dict[str, Any]`, *optional*): + Additional provider-specific parameters to pass to the model. Refer to the provider's documentation + for supported parameters. + Returns: + `bytes`: The generated audio. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from pathlib import Path + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + + >>> audio = client.text_to_speech("Hello world") + >>> Path("hello_world.flac").write_bytes(audio) + ``` + + Example using a third-party provider directly. Usage will be billed on your Replicate account. + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="replicate", + ... api_key="your-replicate-api-key", # Pass your Replicate API key directly + ... ) + >>> audio = client.text_to_speech( + ... text="Hello world", + ... model="OuteAI/OuteTTS-0.3-500M", + ... ) + >>> Path("hello_world.flac").write_bytes(audio) + ``` + + Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account. + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="replicate", + ... api_key="hf_...", # Pass your HF token + ... ) + >>> audio =client.text_to_speech( + ... text="Hello world", + ... model="OuteAI/OuteTTS-0.3-500M", + ... ) + >>> Path("hello_world.flac").write_bytes(audio) + ``` + Example using Replicate provider with extra parameters + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="replicate", # Use replicate provider + ... api_key="hf_...", # Pass your HF token + ... ) + >>> audio = client.text_to_speech( + ... "Hello, my name is Kororo, an awesome text-to-speech model.", + ... model="hexgrad/Kokoro-82M", + ... extra_body={"voice": "af_nicole"}, + ... ) + >>> Path("hello.flac").write_bytes(audio) + ``` + + Example music-gen using "YuE-s1-7B-anneal-en-cot" on fal.ai + ```py + >>> from huggingface_hub import InferenceClient + >>> lyrics = ''' + ... [verse] + ... In the town where I was born + ... Lived a man who sailed to sea + ... And he told us of his life + ... In the land of submarines + ... So we sailed on to the sun + ... 'Til we found a sea of green + ... And we lived beneath the waves + ... In our yellow submarine + + ... [chorus] + ... We all live in a yellow submarine + ... Yellow submarine, yellow submarine + ... We all live in a yellow submarine + ... Yellow submarine, yellow submarine + ... ''' + >>> genres = "pavarotti-style tenor voice" + >>> client = InferenceClient( + ... provider="fal-ai", + ... model="m-a-p/YuE-s1-7B-anneal-en-cot", + ... api_key=..., + ... ) + >>> audio = client.text_to_speech(lyrics, extra_body={"genres": genres}) + >>> with open("output.mp3", "wb") as f: + ... f.write(audio) + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="text-to-speech", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=text, + parameters={ + "do_sample": do_sample, + "early_stopping": early_stopping, + "epsilon_cutoff": epsilon_cutoff, + "eta_cutoff": eta_cutoff, + "max_length": max_length, + "max_new_tokens": max_new_tokens, + "min_length": min_length, + "min_new_tokens": min_new_tokens, + "num_beam_groups": num_beam_groups, + "num_beams": num_beams, + "penalty_alpha": penalty_alpha, + "temperature": temperature, + "top_k": top_k, + "top_p": top_p, + "typical_p": typical_p, + "use_cache": use_cache, + **(extra_body or {}), + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + response = provider_helper.get_response(response) + return response + + def token_classification( + self, + text: str, + *, + model: Optional[str] = None, + aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None, + ignore_labels: Optional[List[str]] = None, + stride: Optional[int] = None, + ) -> List[TokenClassificationOutputElement]: + """ + Perform token classification on the given text. + Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. + + Args: + text (`str`): + A string to be classified. + model (`str`, *optional*): + The model to use for the token classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended token classification model will be used. + Defaults to None. + aggregation_strategy (`"TokenClassificationAggregationStrategy"`, *optional*): + The strategy used to fuse tokens based on model predictions + ignore_labels (`List[str`, *optional*): + A list of labels to ignore + stride (`int`, *optional*): + The number of overlapping tokens between chunks when splitting the input text. + + Returns: + `List[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica") + [ + TokenClassificationOutputElement( + entity_group='PER', + score=0.9971321225166321, + word='Sarah Jessica Parker', + start=11, + end=31, + ), + TokenClassificationOutputElement( + entity_group='PER', + score=0.9773476123809814, + word='Jessica', + start=52, + end=59, + ) + ] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="token-classification", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=text, + parameters={ + "aggregation_strategy": aggregation_strategy, + "ignore_labels": ignore_labels, + "stride": stride, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + return TokenClassificationOutputElement.parse_obj_as_list(response) + + def translation( + self, + text: str, + *, + model: Optional[str] = None, + src_lang: Optional[str] = None, + tgt_lang: Optional[str] = None, + clean_up_tokenization_spaces: Optional[bool] = None, + truncation: Optional["TranslationTruncationStrategy"] = None, + generate_parameters: Optional[Dict[str, Any]] = None, + ) -> TranslationOutput: + """ + Convert text from one language to another. + + Check out https://huggingface.co/tasks/translation for more information on how to choose the best model for + your specific use case. Source and target languages usually depend on the model. + However, it is possible to specify source and target languages for certain models. If you are working with one of these models, + you can use `src_lang` and `tgt_lang` arguments to pass the relevant information. + + Args: + text (`str`): + A string to be translated. + model (`str`, *optional*): + The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended translation model will be used. + Defaults to None. + src_lang (`str`, *optional*): + The source language of the text. Required for models that can translate from multiple languages. + tgt_lang (`str`, *optional*): + Target language to translate to. Required for models that can translate to multiple languages. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether to clean up the potential extra spaces in the text output. + truncation (`"TranslationTruncationStrategy"`, *optional*): + The truncation strategy to use. + generate_parameters (`Dict[str, Any]`, *optional*): + Additional parametrization of the text generation algorithm. + + Returns: + [`TranslationOutput`]: The generated translated text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + `ValueError`: + If only one of the `src_lang` and `tgt_lang` arguments are provided. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.translation("My name is Wolfgang and I live in Berlin") + 'Mein Name ist Wolfgang und ich lebe in Berlin.' + >>> client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr") + TranslationOutput(translation_text='Je m'appelle Wolfgang et je vis à Berlin.') + ``` + + Specifying languages: + ```py + >>> client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX") + "Mon nom est Sarah Jessica Parker mais vous pouvez m'appeler Jessica" + ``` + """ + # Throw error if only one of `src_lang` and `tgt_lang` was given + if src_lang is not None and tgt_lang is None: + raise ValueError("You cannot specify `src_lang` without specifying `tgt_lang`.") + + if src_lang is None and tgt_lang is not None: + raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.") + + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="translation", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=text, + parameters={ + "src_lang": src_lang, + "tgt_lang": tgt_lang, + "clean_up_tokenization_spaces": clean_up_tokenization_spaces, + "truncation": truncation, + "generate_parameters": generate_parameters, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + return TranslationOutput.parse_obj_as_list(response)[0] + + def visual_question_answering( + self, + image: ContentT, + question: str, + *, + model: Optional[str] = None, + top_k: Optional[int] = None, + ) -> List[VisualQuestionAnsweringOutputElement]: + """ + Answering open-ended questions based on an image. + + Args: + image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): + The input image for the context. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. + question (`str`): + Question to be answered. + model (`str`, *optional*): + The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used. + Defaults to None. + top_k (`int`, *optional*): + The number of answers to return (will be chosen by order of likelihood). Note that we return less than + topk answers if there are not enough options available within the context. + Returns: + `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability. + + Raises: + `InferenceTimeoutError`: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.visual_question_answering( + ... image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg", + ... question="What is the animal doing?" + ... ) + [ + VisualQuestionAnsweringOutputElement(score=0.778609573841095, answer='laying down'), + VisualQuestionAnsweringOutputElement(score=0.6957435607910156, answer='sitting'), + ] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="visual-question-answering", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=image, + parameters={"top_k": top_k}, + headers=self.headers, + model=model_id, + api_key=self.token, + extra_payload={"question": question, "image": _b64_encode(image)}, + ) + response = self._inner_post(request_parameters) + return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response) + + def zero_shot_classification( + self, + text: str, + candidate_labels: List[str], + *, + multi_label: Optional[bool] = False, + hypothesis_template: Optional[str] = None, + model: Optional[str] = None, + ) -> List[ZeroShotClassificationOutputElement]: + """ + Provide as input a text and a set of candidate labels to classify the input text. + + Args: + text (`str`): + The input text to classify. + candidate_labels (`List[str]`): + The set of possible class labels to classify the text into. + labels (`List[str]`, *optional*): + (deprecated) List of strings. Each string is the verbalization of a possible label for the input text. + multi_label (`bool`, *optional*): + Whether multiple candidate labels can be true. If false, the scores are normalized such that the sum of + the label likelihoods for each sequence is 1. If true, the labels are considered independent and + probabilities are normalized for each candidate. + hypothesis_template (`str`, *optional*): + The sentence used in conjunction with `candidate_labels` to attempt the text classification by + replacing the placeholder with the candidate labels. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot classification model will be used. + + + Returns: + `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example with `multi_label=False`: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> text = ( + ... "A new model offers an explanation for how the Galilean satellites formed around the solar system's" + ... "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling" + ... " mysteries when he went for a run up a hill in Nice, France." + ... ) + >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"] + >>> client.zero_shot_classification(text, labels) + [ + ZeroShotClassificationOutputElement(label='scientific discovery', score=0.7961668968200684), + ZeroShotClassificationOutputElement(label='space & cosmos', score=0.18570658564567566), + ZeroShotClassificationOutputElement(label='microbiology', score=0.00730885099619627), + ZeroShotClassificationOutputElement(label='archeology', score=0.006258360575884581), + ZeroShotClassificationOutputElement(label='robots', score=0.004559356719255447), + ] + >>> client.zero_shot_classification(text, labels, multi_label=True) + [ + ZeroShotClassificationOutputElement(label='scientific discovery', score=0.9829297661781311), + ZeroShotClassificationOutputElement(label='space & cosmos', score=0.755190908908844), + ZeroShotClassificationOutputElement(label='microbiology', score=0.0005462635890580714), + ZeroShotClassificationOutputElement(label='archeology', score=0.00047131875180639327), + ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354), + ] + ``` + + Example with `multi_label=True` and a custom `hypothesis_template`: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.zero_shot_classification( + ... text="I really like our dinner and I'm very happy. I don't like the weather though.", + ... labels=["positive", "negative", "pessimistic", "optimistic"], + ... multi_label=True, + ... hypothesis_template="This text is {} towards the weather" + ... ) + [ + ZeroShotClassificationOutputElement(label='negative', score=0.9231801629066467), + ZeroShotClassificationOutputElement(label='pessimistic', score=0.8760990500450134), + ZeroShotClassificationOutputElement(label='optimistic', score=0.0008674879791215062), + ZeroShotClassificationOutputElement(label='positive', score=0.0005250611575320363) + ] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="zero-shot-classification", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=text, + parameters={ + "candidate_labels": candidate_labels, + "multi_label": multi_label, + "hypothesis_template": hypothesis_template, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + output = _bytes_to_dict(response) + return [ + ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score}) + for label, score in zip(output["labels"], output["scores"]) + ] + + def zero_shot_image_classification( + self, + image: ContentT, + candidate_labels: List[str], + *, + model: Optional[str] = None, + hypothesis_template: Optional[str] = None, + # deprecated argument + labels: List[str] = None, # type: ignore + ) -> List[ZeroShotImageClassificationOutputElement]: + """ + Provide input image and text labels to predict text labels for the image. + + Args: + image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): + The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. + candidate_labels (`List[str]`): + The candidate labels for this image + labels (`List[str]`, *optional*): + (deprecated) List of string possible labels. There must be at least 2 labels. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot image classification model will be used. + hypothesis_template (`str`, *optional*): + The sentence used in conjunction with `candidate_labels` to attempt the image classification by + replacing the placeholder with the candidate labels. + + Returns: + `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + + >>> client.zero_shot_image_classification( + ... "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg", + ... labels=["dog", "cat", "horse"], + ... ) + [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...] + ``` + """ + # Raise ValueError if input is less than 2 labels + if len(candidate_labels) < 2: + raise ValueError("You must specify at least 2 classes to compare.") + + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="zero-shot-image-classification", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=image, + parameters={ + "candidate_labels": candidate_labels, + "hypothesis_template": hypothesis_template, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = self._inner_post(request_parameters) + return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response) + + @_deprecate_method( + version="0.35.0", + message=( + "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)." + " Use `HfApi.list_models(..., inference_provider='...')` to list warm models per provider." + ), + ) + def list_deployed_models( + self, frameworks: Union[None, str, Literal["all"], List[str]] = None + ) -> Dict[str, List[str]]: + """ + List models deployed on the HF Serverless Inference API service. + + This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that + are supported and account for 95% of the hosted models. However, if you want a complete list of models you can + specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested + in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more + frameworks are checked, the more time it will take. + + + + This endpoint method does not return a live list of all models available for the HF Inference API service. + It searches over a cached list of models that were recently available and the list may not be up to date. + If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`]. + + + + + + This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to + check its availability, you can directly use [`~InferenceClient.get_model_status`]. + + + + Args: + frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*): + The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to + "all", all available frameworks will be tested. It is also possible to provide a single framework or a + custom set of frameworks to check. + + Returns: + `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs. + + Example: + ```python + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + + # Discover zero-shot-classification models currently deployed + >>> models = client.list_deployed_models() + >>> models["zero-shot-classification"] + ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...] + + # List from only 1 framework + >>> client.list_deployed_models("text-generation-inference") + {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...} + ``` + """ + if self.provider != "hf-inference": + raise ValueError(f"Listing deployed models is not supported on '{self.provider}'.") + + # Resolve which frameworks to check + if frameworks is None: + frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS + elif frameworks == "all": + frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS + elif isinstance(frameworks, str): + frameworks = [frameworks] + frameworks = list(set(frameworks)) + + # Fetch them iteratively + models_by_task: Dict[str, List[str]] = {} + + def _unpack_response(framework: str, items: List[Dict]) -> None: + for model in items: + if framework == "sentence-transformers": + # Model running with the `sentence-transformers` framework can work with both tasks even if not + # branded as such in the API response + models_by_task.setdefault("feature-extraction", []).append(model["model_id"]) + models_by_task.setdefault("sentence-similarity", []).append(model["model_id"]) + else: + models_by_task.setdefault(model["task"], []).append(model["model_id"]) + + for framework in frameworks: + response = get_session().get( + f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token) + ) + hf_raise_for_status(response) + _unpack_response(framework, response.json()) + + # Sort alphabetically for discoverability and return + for task, models in models_by_task.items(): + models_by_task[task] = sorted(set(models), key=lambda x: x.lower()) + return models_by_task + + def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]: + """ + Get information about the deployed endpoint. + + This endpoint is only available on endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI). + Endpoints powered by `transformers` return an empty payload. + + Args: + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `Dict[str, Any]`: Information about the endpoint. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") + >>> client.get_endpoint_info() + { + 'model_id': 'meta-llama/Meta-Llama-3-70B-Instruct', + 'model_sha': None, + 'model_dtype': 'torch.float16', + 'model_device_type': 'cuda', + 'model_pipeline_tag': None, + 'max_concurrent_requests': 128, + 'max_best_of': 2, + 'max_stop_sequences': 4, + 'max_input_length': 8191, + 'max_total_tokens': 8192, + 'waiting_served_ratio': 0.3, + 'max_batch_total_tokens': 1259392, + 'max_waiting_tokens': 20, + 'max_batch_size': None, + 'validation_workers': 32, + 'max_client_batch_size': 4, + 'version': '2.0.2', + 'sha': 'dccab72549635c7eb5ddb17f43f0b7cdff07c214', + 'docker_label': 'sha-dccab72' + } + ``` + """ + if self.provider != "hf-inference": + raise ValueError(f"Getting endpoint info is not supported on '{self.provider}'.") + + model = model or self.model + if model is None: + raise ValueError("Model id not provided.") + if model.startswith(("http://", "https://")): + url = model.rstrip("/") + "/info" + else: + url = f"{constants.INFERENCE_ENDPOINT}/models/{model}/info" + + response = get_session().get(url, headers=build_hf_headers(token=self.token)) + hf_raise_for_status(response) + return response.json() + + def health_check(self, model: Optional[str] = None) -> bool: + """ + Check the health of the deployed endpoint. + + Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI). + For Inference API, please use [`InferenceClient.get_model_status`] instead. + + Args: + model (`str`, *optional*): + URL of the Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `bool`: True if everything is working fine. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient("https://jzgu0buei5.us-east-1.aws.endpoints.huggingface.cloud") + >>> client.health_check() + True + ``` + """ + if self.provider != "hf-inference": + raise ValueError(f"Health check is not supported on '{self.provider}'.") + + model = model or self.model + if model is None: + raise ValueError("Model id not provided.") + if not model.startswith(("http://", "https://")): + raise ValueError( + "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`." + ) + url = model.rstrip("/") + "/health" + + response = get_session().get(url, headers=build_hf_headers(token=self.token)) + return response.status_code == 200 + + @_deprecate_method( + version="0.35.0", + message=( + "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)." + " Use `HfApi.model_info` to get the model status both with HF Inference API and external providers." + ), + ) + def get_model_status(self, model: Optional[str] = None) -> ModelStatus: + """ + Get the status of a model hosted on the HF Inference API. + + + + This endpoint is mostly useful when you already know which model you want to use and want to check its + availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`]. + + + + Args: + model (`str`, *optional*): + Identifier of the model for witch the status gonna be checked. If model is not provided, + the model associated with this instance of [`InferenceClient`] will be used. Only HF Inference API service can be checked so the + identifier cannot be a URL. + + + Returns: + [`ModelStatus`]: An instance of ModelStatus dataclass, containing information, + about the state of the model: load, state, compute type and framework. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct") + ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference') + ``` + """ + if self.provider != "hf-inference": + raise ValueError(f"Getting model status is not supported on '{self.provider}'.") + + model = model or self.model + if model is None: + raise ValueError("Model id not provided.") + if model.startswith("https://"): + raise NotImplementedError("Model status is only available for Inference API endpoints.") + url = f"{constants.INFERENCE_ENDPOINT}/status/{model}" + + response = get_session().get(url, headers=build_hf_headers(token=self.token)) + hf_raise_for_status(response) + response_data = response.json() + + if "error" in response_data: + raise ValueError(response_data["error"]) + + return ModelStatus( + loaded=response_data["loaded"], + state=response_data["state"], + compute_type=response_data["compute_type"], + framework=response_data["framework"], + ) + + @property + def chat(self) -> "ProxyClientChat": + return ProxyClientChat(self) + + +class _ProxyClient: + """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client.""" + + def __init__(self, client: InferenceClient): + self._client = client + + +class ProxyClientChat(_ProxyClient): + """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client.""" + + @property + def completions(self) -> "ProxyClientChatCompletions": + return ProxyClientChatCompletions(self._client) + + +class ProxyClientChatCompletions(_ProxyClient): + """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client.""" + + @property + def create(self): + return self._client.chat_completion diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_common.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_common.py new file mode 100644 index 0000000000000000000000000000000000000000..84636b212ccf7c87bd882c4f1a4056b5c7ec9e05 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_common.py @@ -0,0 +1,457 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities used by both the sync and async inference clients.""" + +import base64 +import io +import json +import logging +import mimetypes +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterable, + BinaryIO, + ContextManager, + Dict, + Generator, + Iterable, + List, + Literal, + NoReturn, + Optional, + Union, + overload, +) + +from requests import HTTPError + +from huggingface_hub.errors import ( + GenerationError, + IncompleteGenerationError, + OverloadedError, + TextGenerationError, + UnknownError, + ValidationError, +) + +from ..utils import get_session, is_aiohttp_available, is_numpy_available, is_pillow_available +from ._generated.types import ChatCompletionStreamOutput, TextGenerationStreamOutput + + +if TYPE_CHECKING: + from aiohttp import ClientResponse, ClientSession + from PIL.Image import Image + +# TYPES +UrlT = str +PathT = Union[str, Path] +BinaryT = Union[bytes, BinaryIO] +ContentT = Union[BinaryT, PathT, UrlT, "Image"] + +# Use to set a Accept: image/png header +TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"} + +logger = logging.getLogger(__name__) + + +@dataclass +class RequestParameters: + url: str + task: str + model: Optional[str] + json: Optional[Union[str, Dict, List]] + data: Optional[ContentT] + headers: Dict[str, Any] + + +# Add dataclass for ModelStatus. We use this dataclass in get_model_status function. +@dataclass +class ModelStatus: + """ + This Dataclass represents the model status in the HF Inference API. + + Args: + loaded (`bool`): + If the model is currently loaded into HF's Inference API. Models + are loaded on-demand, leading to the user's first request taking longer. + If a model is loaded, you can be assured that it is in a healthy state. + state (`str`): + The current state of the model. This can be 'Loaded', 'Loadable', 'TooBig'. + If a model's state is 'Loadable', it's not too big and has a supported + backend. Loadable models are automatically loaded when the user first + requests inference on the endpoint. This means it is transparent for the + user to load a model, except that the first call takes longer to complete. + compute_type (`Dict`): + Information about the compute resource the model is using or will use, such as 'gpu' type and number of + replicas. + framework (`str`): + The name of the framework that the model was built with, such as 'transformers' + or 'text-generation-inference'. + """ + + loaded: bool + state: str + compute_type: Dict + framework: str + + +## IMPORT UTILS + + +def _import_aiohttp(): + # Make sure `aiohttp` is installed on the machine. + if not is_aiohttp_available(): + raise ImportError("Please install aiohttp to use `AsyncInferenceClient` (`pip install aiohttp`).") + import aiohttp + + return aiohttp + + +def _import_numpy(): + """Make sure `numpy` is installed on the machine.""" + if not is_numpy_available(): + raise ImportError("Please install numpy to use deal with embeddings (`pip install numpy`).") + import numpy + + return numpy + + +def _import_pil_image(): + """Make sure `PIL` is installed on the machine.""" + if not is_pillow_available(): + raise ImportError( + "Please install Pillow to use deal with images (`pip install Pillow`). If you don't want the image to be" + " post-processed, use `client.post(...)` and get the raw response from the server." + ) + from PIL import Image + + return Image + + +## ENCODING / DECODING UTILS + + +@overload +def _open_as_binary( + content: ContentT, +) -> ContextManager[BinaryT]: ... # means "if input is not None, output is not None" + + +@overload +def _open_as_binary( + content: Literal[None], +) -> ContextManager[Literal[None]]: ... # means "if input is None, output is None" + + +@contextmanager # type: ignore +def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]: + """Open `content` as a binary file, either from a URL, a local path, raw bytes, or a PIL Image. + + Do nothing if `content` is None. + + TODO: handle base64 as input + """ + # If content is a string => must be either a URL or a path + if isinstance(content, str): + if content.startswith("https://") or content.startswith("http://"): + logger.debug(f"Downloading content from {content}") + yield get_session().get(content).content # TODO: retrieve as stream and pipe to post request ? + return + content = Path(content) + if not content.exists(): + raise FileNotFoundError( + f"File not found at {content}. If `data` is a string, it must either be a URL or a path to a local" + " file. To pass raw content, please encode it as bytes first." + ) + + # If content is a Path => open it + if isinstance(content, Path): + logger.debug(f"Opening content from {content}") + with content.open("rb") as f: + yield f + return + + # If content is a PIL Image => convert to bytes + if is_pillow_available(): + from PIL import Image + + if isinstance(content, Image.Image): + logger.debug("Converting PIL Image to bytes") + buffer = io.BytesIO() + content.save(buffer, format=content.format or "PNG") + yield buffer.getvalue() + return + + # Otherwise: already a file-like object or None + yield content # type: ignore + + +def _b64_encode(content: ContentT) -> str: + """Encode a raw file (image, audio) into base64. Can be bytes, an opened file, a path or a URL.""" + with _open_as_binary(content) as data: + data_as_bytes = data if isinstance(data, bytes) else data.read() + return base64.b64encode(data_as_bytes).decode() + + +def _as_url(content: ContentT, default_mime_type: str) -> str: + if isinstance(content, str) and (content.startswith("https://") or content.startswith("http://")): + return content + + # Handle MIME type detection for different content types + mime_type = None + if isinstance(content, (str, Path)): + mime_type = mimetypes.guess_type(content, strict=False)[0] + elif is_pillow_available(): + from PIL import Image + + if isinstance(content, Image.Image): + # Determine MIME type from PIL Image format, in sync with `_open_as_binary` + mime_type = f"image/{(content.format or 'PNG').lower()}" + + mime_type = mime_type or default_mime_type + encoded_data = _b64_encode(content) + return f"data:{mime_type};base64,{encoded_data}" + + +def _b64_to_image(encoded_image: str) -> "Image": + """Parse a base64-encoded string into a PIL Image.""" + Image = _import_pil_image() + return Image.open(io.BytesIO(base64.b64decode(encoded_image))) + + +def _bytes_to_list(content: bytes) -> List: + """Parse bytes from a Response object into a Python list. + + Expects the response body to be JSON-encoded data. + + NOTE: This is exactly the same implementation as `_bytes_to_dict` and will not complain if the returned data is a + dictionary. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect. + """ + return json.loads(content.decode()) + + +def _bytes_to_dict(content: bytes) -> Dict: + """Parse bytes from a Response object into a Python dictionary. + + Expects the response body to be JSON-encoded data. + + NOTE: This is exactly the same implementation as `_bytes_to_list` and will not complain if the returned data is a + list. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect. + """ + return json.loads(content.decode()) + + +def _bytes_to_image(content: bytes) -> "Image": + """Parse bytes from a Response object into a PIL Image. + + Expects the response body to be raw bytes. To deal with b64 encoded images, use `_b64_to_image` instead. + """ + Image = _import_pil_image() + return Image.open(io.BytesIO(content)) + + +def _as_dict(response: Union[bytes, Dict]) -> Dict: + return json.loads(response) if isinstance(response, bytes) else response + + +## PAYLOAD UTILS + + +## STREAMING UTILS + + +def _stream_text_generation_response( + bytes_output_as_lines: Iterable[bytes], details: bool +) -> Union[Iterable[str], Iterable[TextGenerationStreamOutput]]: + """Used in `InferenceClient.text_generation`.""" + # Parse ServerSentEvents + for byte_payload in bytes_output_as_lines: + try: + output = _format_text_generation_stream_output(byte_payload, details) + except StopIteration: + break + if output is not None: + yield output + + +async def _async_stream_text_generation_response( + bytes_output_as_lines: AsyncIterable[bytes], details: bool +) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]: + """Used in `AsyncInferenceClient.text_generation`.""" + # Parse ServerSentEvents + async for byte_payload in bytes_output_as_lines: + try: + output = _format_text_generation_stream_output(byte_payload, details) + except StopIteration: + break + if output is not None: + yield output + + +def _format_text_generation_stream_output( + byte_payload: bytes, details: bool +) -> Optional[Union[str, TextGenerationStreamOutput]]: + if not byte_payload.startswith(b"data:"): + return None # empty line + + if byte_payload.strip() == b"data: [DONE]": + raise StopIteration("[DONE] signal received.") + + # Decode payload + payload = byte_payload.decode("utf-8") + json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) + + # Either an error as being returned + if json_payload.get("error") is not None: + raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type")) + + # Or parse token payload + output = TextGenerationStreamOutput.parse_obj_as_instance(json_payload) + return output.token.text if not details else output + + +def _stream_chat_completion_response( + bytes_lines: Iterable[bytes], +) -> Iterable[ChatCompletionStreamOutput]: + """Used in `InferenceClient.chat_completion` if model is served with TGI.""" + for item in bytes_lines: + try: + output = _format_chat_completion_stream_output(item) + except StopIteration: + break + if output is not None: + yield output + + +async def _async_stream_chat_completion_response( + bytes_lines: AsyncIterable[bytes], +) -> AsyncIterable[ChatCompletionStreamOutput]: + """Used in `AsyncInferenceClient.chat_completion`.""" + async for item in bytes_lines: + try: + output = _format_chat_completion_stream_output(item) + except StopIteration: + break + if output is not None: + yield output + + +def _format_chat_completion_stream_output( + byte_payload: bytes, +) -> Optional[ChatCompletionStreamOutput]: + if not byte_payload.startswith(b"data:"): + return None # empty line + + if byte_payload.strip() == b"data: [DONE]": + raise StopIteration("[DONE] signal received.") + + # Decode payload + payload = byte_payload.decode("utf-8") + json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) + + # Either an error as being returned + if json_payload.get("error") is not None: + raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type")) + + # Or parse token payload + return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload) + + +async def _async_yield_from(client: "ClientSession", response: "ClientResponse") -> AsyncIterable[bytes]: + try: + async for byte_payload in response.content: + yield byte_payload.strip() + finally: + # Always close the underlying HTTP session to avoid resource leaks + await client.close() + + +# "TGI servers" are servers running with the `text-generation-inference` backend. +# This backend is the go-to solution to run large language models at scale. However, +# for some smaller models (e.g. "gpt2") the default `transformers` + `api-inference` +# solution is still in use. +# +# Both approaches have very similar APIs, but not exactly the same. What we do first in +# the `text_generation` method is to assume the model is served via TGI. If we realize +# it's not the case (i.e. we receive an HTTP 400 Bad Request), we fallback to the +# default API with a warning message. When that's the case, We remember the unsupported +# attributes for this model in the `_UNSUPPORTED_TEXT_GENERATION_KWARGS` global variable. +# +# In addition, TGI servers have a built-in API route for chat-completion, which is not +# available on the default API. We use this route to provide a more consistent behavior +# when available. +# +# For more details, see https://github.com/huggingface/text-generation-inference and +# https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task. + +_UNSUPPORTED_TEXT_GENERATION_KWARGS: Dict[Optional[str], List[str]] = {} + + +def _set_unsupported_text_generation_kwargs(model: Optional[str], unsupported_kwargs: List[str]) -> None: + _UNSUPPORTED_TEXT_GENERATION_KWARGS.setdefault(model, []).extend(unsupported_kwargs) + + +def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]: + return _UNSUPPORTED_TEXT_GENERATION_KWARGS.get(model, []) + + +# TEXT GENERATION ERRORS +# ---------------------- +# Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation +# inference project (https://github.com/huggingface/text-generation-inference). +# ---------------------- + + +def raise_text_generation_error(http_error: HTTPError) -> NoReturn: + """ + Try to parse text-generation-inference error message and raise HTTPError in any case. + + Args: + error (`HTTPError`): + The HTTPError that have been raised. + """ + # Try to parse a Text Generation Inference error + + try: + # Hacky way to retrieve payload in case of aiohttp error + payload = getattr(http_error, "response_error_payload", None) or http_error.response.json() + error = payload.get("error") + error_type = payload.get("error_type") + except Exception: # no payload + raise http_error + + # If error_type => more information than `hf_raise_for_status` + if error_type is not None: + exception = _parse_text_generation_error(error, error_type) + raise exception from http_error + + # Otherwise, fallback to default error + raise http_error + + +def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError: + if error_type == "generation": + return GenerationError(error) # type: ignore + if error_type == "incomplete_generation": + return IncompleteGenerationError(error) # type: ignore + if error_type == "overloaded": + return OverloadedError(error) # type: ignore + if error_type == "validation": + return ValidationError(error) # type: ignore + return UnknownError(error) # type: ignore diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/__init__.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d2d0acb5bf08ddb86f942ab78122bd55a0ece04 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/_async_client.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/_async_client.py new file mode 100644 index 0000000000000000000000000000000000000000..55fd9e69ca583708a98630bdabd9b592cf07360a --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/_async_client.py @@ -0,0 +1,3665 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WARNING +# This entire file has been adapted from the sync-client code in `src/huggingface_hub/inference/_client.py`. +# Any change in InferenceClient will be automatically reflected in AsyncInferenceClient. +# To re-generate the code, run `make style` or `python ./utils/generate_async_inference_client.py --update`. +# WARNING +import asyncio +import base64 +import logging +import re +import warnings +from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Set, Union, overload + +from huggingface_hub import constants +from huggingface_hub.errors import InferenceTimeoutError +from huggingface_hub.inference._common import ( + TASKS_EXPECTING_IMAGES, + ContentT, + ModelStatus, + RequestParameters, + _async_stream_chat_completion_response, + _async_stream_text_generation_response, + _b64_encode, + _b64_to_image, + _bytes_to_dict, + _bytes_to_image, + _bytes_to_list, + _get_unsupported_text_generation_kwargs, + _import_numpy, + _open_as_binary, + _set_unsupported_text_generation_kwargs, + raise_text_generation_error, +) +from huggingface_hub.inference._generated.types import ( + AudioClassificationOutputElement, + AudioClassificationOutputTransform, + AudioToAudioOutputElement, + AutomaticSpeechRecognitionOutput, + ChatCompletionInputGrammarType, + ChatCompletionInputMessage, + ChatCompletionInputStreamOptions, + ChatCompletionInputTool, + ChatCompletionInputToolChoiceClass, + ChatCompletionInputToolChoiceEnum, + ChatCompletionOutput, + ChatCompletionStreamOutput, + DocumentQuestionAnsweringOutputElement, + FillMaskOutputElement, + ImageClassificationOutputElement, + ImageClassificationOutputTransform, + ImageSegmentationOutputElement, + ImageSegmentationSubtask, + ImageToImageTargetSize, + ImageToTextOutput, + ImageToVideoTargetSize, + ObjectDetectionOutputElement, + Padding, + QuestionAnsweringOutputElement, + SummarizationOutput, + SummarizationTruncationStrategy, + TableQuestionAnsweringOutputElement, + TextClassificationOutputElement, + TextClassificationOutputTransform, + TextGenerationInputGrammarType, + TextGenerationOutput, + TextGenerationStreamOutput, + TextToSpeechEarlyStoppingEnum, + TokenClassificationAggregationStrategy, + TokenClassificationOutputElement, + TranslationOutput, + TranslationTruncationStrategy, + VisualQuestionAnsweringOutputElement, + ZeroShotClassificationOutputElement, + ZeroShotImageClassificationOutputElement, +) +from huggingface_hub.inference._providers import PROVIDER_OR_POLICY_T, get_provider_helper +from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status +from huggingface_hub.utils._auth import get_token +from huggingface_hub.utils._deprecation import _deprecate_method + +from .._common import _async_yield_from, _import_aiohttp + + +if TYPE_CHECKING: + import numpy as np + from aiohttp import ClientResponse, ClientSession + from PIL.Image import Image + +logger = logging.getLogger(__name__) + + +MODEL_KWARGS_NOT_USED_REGEX = re.compile(r"The following `model_kwargs` are not used by the model: \[(.*?)\]") + + +class AsyncInferenceClient: + """ + Initialize a new Inference Client. + + [`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used + seamlessly with either the (free) Inference API, self-hosted Inference Endpoints, or third-party Inference Providers. + + Args: + model (`str`, `optional`): + The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct` + or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is + automatically selected for the task. + Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2 + arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL. + provider (`str`, *optional*): + Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`. + Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. + If model is a URL or `base_url` is passed, then `provider` is not used. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2 + arguments are mutually exclusive and have the exact same behavior. + timeout (`float`, `optional`): + The maximum number of seconds to wait for a response from the server. Defaults to None, meaning it will loop until the server is available. + headers (`Dict[str, str]`, `optional`): + Additional headers to send to the server. By default only the authorization and user-agent headers are sent. + Values in this dictionary will override the default values. + bill_to (`str`, `optional`): + The billing account to use for the requests. By default the requests are billed on the user's account. + Requests can only be billed to an organization the user is a member of, and which has subscribed to Enterprise Hub. + cookies (`Dict[str, str]`, `optional`): + Additional cookies to send to the server. + trust_env ('bool', 'optional'): + Trust environment settings for proxy configuration if the parameter is `True` (`False` by default). + proxies (`Any`, `optional`): + Proxies to use for the request. + base_url (`str`, `optional`): + Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`] + follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None. + api_key (`str`, `optional`): + Token to use for authentication. This is a duplicated argument from `token` to make [`InferenceClient`] + follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None. + """ + + def __init__( + self, + model: Optional[str] = None, + *, + provider: Optional[PROVIDER_OR_POLICY_T] = None, + token: Optional[str] = None, + timeout: Optional[float] = None, + headers: Optional[Dict[str, str]] = None, + cookies: Optional[Dict[str, str]] = None, + trust_env: bool = False, + proxies: Optional[Any] = None, + bill_to: Optional[str] = None, + # OpenAI compatibility + base_url: Optional[str] = None, + api_key: Optional[str] = None, + ) -> None: + if model is not None and base_url is not None: + raise ValueError( + "Received both `model` and `base_url` arguments. Please provide only one of them." + " `base_url` is an alias for `model` to make the API compatible with OpenAI's client." + " If using `base_url` for chat completion, the `/chat/completions` suffix path will be appended to the base url." + " When passing a URL as `model`, the client will not append any suffix path to it." + ) + if token is not None and api_key is not None: + raise ValueError( + "Received both `token` and `api_key` arguments. Please provide only one of them." + " `api_key` is an alias for `token` to make the API compatible with OpenAI's client." + " It has the exact same behavior as `token`." + ) + token = token if token is not None else api_key + if isinstance(token, bool): + # Legacy behavior: previously is was possible to pass `token=False` to disable authentication. This is not + # supported anymore as authentication is required. Better to explicitly raise here rather than risking + # sending the locally saved token without the user knowing about it. + if token is False: + raise ValueError( + "Cannot use `token=False` to disable authentication as authentication is required to run Inference." + ) + warnings.warn( + "Using `token=True` to automatically use the locally saved token is deprecated and will be removed in a future release. " + "Please use `token=None` instead (default).", + DeprecationWarning, + ) + token = get_token() + + self.model: Optional[str] = base_url or model + self.token: Optional[str] = token + + self.headers = {**headers} if headers is not None else {} + if bill_to is not None: + if ( + constants.HUGGINGFACE_HEADER_X_BILL_TO in self.headers + and self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] != bill_to + ): + warnings.warn( + f"Overriding existing '{self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO]}' value in headers with '{bill_to}'.", + UserWarning, + ) + self.headers[constants.HUGGINGFACE_HEADER_X_BILL_TO] = bill_to + + if token is not None and not token.startswith("hf_"): + warnings.warn( + "You've provided an external provider's API key, so requests will be billed directly by the provider. " + "The `bill_to` parameter is only applicable for Hugging Face billing and will be ignored.", + UserWarning, + ) + + # Configure provider + self.provider = provider + + self.cookies = cookies + self.timeout = timeout + self.trust_env = trust_env + self.proxies = proxies + + # Keep track of the sessions to close them properly + self._sessions: Dict["ClientSession", Set["ClientResponse"]] = dict() + + def __repr__(self): + return f"" + + @overload + async def _inner_post( # type: ignore[misc] + self, request_parameters: RequestParameters, *, stream: Literal[False] = ... + ) -> bytes: ... + + @overload + async def _inner_post( # type: ignore[misc] + self, request_parameters: RequestParameters, *, stream: Literal[True] = ... + ) -> AsyncIterable[bytes]: ... + + @overload + async def _inner_post( + self, request_parameters: RequestParameters, *, stream: bool = False + ) -> Union[bytes, AsyncIterable[bytes]]: ... + + async def _inner_post( + self, request_parameters: RequestParameters, *, stream: bool = False + ) -> Union[bytes, AsyncIterable[bytes]]: + """Make a request to the inference server.""" + + aiohttp = _import_aiohttp() + + # TODO: this should be handled in provider helpers directly + if request_parameters.task in TASKS_EXPECTING_IMAGES and "Accept" not in request_parameters.headers: + request_parameters.headers["Accept"] = "image/png" + + with _open_as_binary(request_parameters.data) as data_as_binary: + # Do not use context manager as we don't want to close the connection immediately when returning + # a stream + session = self._get_client_session(headers=request_parameters.headers) + + try: + response = await session.post( + request_parameters.url, json=request_parameters.json, data=data_as_binary, proxy=self.proxies + ) + response_error_payload = None + if response.status != 200: + try: + response_error_payload = await response.json() # get payload before connection closed + except Exception: + pass + response.raise_for_status() + if stream: + return _async_yield_from(session, response) + else: + content = await response.read() + await session.close() + return content + except asyncio.TimeoutError as error: + await session.close() + # Convert any `TimeoutError` to a `InferenceTimeoutError` + raise InferenceTimeoutError(f"Inference call timed out: {request_parameters.url}") from error # type: ignore + except aiohttp.ClientResponseError as error: + error.response_error_payload = response_error_payload + await session.close() + raise error + except Exception: + await session.close() + raise + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + def __del__(self): + if len(self._sessions) > 0: + warnings.warn( + "Deleting 'AsyncInferenceClient' client but some sessions are still open. " + "This can happen if you've stopped streaming data from the server before the stream was complete. " + "To close the client properly, you must call `await client.close()` " + "or use an async context (e.g. `async with AsyncInferenceClient(): ...`." + ) + + async def close(self): + """Close all open sessions. + + By default, 'aiohttp.ClientSession' objects are closed automatically when a call is completed. However, if you + are streaming data from the server and you stop before the stream is complete, you must call this method to + close the session properly. + + Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`). + """ + await asyncio.gather(*[session.close() for session in self._sessions.keys()]) + + async def audio_classification( + self, + audio: ContentT, + *, + model: Optional[str] = None, + top_k: Optional[int] = None, + function_to_apply: Optional["AudioClassificationOutputTransform"] = None, + ) -> List[AudioClassificationOutputElement]: + """ + Perform audio classification on the provided audio content. + + Args: + audio (Union[str, Path, bytes, BinaryIO]): + The audio content to classify. It can be raw audio bytes, a local audio file, or a URL pointing to an + audio file. + model (`str`, *optional*): + The model to use for audio classification. Can be a model ID hosted on the Hugging Face Hub + or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for + audio classification will be used. + top_k (`int`, *optional*): + When specified, limits the output to the top K most probable classes. + function_to_apply (`"AudioClassificationOutputTransform"`, *optional*): + The function to apply to the model outputs in order to retrieve the scores. + + Returns: + `List[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.audio_classification("audio.flac") + [ + AudioClassificationOutputElement(score=0.4976358711719513, label='hap'), + AudioClassificationOutputElement(score=0.3677836060523987, label='neu'), + ... + ] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="audio-classification", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=audio, + parameters={"function_to_apply": function_to_apply, "top_k": top_k}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + return AudioClassificationOutputElement.parse_obj_as_list(response) + + async def audio_to_audio( + self, + audio: ContentT, + *, + model: Optional[str] = None, + ) -> List[AudioToAudioOutputElement]: + """ + Performs multiple tasks related to audio-to-audio depending on the model (eg: speech enhancement, source separation). + + Args: + audio (Union[str, Path, bytes, BinaryIO]): + The audio content for the model. It can be raw audio bytes, a local audio file, or a URL pointing to an + audio file. + model (`str`, *optional*): + The model can be any model which takes an audio file and returns another audio file. Can be a model ID hosted on the Hugging Face Hub + or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for + audio_to_audio will be used. + + Returns: + `List[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob. + + Raises: + `InferenceTimeoutError`: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> audio_output = await client.audio_to_audio("audio.flac") + >>> async for i, item in enumerate(audio_output): + >>> with open(f"output_{i}.flac", "wb") as f: + f.write(item.blob) + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="audio-to-audio", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=audio, + parameters={}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + audio_output = AudioToAudioOutputElement.parse_obj_as_list(response) + for item in audio_output: + item.blob = base64.b64decode(item.blob) + return audio_output + + async def automatic_speech_recognition( + self, + audio: ContentT, + *, + model: Optional[str] = None, + extra_body: Optional[Dict] = None, + ) -> AutomaticSpeechRecognitionOutput: + """ + Perform automatic speech recognition (ASR or audio-to-text) on the given audio content. + + Args: + audio (Union[str, Path, bytes, BinaryIO]): + The content to transcribe. It can be raw audio bytes, local audio file, or a URL to an audio file. + model (`str`, *optional*): + The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. If not provided, the default recommended model for ASR will be used. + extra_body (`Dict`, *optional*): + Additional provider-specific parameters to pass to the model. Refer to the provider's documentation + for supported parameters. + Returns: + [`AutomaticSpeechRecognitionOutput`]: An item containing the transcribed text and optionally the timestamp chunks. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.automatic_speech_recognition("hello_world.flac").text + "hello world" + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="automatic-speech-recognition", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=audio, + parameters={**(extra_body or {})}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response) + + @overload + async def chat_completion( # type: ignore + self, + messages: List[Union[Dict, ChatCompletionInputMessage]], + *, + model: Optional[str] = None, + stream: Literal[False] = False, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[ChatCompletionInputGrammarType] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stream_options: Optional[ChatCompletionInputStreamOptions] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[ChatCompletionInputTool]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + extra_body: Optional[Dict] = None, + ) -> ChatCompletionOutput: ... + + @overload + async def chat_completion( # type: ignore + self, + messages: List[Union[Dict, ChatCompletionInputMessage]], + *, + model: Optional[str] = None, + stream: Literal[True] = True, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[ChatCompletionInputGrammarType] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stream_options: Optional[ChatCompletionInputStreamOptions] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[ChatCompletionInputTool]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + extra_body: Optional[Dict] = None, + ) -> AsyncIterable[ChatCompletionStreamOutput]: ... + + @overload + async def chat_completion( + self, + messages: List[Union[Dict, ChatCompletionInputMessage]], + *, + model: Optional[str] = None, + stream: bool = False, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[ChatCompletionInputGrammarType] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stream_options: Optional[ChatCompletionInputStreamOptions] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[ChatCompletionInputTool]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + extra_body: Optional[Dict] = None, + ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]: ... + + async def chat_completion( + self, + messages: List[Union[Dict, ChatCompletionInputMessage]], + *, + model: Optional[str] = None, + stream: bool = False, + # Parameters from ChatCompletionInput (handled manually) + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[ChatCompletionInputGrammarType] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stream_options: Optional[ChatCompletionInputStreamOptions] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[ChatCompletionInputTool]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + extra_body: Optional[Dict] = None, + ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]: + """ + A method for completing conversations using a specified language model. + + + + The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client. + Inputs and outputs are strictly the same and using either syntax will yield the same results. + Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility) + for more details about OpenAI's compatibility. + + + + + You can pass provider-specific parameters to the model by using the `extra_body` argument. + + + Args: + messages (List of [`ChatCompletionInputMessage`]): + Conversation history consisting of roles and content pairs. + model (`str`, *optional*): + The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used. + See https://huggingface.co/tasks/text-generation for more details. + If `model` is a model ID, it is passed to the server as the `model` parameter. If you want to define a + custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`]. + frequency_penalty (`float`, *optional*): + Penalizes new tokens based on their existing frequency + in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0. + logit_bias (`List[float]`, *optional*): + Adjusts the likelihood of specific tokens appearing in the generated output. + logprobs (`bool`, *optional*): + Whether to return log probabilities of the output tokens or not. If true, returns the log + probabilities of each output token returned in the content of message. + max_tokens (`int`, *optional*): + Maximum number of tokens allowed in the response. Defaults to 100. + n (`int`, *optional*): + The number of completions to generate for each prompt. + presence_penalty (`float`, *optional*): + Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the + text so far, increasing the model's likelihood to talk about new topics. + response_format ([`ChatCompletionInputGrammarType`], *optional*): + Grammar constraints. Can be either a JSONSchema or a regex. + seed (Optional[`int`], *optional*): + Seed for reproducible control flow. Defaults to None. + stop (`List[str]`, *optional*): + Up to four strings which trigger the end of the response. + Defaults to None. + stream (`bool`, *optional*): + Enable realtime streaming of responses. Defaults to False. + stream_options ([`ChatCompletionInputStreamOptions`], *optional*): + Options for streaming completions. + temperature (`float`, *optional*): + Controls randomness of the generations. Lower values ensure + less random completions. Range: [0, 2]. Defaults to 1.0. + top_logprobs (`int`, *optional*): + An integer between 0 and 5 specifying the number of most likely tokens to return at each token + position, each with an associated log probability. logprobs must be set to true if this parameter is + used. + top_p (`float`, *optional*): + Fraction of the most likely next words to sample from. + Must be between 0 and 1. Defaults to 1.0. + tool_choice ([`ChatCompletionInputToolChoiceClass`] or [`ChatCompletionInputToolChoiceEnum`], *optional*): + The tool to use for the completion. Defaults to "auto". + tool_prompt (`str`, *optional*): + A prompt to be appended before the tools. + tools (List of [`ChatCompletionInputTool`], *optional*): + A list of tools the model may call. Currently, only functions are supported as a tool. Use this to + provide a list of functions the model may generate JSON inputs for. + extra_body (`Dict`, *optional*): + Additional provider-specific parameters to pass to the model. Refer to the provider's documentation + for supported parameters. + Returns: + [`ChatCompletionOutput`] or Iterable of [`ChatCompletionStreamOutput`]: + Generated text returned from the server: + - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default). + - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`]. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> messages = [{"role": "user", "content": "What is the capital of France?"}] + >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct") + >>> await client.chat_completion(messages, max_tokens=100) + ChatCompletionOutput( + choices=[ + ChatCompletionOutputComplete( + finish_reason='eos_token', + index=0, + message=ChatCompletionOutputMessage( + role='assistant', + content='The capital of France is Paris.', + name=None, + tool_calls=None + ), + logprobs=None + ) + ], + created=1719907176, + id='', + model='meta-llama/Meta-Llama-3-8B-Instruct', + object='text_completion', + system_fingerprint='2.0.4-sha-f426a33', + usage=ChatCompletionOutputUsage( + completion_tokens=8, + prompt_tokens=17, + total_tokens=25 + ) + ) + ``` + + Example using streaming: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> messages = [{"role": "user", "content": "What is the capital of France?"}] + >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct") + >>> async for token in await client.chat_completion(messages, max_tokens=10, stream=True): + ... print(token) + ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content='The', role='assistant'), index=0, finish_reason=None)], created=1710498504) + ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504) + (...) + ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504) + ``` + + Example using OpenAI's syntax: + ```py + # Must be run in an async context + # instead of `from openai import OpenAI` + from huggingface_hub import AsyncInferenceClient + + # instead of `client = OpenAI(...)` + client = AsyncInferenceClient( + base_url=..., + api_key=..., + ) + + output = await client.chat.completions.create( + model="meta-llama/Meta-Llama-3-8B-Instruct", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Count to 10"}, + ], + stream=True, + max_tokens=1024, + ) + + for chunk in output: + print(chunk.choices[0].delta.content) + ``` + + Example using a third-party provider directly with extra (provider-specific) parameters. Usage will be billed on your Together AI account. + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="together", # Use Together AI provider + ... api_key="", # Pass your Together API key directly + ... ) + >>> client.chat_completion( + ... model="meta-llama/Meta-Llama-3-8B-Instruct", + ... messages=[{"role": "user", "content": "What is the capital of France?"}], + ... extra_body={"safety_model": "Meta-Llama/Llama-Guard-7b"}, + ... ) + ``` + + Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account. + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="sambanova", # Use Sambanova provider + ... api_key="hf_...", # Pass your HF token + ... ) + >>> client.chat_completion( + ... model="meta-llama/Meta-Llama-3-8B-Instruct", + ... messages=[{"role": "user", "content": "What is the capital of France?"}], + ... ) + ``` + + Example using Image + Text as input: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + + # provide a remote URL + >>> image_url ="https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" + # or a base64-encoded image + >>> image_path = "/path/to/image.jpeg" + >>> with open(image_path, "rb") as f: + ... base64_image = base64.b64encode(f.read()).decode("utf-8") + >>> image_url = f"data:image/jpeg;base64,{base64_image}" + + >>> client = AsyncInferenceClient("meta-llama/Llama-3.2-11B-Vision-Instruct") + >>> output = await client.chat.completions.create( + ... messages=[ + ... { + ... "role": "user", + ... "content": [ + ... { + ... "type": "image_url", + ... "image_url": {"url": image_url}, + ... }, + ... { + ... "type": "text", + ... "text": "Describe this image in one sentence.", + ... }, + ... ], + ... }, + ... ], + ... ) + >>> output + The image depicts the iconic Statue of Liberty situated in New York Harbor, New York, on a clear day. + ``` + + Example using tools: + ```py + # Must be run in an async context + >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") + >>> messages = [ + ... { + ... "role": "system", + ... "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.", + ... }, + ... { + ... "role": "user", + ... "content": "What's the weather like the next 3 days in San Francisco, CA?", + ... }, + ... ] + >>> tools = [ + ... { + ... "type": "function", + ... "function": { + ... "name": "get_current_weather", + ... "description": "Get the current weather", + ... "parameters": { + ... "type": "object", + ... "properties": { + ... "location": { + ... "type": "string", + ... "description": "The city and state, e.g. San Francisco, CA", + ... }, + ... "format": { + ... "type": "string", + ... "enum": ["celsius", "fahrenheit"], + ... "description": "The temperature unit to use. Infer this from the users location.", + ... }, + ... }, + ... "required": ["location", "format"], + ... }, + ... }, + ... }, + ... { + ... "type": "function", + ... "function": { + ... "name": "get_n_day_weather_forecast", + ... "description": "Get an N-day weather forecast", + ... "parameters": { + ... "type": "object", + ... "properties": { + ... "location": { + ... "type": "string", + ... "description": "The city and state, e.g. San Francisco, CA", + ... }, + ... "format": { + ... "type": "string", + ... "enum": ["celsius", "fahrenheit"], + ... "description": "The temperature unit to use. Infer this from the users location.", + ... }, + ... "num_days": { + ... "type": "integer", + ... "description": "The number of days to forecast", + ... }, + ... }, + ... "required": ["location", "format", "num_days"], + ... }, + ... }, + ... }, + ... ] + + >>> response = await client.chat_completion( + ... model="meta-llama/Meta-Llama-3-70B-Instruct", + ... messages=messages, + ... tools=tools, + ... tool_choice="auto", + ... max_tokens=500, + ... ) + >>> response.choices[0].message.tool_calls[0].function + ChatCompletionOutputFunctionDefinition( + arguments={ + 'location': 'San Francisco, CA', + 'format': 'fahrenheit', + 'num_days': 3 + }, + name='get_n_day_weather_forecast', + description=None + ) + ``` + + Example using response_format: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") + >>> messages = [ + ... { + ... "role": "user", + ... "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I saw and when?", + ... }, + ... ] + >>> response_format = { + ... "type": "json", + ... "value": { + ... "properties": { + ... "location": {"type": "string"}, + ... "activity": {"type": "string"}, + ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5}, + ... "animals": {"type": "array", "items": {"type": "string"}}, + ... }, + ... "required": ["location", "activity", "animals_seen", "animals"], + ... }, + ... } + >>> response = await client.chat_completion( + ... messages=messages, + ... response_format=response_format, + ... max_tokens=500, + ... ) + >>> response.choices[0].message.content + '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}' + ``` + """ + # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently. + # `self.model` takes precedence over 'model' argument for building URL. + # `model` takes precedence for payload value. + model_id_or_url = self.model or model + payload_model = model or self.model + + # Get the provider helper + provider_helper = get_provider_helper( + self.provider, + task="conversational", + model=model_id_or_url + if model_id_or_url is not None and model_id_or_url.startswith(("http://", "https://")) + else payload_model, + ) + + # Prepare the payload + parameters = { + "model": payload_model, + "frequency_penalty": frequency_penalty, + "logit_bias": logit_bias, + "logprobs": logprobs, + "max_tokens": max_tokens, + "n": n, + "presence_penalty": presence_penalty, + "response_format": response_format, + "seed": seed, + "stop": stop, + "temperature": temperature, + "tool_choice": tool_choice, + "tool_prompt": tool_prompt, + "tools": tools, + "top_logprobs": top_logprobs, + "top_p": top_p, + "stream": stream, + "stream_options": stream_options, + **(extra_body or {}), + } + request_parameters = provider_helper.prepare_request( + inputs=messages, + parameters=parameters, + headers=self.headers, + model=model_id_or_url, + api_key=self.token, + ) + data = await self._inner_post(request_parameters, stream=stream) + + if stream: + return _async_stream_chat_completion_response(data) # type: ignore[arg-type] + + return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type] + + async def document_question_answering( + self, + image: ContentT, + question: str, + *, + model: Optional[str] = None, + doc_stride: Optional[int] = None, + handle_impossible_answer: Optional[bool] = None, + lang: Optional[str] = None, + max_answer_len: Optional[int] = None, + max_question_len: Optional[int] = None, + max_seq_len: Optional[int] = None, + top_k: Optional[int] = None, + word_boxes: Optional[List[Union[List[float], str]]] = None, + ) -> List[DocumentQuestionAnsweringOutputElement]: + """ + Answer questions on document images. + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The input image for the context. It can be raw bytes, an image file, or a URL to an online image. + question (`str`): + Question to be answered. + model (`str`, *optional*): + The model to use for the document question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used. + Defaults to None. + doc_stride (`int`, *optional*): + If the words in the document are too long to fit with the question for the model, it will be split in + several chunks with some overlap. This argument controls the size of that overlap. + handle_impossible_answer (`bool`, *optional*): + Whether to accept impossible as an answer + lang (`str`, *optional*): + Language to use while running OCR. Defaults to english. + max_answer_len (`int`, *optional*): + The maximum length of predicted answers (e.g., only answers with a shorter length are considered). + max_question_len (`int`, *optional*): + The maximum length of the question after tokenization. It will be truncated if needed. + max_seq_len (`int`, *optional*): + The maximum length of the total sentence (context + question) in tokens of each chunk passed to the + model. The context will be split in several chunks (using doc_stride as overlap) if needed. + top_k (`int`, *optional*): + The number of answers to return (will be chosen by order of likelihood). Can return less than top_k + answers if there are not enough options available within the context. + word_boxes (`List[Union[List[float], str`, *optional*): + A list of words and bounding boxes (normalized 0->1000). If provided, the inference will skip the OCR + step and use the provided bounding boxes instead. + Returns: + `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?") + [DocumentQuestionAnsweringOutputElement(answer='us-001', end=16, score=0.9999666213989258, start=16)] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="document-question-answering", model=model_id) + inputs: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} + request_parameters = provider_helper.prepare_request( + inputs=inputs, + parameters={ + "doc_stride": doc_stride, + "handle_impossible_answer": handle_impossible_answer, + "lang": lang, + "max_answer_len": max_answer_len, + "max_question_len": max_question_len, + "max_seq_len": max_seq_len, + "top_k": top_k, + "word_boxes": word_boxes, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response) + + async def feature_extraction( + self, + text: str, + *, + normalize: Optional[bool] = None, + prompt_name: Optional[str] = None, + truncate: Optional[bool] = None, + truncation_direction: Optional[Literal["Left", "Right"]] = None, + model: Optional[str] = None, + ) -> "np.ndarray": + """ + Generate embeddings for a given text. + + Args: + text (`str`): + The text to embed. + model (`str`, *optional*): + The model to use for the feature extraction task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended feature extraction model will be used. + Defaults to None. + normalize (`bool`, *optional*): + Whether to normalize the embeddings or not. + Only available on server powered by Text-Embedding-Inference. + prompt_name (`str`, *optional*): + The name of the prompt that should be used by for encoding. If not set, no prompt will be applied. + Must be a key in the `Sentence Transformers` configuration `prompts` dictionary. + For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ",...}, + then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?" + because the prompt text will be prepended before any text to encode. + truncate (`bool`, *optional*): + Whether to truncate the embeddings or not. + Only available on server powered by Text-Embedding-Inference. + truncation_direction (`Literal["Left", "Right"]`, *optional*): + Which side of the input should be truncated when `truncate=True` is passed. + + Returns: + `np.ndarray`: The embedding representing the input text as a float32 numpy array. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.feature_extraction("Hi, who are you?") + array([[ 2.424802 , 2.93384 , 1.1750331 , ..., 1.240499, -0.13776633, -0.7889173 ], + [-0.42943227, -0.6364878 , -1.693462 , ..., 0.41978157, -2.4336355 , 0.6162071 ], + ..., + [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32) + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="feature-extraction", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=text, + parameters={ + "normalize": normalize, + "prompt_name": prompt_name, + "truncate": truncate, + "truncation_direction": truncation_direction, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + np = _import_numpy() + return np.array(provider_helper.get_response(response), dtype="float32") + + async def fill_mask( + self, + text: str, + *, + model: Optional[str] = None, + targets: Optional[List[str]] = None, + top_k: Optional[int] = None, + ) -> List[FillMaskOutputElement]: + """ + Fill in a hole with a missing word (token to be precise). + + Args: + text (`str`): + a string to be filled from, must contain the [MASK] token (check model card for exact name of the mask). + model (`str`, *optional*): + The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used. + targets (`List[str`, *optional*): + When passed, the model will limit the scores to the passed targets instead of looking up in the whole + vocabulary. If the provided targets are not in the model vocab, they will be tokenized and the first + resulting token will be used (with a warning, and that might be slower). + top_k (`int`, *optional*): + When passed, overrides the number of predictions to return. + Returns: + `List[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated + probability, token reference, and completed text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.fill_mask("The goal of life is .") + [ + FillMaskOutputElement(score=0.06897063553333282, token=11098, token_str=' happiness', sequence='The goal of life is happiness.'), + FillMaskOutputElement(score=0.06554922461509705, token=45075, token_str=' immortality', sequence='The goal of life is immortality.') + ] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="fill-mask", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=text, + parameters={"targets": targets, "top_k": top_k}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + return FillMaskOutputElement.parse_obj_as_list(response) + + async def image_classification( + self, + image: ContentT, + *, + model: Optional[str] = None, + function_to_apply: Optional["ImageClassificationOutputTransform"] = None, + top_k: Optional[int] = None, + ) -> List[ImageClassificationOutputElement]: + """ + Perform image classification on the given image using the specified model. + + Args: + image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): + The image to classify. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. + model (`str`, *optional*): + The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a + deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used. + function_to_apply (`"ImageClassificationOutputTransform"`, *optional*): + The function to apply to the model outputs in order to retrieve the scores. + top_k (`int`, *optional*): + When specified, limits the output to the top K most probable classes. + Returns: + `List[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") + [ImageClassificationOutputElement(label='Blenheim spaniel', score=0.9779096841812134), ...] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="image-classification", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=image, + parameters={"function_to_apply": function_to_apply, "top_k": top_k}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + return ImageClassificationOutputElement.parse_obj_as_list(response) + + async def image_segmentation( + self, + image: ContentT, + *, + model: Optional[str] = None, + mask_threshold: Optional[float] = None, + overlap_mask_area_threshold: Optional[float] = None, + subtask: Optional["ImageSegmentationSubtask"] = None, + threshold: Optional[float] = None, + ) -> List[ImageSegmentationOutputElement]: + """ + Perform image segmentation on the given image using the specified model. + + + + You must have `PIL` installed if you want to work with images (`pip install Pillow`). + + + + Args: + image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): + The image to segment. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. + model (`str`, *optional*): + The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a + deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used. + mask_threshold (`float`, *optional*): + Threshold to use when turning the predicted masks into binary values. + overlap_mask_area_threshold (`float`, *optional*): + Mask overlap threshold to eliminate small, disconnected segments. + subtask (`"ImageSegmentationSubtask"`, *optional*): + Segmentation task to be performed, depending on model capabilities. + threshold (`float`, *optional*): + Probability threshold to filter out predicted masks. + Returns: + `List[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.image_segmentation("cat.jpg") + [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=), ...] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="image-segmentation", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=image, + parameters={ + "mask_threshold": mask_threshold, + "overlap_mask_area_threshold": overlap_mask_area_threshold, + "subtask": subtask, + "threshold": threshold, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + output = ImageSegmentationOutputElement.parse_obj_as_list(response) + for item in output: + item.mask = _b64_to_image(item.mask) # type: ignore [assignment] + return output + + async def image_to_image( + self, + image: ContentT, + prompt: Optional[str] = None, + *, + negative_prompt: Optional[str] = None, + num_inference_steps: Optional[int] = None, + guidance_scale: Optional[float] = None, + model: Optional[str] = None, + target_size: Optional[ImageToImageTargetSize] = None, + **kwargs, + ) -> "Image": + """ + Perform image-to-image translation using a specified model. + + + + You must have `PIL` installed if you want to work with images (`pip install Pillow`). + + + + Args: + image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): + The input image for translation. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. + prompt (`str`, *optional*): + The text prompt to guide the image generation. + negative_prompt (`str`, *optional*): + One prompt to guide what NOT to include in image generation. + num_inference_steps (`int`, *optional*): + For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + guidance_scale (`float`, *optional*): + For diffusion models. A higher guidance scale value encourages the model to generate images closely + linked to the text prompt at the expense of lower image quality. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + target_size (`ImageToImageTargetSize`, *optional*): + The size in pixel of the output image. + + Returns: + `Image`: The translated image. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> image = await client.image_to_image("cat.jpg", prompt="turn the cat into a tiger") + >>> image.save("tiger.jpg") + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="image-to-image", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=image, + parameters={ + "prompt": prompt, + "negative_prompt": negative_prompt, + "target_size": target_size, + "num_inference_steps": num_inference_steps, + "guidance_scale": guidance_scale, + **kwargs, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + response = provider_helper.get_response(response, request_parameters) + return _bytes_to_image(response) + + async def image_to_video( + self, + image: ContentT, + *, + model: Optional[str] = None, + prompt: Optional[str] = None, + negative_prompt: Optional[str] = None, + num_frames: Optional[float] = None, + num_inference_steps: Optional[int] = None, + guidance_scale: Optional[float] = None, + seed: Optional[int] = None, + target_size: Optional[ImageToVideoTargetSize] = None, + **kwargs, + ) -> bytes: + """ + Generate a video from an input image. + + Args: + image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): + The input image to generate a video from. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + prompt (`str`, *optional*): + The text prompt to guide the video generation. + negative_prompt (`str`, *optional*): + One prompt to guide what NOT to include in video generation. + num_frames (`float`, *optional*): + The num_frames parameter determines how many video frames are generated. + num_inference_steps (`int`, *optional*): + For diffusion models. The number of denoising steps. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + guidance_scale (`float`, *optional*): + For diffusion models. A higher guidance scale value encourages the model to generate videos closely + linked to the text prompt at the expense of lower image quality. + seed (`int`, *optional*): + The seed to use for the video generation. + target_size (`ImageToVideoTargetSize`, *optional*): + The size in pixel of the output video frames. + num_inference_steps (`int`, *optional*): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + seed (`int`, *optional*): + Seed for the random number generator. + + Returns: + `bytes`: The generated video. + + Examples: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> video = await client.image_to_video("cat.jpg", model="Wan-AI/Wan2.2-I2V-A14B", prompt="turn the cat into a tiger") + >>> with open("tiger.mp4", "wb") as f: + ... f.write(video) + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="image-to-video", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=image, + parameters={ + "prompt": prompt, + "negative_prompt": negative_prompt, + "num_frames": num_frames, + "num_inference_steps": num_inference_steps, + "guidance_scale": guidance_scale, + "seed": seed, + "target_size": target_size, + **kwargs, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + response = provider_helper.get_response(response, request_parameters) + return response + + async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput: + """ + Takes an input image and return text. + + Models can have very different outputs depending on your use case (image captioning, optical character recognition + (OCR), Pix2Struct, etc). Please have a look to the model card to learn more about a model's specificities. + + Args: + image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): + The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + [`ImageToTextOutput`]: The generated text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.image_to_text("cat.jpg") + 'a cat standing in a grassy field ' + >>> await client.image_to_text("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") + 'a dog laying on the grass next to a flower pot ' + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="image-to-text", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=image, + parameters={}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + output = ImageToTextOutput.parse_obj(response) + return output[0] if isinstance(output, list) else output + + async def object_detection( + self, image: ContentT, *, model: Optional[str] = None, threshold: Optional[float] = None + ) -> List[ObjectDetectionOutputElement]: + """ + Perform object detection on the given image using the specified model. + + + + You must have `PIL` installed if you want to work with images (`pip install Pillow`). + + + + Args: + image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): + The image to detect objects on. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. + model (`str`, *optional*): + The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a + deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used. + threshold (`float`, *optional*): + The probability necessary to make a prediction. + Returns: + `List[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + `ValueError`: + If the request output is not a List. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.object_detection("people.jpg") + [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="object-detection", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=image, + parameters={"threshold": threshold}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + return ObjectDetectionOutputElement.parse_obj_as_list(response) + + async def question_answering( + self, + question: str, + context: str, + *, + model: Optional[str] = None, + align_to_words: Optional[bool] = None, + doc_stride: Optional[int] = None, + handle_impossible_answer: Optional[bool] = None, + max_answer_len: Optional[int] = None, + max_question_len: Optional[int] = None, + max_seq_len: Optional[int] = None, + top_k: Optional[int] = None, + ) -> Union[QuestionAnsweringOutputElement, List[QuestionAnsweringOutputElement]]: + """ + Retrieve the answer to a question from a given text. + + Args: + question (`str`): + Question to be answered. + context (`str`): + The context of the question. + model (`str`): + The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. + align_to_words (`bool`, *optional*): + Attempts to align the answer to real words. Improves quality on space separated languages. Might hurt + on non-space-separated languages (like Japanese or Chinese) + doc_stride (`int`, *optional*): + If the context is too long to fit with the question for the model, it will be split in several chunks + with some overlap. This argument controls the size of that overlap. + handle_impossible_answer (`bool`, *optional*): + Whether to accept impossible as an answer. + max_answer_len (`int`, *optional*): + The maximum length of predicted answers (e.g., only answers with a shorter length are considered). + max_question_len (`int`, *optional*): + The maximum length of the question after tokenization. It will be truncated if needed. + max_seq_len (`int`, *optional*): + The maximum length of the total sentence (context + question) in tokens of each chunk passed to the + model. The context will be split in several chunks (using docStride as overlap) if needed. + top_k (`int`, *optional*): + The number of answers to return (will be chosen by order of likelihood). Note that we return less than + topk answers if there are not enough options available within the context. + + Returns: + Union[`QuestionAnsweringOutputElement`, List[`QuestionAnsweringOutputElement`]]: + When top_k is 1 or not provided, it returns a single `QuestionAnsweringOutputElement`. + When top_k is greater than 1, it returns a list of `QuestionAnsweringOutputElement`. + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.") + QuestionAnsweringOutputElement(answer='Clara', end=16, score=0.9326565265655518, start=11) + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="question-answering", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs={"question": question, "context": context}, + parameters={ + "align_to_words": align_to_words, + "doc_stride": doc_stride, + "handle_impossible_answer": handle_impossible_answer, + "max_answer_len": max_answer_len, + "max_question_len": max_question_len, + "max_seq_len": max_seq_len, + "top_k": top_k, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + # Parse the response as a single `QuestionAnsweringOutputElement` when top_k is 1 or not provided, or a list of `QuestionAnsweringOutputElement` to ensure backward compatibility. + output = QuestionAnsweringOutputElement.parse_obj(response) + return output + + async def sentence_similarity( + self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None + ) -> List[float]: + """ + Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings. + + Args: + sentence (`str`): + The main sentence to compare to others. + other_sentences (`List[str]`): + The list of sentences to compare to. + model (`str`, *optional*): + The model to use for the sentence similarity task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended sentence similarity model will be used. + Defaults to None. + + Returns: + `List[float]`: The embedding representing the input text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.sentence_similarity( + ... "Machine learning is so easy.", + ... other_sentences=[ + ... "Deep learning is so straightforward.", + ... "This is so difficult, like rocket science.", + ... "I can't believe how much I struggled with this.", + ... ], + ... ) + [0.7785726189613342, 0.45876261591911316, 0.2906220555305481] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="sentence-similarity", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs={"source_sentence": sentence, "sentences": other_sentences}, + parameters={}, + extra_payload={}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + return _bytes_to_list(response) + + async def summarization( + self, + text: str, + *, + model: Optional[str] = None, + clean_up_tokenization_spaces: Optional[bool] = None, + generate_parameters: Optional[Dict[str, Any]] = None, + truncation: Optional["SummarizationTruncationStrategy"] = None, + ) -> SummarizationOutput: + """ + Generate a summary of a given text using a specified model. + + Args: + text (`str`): + The input text to summarize. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. If not provided, the default recommended model for summarization will be used. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether to clean up the potential extra spaces in the text output. + generate_parameters (`Dict[str, Any]`, *optional*): + Additional parametrization of the text generation algorithm. + truncation (`"SummarizationTruncationStrategy"`, *optional*): + The truncation strategy to use. + Returns: + [`SummarizationOutput`]: The generated summary text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.summarization("The Eiffel tower...") + SummarizationOutput(generated_text="The Eiffel tower is one of the most famous landmarks in the world....") + ``` + """ + parameters = { + "clean_up_tokenization_spaces": clean_up_tokenization_spaces, + "generate_parameters": generate_parameters, + "truncation": truncation, + } + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="summarization", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=text, + parameters=parameters, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + return SummarizationOutput.parse_obj_as_list(response)[0] + + async def table_question_answering( + self, + table: Dict[str, Any], + query: str, + *, + model: Optional[str] = None, + padding: Optional["Padding"] = None, + sequential: Optional[bool] = None, + truncation: Optional[bool] = None, + ) -> TableQuestionAnsweringOutputElement: + """ + Retrieve the answer to a question from information given in a table. + + Args: + table (`str`): + A table of data represented as a dict of lists where entries are headers and the lists are all the + values, all lists must have the same size. + query (`str`): + The query in plain text that you want to ask the table. + model (`str`): + The model to use for the table-question-answering task. Can be a model ID hosted on the Hugging Face + Hub or a URL to a deployed Inference Endpoint. + padding (`"Padding"`, *optional*): + Activates and controls padding. + sequential (`bool`, *optional*): + Whether to do inference sequentially or as a batch. Batching is faster, but models like SQA require the + inference to be done sequentially to extract relations within sequences, given their conversational + nature. + truncation (`bool`, *optional*): + Activates and controls truncation. + + Returns: + [`TableQuestionAnsweringOutputElement`]: a table question answering output containing the answer, coordinates, cells and the aggregator used. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> query = "How many stars does the transformers repository have?" + >>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]} + >>> await client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq") + TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE') + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="table-question-answering", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs={"query": query, "table": table}, + parameters={"model": model, "padding": padding, "sequential": sequential, "truncation": truncation}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response) + + async def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]: + """ + Classifying a target category (a group) based on a set of attributes. + + Args: + table (`Dict[str, Any]`): + Set of attributes to classify. + model (`str`, *optional*): + The model to use for the tabular classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended tabular classification model will be used. + Defaults to None. + + Returns: + `List`: a list of labels, one per row in the initial table. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> table = { + ... "fixed_acidity": ["7.4", "7.8", "10.3"], + ... "volatile_acidity": ["0.7", "0.88", "0.32"], + ... "citric_acid": ["0", "0", "0.45"], + ... "residual_sugar": ["1.9", "2.6", "6.4"], + ... "chlorides": ["0.076", "0.098", "0.073"], + ... "free_sulfur_dioxide": ["11", "25", "5"], + ... "total_sulfur_dioxide": ["34", "67", "13"], + ... "density": ["0.9978", "0.9968", "0.9976"], + ... "pH": ["3.51", "3.2", "3.23"], + ... "sulphates": ["0.56", "0.68", "0.82"], + ... "alcohol": ["9.4", "9.8", "12.6"], + ... } + >>> await client.tabular_classification(table=table, model="julien-c/wine-quality") + ["5", "5", "5"] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="tabular-classification", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=None, + extra_payload={"table": table}, + parameters={}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + return _bytes_to_list(response) + + async def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]: + """ + Predicting a numerical target value given a set of attributes/features in a table. + + Args: + table (`Dict[str, Any]`): + Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical. + model (`str`, *optional*): + The model to use for the tabular regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended tabular regression model will be used. + Defaults to None. + + Returns: + `List`: a list of predicted numerical target values. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> table = { + ... "Height": ["11.52", "12.48", "12.3778"], + ... "Length1": ["23.2", "24", "23.9"], + ... "Length2": ["25.4", "26.3", "26.5"], + ... "Length3": ["30", "31.2", "31.1"], + ... "Species": ["Bream", "Bream", "Bream"], + ... "Width": ["4.02", "4.3056", "4.6961"], + ... } + >>> await client.tabular_regression(table, model="scikit-learn/Fish-Weight") + [110, 120, 130] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="tabular-regression", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=None, + parameters={}, + extra_payload={"table": table}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + return _bytes_to_list(response) + + async def text_classification( + self, + text: str, + *, + model: Optional[str] = None, + top_k: Optional[int] = None, + function_to_apply: Optional["TextClassificationOutputTransform"] = None, + ) -> List[TextClassificationOutputElement]: + """ + Perform text classification (e.g. sentiment-analysis) on the given text. + + Args: + text (`str`): + A string to be classified. + model (`str`, *optional*): + The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used. + Defaults to None. + top_k (`int`, *optional*): + When specified, limits the output to the top K most probable classes. + function_to_apply (`"TextClassificationOutputTransform"`, *optional*): + The function to apply to the model outputs in order to retrieve the scores. + + Returns: + `List[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.text_classification("I like you") + [ + TextClassificationOutputElement(label='POSITIVE', score=0.9998695850372314), + TextClassificationOutputElement(label='NEGATIVE', score=0.0001304351753788069), + ] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="text-classification", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=text, + parameters={ + "function_to_apply": function_to_apply, + "top_k": top_k, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value] + + @overload + async def text_generation( + self, + prompt: str, + *, + details: Literal[True], + stream: Literal[True], + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> AsyncIterable[TextGenerationStreamOutput]: ... + + @overload + async def text_generation( + self, + prompt: str, + *, + details: Literal[True], + stream: Optional[Literal[False]] = None, + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> TextGenerationOutput: ... + + @overload + async def text_generation( + self, + prompt: str, + *, + details: Optional[Literal[False]] = None, + stream: Literal[True], + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = None, # Manual default value + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> AsyncIterable[str]: ... + + @overload + async def text_generation( + self, + prompt: str, + *, + details: Optional[Literal[False]] = None, + stream: Optional[Literal[False]] = None, + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> str: ... + + @overload + async def text_generation( + self, + prompt: str, + *, + details: Optional[bool] = None, + stream: Optional[bool] = None, + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]: ... + + async def text_generation( + self, + prompt: str, + *, + details: Optional[bool] = None, + stream: Optional[bool] = None, + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]: + """ + Given a prompt, generate the following text. + + + + If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method. + It accepts a list of messages instead of a single text prompt and handles the chat templating for you. + + + + Args: + prompt (`str`): + Input text. + details (`bool`, *optional*): + By default, text_generation returns a string. Pass `details=True` if you want a detailed output (tokens, + probabilities, seed, finish reason, etc.). Only available for models running on with the + `text-generation-inference` backend. + stream (`bool`, *optional*): + By default, text_generation returns the full generated text. Pass `stream=True` if you want a stream of + tokens to be returned. Only available for models running on with the `text-generation-inference` + backend. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + adapter_id (`str`, *optional*): + Lora adapter id. + best_of (`int`, *optional*): + Generate best_of sequences and return the one if the highest token logprobs. + decoder_input_details (`bool`, *optional*): + Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken + into account. Defaults to `False`. + do_sample (`bool`, *optional*): + Activate logits sampling + frequency_penalty (`float`, *optional*): + Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in + the text so far, decreasing the model's likelihood to repeat the same line verbatim. + grammar ([`TextGenerationInputGrammarType`], *optional*): + Grammar constraints. Can be either a JSONSchema or a regex. + max_new_tokens (`int`, *optional*): + Maximum number of generated tokens. Defaults to 100. + repetition_penalty (`float`, *optional*): + The parameter for repetition penalty. 1.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + return_full_text (`bool`, *optional*): + Whether to prepend the prompt to the generated text + seed (`int`, *optional*): + Random sampling seed + stop (`List[str]`, *optional*): + Stop generating tokens if a member of `stop` is generated. + stop_sequences (`List[str]`, *optional*): + Deprecated argument. Use `stop` instead. + temperature (`float`, *optional*): + The value used to module the logits distribution. + top_n_tokens (`int`, *optional*): + Return information about the `top_n_tokens` most likely tokens at each generation step, instead of + just the sampled token. + top_k (`int`, *optional`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`, *optional`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. + truncate (`int`, *optional`): + Truncate inputs tokens to the given size. + typical_p (`float`, *optional`): + Typical Decoding mass + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + watermark (`bool`, *optional*): + Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + + Returns: + `Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]`: + Generated text returned from the server: + - if `stream=False` and `details=False`, the generated text is returned as a `str` (default) + - if `stream=True` and `details=False`, the generated text is returned token by token as a `Iterable[str]` + - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.TextGenerationOutput`] + - if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.TextGenerationStreamOutput`] + + Raises: + `ValidationError`: + If input values are not valid. No HTTP call is made to the server. + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + + # Case 1: generate text + >>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12) + '100% open source and built to be easy to use.' + + # Case 2: iterate over the generated tokens. Useful for large generation. + >>> async for token in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True): + ... print(token) + 100 + % + open + source + and + built + to + be + easy + to + use + . + + # Case 3: get more details about the generation process. + >>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True) + TextGenerationOutput( + generated_text='100% open source and built to be easy to use.', + details=TextGenerationDetails( + finish_reason='length', + generated_tokens=12, + seed=None, + prefill=[ + TextGenerationPrefillOutputToken(id=487, text='The', logprob=None), + TextGenerationPrefillOutputToken(id=53789, text=' hugging', logprob=-13.171875), + (...) + TextGenerationPrefillOutputToken(id=204, text=' ', logprob=-7.0390625) + ], + tokens=[ + TokenElement(id=1425, text='100', logprob=-1.0175781, special=False), + TokenElement(id=16, text='%', logprob=-0.0463562, special=False), + (...) + TokenElement(id=25, text='.', logprob=-0.5703125, special=False) + ], + best_of_sequences=None + ) + ) + + # Case 4: iterate over the generated tokens with more details. + # Last object is more complete, containing the full generated text and the finish reason. + >>> async for details in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True): + ... print(details) + ... + TextGenerationStreamOutput(token=TokenElement(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement( + id=25, + text='.', + logprob=-0.5703125, + special=False), + generated_text='100% open source and built to be easy to use.', + details=TextGenerationStreamOutputStreamDetails(finish_reason='length', generated_tokens=12, seed=None) + ) + + # Case 5: generate constrained output using grammar + >>> response = await client.text_generation( + ... prompt="I saw a puppy a cat and a raccoon during my bike ride in the park", + ... model="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1", + ... max_new_tokens=100, + ... repetition_penalty=1.3, + ... grammar={ + ... "type": "json", + ... "value": { + ... "properties": { + ... "location": {"type": "string"}, + ... "activity": {"type": "string"}, + ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5}, + ... "animals": {"type": "array", "items": {"type": "string"}}, + ... }, + ... "required": ["location", "activity", "animals_seen", "animals"], + ... }, + ... }, + ... ) + >>> json.loads(response) + { + "activity": "bike riding", + "animals": ["puppy", "cat", "raccoon"], + "animals_seen": 3, + "location": "park" + } + ``` + """ + if decoder_input_details and not details: + warnings.warn( + "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that" + " the output from the server will be truncated." + ) + decoder_input_details = False + + if stop_sequences is not None: + warnings.warn( + "`stop_sequences` is a deprecated argument for `text_generation` task" + " and will be removed in version '0.28.0'. Use `stop` instead.", + FutureWarning, + ) + if stop is None: + stop = stop_sequences # use deprecated arg if provided + + # Build payload + parameters = { + "adapter_id": adapter_id, + "best_of": best_of, + "decoder_input_details": decoder_input_details, + "details": details, + "do_sample": do_sample, + "frequency_penalty": frequency_penalty, + "grammar": grammar, + "max_new_tokens": max_new_tokens, + "repetition_penalty": repetition_penalty, + "return_full_text": return_full_text, + "seed": seed, + "stop": stop, + "temperature": temperature, + "top_k": top_k, + "top_n_tokens": top_n_tokens, + "top_p": top_p, + "truncate": truncate, + "typical_p": typical_p, + "watermark": watermark, + } + + # Remove some parameters if not a TGI server + unsupported_kwargs = _get_unsupported_text_generation_kwargs(model) + if len(unsupported_kwargs) > 0: + # The server does not support some parameters + # => means it is not a TGI server + # => remove unsupported parameters and warn the user + + ignored_parameters = [] + for key in unsupported_kwargs: + if parameters.get(key): + ignored_parameters.append(key) + parameters.pop(key, None) + if len(ignored_parameters) > 0: + warnings.warn( + "API endpoint/model for text-generation is not served via TGI. Ignoring following parameters:" + f" {', '.join(ignored_parameters)}.", + UserWarning, + ) + if details: + warnings.warn( + "API endpoint/model for text-generation is not served via TGI. Parameter `details=True` will" + " be ignored meaning only the generated text will be returned.", + UserWarning, + ) + details = False + if stream: + raise ValueError( + "API endpoint/model for text-generation is not served via TGI. Cannot return output as a stream." + " Please pass `stream=False` as input." + ) + + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="text-generation", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=prompt, + parameters=parameters, + extra_payload={"stream": stream}, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + + # Handle errors separately for more precise error messages + try: + bytes_output = await self._inner_post(request_parameters, stream=stream or False) + except _import_aiohttp().ClientResponseError as e: + match = MODEL_KWARGS_NOT_USED_REGEX.search(e.response_error_payload["error"]) + if e.status == 400 and match: + unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")] + _set_unsupported_text_generation_kwargs(model, unused_params) + return await self.text_generation( # type: ignore + prompt=prompt, + details=details, + stream=stream, + model=model_id, + adapter_id=adapter_id, + best_of=best_of, + decoder_input_details=decoder_input_details, + do_sample=do_sample, + frequency_penalty=frequency_penalty, + grammar=grammar, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty, + return_full_text=return_full_text, + seed=seed, + stop=stop, + temperature=temperature, + top_k=top_k, + top_n_tokens=top_n_tokens, + top_p=top_p, + truncate=truncate, + typical_p=typical_p, + watermark=watermark, + ) + raise_text_generation_error(e) + + # Parse output + if stream: + return _async_stream_text_generation_response(bytes_output, details) # type: ignore + + data = _bytes_to_dict(bytes_output) # type: ignore[arg-type] + + # Data can be a single element (dict) or an iterable of dicts where we select the first element of. + if isinstance(data, list): + data = data[0] + response = provider_helper.get_response(data, request_parameters) + return TextGenerationOutput.parse_obj_as_instance(response) if details else response["generated_text"] + + async def text_to_image( + self, + prompt: str, + *, + negative_prompt: Optional[str] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = None, + guidance_scale: Optional[float] = None, + model: Optional[str] = None, + scheduler: Optional[str] = None, + seed: Optional[int] = None, + extra_body: Optional[Dict[str, Any]] = None, + ) -> "Image": + """ + Generate an image based on a given text using a specified model. + + + + You must have `PIL` installed if you want to work with images (`pip install Pillow`). + + + + + You can pass provider-specific parameters to the model by using the `extra_body` argument. + + + Args: + prompt (`str`): + The prompt to generate an image from. + negative_prompt (`str`, *optional*): + One prompt to guide what NOT to include in image generation. + height (`int`, *optional*): + The height in pixels of the output image + width (`int`, *optional*): + The width in pixels of the output image + num_inference_steps (`int`, *optional*): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*): + A higher guidance scale value encourages the model to generate images closely linked to the text + prompt, but values too high may cause saturation and other artifacts. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. If not provided, the default recommended text-to-image model will be used. + Defaults to None. + scheduler (`str`, *optional*): + Override the scheduler with a compatible one. + seed (`int`, *optional*): + Seed for the random number generator. + extra_body (`Dict[str, Any]`, *optional*): + Additional provider-specific parameters to pass to the model. Refer to the provider's documentation + for supported parameters. + + Returns: + `Image`: The generated image. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + + >>> image = await client.text_to_image("An astronaut riding a horse on the moon.") + >>> image.save("astronaut.png") + + >>> image = await client.text_to_image( + ... "An astronaut riding a horse on the moon.", + ... negative_prompt="low resolution, blurry", + ... model="stabilityai/stable-diffusion-2-1", + ... ) + >>> image.save("better_astronaut.png") + ``` + Example using a third-party provider directly. Usage will be billed on your fal.ai account. + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="fal-ai", # Use fal.ai provider + ... api_key="fal-ai-api-key", # Pass your fal.ai API key + ... ) + >>> image = client.text_to_image( + ... "A majestic lion in a fantasy forest", + ... model="black-forest-labs/FLUX.1-schnell", + ... ) + >>> image.save("lion.png") + ``` + + Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account. + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="replicate", # Use replicate provider + ... api_key="hf_...", # Pass your HF token + ... ) + >>> image = client.text_to_image( + ... "An astronaut riding a horse on the moon.", + ... model="black-forest-labs/FLUX.1-dev", + ... ) + >>> image.save("astronaut.png") + ``` + + Example using Replicate provider with extra parameters + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="replicate", # Use replicate provider + ... api_key="hf_...", # Pass your HF token + ... ) + >>> image = client.text_to_image( + ... "An astronaut riding a horse on the moon.", + ... model="black-forest-labs/FLUX.1-schnell", + ... extra_body={"output_quality": 100}, + ... ) + >>> image.save("astronaut.png") + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="text-to-image", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=prompt, + parameters={ + "negative_prompt": negative_prompt, + "height": height, + "width": width, + "num_inference_steps": num_inference_steps, + "guidance_scale": guidance_scale, + "scheduler": scheduler, + "seed": seed, + **(extra_body or {}), + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + response = provider_helper.get_response(response) + return _bytes_to_image(response) + + async def text_to_video( + self, + prompt: str, + *, + model: Optional[str] = None, + guidance_scale: Optional[float] = None, + negative_prompt: Optional[List[str]] = None, + num_frames: Optional[float] = None, + num_inference_steps: Optional[int] = None, + seed: Optional[int] = None, + extra_body: Optional[Dict[str, Any]] = None, + ) -> bytes: + """ + Generate a video based on a given text. + + + You can pass provider-specific parameters to the model by using the `extra_body` argument. + + + Args: + prompt (`str`): + The prompt to generate a video from. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. If not provided, the default recommended text-to-video model will be used. + Defaults to None. + guidance_scale (`float`, *optional*): + A higher guidance scale value encourages the model to generate videos closely linked to the text + prompt, but values too high may cause saturation and other artifacts. + negative_prompt (`List[str]`, *optional*): + One or several prompt to guide what NOT to include in video generation. + num_frames (`float`, *optional*): + The num_frames parameter determines how many video frames are generated. + num_inference_steps (`int`, *optional*): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + seed (`int`, *optional*): + Seed for the random number generator. + extra_body (`Dict[str, Any]`, *optional*): + Additional provider-specific parameters to pass to the model. Refer to the provider's documentation + for supported parameters. + + Returns: + `bytes`: The generated video. + + Example: + + Example using a third-party provider directly. Usage will be billed on your fal.ai account. + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="fal-ai", # Using fal.ai provider + ... api_key="fal-ai-api-key", # Pass your fal.ai API key + ... ) + >>> video = client.text_to_video( + ... "A majestic lion running in a fantasy forest", + ... model="tencent/HunyuanVideo", + ... ) + >>> with open("lion.mp4", "wb") as file: + ... file.write(video) + ``` + + Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account. + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="replicate", # Using replicate provider + ... api_key="hf_...", # Pass your HF token + ... ) + >>> video = client.text_to_video( + ... "A cat running in a park", + ... model="genmo/mochi-1-preview", + ... ) + >>> with open("cat.mp4", "wb") as file: + ... file.write(video) + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="text-to-video", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=prompt, + parameters={ + "guidance_scale": guidance_scale, + "negative_prompt": negative_prompt, + "num_frames": num_frames, + "num_inference_steps": num_inference_steps, + "seed": seed, + **(extra_body or {}), + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + response = provider_helper.get_response(response, request_parameters) + return response + + async def text_to_speech( + self, + text: str, + *, + model: Optional[str] = None, + do_sample: Optional[bool] = None, + early_stopping: Optional[Union[bool, "TextToSpeechEarlyStoppingEnum"]] = None, + epsilon_cutoff: Optional[float] = None, + eta_cutoff: Optional[float] = None, + max_length: Optional[int] = None, + max_new_tokens: Optional[int] = None, + min_length: Optional[int] = None, + min_new_tokens: Optional[int] = None, + num_beam_groups: Optional[int] = None, + num_beams: Optional[int] = None, + penalty_alpha: Optional[float] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + typical_p: Optional[float] = None, + use_cache: Optional[bool] = None, + extra_body: Optional[Dict[str, Any]] = None, + ) -> bytes: + """ + Synthesize an audio of a voice pronouncing a given text. + + + You can pass provider-specific parameters to the model by using the `extra_body` argument. + + + Args: + text (`str`): + The text to synthesize. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. If not provided, the default recommended text-to-speech model will be used. + Defaults to None. + do_sample (`bool`, *optional*): + Whether to use sampling instead of greedy decoding when generating new tokens. + early_stopping (`Union[bool, "TextToSpeechEarlyStoppingEnum"]`, *optional*): + Controls the stopping condition for beam-based methods. + epsilon_cutoff (`float`, *optional*): + If set to float strictly between 0 and 1, only tokens with a conditional probability greater than + epsilon_cutoff will be sampled. In the paper, suggested values range from 3e-4 to 9e-4, depending on + the size of the model. See [Truncation Sampling as Language Model + Desmoothing](https://hf.co/papers/2210.15191) for more details. + eta_cutoff (`float`, *optional*): + Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to float strictly + between 0 and 1, a token is only considered if it is greater than either eta_cutoff or sqrt(eta_cutoff) + * exp(-entropy(softmax(next_token_logits))). The latter term is intuitively the expected next token + probability, scaled by sqrt(eta_cutoff). In the paper, suggested values range from 3e-4 to 2e-3, + depending on the size of the model. See [Truncation Sampling as Language Model + Desmoothing](https://hf.co/papers/2210.15191) for more details. + max_length (`int`, *optional*): + The maximum length (in tokens) of the generated text, including the input. + max_new_tokens (`int`, *optional*): + The maximum number of tokens to generate. Takes precedence over max_length. + min_length (`int`, *optional*): + The minimum length (in tokens) of the generated text, including the input. + min_new_tokens (`int`, *optional*): + The minimum number of tokens to generate. Takes precedence over min_length. + num_beam_groups (`int`, *optional*): + Number of groups to divide num_beams into in order to ensure diversity among different groups of beams. + See [this paper](https://hf.co/papers/1610.02424) for more details. + num_beams (`int`, *optional*): + Number of beams to use for beam search. + penalty_alpha (`float`, *optional*): + The value balances the model confidence and the degeneration penalty in contrastive search decoding. + temperature (`float`, *optional*): + The value used to modulate the next token probabilities. + top_k (`int`, *optional*): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`, *optional*): + If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to + top_p or higher are kept for generation. + typical_p (`float`, *optional*): + Local typicality measures how similar the conditional probability of predicting a target token next is + to the expected conditional probability of predicting a random token next, given the partial text + already generated. If set to float < 1, the smallest set of the most locally typical tokens with + probabilities that add up to typical_p or higher are kept for generation. See [this + paper](https://hf.co/papers/2202.00666) for more details. + use_cache (`bool`, *optional*): + Whether the model should use the past last key/values attentions to speed up decoding + extra_body (`Dict[str, Any]`, *optional*): + Additional provider-specific parameters to pass to the model. Refer to the provider's documentation + for supported parameters. + Returns: + `bytes`: The generated audio. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from pathlib import Path + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + + >>> audio = await client.text_to_speech("Hello world") + >>> Path("hello_world.flac").write_bytes(audio) + ``` + + Example using a third-party provider directly. Usage will be billed on your Replicate account. + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="replicate", + ... api_key="your-replicate-api-key", # Pass your Replicate API key directly + ... ) + >>> audio = client.text_to_speech( + ... text="Hello world", + ... model="OuteAI/OuteTTS-0.3-500M", + ... ) + >>> Path("hello_world.flac").write_bytes(audio) + ``` + + Example using a third-party provider through Hugging Face Routing. Usage will be billed on your Hugging Face account. + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="replicate", + ... api_key="hf_...", # Pass your HF token + ... ) + >>> audio =client.text_to_speech( + ... text="Hello world", + ... model="OuteAI/OuteTTS-0.3-500M", + ... ) + >>> Path("hello_world.flac").write_bytes(audio) + ``` + Example using Replicate provider with extra parameters + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient( + ... provider="replicate", # Use replicate provider + ... api_key="hf_...", # Pass your HF token + ... ) + >>> audio = client.text_to_speech( + ... "Hello, my name is Kororo, an awesome text-to-speech model.", + ... model="hexgrad/Kokoro-82M", + ... extra_body={"voice": "af_nicole"}, + ... ) + >>> Path("hello.flac").write_bytes(audio) + ``` + + Example music-gen using "YuE-s1-7B-anneal-en-cot" on fal.ai + ```py + >>> from huggingface_hub import InferenceClient + >>> lyrics = ''' + ... [verse] + ... In the town where I was born + ... Lived a man who sailed to sea + ... And he told us of his life + ... In the land of submarines + ... So we sailed on to the sun + ... 'Til we found a sea of green + ... And we lived beneath the waves + ... In our yellow submarine + + ... [chorus] + ... We all live in a yellow submarine + ... Yellow submarine, yellow submarine + ... We all live in a yellow submarine + ... Yellow submarine, yellow submarine + ... ''' + >>> genres = "pavarotti-style tenor voice" + >>> client = InferenceClient( + ... provider="fal-ai", + ... model="m-a-p/YuE-s1-7B-anneal-en-cot", + ... api_key=..., + ... ) + >>> audio = client.text_to_speech(lyrics, extra_body={"genres": genres}) + >>> with open("output.mp3", "wb") as f: + ... f.write(audio) + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="text-to-speech", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=text, + parameters={ + "do_sample": do_sample, + "early_stopping": early_stopping, + "epsilon_cutoff": epsilon_cutoff, + "eta_cutoff": eta_cutoff, + "max_length": max_length, + "max_new_tokens": max_new_tokens, + "min_length": min_length, + "min_new_tokens": min_new_tokens, + "num_beam_groups": num_beam_groups, + "num_beams": num_beams, + "penalty_alpha": penalty_alpha, + "temperature": temperature, + "top_k": top_k, + "top_p": top_p, + "typical_p": typical_p, + "use_cache": use_cache, + **(extra_body or {}), + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + response = provider_helper.get_response(response) + return response + + async def token_classification( + self, + text: str, + *, + model: Optional[str] = None, + aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None, + ignore_labels: Optional[List[str]] = None, + stride: Optional[int] = None, + ) -> List[TokenClassificationOutputElement]: + """ + Perform token classification on the given text. + Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. + + Args: + text (`str`): + A string to be classified. + model (`str`, *optional*): + The model to use for the token classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended token classification model will be used. + Defaults to None. + aggregation_strategy (`"TokenClassificationAggregationStrategy"`, *optional*): + The strategy used to fuse tokens based on model predictions + ignore_labels (`List[str`, *optional*): + A list of labels to ignore + stride (`int`, *optional*): + The number of overlapping tokens between chunks when splitting the input text. + + Returns: + `List[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica") + [ + TokenClassificationOutputElement( + entity_group='PER', + score=0.9971321225166321, + word='Sarah Jessica Parker', + start=11, + end=31, + ), + TokenClassificationOutputElement( + entity_group='PER', + score=0.9773476123809814, + word='Jessica', + start=52, + end=59, + ) + ] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="token-classification", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=text, + parameters={ + "aggregation_strategy": aggregation_strategy, + "ignore_labels": ignore_labels, + "stride": stride, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + return TokenClassificationOutputElement.parse_obj_as_list(response) + + async def translation( + self, + text: str, + *, + model: Optional[str] = None, + src_lang: Optional[str] = None, + tgt_lang: Optional[str] = None, + clean_up_tokenization_spaces: Optional[bool] = None, + truncation: Optional["TranslationTruncationStrategy"] = None, + generate_parameters: Optional[Dict[str, Any]] = None, + ) -> TranslationOutput: + """ + Convert text from one language to another. + + Check out https://huggingface.co/tasks/translation for more information on how to choose the best model for + your specific use case. Source and target languages usually depend on the model. + However, it is possible to specify source and target languages for certain models. If you are working with one of these models, + you can use `src_lang` and `tgt_lang` arguments to pass the relevant information. + + Args: + text (`str`): + A string to be translated. + model (`str`, *optional*): + The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended translation model will be used. + Defaults to None. + src_lang (`str`, *optional*): + The source language of the text. Required for models that can translate from multiple languages. + tgt_lang (`str`, *optional*): + Target language to translate to. Required for models that can translate to multiple languages. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether to clean up the potential extra spaces in the text output. + truncation (`"TranslationTruncationStrategy"`, *optional*): + The truncation strategy to use. + generate_parameters (`Dict[str, Any]`, *optional*): + Additional parametrization of the text generation algorithm. + + Returns: + [`TranslationOutput`]: The generated translated text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + `ValueError`: + If only one of the `src_lang` and `tgt_lang` arguments are provided. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.translation("My name is Wolfgang and I live in Berlin") + 'Mein Name ist Wolfgang und ich lebe in Berlin.' + >>> await client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr") + TranslationOutput(translation_text='Je m'appelle Wolfgang et je vis à Berlin.') + ``` + + Specifying languages: + ```py + >>> client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX") + "Mon nom est Sarah Jessica Parker mais vous pouvez m'appeler Jessica" + ``` + """ + # Throw error if only one of `src_lang` and `tgt_lang` was given + if src_lang is not None and tgt_lang is None: + raise ValueError("You cannot specify `src_lang` without specifying `tgt_lang`.") + + if src_lang is None and tgt_lang is not None: + raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.") + + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="translation", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=text, + parameters={ + "src_lang": src_lang, + "tgt_lang": tgt_lang, + "clean_up_tokenization_spaces": clean_up_tokenization_spaces, + "truncation": truncation, + "generate_parameters": generate_parameters, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + return TranslationOutput.parse_obj_as_list(response)[0] + + async def visual_question_answering( + self, + image: ContentT, + question: str, + *, + model: Optional[str] = None, + top_k: Optional[int] = None, + ) -> List[VisualQuestionAnsweringOutputElement]: + """ + Answering open-ended questions based on an image. + + Args: + image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): + The input image for the context. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. + question (`str`): + Question to be answered. + model (`str`, *optional*): + The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used. + Defaults to None. + top_k (`int`, *optional*): + The number of answers to return (will be chosen by order of likelihood). Note that we return less than + topk answers if there are not enough options available within the context. + Returns: + `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability. + + Raises: + `InferenceTimeoutError`: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.visual_question_answering( + ... image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg", + ... question="What is the animal doing?" + ... ) + [ + VisualQuestionAnsweringOutputElement(score=0.778609573841095, answer='laying down'), + VisualQuestionAnsweringOutputElement(score=0.6957435607910156, answer='sitting'), + ] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="visual-question-answering", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=image, + parameters={"top_k": top_k}, + headers=self.headers, + model=model_id, + api_key=self.token, + extra_payload={"question": question, "image": _b64_encode(image)}, + ) + response = await self._inner_post(request_parameters) + return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response) + + async def zero_shot_classification( + self, + text: str, + candidate_labels: List[str], + *, + multi_label: Optional[bool] = False, + hypothesis_template: Optional[str] = None, + model: Optional[str] = None, + ) -> List[ZeroShotClassificationOutputElement]: + """ + Provide as input a text and a set of candidate labels to classify the input text. + + Args: + text (`str`): + The input text to classify. + candidate_labels (`List[str]`): + The set of possible class labels to classify the text into. + labels (`List[str]`, *optional*): + (deprecated) List of strings. Each string is the verbalization of a possible label for the input text. + multi_label (`bool`, *optional*): + Whether multiple candidate labels can be true. If false, the scores are normalized such that the sum of + the label likelihoods for each sequence is 1. If true, the labels are considered independent and + probabilities are normalized for each candidate. + hypothesis_template (`str`, *optional*): + The sentence used in conjunction with `candidate_labels` to attempt the text classification by + replacing the placeholder with the candidate labels. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot classification model will be used. + + + Returns: + `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example with `multi_label=False`: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> text = ( + ... "A new model offers an explanation for how the Galilean satellites formed around the solar system's" + ... "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling" + ... " mysteries when he went for a run up a hill in Nice, France." + ... ) + >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"] + >>> await client.zero_shot_classification(text, labels) + [ + ZeroShotClassificationOutputElement(label='scientific discovery', score=0.7961668968200684), + ZeroShotClassificationOutputElement(label='space & cosmos', score=0.18570658564567566), + ZeroShotClassificationOutputElement(label='microbiology', score=0.00730885099619627), + ZeroShotClassificationOutputElement(label='archeology', score=0.006258360575884581), + ZeroShotClassificationOutputElement(label='robots', score=0.004559356719255447), + ] + >>> await client.zero_shot_classification(text, labels, multi_label=True) + [ + ZeroShotClassificationOutputElement(label='scientific discovery', score=0.9829297661781311), + ZeroShotClassificationOutputElement(label='space & cosmos', score=0.755190908908844), + ZeroShotClassificationOutputElement(label='microbiology', score=0.0005462635890580714), + ZeroShotClassificationOutputElement(label='archeology', score=0.00047131875180639327), + ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354), + ] + ``` + + Example with `multi_label=True` and a custom `hypothesis_template`: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.zero_shot_classification( + ... text="I really like our dinner and I'm very happy. I don't like the weather though.", + ... labels=["positive", "negative", "pessimistic", "optimistic"], + ... multi_label=True, + ... hypothesis_template="This text is {} towards the weather" + ... ) + [ + ZeroShotClassificationOutputElement(label='negative', score=0.9231801629066467), + ZeroShotClassificationOutputElement(label='pessimistic', score=0.8760990500450134), + ZeroShotClassificationOutputElement(label='optimistic', score=0.0008674879791215062), + ZeroShotClassificationOutputElement(label='positive', score=0.0005250611575320363) + ] + ``` + """ + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="zero-shot-classification", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=text, + parameters={ + "candidate_labels": candidate_labels, + "multi_label": multi_label, + "hypothesis_template": hypothesis_template, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + output = _bytes_to_dict(response) + return [ + ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score}) + for label, score in zip(output["labels"], output["scores"]) + ] + + async def zero_shot_image_classification( + self, + image: ContentT, + candidate_labels: List[str], + *, + model: Optional[str] = None, + hypothesis_template: Optional[str] = None, + # deprecated argument + labels: List[str] = None, # type: ignore + ) -> List[ZeroShotImageClassificationOutputElement]: + """ + Provide input image and text labels to predict text labels for the image. + + Args: + image (`Union[str, Path, bytes, BinaryIO, PIL.Image.Image]`): + The input image to caption. It can be raw bytes, an image file, a URL to an online image, or a PIL Image. + candidate_labels (`List[str]`): + The candidate labels for this image + labels (`List[str]`, *optional*): + (deprecated) List of string possible labels. There must be at least 2 labels. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. If not provided, the default recommended zero-shot image classification model will be used. + hypothesis_template (`str`, *optional*): + The sentence used in conjunction with `candidate_labels` to attempt the image classification by + replacing the placeholder with the candidate labels. + + Returns: + `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + + >>> await client.zero_shot_image_classification( + ... "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg", + ... labels=["dog", "cat", "horse"], + ... ) + [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...] + ``` + """ + # Raise ValueError if input is less than 2 labels + if len(candidate_labels) < 2: + raise ValueError("You must specify at least 2 classes to compare.") + + model_id = model or self.model + provider_helper = get_provider_helper(self.provider, task="zero-shot-image-classification", model=model_id) + request_parameters = provider_helper.prepare_request( + inputs=image, + parameters={ + "candidate_labels": candidate_labels, + "hypothesis_template": hypothesis_template, + }, + headers=self.headers, + model=model_id, + api_key=self.token, + ) + response = await self._inner_post(request_parameters) + return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response) + + @_deprecate_method( + version="0.35.0", + message=( + "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)." + " Use `HfApi.list_models(..., inference_provider='...')` to list warm models per provider." + ), + ) + async def list_deployed_models( + self, frameworks: Union[None, str, Literal["all"], List[str]] = None + ) -> Dict[str, List[str]]: + """ + List models deployed on the HF Serverless Inference API service. + + This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that + are supported and account for 95% of the hosted models. However, if you want a complete list of models you can + specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested + in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more + frameworks are checked, the more time it will take. + + + + This endpoint method does not return a live list of all models available for the HF Inference API service. + It searches over a cached list of models that were recently available and the list may not be up to date. + If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`]. + + + + + + This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to + check its availability, you can directly use [`~InferenceClient.get_model_status`]. + + + + Args: + frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*): + The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to + "all", all available frameworks will be tested. It is also possible to provide a single framework or a + custom set of frameworks to check. + + Returns: + `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs. + + Example: + ```py + # Must be run in an async contextthon + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + + # Discover zero-shot-classification models currently deployed + >>> models = await client.list_deployed_models() + >>> models["zero-shot-classification"] + ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...] + + # List from only 1 framework + >>> await client.list_deployed_models("text-generation-inference") + {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...} + ``` + """ + if self.provider != "hf-inference": + raise ValueError(f"Listing deployed models is not supported on '{self.provider}'.") + + # Resolve which frameworks to check + if frameworks is None: + frameworks = constants.MAIN_INFERENCE_API_FRAMEWORKS + elif frameworks == "all": + frameworks = constants.ALL_INFERENCE_API_FRAMEWORKS + elif isinstance(frameworks, str): + frameworks = [frameworks] + frameworks = list(set(frameworks)) + + # Fetch them iteratively + models_by_task: Dict[str, List[str]] = {} + + def _unpack_response(framework: str, items: List[Dict]) -> None: + for model in items: + if framework == "sentence-transformers": + # Model running with the `sentence-transformers` framework can work with both tasks even if not + # branded as such in the API response + models_by_task.setdefault("feature-extraction", []).append(model["model_id"]) + models_by_task.setdefault("sentence-similarity", []).append(model["model_id"]) + else: + models_by_task.setdefault(model["task"], []).append(model["model_id"]) + + for framework in frameworks: + response = get_session().get( + f"{constants.INFERENCE_ENDPOINT}/framework/{framework}", headers=build_hf_headers(token=self.token) + ) + hf_raise_for_status(response) + _unpack_response(framework, response.json()) + + # Sort alphabetically for discoverability and return + for task, models in models_by_task.items(): + models_by_task[task] = sorted(set(models), key=lambda x: x.lower()) + return models_by_task + + def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession": + aiohttp = _import_aiohttp() + client_headers = self.headers.copy() + if headers is not None: + client_headers.update(headers) + + # Return a new aiohttp ClientSession with correct settings. + session = aiohttp.ClientSession( + headers=client_headers, + cookies=self.cookies, + timeout=aiohttp.ClientTimeout(self.timeout), + trust_env=self.trust_env, + ) + + # Keep track of sessions to close them later + self._sessions[session] = set() + + # Override the `._request` method to register responses to be closed + session._wrapped_request = session._request + + async def _request(method, url, **kwargs): + response = await session._wrapped_request(method, url, **kwargs) + self._sessions[session].add(response) + return response + + session._request = _request + + # Override the 'close' method to + # 1. close ongoing responses + # 2. deregister the session when closed + session._close = session.close + + async def close_session(): + for response in self._sessions[session]: + response.close() + await session._close() + self._sessions.pop(session, None) + + session.close = close_session + return session + + async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]: + """ + Get information about the deployed endpoint. + + This endpoint is only available on endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI). + Endpoints powered by `transformers` return an empty payload. + + Args: + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `Dict[str, Any]`: Information about the endpoint. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") + >>> await client.get_endpoint_info() + { + 'model_id': 'meta-llama/Meta-Llama-3-70B-Instruct', + 'model_sha': None, + 'model_dtype': 'torch.float16', + 'model_device_type': 'cuda', + 'model_pipeline_tag': None, + 'max_concurrent_requests': 128, + 'max_best_of': 2, + 'max_stop_sequences': 4, + 'max_input_length': 8191, + 'max_total_tokens': 8192, + 'waiting_served_ratio': 0.3, + 'max_batch_total_tokens': 1259392, + 'max_waiting_tokens': 20, + 'max_batch_size': None, + 'validation_workers': 32, + 'max_client_batch_size': 4, + 'version': '2.0.2', + 'sha': 'dccab72549635c7eb5ddb17f43f0b7cdff07c214', + 'docker_label': 'sha-dccab72' + } + ``` + """ + if self.provider != "hf-inference": + raise ValueError(f"Getting endpoint info is not supported on '{self.provider}'.") + + model = model or self.model + if model is None: + raise ValueError("Model id not provided.") + if model.startswith(("http://", "https://")): + url = model.rstrip("/") + "/info" + else: + url = f"{constants.INFERENCE_ENDPOINT}/models/{model}/info" + + async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client: + response = await client.get(url, proxy=self.proxies) + response.raise_for_status() + return await response.json() + + async def health_check(self, model: Optional[str] = None) -> bool: + """ + Check the health of the deployed endpoint. + + Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI). + For Inference API, please use [`InferenceClient.get_model_status`] instead. + + Args: + model (`str`, *optional*): + URL of the Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `bool`: True if everything is working fine. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient("https://jzgu0buei5.us-east-1.aws.endpoints.huggingface.cloud") + >>> await client.health_check() + True + ``` + """ + if self.provider != "hf-inference": + raise ValueError(f"Health check is not supported on '{self.provider}'.") + + model = model or self.model + if model is None: + raise ValueError("Model id not provided.") + if not model.startswith(("http://", "https://")): + raise ValueError( + "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`." + ) + url = model.rstrip("/") + "/health" + + async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client: + response = await client.get(url, proxy=self.proxies) + return response.status == 200 + + @_deprecate_method( + version="0.35.0", + message=( + "HF Inference API is getting revamped and will only support warm models in the future (no cold start allowed)." + " Use `HfApi.model_info` to get the model status both with HF Inference API and external providers." + ), + ) + async def get_model_status(self, model: Optional[str] = None) -> ModelStatus: + """ + Get the status of a model hosted on the HF Inference API. + + + + This endpoint is mostly useful when you already know which model you want to use and want to check its + availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`]. + + + + Args: + model (`str`, *optional*): + Identifier of the model for witch the status gonna be checked. If model is not provided, + the model associated with this instance of [`InferenceClient`] will be used. Only HF Inference API service can be checked so the + identifier cannot be a URL. + + + Returns: + [`ModelStatus`]: An instance of ModelStatus dataclass, containing information, + about the state of the model: load, state, compute type and framework. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct") + ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference') + ``` + """ + if self.provider != "hf-inference": + raise ValueError(f"Getting model status is not supported on '{self.provider}'.") + + model = model or self.model + if model is None: + raise ValueError("Model id not provided.") + if model.startswith("https://"): + raise NotImplementedError("Model status is only available for Inference API endpoints.") + url = f"{constants.INFERENCE_ENDPOINT}/status/{model}" + + async with self._get_client_session(headers=build_hf_headers(token=self.token)) as client: + response = await client.get(url, proxy=self.proxies) + response.raise_for_status() + response_data = await response.json() + + if "error" in response_data: + raise ValueError(response_data["error"]) + + return ModelStatus( + loaded=response_data["loaded"], + state=response_data["state"], + compute_type=response_data["compute_type"], + framework=response_data["framework"], + ) + + @property + def chat(self) -> "ProxyClientChat": + return ProxyClientChat(self) + + +class _ProxyClient: + """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client.""" + + def __init__(self, client: AsyncInferenceClient): + self._client = client + + +class ProxyClientChat(_ProxyClient): + """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client.""" + + @property + def completions(self) -> "ProxyClientChatCompletions": + return ProxyClientChatCompletions(self._client) + + +class ProxyClientChatCompletions(_ProxyClient): + """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client.""" + + @property + def create(self): + return self._client.chat_completion diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__init__.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfffc0ae3bce71532382ee87d03c40dc376cfae7 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__init__.py @@ -0,0 +1,192 @@ +# This file is auto-generated by `utils/generate_inference_types.py`. +# Do not modify it manually. +# +# ruff: noqa: F401 + +from .audio_classification import ( + AudioClassificationInput, + AudioClassificationOutputElement, + AudioClassificationOutputTransform, + AudioClassificationParameters, +) +from .audio_to_audio import AudioToAudioInput, AudioToAudioOutputElement +from .automatic_speech_recognition import ( + AutomaticSpeechRecognitionEarlyStoppingEnum, + AutomaticSpeechRecognitionGenerationParameters, + AutomaticSpeechRecognitionInput, + AutomaticSpeechRecognitionOutput, + AutomaticSpeechRecognitionOutputChunk, + AutomaticSpeechRecognitionParameters, +) +from .base import BaseInferenceType +from .chat_completion import ( + ChatCompletionInput, + ChatCompletionInputFunctionDefinition, + ChatCompletionInputFunctionName, + ChatCompletionInputGrammarType, + ChatCompletionInputJSONSchema, + ChatCompletionInputMessage, + ChatCompletionInputMessageChunk, + ChatCompletionInputMessageChunkType, + ChatCompletionInputResponseFormatJSONObject, + ChatCompletionInputResponseFormatJSONSchema, + ChatCompletionInputResponseFormatText, + ChatCompletionInputStreamOptions, + ChatCompletionInputTool, + ChatCompletionInputToolCall, + ChatCompletionInputToolChoiceClass, + ChatCompletionInputToolChoiceEnum, + ChatCompletionInputURL, + ChatCompletionOutput, + ChatCompletionOutputComplete, + ChatCompletionOutputFunctionDefinition, + ChatCompletionOutputLogprob, + ChatCompletionOutputLogprobs, + ChatCompletionOutputMessage, + ChatCompletionOutputToolCall, + ChatCompletionOutputTopLogprob, + ChatCompletionOutputUsage, + ChatCompletionStreamOutput, + ChatCompletionStreamOutputChoice, + ChatCompletionStreamOutputDelta, + ChatCompletionStreamOutputDeltaToolCall, + ChatCompletionStreamOutputFunction, + ChatCompletionStreamOutputLogprob, + ChatCompletionStreamOutputLogprobs, + ChatCompletionStreamOutputTopLogprob, + ChatCompletionStreamOutputUsage, +) +from .depth_estimation import DepthEstimationInput, DepthEstimationOutput +from .document_question_answering import ( + DocumentQuestionAnsweringInput, + DocumentQuestionAnsweringInputData, + DocumentQuestionAnsweringOutputElement, + DocumentQuestionAnsweringParameters, +) +from .feature_extraction import FeatureExtractionInput, FeatureExtractionInputTruncationDirection +from .fill_mask import FillMaskInput, FillMaskOutputElement, FillMaskParameters +from .image_classification import ( + ImageClassificationInput, + ImageClassificationOutputElement, + ImageClassificationOutputTransform, + ImageClassificationParameters, +) +from .image_segmentation import ( + ImageSegmentationInput, + ImageSegmentationOutputElement, + ImageSegmentationParameters, + ImageSegmentationSubtask, +) +from .image_to_image import ImageToImageInput, ImageToImageOutput, ImageToImageParameters, ImageToImageTargetSize +from .image_to_text import ( + ImageToTextEarlyStoppingEnum, + ImageToTextGenerationParameters, + ImageToTextInput, + ImageToTextOutput, + ImageToTextParameters, +) +from .image_to_video import ImageToVideoInput, ImageToVideoOutput, ImageToVideoParameters, ImageToVideoTargetSize +from .object_detection import ( + ObjectDetectionBoundingBox, + ObjectDetectionInput, + ObjectDetectionOutputElement, + ObjectDetectionParameters, +) +from .question_answering import ( + QuestionAnsweringInput, + QuestionAnsweringInputData, + QuestionAnsweringOutputElement, + QuestionAnsweringParameters, +) +from .sentence_similarity import SentenceSimilarityInput, SentenceSimilarityInputData +from .summarization import ( + SummarizationInput, + SummarizationOutput, + SummarizationParameters, + SummarizationTruncationStrategy, +) +from .table_question_answering import ( + Padding, + TableQuestionAnsweringInput, + TableQuestionAnsweringInputData, + TableQuestionAnsweringOutputElement, + TableQuestionAnsweringParameters, +) +from .text2text_generation import ( + Text2TextGenerationInput, + Text2TextGenerationOutput, + Text2TextGenerationParameters, + Text2TextGenerationTruncationStrategy, +) +from .text_classification import ( + TextClassificationInput, + TextClassificationOutputElement, + TextClassificationOutputTransform, + TextClassificationParameters, +) +from .text_generation import ( + TextGenerationInput, + TextGenerationInputGenerateParameters, + TextGenerationInputGrammarType, + TextGenerationOutput, + TextGenerationOutputBestOfSequence, + TextGenerationOutputDetails, + TextGenerationOutputFinishReason, + TextGenerationOutputPrefillToken, + TextGenerationOutputToken, + TextGenerationStreamOutput, + TextGenerationStreamOutputStreamDetails, + TextGenerationStreamOutputToken, + TypeEnum, +) +from .text_to_audio import ( + TextToAudioEarlyStoppingEnum, + TextToAudioGenerationParameters, + TextToAudioInput, + TextToAudioOutput, + TextToAudioParameters, +) +from .text_to_image import TextToImageInput, TextToImageOutput, TextToImageParameters +from .text_to_speech import ( + TextToSpeechEarlyStoppingEnum, + TextToSpeechGenerationParameters, + TextToSpeechInput, + TextToSpeechOutput, + TextToSpeechParameters, +) +from .text_to_video import TextToVideoInput, TextToVideoOutput, TextToVideoParameters +from .token_classification import ( + TokenClassificationAggregationStrategy, + TokenClassificationInput, + TokenClassificationOutputElement, + TokenClassificationParameters, +) +from .translation import TranslationInput, TranslationOutput, TranslationParameters, TranslationTruncationStrategy +from .video_classification import ( + VideoClassificationInput, + VideoClassificationOutputElement, + VideoClassificationOutputTransform, + VideoClassificationParameters, +) +from .visual_question_answering import ( + VisualQuestionAnsweringInput, + VisualQuestionAnsweringInputData, + VisualQuestionAnsweringOutputElement, + VisualQuestionAnsweringParameters, +) +from .zero_shot_classification import ( + ZeroShotClassificationInput, + ZeroShotClassificationOutputElement, + ZeroShotClassificationParameters, +) +from .zero_shot_image_classification import ( + ZeroShotImageClassificationInput, + ZeroShotImageClassificationOutputElement, + ZeroShotImageClassificationParameters, +) +from .zero_shot_object_detection import ( + ZeroShotObjectDetectionBoundingBox, + ZeroShotObjectDetectionInput, + ZeroShotObjectDetectionOutputElement, + ZeroShotObjectDetectionParameters, +) diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..553279d7c15d1a2541f1d60f2daf12a9f16f5150 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/audio_classification.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/audio_classification.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2abeeb0f6c036d1a65930f686045203a35c5216 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/audio_classification.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/audio_to_audio.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/audio_to_audio.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5f5482deddfdaf81288b6de068afb8ed3097453 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/audio_to_audio.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/automatic_speech_recognition.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/automatic_speech_recognition.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b86b518a53e77eef74afae74b974da8911ffdf4c Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/automatic_speech_recognition.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/base.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f602948ec119236d9e67fcee80dcef9579d1051 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/base.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/chat_completion.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/chat_completion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1b4c109680025dca8fe26e9e559b0425a90d90a Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/chat_completion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/depth_estimation.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/depth_estimation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e16ecd30baf6683e2547fd121c47fedc4f282d29 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/depth_estimation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/document_question_answering.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/document_question_answering.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..867a148d084ce347a2eb524388af2aa4957cb03f Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/document_question_answering.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/feature_extraction.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/feature_extraction.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8030c8e0818c7bd5236d7f70843a42a9d2cb3fb Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/feature_extraction.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/fill_mask.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/fill_mask.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d51b0914b185dba3865271fc23a16fdd15f05237 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/fill_mask.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/image_classification.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/image_classification.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bee2f956299580d716f560fe7f6cf0963e2f333b Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/image_classification.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/image_segmentation.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/image_segmentation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..751085760cce30aef66ba4ab6e293233a5e36f15 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/image_segmentation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/image_to_image.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/image_to_image.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2bb5ce3f9c98f4b0d9c8f075643a51f6a0dbbe6 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/image_to_image.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/image_to_text.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/image_to_text.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fa1b293a82b75c49a0413d8042b0faffdfd2f7d Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/image_to_text.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/image_to_video.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/image_to_video.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d463a23a547caac311b9fa88a3eb931863aa0caa Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/image_to_video.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/object_detection.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/object_detection.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05e4b80d8db53404c16846de17c29693eebf84f5 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/object_detection.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/question_answering.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/question_answering.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c52e7a164a5f8680d6cc8db9c4fbefabae79452 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/question_answering.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/sentence_similarity.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/sentence_similarity.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aafa37787f6eeb05cfcde71cf7d86389a69c4abf Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/sentence_similarity.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/summarization.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/summarization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b3c93bb6a7b5f601709f90362452b6af1ec1cf2 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/summarization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/table_question_answering.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/table_question_answering.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..545b5ac45a0392b070f59189ac5184e922628ad1 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/table_question_answering.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text2text_generation.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text2text_generation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7209f0a64106c5ce834878a21a86318f8f0f8fb9 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text2text_generation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_classification.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_classification.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74ee1ca020fd10c3413d7b1f432e303e5a3a9197 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_classification.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_generation.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_generation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7cbc2833b706cd6a07555858fee716d0dcbfaf0 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_generation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_to_audio.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_to_audio.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e58352d578540ae3ceb16c91c4b4693cf0f93d5e Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_to_audio.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_to_image.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_to_image.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65702e221812bc2050fc31b27292ea36b5378a63 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_to_image.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_to_speech.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_to_speech.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38d97c8916f1bac6cb66b29af51e4ee5e0ec1210 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_to_speech.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_to_video.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_to_video.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42ff2f16f2468c950fb45de9fba522aeb7070c37 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/text_to_video.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/token_classification.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/token_classification.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0812b446a1052095a1360b35a18626a1d7077526 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/token_classification.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/translation.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/translation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9131c220564260eeddbd401ae5288ff75d51130c Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/translation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/video_classification.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/video_classification.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43f17b02a905c5b52b26d9fda44f9f1e7463748f Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/video_classification.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/visual_question_answering.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/visual_question_answering.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6ad738d4c2ed747ad25fbed0bcc008b203a22b8 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/visual_question_answering.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/zero_shot_classification.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/zero_shot_classification.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..743842c45b0018648049e458118b96f02573e9c4 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/zero_shot_classification.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/zero_shot_image_classification.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/zero_shot_image_classification.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52382f652ace8ee839bbf2696effb0ac4f585b71 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/zero_shot_image_classification.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/zero_shot_object_detection.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/zero_shot_object_detection.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab2367650ccf505c8293f106d913681d9609c4dc Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/__pycache__/zero_shot_object_detection.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/audio_classification.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/audio_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..053055787bce933e1fbd393cfbc00d81c43c8c2d --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/audio_classification.py @@ -0,0 +1,43 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Literal, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +AudioClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] + + +@dataclass_with_extra +class AudioClassificationParameters(BaseInferenceType): + """Additional inference parameters for Audio Classification""" + + function_to_apply: Optional["AudioClassificationOutputTransform"] = None + """The function to apply to the model outputs in order to retrieve the scores.""" + top_k: Optional[int] = None + """When specified, limits the output to the top K most probable classes.""" + + +@dataclass_with_extra +class AudioClassificationInput(BaseInferenceType): + """Inputs for Audio Classification inference""" + + inputs: str + """The input audio data as a base64-encoded string. If no `parameters` are provided, you can + also provide the audio data as a raw bytes payload. + """ + parameters: Optional[AudioClassificationParameters] = None + """Additional inference parameters for Audio Classification""" + + +@dataclass_with_extra +class AudioClassificationOutputElement(BaseInferenceType): + """Outputs for Audio Classification inference""" + + label: str + """The predicted class label.""" + score: float + """The corresponding probability.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/audio_to_audio.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/audio_to_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..43f376b5345fab6b854b028d1c17416c020d7bc1 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/audio_to_audio.py @@ -0,0 +1,30 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class AudioToAudioInput(BaseInferenceType): + """Inputs for Audio to Audio inference""" + + inputs: Any + """The input audio data""" + + +@dataclass_with_extra +class AudioToAudioOutputElement(BaseInferenceType): + """Outputs of inference for the Audio To Audio task + A generated audio file with its label. + """ + + blob: Any + """The generated audio file.""" + content_type: str + """The content type of audio file.""" + label: str + """The label of the audio file.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py new file mode 100644 index 0000000000000000000000000000000000000000..f6bfd28256c82309b160f337aba5a54e2dd11872 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py @@ -0,0 +1,113 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import List, Literal, Optional, Union + +from .base import BaseInferenceType, dataclass_with_extra + + +AutomaticSpeechRecognitionEarlyStoppingEnum = Literal["never"] + + +@dataclass_with_extra +class AutomaticSpeechRecognitionGenerationParameters(BaseInferenceType): + """Parametrization of the text generation process""" + + do_sample: Optional[bool] = None + """Whether to use sampling instead of greedy decoding when generating new tokens.""" + early_stopping: Optional[Union[bool, "AutomaticSpeechRecognitionEarlyStoppingEnum"]] = None + """Controls the stopping condition for beam-based methods.""" + epsilon_cutoff: Optional[float] = None + """If set to float strictly between 0 and 1, only tokens with a conditional probability + greater than epsilon_cutoff will be sampled. In the paper, suggested values range from + 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language + Model Desmoothing](https://hf.co/papers/2210.15191) for more details. + """ + eta_cutoff: Optional[float] = None + """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to + float strictly between 0 and 1, a token is only considered if it is greater than either + eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter + term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In + the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. + See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) + for more details. + """ + max_length: Optional[int] = None + """The maximum length (in tokens) of the generated text, including the input.""" + max_new_tokens: Optional[int] = None + """The maximum number of tokens to generate. Takes precedence over max_length.""" + min_length: Optional[int] = None + """The minimum length (in tokens) of the generated text, including the input.""" + min_new_tokens: Optional[int] = None + """The minimum number of tokens to generate. Takes precedence over min_length.""" + num_beam_groups: Optional[int] = None + """Number of groups to divide num_beams into in order to ensure diversity among different + groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. + """ + num_beams: Optional[int] = None + """Number of beams to use for beam search.""" + penalty_alpha: Optional[float] = None + """The value balances the model confidence and the degeneration penalty in contrastive + search decoding. + """ + temperature: Optional[float] = None + """The value used to modulate the next token probabilities.""" + top_k: Optional[int] = None + """The number of highest probability vocabulary tokens to keep for top-k-filtering.""" + top_p: Optional[float] = None + """If set to float < 1, only the smallest set of most probable tokens with probabilities + that add up to top_p or higher are kept for generation. + """ + typical_p: Optional[float] = None + """Local typicality measures how similar the conditional probability of predicting a target + token next is to the expected conditional probability of predicting a random token next, + given the partial text already generated. If set to float < 1, the smallest set of the + most locally typical tokens with probabilities that add up to typical_p or higher are + kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details. + """ + use_cache: Optional[bool] = None + """Whether the model should use the past last key/values attentions to speed up decoding""" + + +@dataclass_with_extra +class AutomaticSpeechRecognitionParameters(BaseInferenceType): + """Additional inference parameters for Automatic Speech Recognition""" + + generation_parameters: Optional[AutomaticSpeechRecognitionGenerationParameters] = None + """Parametrization of the text generation process""" + return_timestamps: Optional[bool] = None + """Whether to output corresponding timestamps with the generated text""" + + +@dataclass_with_extra +class AutomaticSpeechRecognitionInput(BaseInferenceType): + """Inputs for Automatic Speech Recognition inference""" + + inputs: str + """The input audio data as a base64-encoded string. If no `parameters` are provided, you can + also provide the audio data as a raw bytes payload. + """ + parameters: Optional[AutomaticSpeechRecognitionParameters] = None + """Additional inference parameters for Automatic Speech Recognition""" + + +@dataclass_with_extra +class AutomaticSpeechRecognitionOutputChunk(BaseInferenceType): + text: str + """A chunk of text identified by the model""" + timestamp: List[float] + """The start and end timestamps corresponding with the text""" + + +@dataclass_with_extra +class AutomaticSpeechRecognitionOutput(BaseInferenceType): + """Outputs of inference for the Automatic Speech Recognition task""" + + text: str + """The recognized text.""" + chunks: Optional[List[AutomaticSpeechRecognitionOutputChunk]] = None + """When returnTimestamps is enabled, chunks contains a list of audio chunks identified by + the model. + """ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/base.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/base.py new file mode 100644 index 0000000000000000000000000000000000000000..1f0c4687ceccbfb738da3f38c583c2516d065a01 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/base.py @@ -0,0 +1,161 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains a base class for all inference types.""" + +import inspect +import json +from dataclasses import asdict, dataclass +from typing import Any, Dict, List, Type, TypeVar, Union, get_args + + +T = TypeVar("T", bound="BaseInferenceType") + + +def _repr_with_extra(self): + fields = list(self.__dataclass_fields__.keys()) + other_fields = list(k for k in self.__dict__ if k not in fields) + return f"{self.__class__.__name__}({', '.join(f'{k}={self.__dict__[k]!r}' for k in fields + other_fields)})" + + +def dataclass_with_extra(cls: Type[T]) -> Type[T]: + """Decorator to add a custom __repr__ method to a dataclass, showing all fields, including extra ones. + + This decorator only works with dataclasses that inherit from `BaseInferenceType`. + """ + cls = dataclass(cls) + cls.__repr__ = _repr_with_extra # type: ignore[method-assign] + return cls + + +@dataclass +class BaseInferenceType(dict): + """Base class for all inference types. + + Object is a dataclass and a dict for backward compatibility but plan is to remove the dict part in the future. + + Handle parsing from dict, list and json strings in a permissive way to ensure future-compatibility (e.g. all fields + are made optional, and non-expected fields are added as dict attributes). + """ + + @classmethod + def parse_obj_as_list(cls: Type[T], data: Union[bytes, str, List, Dict]) -> List[T]: + """Alias to parse server response and return a single instance. + + See `parse_obj` for more details. + """ + output = cls.parse_obj(data) + if not isinstance(output, list): + raise ValueError(f"Invalid input data for {cls}. Expected a list, but got {type(output)}.") + return output + + @classmethod + def parse_obj_as_instance(cls: Type[T], data: Union[bytes, str, List, Dict]) -> T: + """Alias to parse server response and return a single instance. + + See `parse_obj` for more details. + """ + output = cls.parse_obj(data) + if isinstance(output, list): + raise ValueError(f"Invalid input data for {cls}. Expected a single instance, but got a list.") + return output + + @classmethod + def parse_obj(cls: Type[T], data: Union[bytes, str, List, Dict]) -> Union[List[T], T]: + """Parse server response as a dataclass or list of dataclasses. + + To enable future-compatibility, we want to handle cases where the server return more fields than expected. + In such cases, we don't want to raise an error but still create the dataclass object. Remaining fields are + added as dict attributes. + """ + # Parse server response (from bytes) + if isinstance(data, bytes): + data = data.decode() + if isinstance(data, str): + data = json.loads(data) + + # If a list, parse each item individually + if isinstance(data, List): + return [cls.parse_obj(d) for d in data] # type: ignore [misc] + + # At this point, we expect a dict + if not isinstance(data, dict): + raise ValueError(f"Invalid data type: {type(data)}") + + init_values = {} + other_values = {} + for key, value in data.items(): + key = normalize_key(key) + if key in cls.__dataclass_fields__ and cls.__dataclass_fields__[key].init: + if isinstance(value, dict) or isinstance(value, list): + field_type = cls.__dataclass_fields__[key].type + + # if `field_type` is a `BaseInferenceType`, parse it + if inspect.isclass(field_type) and issubclass(field_type, BaseInferenceType): + value = field_type.parse_obj(value) + + # otherwise, recursively parse nested dataclasses (if possible) + # `get_args` returns handle Union and Optional for us + else: + expected_types = get_args(field_type) + for expected_type in expected_types: + if getattr(expected_type, "_name", None) == "List": + expected_type = get_args(expected_type)[ + 0 + ] # assume same type for all items in the list + if inspect.isclass(expected_type) and issubclass(expected_type, BaseInferenceType): + value = expected_type.parse_obj(value) + break + init_values[key] = value + else: + other_values[key] = value + + # Make all missing fields default to None + # => ensure that dataclass initialization will never fail even if the server does not return all fields. + for key in cls.__dataclass_fields__: + if key not in init_values: + init_values[key] = None + + # Initialize dataclass with expected values + item = cls(**init_values) + + # Add remaining fields as dict attributes + item.update(other_values) + + # Add remaining fields as extra dataclass fields. + # They won't be part of the dataclass fields but will be accessible as attributes. + # Use @dataclass_with_extra to show them in __repr__. + item.__dict__.update(other_values) + return item + + def __post_init__(self): + self.update(asdict(self)) + + def __setitem__(self, __key: Any, __value: Any) -> None: + # Hacky way to keep dataclass values in sync when dict is updated + super().__setitem__(__key, __value) + if __key in self.__dataclass_fields__ and getattr(self, __key, None) != __value: + self.__setattr__(__key, __value) + return + + def __setattr__(self, __name: str, __value: Any) -> None: + # Hacky way to keep dict values is sync when dataclass is updated + super().__setattr__(__name, __value) + if self.get(__name) != __value: + self[__name] = __value + return + + +def normalize_key(key: str) -> str: + # e.g "content-type" -> "content_type", "Accept" -> "accept" + return key.replace("-", "_").replace(" ", "_").lower() diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/chat_completion.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/chat_completion.py new file mode 100644 index 0000000000000000000000000000000000000000..fe455ee71084920a3b8b246d875b8ab1ef555ad9 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/chat_completion.py @@ -0,0 +1,345 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, Dict, List, Literal, Optional, Union + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class ChatCompletionInputURL(BaseInferenceType): + url: str + + +ChatCompletionInputMessageChunkType = Literal["text", "image_url"] + + +@dataclass_with_extra +class ChatCompletionInputMessageChunk(BaseInferenceType): + type: "ChatCompletionInputMessageChunkType" + image_url: Optional[ChatCompletionInputURL] = None + text: Optional[str] = None + + +@dataclass_with_extra +class ChatCompletionInputFunctionDefinition(BaseInferenceType): + name: str + parameters: Any + description: Optional[str] = None + + +@dataclass_with_extra +class ChatCompletionInputToolCall(BaseInferenceType): + function: ChatCompletionInputFunctionDefinition + id: str + type: str + + +@dataclass_with_extra +class ChatCompletionInputMessage(BaseInferenceType): + role: str + content: Optional[Union[List[ChatCompletionInputMessageChunk], str]] = None + name: Optional[str] = None + tool_calls: Optional[List[ChatCompletionInputToolCall]] = None + + +@dataclass_with_extra +class ChatCompletionInputJSONSchema(BaseInferenceType): + name: str + """ + The name of the response format. + """ + description: Optional[str] = None + """ + A description of what the response format is for, used by the model to determine + how to respond in the format. + """ + schema: Optional[Dict[str, object]] = None + """ + The schema for the response format, described as a JSON Schema object. Learn how + to build JSON schemas [here](https://json-schema.org/). + """ + strict: Optional[bool] = None + """ + Whether to enable strict schema adherence when generating the output. If set to + true, the model will always follow the exact schema defined in the `schema` + field. + """ + + +@dataclass_with_extra +class ChatCompletionInputResponseFormatText(BaseInferenceType): + type: Literal["text"] + + +@dataclass_with_extra +class ChatCompletionInputResponseFormatJSONSchema(BaseInferenceType): + type: Literal["json_schema"] + json_schema: ChatCompletionInputJSONSchema + + +@dataclass_with_extra +class ChatCompletionInputResponseFormatJSONObject(BaseInferenceType): + type: Literal["json_object"] + + +ChatCompletionInputGrammarType = Union[ + ChatCompletionInputResponseFormatText, + ChatCompletionInputResponseFormatJSONSchema, + ChatCompletionInputResponseFormatJSONObject, +] + + +@dataclass_with_extra +class ChatCompletionInputStreamOptions(BaseInferenceType): + include_usage: Optional[bool] = None + """If set, an additional chunk will be streamed before the data: [DONE] message. The usage + field on this chunk shows the token usage statistics for the entire request, and the + choices field will always be an empty array. All other chunks will also include a usage + field, but with a null value. + """ + + +@dataclass_with_extra +class ChatCompletionInputFunctionName(BaseInferenceType): + name: str + + +@dataclass_with_extra +class ChatCompletionInputToolChoiceClass(BaseInferenceType): + function: ChatCompletionInputFunctionName + + +ChatCompletionInputToolChoiceEnum = Literal["auto", "none", "required"] + + +@dataclass_with_extra +class ChatCompletionInputTool(BaseInferenceType): + function: ChatCompletionInputFunctionDefinition + type: str + + +@dataclass_with_extra +class ChatCompletionInput(BaseInferenceType): + """Chat Completion Input. + Auto-generated from TGI specs. + For more details, check out + https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. + """ + + messages: List[ChatCompletionInputMessage] + """A list of messages comprising the conversation so far.""" + frequency_penalty: Optional[float] = None + """Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing + frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. + """ + logit_bias: Optional[List[float]] = None + """UNUSED + Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON + object that maps tokens + (specified by their token ID in the tokenizer) to an associated bias value from -100 to + 100. Mathematically, + the bias is added to the logits generated by the model prior to sampling. The exact + effect will vary per model, + but values between -1 and 1 should decrease or increase likelihood of selection; values + like -100 or 100 should + result in a ban or exclusive selection of the relevant token. + """ + logprobs: Optional[bool] = None + """Whether to return log probabilities of the output tokens or not. If true, returns the log + probabilities of each + output token returned in the content of message. + """ + max_tokens: Optional[int] = None + """The maximum number of tokens that can be generated in the chat completion.""" + model: Optional[str] = None + """[UNUSED] ID of the model to use. See the model endpoint compatibility table for details + on which models work with the Chat API. + """ + n: Optional[int] = None + """UNUSED + How many chat completion choices to generate for each input message. Note that you will + be charged based on the + number of generated tokens across all of the choices. Keep n as 1 to minimize costs. + """ + presence_penalty: Optional[float] = None + """Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they + appear in the text so far, + increasing the model's likelihood to talk about new topics + """ + response_format: Optional[ChatCompletionInputGrammarType] = None + seed: Optional[int] = None + stop: Optional[List[str]] = None + """Up to 4 sequences where the API will stop generating further tokens.""" + stream: Optional[bool] = None + stream_options: Optional[ChatCompletionInputStreamOptions] = None + temperature: Optional[float] = None + """What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the + output more random, while + lower values like 0.2 will make it more focused and deterministic. + We generally recommend altering this or `top_p` but not both. + """ + tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None + tool_prompt: Optional[str] = None + """A prompt to be appended before the tools""" + tools: Optional[List[ChatCompletionInputTool]] = None + """A list of tools the model may call. Currently, only functions are supported as a tool. + Use this to provide a list of + functions the model may generate JSON inputs for. + """ + top_logprobs: Optional[int] = None + """An integer between 0 and 5 specifying the number of most likely tokens to return at each + token position, each with + an associated log probability. logprobs must be set to true if this parameter is used. + """ + top_p: Optional[float] = None + """An alternative to sampling with temperature, called nucleus sampling, where the model + considers the results of the + tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + probability mass are considered. + """ + + +@dataclass_with_extra +class ChatCompletionOutputTopLogprob(BaseInferenceType): + logprob: float + token: str + + +@dataclass_with_extra +class ChatCompletionOutputLogprob(BaseInferenceType): + logprob: float + token: str + top_logprobs: List[ChatCompletionOutputTopLogprob] + + +@dataclass_with_extra +class ChatCompletionOutputLogprobs(BaseInferenceType): + content: List[ChatCompletionOutputLogprob] + + +@dataclass_with_extra +class ChatCompletionOutputFunctionDefinition(BaseInferenceType): + arguments: str + name: str + description: Optional[str] = None + + +@dataclass_with_extra +class ChatCompletionOutputToolCall(BaseInferenceType): + function: ChatCompletionOutputFunctionDefinition + id: str + type: str + + +@dataclass_with_extra +class ChatCompletionOutputMessage(BaseInferenceType): + role: str + content: Optional[str] = None + tool_call_id: Optional[str] = None + tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None + + +@dataclass_with_extra +class ChatCompletionOutputComplete(BaseInferenceType): + finish_reason: str + index: int + message: ChatCompletionOutputMessage + logprobs: Optional[ChatCompletionOutputLogprobs] = None + + +@dataclass_with_extra +class ChatCompletionOutputUsage(BaseInferenceType): + completion_tokens: int + prompt_tokens: int + total_tokens: int + + +@dataclass_with_extra +class ChatCompletionOutput(BaseInferenceType): + """Chat Completion Output. + Auto-generated from TGI specs. + For more details, check out + https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. + """ + + choices: List[ChatCompletionOutputComplete] + created: int + id: str + model: str + system_fingerprint: str + usage: ChatCompletionOutputUsage + + +@dataclass_with_extra +class ChatCompletionStreamOutputFunction(BaseInferenceType): + arguments: str + name: Optional[str] = None + + +@dataclass_with_extra +class ChatCompletionStreamOutputDeltaToolCall(BaseInferenceType): + function: ChatCompletionStreamOutputFunction + id: str + index: int + type: str + + +@dataclass_with_extra +class ChatCompletionStreamOutputDelta(BaseInferenceType): + role: str + content: Optional[str] = None + tool_call_id: Optional[str] = None + tool_calls: Optional[List[ChatCompletionStreamOutputDeltaToolCall]] = None + + +@dataclass_with_extra +class ChatCompletionStreamOutputTopLogprob(BaseInferenceType): + logprob: float + token: str + + +@dataclass_with_extra +class ChatCompletionStreamOutputLogprob(BaseInferenceType): + logprob: float + token: str + top_logprobs: List[ChatCompletionStreamOutputTopLogprob] + + +@dataclass_with_extra +class ChatCompletionStreamOutputLogprobs(BaseInferenceType): + content: List[ChatCompletionStreamOutputLogprob] + + +@dataclass_with_extra +class ChatCompletionStreamOutputChoice(BaseInferenceType): + delta: ChatCompletionStreamOutputDelta + index: int + finish_reason: Optional[str] = None + logprobs: Optional[ChatCompletionStreamOutputLogprobs] = None + + +@dataclass_with_extra +class ChatCompletionStreamOutputUsage(BaseInferenceType): + completion_tokens: int + prompt_tokens: int + total_tokens: int + + +@dataclass_with_extra +class ChatCompletionStreamOutput(BaseInferenceType): + """Chat Completion Stream Output. + Auto-generated from TGI specs. + For more details, check out + https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. + """ + + choices: List[ChatCompletionStreamOutputChoice] + created: int + id: str + model: str + system_fingerprint: str + usage: Optional[ChatCompletionStreamOutputUsage] = None diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/depth_estimation.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/depth_estimation.py new file mode 100644 index 0000000000000000000000000000000000000000..1e09bdffa194f97444e484de6e930f67ac030207 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/depth_estimation.py @@ -0,0 +1,28 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, Dict, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class DepthEstimationInput(BaseInferenceType): + """Inputs for Depth Estimation inference""" + + inputs: Any + """The input image data""" + parameters: Optional[Dict[str, Any]] = None + """Additional inference parameters for Depth Estimation""" + + +@dataclass_with_extra +class DepthEstimationOutput(BaseInferenceType): + """Outputs of inference for the Depth Estimation task""" + + depth: Any + """The predicted depth as an image""" + predicted_depth: Any + """The predicted depth as a tensor""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/document_question_answering.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/document_question_answering.py new file mode 100644 index 0000000000000000000000000000000000000000..2457d2c8c237f055f660e0e8291d846bb036949d --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/document_question_answering.py @@ -0,0 +1,80 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, List, Optional, Union + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class DocumentQuestionAnsweringInputData(BaseInferenceType): + """One (document, question) pair to answer""" + + image: Any + """The image on which the question is asked""" + question: str + """A question to ask of the document""" + + +@dataclass_with_extra +class DocumentQuestionAnsweringParameters(BaseInferenceType): + """Additional inference parameters for Document Question Answering""" + + doc_stride: Optional[int] = None + """If the words in the document are too long to fit with the question for the model, it will + be split in several chunks with some overlap. This argument controls the size of that + overlap. + """ + handle_impossible_answer: Optional[bool] = None + """Whether to accept impossible as an answer""" + lang: Optional[str] = None + """Language to use while running OCR. Defaults to english.""" + max_answer_len: Optional[int] = None + """The maximum length of predicted answers (e.g., only answers with a shorter length are + considered). + """ + max_question_len: Optional[int] = None + """The maximum length of the question after tokenization. It will be truncated if needed.""" + max_seq_len: Optional[int] = None + """The maximum length of the total sentence (context + question) in tokens of each chunk + passed to the model. The context will be split in several chunks (using doc_stride as + overlap) if needed. + """ + top_k: Optional[int] = None + """The number of answers to return (will be chosen by order of likelihood). Can return less + than top_k answers if there are not enough options available within the context. + """ + word_boxes: Optional[List[Union[List[float], str]]] = None + """A list of words and bounding boxes (normalized 0->1000). If provided, the inference will + skip the OCR step and use the provided bounding boxes instead. + """ + + +@dataclass_with_extra +class DocumentQuestionAnsweringInput(BaseInferenceType): + """Inputs for Document Question Answering inference""" + + inputs: DocumentQuestionAnsweringInputData + """One (document, question) pair to answer""" + parameters: Optional[DocumentQuestionAnsweringParameters] = None + """Additional inference parameters for Document Question Answering""" + + +@dataclass_with_extra +class DocumentQuestionAnsweringOutputElement(BaseInferenceType): + """Outputs of inference for the Document Question Answering task""" + + answer: str + """The answer to the question.""" + end: int + """The end word index of the answer (in the OCR’d version of the input or provided word + boxes). + """ + score: float + """The probability associated to the answer.""" + start: int + """The start word index of the answer (in the OCR’d version of the input or provided word + boxes). + """ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/feature_extraction.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/feature_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..e965ddbac2af0a5bf73e662a7c18c847611d18a1 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/feature_extraction.py @@ -0,0 +1,36 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import List, Literal, Optional, Union + +from .base import BaseInferenceType, dataclass_with_extra + + +FeatureExtractionInputTruncationDirection = Literal["Left", "Right"] + + +@dataclass_with_extra +class FeatureExtractionInput(BaseInferenceType): + """Feature Extraction Input. + Auto-generated from TEI specs. + For more details, check out + https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tei-import.ts. + """ + + inputs: Union[List[str], str] + """The text or list of texts to embed.""" + normalize: Optional[bool] = None + prompt_name: Optional[str] = None + """The name of the prompt that should be used by for encoding. If not set, no prompt + will be applied. + Must be a key in the `sentence-transformers` configuration `prompts` dictionary. + For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", + ...}, + then the sentence "What is the capital of France?" will be encoded as + "query: What is the capital of France?" because the prompt text will be prepended before + any text to encode. + """ + truncate: Optional[bool] = None + truncation_direction: Optional["FeatureExtractionInputTruncationDirection"] = None diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/fill_mask.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/fill_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..dfcdc56bc507e50280d38e0f63b024ada6a7ea94 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/fill_mask.py @@ -0,0 +1,47 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, List, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class FillMaskParameters(BaseInferenceType): + """Additional inference parameters for Fill Mask""" + + targets: Optional[List[str]] = None + """When passed, the model will limit the scores to the passed targets instead of looking up + in the whole vocabulary. If the provided targets are not in the model vocab, they will be + tokenized and the first resulting token will be used (with a warning, and that might be + slower). + """ + top_k: Optional[int] = None + """When passed, overrides the number of predictions to return.""" + + +@dataclass_with_extra +class FillMaskInput(BaseInferenceType): + """Inputs for Fill Mask inference""" + + inputs: str + """The text with masked tokens""" + parameters: Optional[FillMaskParameters] = None + """Additional inference parameters for Fill Mask""" + + +@dataclass_with_extra +class FillMaskOutputElement(BaseInferenceType): + """Outputs of inference for the Fill Mask task""" + + score: float + """The corresponding probability""" + sequence: str + """The corresponding input with the mask token prediction.""" + token: int + """The predicted token id (to replace the masked one).""" + token_str: Any + fill_mask_output_token_str: Optional[str] = None + """The predicted token (to replace the masked one).""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/image_classification.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/image_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..0fdda6c83ff4c7aee5dc7794f0530e89d6b43047 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/image_classification.py @@ -0,0 +1,43 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Literal, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +ImageClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] + + +@dataclass_with_extra +class ImageClassificationParameters(BaseInferenceType): + """Additional inference parameters for Image Classification""" + + function_to_apply: Optional["ImageClassificationOutputTransform"] = None + """The function to apply to the model outputs in order to retrieve the scores.""" + top_k: Optional[int] = None + """When specified, limits the output to the top K most probable classes.""" + + +@dataclass_with_extra +class ImageClassificationInput(BaseInferenceType): + """Inputs for Image Classification inference""" + + inputs: str + """The input image data as a base64-encoded string. If no `parameters` are provided, you can + also provide the image data as a raw bytes payload. + """ + parameters: Optional[ImageClassificationParameters] = None + """Additional inference parameters for Image Classification""" + + +@dataclass_with_extra +class ImageClassificationOutputElement(BaseInferenceType): + """Outputs of inference for the Image Classification task""" + + label: str + """The predicted class label.""" + score: float + """The corresponding probability.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/image_segmentation.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/image_segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..3dbf61db83ec2ae6ceafd901c4425567cd2e5b03 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/image_segmentation.py @@ -0,0 +1,51 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Literal, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +ImageSegmentationSubtask = Literal["instance", "panoptic", "semantic"] + + +@dataclass_with_extra +class ImageSegmentationParameters(BaseInferenceType): + """Additional inference parameters for Image Segmentation""" + + mask_threshold: Optional[float] = None + """Threshold to use when turning the predicted masks into binary values.""" + overlap_mask_area_threshold: Optional[float] = None + """Mask overlap threshold to eliminate small, disconnected segments.""" + subtask: Optional["ImageSegmentationSubtask"] = None + """Segmentation task to be performed, depending on model capabilities.""" + threshold: Optional[float] = None + """Probability threshold to filter out predicted masks.""" + + +@dataclass_with_extra +class ImageSegmentationInput(BaseInferenceType): + """Inputs for Image Segmentation inference""" + + inputs: str + """The input image data as a base64-encoded string. If no `parameters` are provided, you can + also provide the image data as a raw bytes payload. + """ + parameters: Optional[ImageSegmentationParameters] = None + """Additional inference parameters for Image Segmentation""" + + +@dataclass_with_extra +class ImageSegmentationOutputElement(BaseInferenceType): + """Outputs of inference for the Image Segmentation task + A predicted mask / segment + """ + + label: str + """The label of the predicted segment.""" + mask: str + """The corresponding mask as a black-and-white image (base64-encoded).""" + score: Optional[float] = None + """The score or confidence degree the model has.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/image_to_image.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/image_to_image.py new file mode 100644 index 0000000000000000000000000000000000000000..e99e719f8b837c76392b58d33ca19f9b615e857e --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/image_to_image.py @@ -0,0 +1,56 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class ImageToImageTargetSize(BaseInferenceType): + """The size in pixel of the output image.""" + + height: int + width: int + + +@dataclass_with_extra +class ImageToImageParameters(BaseInferenceType): + """Additional inference parameters for Image To Image""" + + guidance_scale: Optional[float] = None + """For diffusion models. A higher guidance scale value encourages the model to generate + images closely linked to the text prompt at the expense of lower image quality. + """ + negative_prompt: Optional[str] = None + """One prompt to guide what NOT to include in image generation.""" + num_inference_steps: Optional[int] = None + """For diffusion models. The number of denoising steps. More denoising steps usually lead to + a higher quality image at the expense of slower inference. + """ + prompt: Optional[str] = None + """The text prompt to guide the image generation.""" + target_size: Optional[ImageToImageTargetSize] = None + """The size in pixel of the output image.""" + + +@dataclass_with_extra +class ImageToImageInput(BaseInferenceType): + """Inputs for Image To Image inference""" + + inputs: str + """The input image data as a base64-encoded string. If no `parameters` are provided, you can + also provide the image data as a raw bytes payload. + """ + parameters: Optional[ImageToImageParameters] = None + """Additional inference parameters for Image To Image""" + + +@dataclass_with_extra +class ImageToImageOutput(BaseInferenceType): + """Outputs of inference for the Image To Image task""" + + image: Any + """The output image returned as raw bytes in the payload.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/image_to_text.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/image_to_text.py new file mode 100644 index 0000000000000000000000000000000000000000..b65e0e0068e80dbcab5a4706fb5d49be2538c4ca --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/image_to_text.py @@ -0,0 +1,100 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, Literal, Optional, Union + +from .base import BaseInferenceType, dataclass_with_extra + + +ImageToTextEarlyStoppingEnum = Literal["never"] + + +@dataclass_with_extra +class ImageToTextGenerationParameters(BaseInferenceType): + """Parametrization of the text generation process""" + + do_sample: Optional[bool] = None + """Whether to use sampling instead of greedy decoding when generating new tokens.""" + early_stopping: Optional[Union[bool, "ImageToTextEarlyStoppingEnum"]] = None + """Controls the stopping condition for beam-based methods.""" + epsilon_cutoff: Optional[float] = None + """If set to float strictly between 0 and 1, only tokens with a conditional probability + greater than epsilon_cutoff will be sampled. In the paper, suggested values range from + 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language + Model Desmoothing](https://hf.co/papers/2210.15191) for more details. + """ + eta_cutoff: Optional[float] = None + """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to + float strictly between 0 and 1, a token is only considered if it is greater than either + eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter + term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In + the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. + See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) + for more details. + """ + max_length: Optional[int] = None + """The maximum length (in tokens) of the generated text, including the input.""" + max_new_tokens: Optional[int] = None + """The maximum number of tokens to generate. Takes precedence over max_length.""" + min_length: Optional[int] = None + """The minimum length (in tokens) of the generated text, including the input.""" + min_new_tokens: Optional[int] = None + """The minimum number of tokens to generate. Takes precedence over min_length.""" + num_beam_groups: Optional[int] = None + """Number of groups to divide num_beams into in order to ensure diversity among different + groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. + """ + num_beams: Optional[int] = None + """Number of beams to use for beam search.""" + penalty_alpha: Optional[float] = None + """The value balances the model confidence and the degeneration penalty in contrastive + search decoding. + """ + temperature: Optional[float] = None + """The value used to modulate the next token probabilities.""" + top_k: Optional[int] = None + """The number of highest probability vocabulary tokens to keep for top-k-filtering.""" + top_p: Optional[float] = None + """If set to float < 1, only the smallest set of most probable tokens with probabilities + that add up to top_p or higher are kept for generation. + """ + typical_p: Optional[float] = None + """Local typicality measures how similar the conditional probability of predicting a target + token next is to the expected conditional probability of predicting a random token next, + given the partial text already generated. If set to float < 1, the smallest set of the + most locally typical tokens with probabilities that add up to typical_p or higher are + kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details. + """ + use_cache: Optional[bool] = None + """Whether the model should use the past last key/values attentions to speed up decoding""" + + +@dataclass_with_extra +class ImageToTextParameters(BaseInferenceType): + """Additional inference parameters for Image To Text""" + + generation_parameters: Optional[ImageToTextGenerationParameters] = None + """Parametrization of the text generation process""" + max_new_tokens: Optional[int] = None + """The amount of maximum tokens to generate.""" + + +@dataclass_with_extra +class ImageToTextInput(BaseInferenceType): + """Inputs for Image To Text inference""" + + inputs: Any + """The input image data""" + parameters: Optional[ImageToTextParameters] = None + """Additional inference parameters for Image To Text""" + + +@dataclass_with_extra +class ImageToTextOutput(BaseInferenceType): + """Outputs of inference for the Image To Text task""" + + generated_text: Any + image_to_text_output_generated_text: Optional[str] = None + """The generated text.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/image_to_video.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/image_to_video.py new file mode 100644 index 0000000000000000000000000000000000000000..92192a2a05b7a825c6dd55e96702fece0f3b3316 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/image_to_video.py @@ -0,0 +1,60 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class ImageToVideoTargetSize(BaseInferenceType): + """The size in pixel of the output video frames.""" + + height: int + width: int + + +@dataclass_with_extra +class ImageToVideoParameters(BaseInferenceType): + """Additional inference parameters for Image To Video""" + + guidance_scale: Optional[float] = None + """For diffusion models. A higher guidance scale value encourages the model to generate + videos closely linked to the text prompt at the expense of lower image quality. + """ + negative_prompt: Optional[str] = None + """One prompt to guide what NOT to include in video generation.""" + num_frames: Optional[float] = None + """The num_frames parameter determines how many video frames are generated.""" + num_inference_steps: Optional[int] = None + """The number of denoising steps. More denoising steps usually lead to a higher quality + video at the expense of slower inference. + """ + prompt: Optional[str] = None + """The text prompt to guide the video generation.""" + seed: Optional[int] = None + """Seed for the random number generator.""" + target_size: Optional[ImageToVideoTargetSize] = None + """The size in pixel of the output video frames.""" + + +@dataclass_with_extra +class ImageToVideoInput(BaseInferenceType): + """Inputs for Image To Video inference""" + + inputs: str + """The input image data as a base64-encoded string. If no `parameters` are provided, you can + also provide the image data as a raw bytes payload. + """ + parameters: Optional[ImageToVideoParameters] = None + """Additional inference parameters for Image To Video""" + + +@dataclass_with_extra +class ImageToVideoOutput(BaseInferenceType): + """Outputs of inference for the Image To Video task""" + + video: Any + """The generated video returned as raw bytes in the payload.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/object_detection.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/object_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..75f3ebcfe1199462d0df60879b5ba6e517f7001e --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/object_detection.py @@ -0,0 +1,58 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class ObjectDetectionParameters(BaseInferenceType): + """Additional inference parameters for Object Detection""" + + threshold: Optional[float] = None + """The probability necessary to make a prediction.""" + + +@dataclass_with_extra +class ObjectDetectionInput(BaseInferenceType): + """Inputs for Object Detection inference""" + + inputs: str + """The input image data as a base64-encoded string. If no `parameters` are provided, you can + also provide the image data as a raw bytes payload. + """ + parameters: Optional[ObjectDetectionParameters] = None + """Additional inference parameters for Object Detection""" + + +@dataclass_with_extra +class ObjectDetectionBoundingBox(BaseInferenceType): + """The predicted bounding box. Coordinates are relative to the top left corner of the input + image. + """ + + xmax: int + """The x-coordinate of the bottom-right corner of the bounding box.""" + xmin: int + """The x-coordinate of the top-left corner of the bounding box.""" + ymax: int + """The y-coordinate of the bottom-right corner of the bounding box.""" + ymin: int + """The y-coordinate of the top-left corner of the bounding box.""" + + +@dataclass_with_extra +class ObjectDetectionOutputElement(BaseInferenceType): + """Outputs of inference for the Object Detection task""" + + box: ObjectDetectionBoundingBox + """The predicted bounding box. Coordinates are relative to the top left corner of the input + image. + """ + label: str + """The predicted label for the bounding box.""" + score: float + """The associated score / probability.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/question_answering.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/question_answering.py new file mode 100644 index 0000000000000000000000000000000000000000..014ab41893c560a2c266bc04a1d60bc933be31c7 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/question_answering.py @@ -0,0 +1,74 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class QuestionAnsweringInputData(BaseInferenceType): + """One (context, question) pair to answer""" + + context: str + """The context to be used for answering the question""" + question: str + """The question to be answered""" + + +@dataclass_with_extra +class QuestionAnsweringParameters(BaseInferenceType): + """Additional inference parameters for Question Answering""" + + align_to_words: Optional[bool] = None + """Attempts to align the answer to real words. Improves quality on space separated + languages. Might hurt on non-space-separated languages (like Japanese or Chinese) + """ + doc_stride: Optional[int] = None + """If the context is too long to fit with the question for the model, it will be split in + several chunks with some overlap. This argument controls the size of that overlap. + """ + handle_impossible_answer: Optional[bool] = None + """Whether to accept impossible as an answer.""" + max_answer_len: Optional[int] = None + """The maximum length of predicted answers (e.g., only answers with a shorter length are + considered). + """ + max_question_len: Optional[int] = None + """The maximum length of the question after tokenization. It will be truncated if needed.""" + max_seq_len: Optional[int] = None + """The maximum length of the total sentence (context + question) in tokens of each chunk + passed to the model. The context will be split in several chunks (using docStride as + overlap) if needed. + """ + top_k: Optional[int] = None + """The number of answers to return (will be chosen by order of likelihood). Note that we + return less than topk answers if there are not enough options available within the + context. + """ + + +@dataclass_with_extra +class QuestionAnsweringInput(BaseInferenceType): + """Inputs for Question Answering inference""" + + inputs: QuestionAnsweringInputData + """One (context, question) pair to answer""" + parameters: Optional[QuestionAnsweringParameters] = None + """Additional inference parameters for Question Answering""" + + +@dataclass_with_extra +class QuestionAnsweringOutputElement(BaseInferenceType): + """Outputs of inference for the Question Answering task""" + + answer: str + """The answer to the question.""" + end: int + """The character position in the input where the answer ends.""" + score: float + """The probability associated to the answer.""" + start: int + """The character position in the input where the answer begins.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/sentence_similarity.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/sentence_similarity.py new file mode 100644 index 0000000000000000000000000000000000000000..66e8bb4d9322d4847556b7a17dc17bd208a37d0c --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/sentence_similarity.py @@ -0,0 +1,27 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, Dict, List, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class SentenceSimilarityInputData(BaseInferenceType): + sentences: List[str] + """A list of strings which will be compared against the source_sentence.""" + source_sentence: str + """The string that you wish to compare the other strings with. This can be a phrase, + sentence, or longer passage, depending on the model being used. + """ + + +@dataclass_with_extra +class SentenceSimilarityInput(BaseInferenceType): + """Inputs for Sentence similarity inference""" + + inputs: SentenceSimilarityInputData + parameters: Optional[Dict[str, Any]] = None + """Additional inference parameters for Sentence Similarity""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/summarization.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/summarization.py new file mode 100644 index 0000000000000000000000000000000000000000..33eae6fcba0e8724babf145f93be005868429c33 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/summarization.py @@ -0,0 +1,41 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, Dict, Literal, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +SummarizationTruncationStrategy = Literal["do_not_truncate", "longest_first", "only_first", "only_second"] + + +@dataclass_with_extra +class SummarizationParameters(BaseInferenceType): + """Additional inference parameters for summarization.""" + + clean_up_tokenization_spaces: Optional[bool] = None + """Whether to clean up the potential extra spaces in the text output.""" + generate_parameters: Optional[Dict[str, Any]] = None + """Additional parametrization of the text generation algorithm.""" + truncation: Optional["SummarizationTruncationStrategy"] = None + """The truncation strategy to use.""" + + +@dataclass_with_extra +class SummarizationInput(BaseInferenceType): + """Inputs for Summarization inference""" + + inputs: str + """The input text to summarize.""" + parameters: Optional[SummarizationParameters] = None + """Additional inference parameters for summarization.""" + + +@dataclass_with_extra +class SummarizationOutput(BaseInferenceType): + """Outputs of inference for the Summarization task""" + + summary_text: str + """The summarized text.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/table_question_answering.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/table_question_answering.py new file mode 100644 index 0000000000000000000000000000000000000000..10e208eeeb50a689d2826a160432a2b005ec006c --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/table_question_answering.py @@ -0,0 +1,62 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Dict, List, Literal, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class TableQuestionAnsweringInputData(BaseInferenceType): + """One (table, question) pair to answer""" + + question: str + """The question to be answered about the table""" + table: Dict[str, List[str]] + """The table to serve as context for the questions""" + + +Padding = Literal["do_not_pad", "longest", "max_length"] + + +@dataclass_with_extra +class TableQuestionAnsweringParameters(BaseInferenceType): + """Additional inference parameters for Table Question Answering""" + + padding: Optional["Padding"] = None + """Activates and controls padding.""" + sequential: Optional[bool] = None + """Whether to do inference sequentially or as a batch. Batching is faster, but models like + SQA require the inference to be done sequentially to extract relations within sequences, + given their conversational nature. + """ + truncation: Optional[bool] = None + """Activates and controls truncation.""" + + +@dataclass_with_extra +class TableQuestionAnsweringInput(BaseInferenceType): + """Inputs for Table Question Answering inference""" + + inputs: TableQuestionAnsweringInputData + """One (table, question) pair to answer""" + parameters: Optional[TableQuestionAnsweringParameters] = None + """Additional inference parameters for Table Question Answering""" + + +@dataclass_with_extra +class TableQuestionAnsweringOutputElement(BaseInferenceType): + """Outputs of inference for the Table Question Answering task""" + + answer: str + """The answer of the question given the table. If there is an aggregator, the answer will be + preceded by `AGGREGATOR >`. + """ + cells: List[str] + """List of strings made up of the answer cell values.""" + coordinates: List[List[int]] + """Coordinates of the cells of the answers.""" + aggregator: Optional[str] = None + """If the model has an aggregator, this returns the aggregator.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text2text_generation.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text2text_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..34ac74e21e8a30d889f1a251f648d4c365325be6 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text2text_generation.py @@ -0,0 +1,42 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, Dict, Literal, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +Text2TextGenerationTruncationStrategy = Literal["do_not_truncate", "longest_first", "only_first", "only_second"] + + +@dataclass_with_extra +class Text2TextGenerationParameters(BaseInferenceType): + """Additional inference parameters for Text2text Generation""" + + clean_up_tokenization_spaces: Optional[bool] = None + """Whether to clean up the potential extra spaces in the text output.""" + generate_parameters: Optional[Dict[str, Any]] = None + """Additional parametrization of the text generation algorithm""" + truncation: Optional["Text2TextGenerationTruncationStrategy"] = None + """The truncation strategy to use""" + + +@dataclass_with_extra +class Text2TextGenerationInput(BaseInferenceType): + """Inputs for Text2text Generation inference""" + + inputs: str + """The input text data""" + parameters: Optional[Text2TextGenerationParameters] = None + """Additional inference parameters for Text2text Generation""" + + +@dataclass_with_extra +class Text2TextGenerationOutput(BaseInferenceType): + """Outputs of inference for the Text2text Generation task""" + + generated_text: Any + text2_text_generation_output_generated_text: Optional[str] = None + """The generated text.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_classification.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..9a172b23f844fa58f757a644d52138a18e7b6ddb --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_classification.py @@ -0,0 +1,41 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Literal, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +TextClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] + + +@dataclass_with_extra +class TextClassificationParameters(BaseInferenceType): + """Additional inference parameters for Text Classification""" + + function_to_apply: Optional["TextClassificationOutputTransform"] = None + """The function to apply to the model outputs in order to retrieve the scores.""" + top_k: Optional[int] = None + """When specified, limits the output to the top K most probable classes.""" + + +@dataclass_with_extra +class TextClassificationInput(BaseInferenceType): + """Inputs for Text Classification inference""" + + inputs: str + """The text to classify""" + parameters: Optional[TextClassificationParameters] = None + """Additional inference parameters for Text Classification""" + + +@dataclass_with_extra +class TextClassificationOutputElement(BaseInferenceType): + """Outputs of inference for the Text Classification task""" + + label: str + """The predicted class label.""" + score: float + """The corresponding probability.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_generation.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..9b79cc691dce3a6d42aef716d4a93a719f2d600c --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_generation.py @@ -0,0 +1,168 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, List, Literal, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +TypeEnum = Literal["json", "regex", "json_schema"] + + +@dataclass_with_extra +class TextGenerationInputGrammarType(BaseInferenceType): + type: "TypeEnum" + value: Any + """A string that represents a [JSON Schema](https://json-schema.org/). + JSON Schema is a declarative language that allows to annotate JSON documents + with types and descriptions. + """ + + +@dataclass_with_extra +class TextGenerationInputGenerateParameters(BaseInferenceType): + adapter_id: Optional[str] = None + """Lora adapter id""" + best_of: Optional[int] = None + """Generate best_of sequences and return the one if the highest token logprobs.""" + decoder_input_details: Optional[bool] = None + """Whether to return decoder input token logprobs and ids.""" + details: Optional[bool] = None + """Whether to return generation details.""" + do_sample: Optional[bool] = None + """Activate logits sampling.""" + frequency_penalty: Optional[float] = None + """The parameter for frequency penalty. 1.0 means no penalty + Penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. + """ + grammar: Optional[TextGenerationInputGrammarType] = None + max_new_tokens: Optional[int] = None + """Maximum number of tokens to generate.""" + repetition_penalty: Optional[float] = None + """The parameter for repetition penalty. 1.0 means no penalty. + See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + """ + return_full_text: Optional[bool] = None + """Whether to prepend the prompt to the generated text""" + seed: Optional[int] = None + """Random sampling seed.""" + stop: Optional[List[str]] = None + """Stop generating tokens if a member of `stop` is generated.""" + temperature: Optional[float] = None + """The value used to module the logits distribution.""" + top_k: Optional[int] = None + """The number of highest probability vocabulary tokens to keep for top-k-filtering.""" + top_n_tokens: Optional[int] = None + """The number of highest probability vocabulary tokens to keep for top-n-filtering.""" + top_p: Optional[float] = None + """Top-p value for nucleus sampling.""" + truncate: Optional[int] = None + """Truncate inputs tokens to the given size.""" + typical_p: Optional[float] = None + """Typical Decoding mass + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) + for more information. + """ + watermark: Optional[bool] = None + """Watermarking with [A Watermark for Large Language + Models](https://arxiv.org/abs/2301.10226). + """ + + +@dataclass_with_extra +class TextGenerationInput(BaseInferenceType): + """Text Generation Input. + Auto-generated from TGI specs. + For more details, check out + https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. + """ + + inputs: str + parameters: Optional[TextGenerationInputGenerateParameters] = None + stream: Optional[bool] = None + + +TextGenerationOutputFinishReason = Literal["length", "eos_token", "stop_sequence"] + + +@dataclass_with_extra +class TextGenerationOutputPrefillToken(BaseInferenceType): + id: int + logprob: float + text: str + + +@dataclass_with_extra +class TextGenerationOutputToken(BaseInferenceType): + id: int + logprob: float + special: bool + text: str + + +@dataclass_with_extra +class TextGenerationOutputBestOfSequence(BaseInferenceType): + finish_reason: "TextGenerationOutputFinishReason" + generated_text: str + generated_tokens: int + prefill: List[TextGenerationOutputPrefillToken] + tokens: List[TextGenerationOutputToken] + seed: Optional[int] = None + top_tokens: Optional[List[List[TextGenerationOutputToken]]] = None + + +@dataclass_with_extra +class TextGenerationOutputDetails(BaseInferenceType): + finish_reason: "TextGenerationOutputFinishReason" + generated_tokens: int + prefill: List[TextGenerationOutputPrefillToken] + tokens: List[TextGenerationOutputToken] + best_of_sequences: Optional[List[TextGenerationOutputBestOfSequence]] = None + seed: Optional[int] = None + top_tokens: Optional[List[List[TextGenerationOutputToken]]] = None + + +@dataclass_with_extra +class TextGenerationOutput(BaseInferenceType): + """Text Generation Output. + Auto-generated from TGI specs. + For more details, check out + https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. + """ + + generated_text: str + details: Optional[TextGenerationOutputDetails] = None + + +@dataclass_with_extra +class TextGenerationStreamOutputStreamDetails(BaseInferenceType): + finish_reason: "TextGenerationOutputFinishReason" + generated_tokens: int + input_length: int + seed: Optional[int] = None + + +@dataclass_with_extra +class TextGenerationStreamOutputToken(BaseInferenceType): + id: int + logprob: float + special: bool + text: str + + +@dataclass_with_extra +class TextGenerationStreamOutput(BaseInferenceType): + """Text Generation Stream Output. + Auto-generated from TGI specs. + For more details, check out + https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. + """ + + index: int + token: TextGenerationStreamOutputToken + details: Optional[TextGenerationStreamOutputStreamDetails] = None + generated_text: Optional[str] = None + top_tokens: Optional[List[TextGenerationStreamOutputToken]] = None diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_to_audio.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_to_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..87af80a598af70800b8386f034c65de0b397479e --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_to_audio.py @@ -0,0 +1,99 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, Literal, Optional, Union + +from .base import BaseInferenceType, dataclass_with_extra + + +TextToAudioEarlyStoppingEnum = Literal["never"] + + +@dataclass_with_extra +class TextToAudioGenerationParameters(BaseInferenceType): + """Parametrization of the text generation process""" + + do_sample: Optional[bool] = None + """Whether to use sampling instead of greedy decoding when generating new tokens.""" + early_stopping: Optional[Union[bool, "TextToAudioEarlyStoppingEnum"]] = None + """Controls the stopping condition for beam-based methods.""" + epsilon_cutoff: Optional[float] = None + """If set to float strictly between 0 and 1, only tokens with a conditional probability + greater than epsilon_cutoff will be sampled. In the paper, suggested values range from + 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language + Model Desmoothing](https://hf.co/papers/2210.15191) for more details. + """ + eta_cutoff: Optional[float] = None + """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to + float strictly between 0 and 1, a token is only considered if it is greater than either + eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter + term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In + the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. + See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) + for more details. + """ + max_length: Optional[int] = None + """The maximum length (in tokens) of the generated text, including the input.""" + max_new_tokens: Optional[int] = None + """The maximum number of tokens to generate. Takes precedence over max_length.""" + min_length: Optional[int] = None + """The minimum length (in tokens) of the generated text, including the input.""" + min_new_tokens: Optional[int] = None + """The minimum number of tokens to generate. Takes precedence over min_length.""" + num_beam_groups: Optional[int] = None + """Number of groups to divide num_beams into in order to ensure diversity among different + groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. + """ + num_beams: Optional[int] = None + """Number of beams to use for beam search.""" + penalty_alpha: Optional[float] = None + """The value balances the model confidence and the degeneration penalty in contrastive + search decoding. + """ + temperature: Optional[float] = None + """The value used to modulate the next token probabilities.""" + top_k: Optional[int] = None + """The number of highest probability vocabulary tokens to keep for top-k-filtering.""" + top_p: Optional[float] = None + """If set to float < 1, only the smallest set of most probable tokens with probabilities + that add up to top_p or higher are kept for generation. + """ + typical_p: Optional[float] = None + """Local typicality measures how similar the conditional probability of predicting a target + token next is to the expected conditional probability of predicting a random token next, + given the partial text already generated. If set to float < 1, the smallest set of the + most locally typical tokens with probabilities that add up to typical_p or higher are + kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details. + """ + use_cache: Optional[bool] = None + """Whether the model should use the past last key/values attentions to speed up decoding""" + + +@dataclass_with_extra +class TextToAudioParameters(BaseInferenceType): + """Additional inference parameters for Text To Audio""" + + generation_parameters: Optional[TextToAudioGenerationParameters] = None + """Parametrization of the text generation process""" + + +@dataclass_with_extra +class TextToAudioInput(BaseInferenceType): + """Inputs for Text To Audio inference""" + + inputs: str + """The input text data""" + parameters: Optional[TextToAudioParameters] = None + """Additional inference parameters for Text To Audio""" + + +@dataclass_with_extra +class TextToAudioOutput(BaseInferenceType): + """Outputs of inference for the Text To Audio task""" + + audio: Any + """The generated audio waveform.""" + sampling_rate: float + """The sampling rate of the generated audio waveform.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_to_image.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_to_image.py new file mode 100644 index 0000000000000000000000000000000000000000..20c963731371339975019ca5d40c95303d79209b --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_to_image.py @@ -0,0 +1,50 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class TextToImageParameters(BaseInferenceType): + """Additional inference parameters for Text To Image""" + + guidance_scale: Optional[float] = None + """A higher guidance scale value encourages the model to generate images closely linked to + the text prompt, but values too high may cause saturation and other artifacts. + """ + height: Optional[int] = None + """The height in pixels of the output image""" + negative_prompt: Optional[str] = None + """One prompt to guide what NOT to include in image generation.""" + num_inference_steps: Optional[int] = None + """The number of denoising steps. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + """ + scheduler: Optional[str] = None + """Override the scheduler with a compatible one.""" + seed: Optional[int] = None + """Seed for the random number generator.""" + width: Optional[int] = None + """The width in pixels of the output image""" + + +@dataclass_with_extra +class TextToImageInput(BaseInferenceType): + """Inputs for Text To Image inference""" + + inputs: str + """The input text data (sometimes called "prompt")""" + parameters: Optional[TextToImageParameters] = None + """Additional inference parameters for Text To Image""" + + +@dataclass_with_extra +class TextToImageOutput(BaseInferenceType): + """Outputs of inference for the Text To Image task""" + + image: Any + """The generated image returned as raw bytes in the payload.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_to_speech.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_to_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..ce2db8f3f901cc99b5d2fcbb362c4b07b2a718e0 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_to_speech.py @@ -0,0 +1,99 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, Literal, Optional, Union + +from .base import BaseInferenceType, dataclass_with_extra + + +TextToSpeechEarlyStoppingEnum = Literal["never"] + + +@dataclass_with_extra +class TextToSpeechGenerationParameters(BaseInferenceType): + """Parametrization of the text generation process""" + + do_sample: Optional[bool] = None + """Whether to use sampling instead of greedy decoding when generating new tokens.""" + early_stopping: Optional[Union[bool, "TextToSpeechEarlyStoppingEnum"]] = None + """Controls the stopping condition for beam-based methods.""" + epsilon_cutoff: Optional[float] = None + """If set to float strictly between 0 and 1, only tokens with a conditional probability + greater than epsilon_cutoff will be sampled. In the paper, suggested values range from + 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language + Model Desmoothing](https://hf.co/papers/2210.15191) for more details. + """ + eta_cutoff: Optional[float] = None + """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to + float strictly between 0 and 1, a token is only considered if it is greater than either + eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter + term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In + the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. + See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) + for more details. + """ + max_length: Optional[int] = None + """The maximum length (in tokens) of the generated text, including the input.""" + max_new_tokens: Optional[int] = None + """The maximum number of tokens to generate. Takes precedence over max_length.""" + min_length: Optional[int] = None + """The minimum length (in tokens) of the generated text, including the input.""" + min_new_tokens: Optional[int] = None + """The minimum number of tokens to generate. Takes precedence over min_length.""" + num_beam_groups: Optional[int] = None + """Number of groups to divide num_beams into in order to ensure diversity among different + groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. + """ + num_beams: Optional[int] = None + """Number of beams to use for beam search.""" + penalty_alpha: Optional[float] = None + """The value balances the model confidence and the degeneration penalty in contrastive + search decoding. + """ + temperature: Optional[float] = None + """The value used to modulate the next token probabilities.""" + top_k: Optional[int] = None + """The number of highest probability vocabulary tokens to keep for top-k-filtering.""" + top_p: Optional[float] = None + """If set to float < 1, only the smallest set of most probable tokens with probabilities + that add up to top_p or higher are kept for generation. + """ + typical_p: Optional[float] = None + """Local typicality measures how similar the conditional probability of predicting a target + token next is to the expected conditional probability of predicting a random token next, + given the partial text already generated. If set to float < 1, the smallest set of the + most locally typical tokens with probabilities that add up to typical_p or higher are + kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details. + """ + use_cache: Optional[bool] = None + """Whether the model should use the past last key/values attentions to speed up decoding""" + + +@dataclass_with_extra +class TextToSpeechParameters(BaseInferenceType): + """Additional inference parameters for Text To Speech""" + + generation_parameters: Optional[TextToSpeechGenerationParameters] = None + """Parametrization of the text generation process""" + + +@dataclass_with_extra +class TextToSpeechInput(BaseInferenceType): + """Inputs for Text To Speech inference""" + + inputs: str + """The input text data""" + parameters: Optional[TextToSpeechParameters] = None + """Additional inference parameters for Text To Speech""" + + +@dataclass_with_extra +class TextToSpeechOutput(BaseInferenceType): + """Outputs of inference for the Text To Speech task""" + + audio: Any + """The generated audio""" + sampling_rate: Optional[float] = None + """The sampling rate of the generated audio waveform.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_to_video.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_to_video.py new file mode 100644 index 0000000000000000000000000000000000000000..e54a1bc094e4aaf7132e502aa268bc052ab34f0a --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/text_to_video.py @@ -0,0 +1,46 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, List, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class TextToVideoParameters(BaseInferenceType): + """Additional inference parameters for Text To Video""" + + guidance_scale: Optional[float] = None + """A higher guidance scale value encourages the model to generate videos closely linked to + the text prompt, but values too high may cause saturation and other artifacts. + """ + negative_prompt: Optional[List[str]] = None + """One or several prompt to guide what NOT to include in video generation.""" + num_frames: Optional[float] = None + """The num_frames parameter determines how many video frames are generated.""" + num_inference_steps: Optional[int] = None + """The number of denoising steps. More denoising steps usually lead to a higher quality + video at the expense of slower inference. + """ + seed: Optional[int] = None + """Seed for the random number generator.""" + + +@dataclass_with_extra +class TextToVideoInput(BaseInferenceType): + """Inputs for Text To Video inference""" + + inputs: str + """The input text data (sometimes called "prompt")""" + parameters: Optional[TextToVideoParameters] = None + """Additional inference parameters for Text To Video""" + + +@dataclass_with_extra +class TextToVideoOutput(BaseInferenceType): + """Outputs of inference for the Text To Video task""" + + video: Any + """The generated video returned as raw bytes in the payload.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/token_classification.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/token_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..e039b6a1db7dcd54dbc9434d3254da0770c6799e --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/token_classification.py @@ -0,0 +1,51 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import List, Literal, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +TokenClassificationAggregationStrategy = Literal["none", "simple", "first", "average", "max"] + + +@dataclass_with_extra +class TokenClassificationParameters(BaseInferenceType): + """Additional inference parameters for Token Classification""" + + aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None + """The strategy used to fuse tokens based on model predictions""" + ignore_labels: Optional[List[str]] = None + """A list of labels to ignore""" + stride: Optional[int] = None + """The number of overlapping tokens between chunks when splitting the input text.""" + + +@dataclass_with_extra +class TokenClassificationInput(BaseInferenceType): + """Inputs for Token Classification inference""" + + inputs: str + """The input text data""" + parameters: Optional[TokenClassificationParameters] = None + """Additional inference parameters for Token Classification""" + + +@dataclass_with_extra +class TokenClassificationOutputElement(BaseInferenceType): + """Outputs of inference for the Token Classification task""" + + end: int + """The character position in the input where this group ends.""" + score: float + """The associated score / probability""" + start: int + """The character position in the input where this group begins.""" + word: str + """The corresponding text""" + entity: Optional[str] = None + """The predicted label for a single token""" + entity_group: Optional[str] = None + """The predicted label for a group of one or more tokens""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/translation.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/translation.py new file mode 100644 index 0000000000000000000000000000000000000000..df95b7dbb1f4ce5b80cec034e004bb6e71387be8 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/translation.py @@ -0,0 +1,49 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, Dict, Literal, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +TranslationTruncationStrategy = Literal["do_not_truncate", "longest_first", "only_first", "only_second"] + + +@dataclass_with_extra +class TranslationParameters(BaseInferenceType): + """Additional inference parameters for Translation""" + + clean_up_tokenization_spaces: Optional[bool] = None + """Whether to clean up the potential extra spaces in the text output.""" + generate_parameters: Optional[Dict[str, Any]] = None + """Additional parametrization of the text generation algorithm.""" + src_lang: Optional[str] = None + """The source language of the text. Required for models that can translate from multiple + languages. + """ + tgt_lang: Optional[str] = None + """Target language to translate to. Required for models that can translate to multiple + languages. + """ + truncation: Optional["TranslationTruncationStrategy"] = None + """The truncation strategy to use.""" + + +@dataclass_with_extra +class TranslationInput(BaseInferenceType): + """Inputs for Translation inference""" + + inputs: str + """The text to translate.""" + parameters: Optional[TranslationParameters] = None + """Additional inference parameters for Translation""" + + +@dataclass_with_extra +class TranslationOutput(BaseInferenceType): + """Outputs of inference for the Translation task""" + + translation_text: str + """The translated text.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/video_classification.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/video_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..e1d7a15bb4ee5fa63aa6ebc3750191bd38549212 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/video_classification.py @@ -0,0 +1,45 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, Literal, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +VideoClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] + + +@dataclass_with_extra +class VideoClassificationParameters(BaseInferenceType): + """Additional inference parameters for Video Classification""" + + frame_sampling_rate: Optional[int] = None + """The sampling rate used to select frames from the video.""" + function_to_apply: Optional["VideoClassificationOutputTransform"] = None + """The function to apply to the model outputs in order to retrieve the scores.""" + num_frames: Optional[int] = None + """The number of sampled frames to consider for classification.""" + top_k: Optional[int] = None + """When specified, limits the output to the top K most probable classes.""" + + +@dataclass_with_extra +class VideoClassificationInput(BaseInferenceType): + """Inputs for Video Classification inference""" + + inputs: Any + """The input video data""" + parameters: Optional[VideoClassificationParameters] = None + """Additional inference parameters for Video Classification""" + + +@dataclass_with_extra +class VideoClassificationOutputElement(BaseInferenceType): + """Outputs of inference for the Video Classification task""" + + label: str + """The predicted class label.""" + score: float + """The corresponding probability.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/visual_question_answering.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/visual_question_answering.py new file mode 100644 index 0000000000000000000000000000000000000000..d368f1621289bc11a17be3e590cf8a040019d455 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/visual_question_answering.py @@ -0,0 +1,49 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import Any, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class VisualQuestionAnsweringInputData(BaseInferenceType): + """One (image, question) pair to answer""" + + image: Any + """The image.""" + question: str + """The question to answer based on the image.""" + + +@dataclass_with_extra +class VisualQuestionAnsweringParameters(BaseInferenceType): + """Additional inference parameters for Visual Question Answering""" + + top_k: Optional[int] = None + """The number of answers to return (will be chosen by order of likelihood). Note that we + return less than topk answers if there are not enough options available within the + context. + """ + + +@dataclass_with_extra +class VisualQuestionAnsweringInput(BaseInferenceType): + """Inputs for Visual Question Answering inference""" + + inputs: VisualQuestionAnsweringInputData + """One (image, question) pair to answer""" + parameters: Optional[VisualQuestionAnsweringParameters] = None + """Additional inference parameters for Visual Question Answering""" + + +@dataclass_with_extra +class VisualQuestionAnsweringOutputElement(BaseInferenceType): + """Outputs of inference for the Visual Question Answering task""" + + score: float + """The associated score / probability""" + answer: Optional[str] = None + """The answer to the question""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/zero_shot_classification.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/zero_shot_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..47b32492e358edcc0de6aa09d53635b0a8156b25 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/zero_shot_classification.py @@ -0,0 +1,45 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import List, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class ZeroShotClassificationParameters(BaseInferenceType): + """Additional inference parameters for Zero Shot Classification""" + + candidate_labels: List[str] + """The set of possible class labels to classify the text into.""" + hypothesis_template: Optional[str] = None + """The sentence used in conjunction with `candidate_labels` to attempt the text + classification by replacing the placeholder with the candidate labels. + """ + multi_label: Optional[bool] = None + """Whether multiple candidate labels can be true. If false, the scores are normalized such + that the sum of the label likelihoods for each sequence is 1. If true, the labels are + considered independent and probabilities are normalized for each candidate. + """ + + +@dataclass_with_extra +class ZeroShotClassificationInput(BaseInferenceType): + """Inputs for Zero Shot Classification inference""" + + inputs: str + """The text to classify""" + parameters: ZeroShotClassificationParameters + """Additional inference parameters for Zero Shot Classification""" + + +@dataclass_with_extra +class ZeroShotClassificationOutputElement(BaseInferenceType): + """Outputs of inference for the Zero Shot Classification task""" + + label: str + """The predicted class label.""" + score: float + """The corresponding probability.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/zero_shot_image_classification.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/zero_shot_image_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..998d66b6b4e3356f0f09a0ad25ebdaf2e76cd03f --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/zero_shot_image_classification.py @@ -0,0 +1,40 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import List, Optional + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class ZeroShotImageClassificationParameters(BaseInferenceType): + """Additional inference parameters for Zero Shot Image Classification""" + + candidate_labels: List[str] + """The candidate labels for this image""" + hypothesis_template: Optional[str] = None + """The sentence used in conjunction with `candidate_labels` to attempt the image + classification by replacing the placeholder with the candidate labels. + """ + + +@dataclass_with_extra +class ZeroShotImageClassificationInput(BaseInferenceType): + """Inputs for Zero Shot Image Classification inference""" + + inputs: str + """The input image data to classify as a base64-encoded string.""" + parameters: ZeroShotImageClassificationParameters + """Additional inference parameters for Zero Shot Image Classification""" + + +@dataclass_with_extra +class ZeroShotImageClassificationOutputElement(BaseInferenceType): + """Outputs of inference for the Zero Shot Image Classification task""" + + label: str + """The predicted class label.""" + score: float + """The corresponding probability.""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/zero_shot_object_detection.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/zero_shot_object_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..8ef76b5fcb93e8126266e4b1464934d01024b1b7 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_generated/types/zero_shot_object_detection.py @@ -0,0 +1,52 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from typing import List + +from .base import BaseInferenceType, dataclass_with_extra + + +@dataclass_with_extra +class ZeroShotObjectDetectionParameters(BaseInferenceType): + """Additional inference parameters for Zero Shot Object Detection""" + + candidate_labels: List[str] + """The candidate labels for this image""" + + +@dataclass_with_extra +class ZeroShotObjectDetectionInput(BaseInferenceType): + """Inputs for Zero Shot Object Detection inference""" + + inputs: str + """The input image data as a base64-encoded string.""" + parameters: ZeroShotObjectDetectionParameters + """Additional inference parameters for Zero Shot Object Detection""" + + +@dataclass_with_extra +class ZeroShotObjectDetectionBoundingBox(BaseInferenceType): + """The predicted bounding box. Coordinates are relative to the top left corner of the input + image. + """ + + xmax: int + xmin: int + ymax: int + ymin: int + + +@dataclass_with_extra +class ZeroShotObjectDetectionOutputElement(BaseInferenceType): + """Outputs of inference for the Zero Shot Object Detection task""" + + box: ZeroShotObjectDetectionBoundingBox + """The predicted bounding box. Coordinates are relative to the top left corner of the input + image. + """ + label: str + """A candidate label""" + score: float + """The associated score / probability""" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__init__.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9bef9d8b9b150f9c94dfa8275dac200894b3941 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/_cli_hacks.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/_cli_hacks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fe01a2039b4fc97c7c3c0c91ff63829cdbb5d1c Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/_cli_hacks.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/agent.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/agent.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a04ceb43e6890c8c1bca0755cce7c5adc645128 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/agent.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/cli.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/cli.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67316ec45f7fd9741f4728d0c96cc467ed9a10a7 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/cli.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/constants.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca13b52c0ba0a56ea5af53f32b11e80560acfbf2 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/constants.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/mcp_client.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/mcp_client.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01bb19aefeb87d6887b5ddf590ce824d3acb4fe8 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/mcp_client.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/types.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/types.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e914423d0863bfe743035b1e06be9cc3f2588b1c Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/types.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38f651e21208a42999b344a7caf8a9e91a0c971a Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/_cli_hacks.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/_cli_hacks.py new file mode 100644 index 0000000000000000000000000000000000000000..44113b9101a78c7fc5c7abfb58965ed0c22032a1 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/_cli_hacks.py @@ -0,0 +1,88 @@ +import asyncio +import sys +from functools import partial + +import typer + + +def _patch_anyio_open_process(): + """ + Patch anyio.open_process to allow detached processes on Windows and Unix-like systems. + + This is necessary to prevent the MCP client from being interrupted by Ctrl+C when running in the CLI. + """ + import subprocess + + import anyio + + if getattr(anyio, "_tiny_agents_patched", False): + return + anyio._tiny_agents_patched = True + + original_open_process = anyio.open_process + + if sys.platform == "win32": + # On Windows, we need to set the creation flags to create a new process group + + async def open_process_in_new_group(*args, **kwargs): + """ + Wrapper for open_process to handle Windows-specific process creation flags. + """ + # Ensure we pass the creation flags for Windows + kwargs.setdefault("creationflags", subprocess.CREATE_NEW_PROCESS_GROUP) + return await original_open_process(*args, **kwargs) + + anyio.open_process = open_process_in_new_group + else: + # For Unix-like systems, we can use setsid to create a new session + async def open_process_in_new_group(*args, **kwargs): + """ + Wrapper for open_process to handle Unix-like systems with start_new_session=True. + """ + kwargs.setdefault("start_new_session", True) + return await original_open_process(*args, **kwargs) + + anyio.open_process = open_process_in_new_group + + +async def _async_prompt(exit_event: asyncio.Event, prompt: str = "» ") -> str: + """ + Asynchronous prompt function that reads input from stdin without blocking. + + This function is designed to work in an asynchronous context, allowing the event loop to gracefully stop it (e.g. on Ctrl+C). + + Alternatively, we could use https://github.com/vxgmichel/aioconsole but that would be an additional dependency. + """ + loop = asyncio.get_event_loop() + + if sys.platform == "win32": + # Windows: Use run_in_executor to avoid blocking the event loop + # Degraded solution: this is not ideal as user will have to CTRL+C once more to stop the prompt (and it'll not be graceful) + return await loop.run_in_executor(None, partial(typer.prompt, prompt, prompt_suffix=" ")) + else: + # UNIX-like: Use loop.add_reader for non-blocking stdin read + future = loop.create_future() + + def on_input(): + line = sys.stdin.readline() + loop.remove_reader(sys.stdin) + future.set_result(line) + + print(prompt, end=" ", flush=True) + loop.add_reader(sys.stdin, on_input) # not supported on Windows + + # Wait for user input or exit event + # Wait until either the user hits enter or exit_event is set + exit_task = asyncio.create_task(exit_event.wait()) + await asyncio.wait( + [future, exit_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + # Check which one has been triggered + if exit_event.is_set(): + future.cancel() + return "" + + line = await future + return line.strip() diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/agent.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..4f88016ba709d15445525e2bd252febb5b2287ab --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/agent.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import asyncio +from typing import AsyncGenerator, Dict, Iterable, List, Optional, Union + +from huggingface_hub import ChatCompletionInputMessage, ChatCompletionStreamOutput, MCPClient + +from .._providers import PROVIDER_OR_POLICY_T +from .constants import DEFAULT_SYSTEM_PROMPT, EXIT_LOOP_TOOLS, MAX_NUM_TURNS +from .types import ServerConfig + + +class Agent(MCPClient): + """ + Implementation of a Simple Agent, which is a simple while loop built right on top of an [`MCPClient`]. + + + + This class is experimental and might be subject to breaking changes in the future without prior notice. + + + + Args: + model (`str`, *optional*): + The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct` + or a URL to a deployed Inference Endpoint or other local or remote endpoint. + servers (`Iterable[Dict]`): + MCP servers to connect to. Each server is a dictionary containing a `type` key and a `config` key. The `type` key can be `"stdio"` or `"sse"`, and the `config` key is a dictionary of arguments for the server. + provider (`str`, *optional*): + Name of the provider to use for inference. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. + If model is a URL or `base_url` is passed, then `provider` is not used. + base_url (`str`, *optional*): + The base URL to run inference. Defaults to None. + api_key (`str`, *optional*): + Token to use for authentication. Will default to the locally Hugging Face saved token if not provided. You can also use your own provider API key to interact directly with the provider's service. + prompt (`str`, *optional*): + The system prompt to use for the agent. Defaults to the default system prompt in `constants.py`. + """ + + def __init__( + self, + *, + model: Optional[str] = None, + servers: Iterable[ServerConfig], + provider: Optional[PROVIDER_OR_POLICY_T] = None, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + prompt: Optional[str] = None, + ): + super().__init__(model=model, provider=provider, base_url=base_url, api_key=api_key) + self._servers_cfg = list(servers) + self.messages: List[Union[Dict, ChatCompletionInputMessage]] = [ + {"role": "system", "content": prompt or DEFAULT_SYSTEM_PROMPT} + ] + + async def load_tools(self) -> None: + for cfg in self._servers_cfg: + await self.add_mcp_server(**cfg) + + async def run( + self, + user_input: str, + *, + abort_event: Optional[asyncio.Event] = None, + ) -> AsyncGenerator[Union[ChatCompletionStreamOutput, ChatCompletionInputMessage], None]: + """ + Run the agent with the given user input. + + Args: + user_input (`str`): + The user input to run the agent with. + abort_event (`asyncio.Event`, *optional*): + An event that can be used to abort the agent. If the event is set, the agent will stop running. + """ + self.messages.append({"role": "user", "content": user_input}) + + num_turns: int = 0 + next_turn_should_call_tools = True + + while True: + if abort_event and abort_event.is_set(): + return + + async for item in self.process_single_turn_with_tools( + self.messages, + exit_loop_tools=EXIT_LOOP_TOOLS, + exit_if_first_chunk_no_tool=(num_turns > 0 and next_turn_should_call_tools), + ): + yield item + + num_turns += 1 + last = self.messages[-1] + + if last.get("role") == "tool" and last.get("name") in {t.function.name for t in EXIT_LOOP_TOOLS}: + return + + if last.get("role") != "tool" and num_turns > MAX_NUM_TURNS: + return + + if last.get("role") != "tool" and next_turn_should_call_tools: + return + + next_turn_should_call_tools = last.get("role") != "tool" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/cli.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..ce0a2536b63ee207f378b3ef45980a915d432a14 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/cli.py @@ -0,0 +1,247 @@ +import asyncio +import os +import signal +import traceback +from typing import Optional + +import typer +from rich import print + +from ._cli_hacks import _async_prompt, _patch_anyio_open_process +from .agent import Agent +from .utils import _load_agent_config + + +app = typer.Typer( + rich_markup_mode="rich", + help="A squad of lightweight composable AI applications built on Hugging Face's Inference Client and MCP stack.", +) + +run_cli = typer.Typer( + name="run", + help="Run the Agent in the CLI", + invoke_without_command=True, +) +app.add_typer(run_cli, name="run") + + +async def run_agent( + agent_path: Optional[str], +) -> None: + """ + Tiny Agent loop. + + Args: + agent_path (`str`, *optional*): + Path to a local folder containing an `agent.json` and optionally a custom `PROMPT.md` file or a built-in agent stored in a Hugging Face dataset. + + """ + _patch_anyio_open_process() # Hacky way to prevent stdio connections to be stopped by Ctrl+C + + config, prompt = _load_agent_config(agent_path) + + inputs = config.get("inputs", []) + servers = config.get("servers", []) + + abort_event = asyncio.Event() + exit_event = asyncio.Event() + first_sigint = True + + loop = asyncio.get_running_loop() + original_sigint_handler = signal.getsignal(signal.SIGINT) + + def _sigint_handler() -> None: + nonlocal first_sigint + if first_sigint: + first_sigint = False + abort_event.set() + print("\n[red]Interrupted. Press Ctrl+C again to quit.[/red]", flush=True) + return + + print("\n[red]Exiting...[/red]", flush=True) + exit_event.set() + + try: + sigint_registered_in_loop = False + try: + loop.add_signal_handler(signal.SIGINT, _sigint_handler) + sigint_registered_in_loop = True + except (AttributeError, NotImplementedError): + # Windows (or any loop that doesn't support it) : fall back to sync + signal.signal(signal.SIGINT, lambda *_: _sigint_handler()) + + # Handle inputs (i.e. env variables injection) + resolved_inputs: dict[str, str] = {} + + if len(inputs) > 0: + print( + "[bold blue]Some initial inputs are required by the agent. " + "Please provide a value or leave empty to load from env.[/bold blue]" + ) + for input_item in inputs: + input_id = input_item["id"] + description = input_item["description"] + env_special_value = f"${{input:{input_id}}}" + + # Check if the input is used by any server or as an apiKey + input_usages = set() + for server in servers: + # Check stdio's "env" and http/sse's "headers" mappings + env_or_headers = server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {}) + for key, value in env_or_headers.items(): + if env_special_value in value: + input_usages.add(key) + + raw_api_key = config.get("apiKey") + if isinstance(raw_api_key, str) and env_special_value in raw_api_key: + input_usages.add("apiKey") + + if not input_usages: + print( + f"[yellow]Input '{input_id}' defined in config but not used by any server or as an API key." + " Skipping.[/yellow]" + ) + continue + + # Prompt user for input + env_variable_key = input_id.replace("-", "_").upper() + print( + f"[blue] • {input_id}[/blue]: {description}. (default: load from {env_variable_key}).", + end=" ", + ) + user_input = (await _async_prompt(exit_event=exit_event)).strip() + if exit_event.is_set(): + return + + # Fallback to environment variable when user left blank + final_value = user_input + if not final_value: + final_value = os.getenv(env_variable_key, "") + if final_value: + print(f"[green]Value successfully loaded from '{env_variable_key}'[/green]") + else: + print( + f"[yellow]No value found for '{env_variable_key}' in environment variables. Continuing.[/yellow]" + ) + resolved_inputs[input_id] = final_value + + # Inject resolved value (can be empty) into stdio's env or http/sse's headers + for server in servers: + env_or_headers = server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {}) + for key, value in env_or_headers.items(): + if env_special_value in value: + env_or_headers[key] = env_or_headers[key].replace(env_special_value, final_value) + + print() + + raw_api_key = config.get("apiKey") + if isinstance(raw_api_key, str): + substituted_api_key = raw_api_key + for input_id, val in resolved_inputs.items(): + substituted_api_key = substituted_api_key.replace(f"${{input:{input_id}}}", val) + config["apiKey"] = substituted_api_key + # Main agent loop + async with Agent( + provider=config.get("provider"), # type: ignore[arg-type] + model=config.get("model"), + base_url=config.get("endpointUrl"), # type: ignore[arg-type] + api_key=config.get("apiKey"), + servers=servers, # type: ignore[arg-type] + prompt=prompt, + ) as agent: + await agent.load_tools() + print(f"[bold blue]Agent loaded with {len(agent.available_tools)} tools:[/bold blue]") + for t in agent.available_tools: + print(f"[blue] • {t.function.name}[/blue]") + + while True: + abort_event.clear() + + # Check if we should exit + if exit_event.is_set(): + return + + try: + user_input = await _async_prompt(exit_event=exit_event) + first_sigint = True + except EOFError: + print("\n[red]EOF received, exiting.[/red]", flush=True) + break + except KeyboardInterrupt: + if not first_sigint and abort_event.is_set(): + continue + else: + print("\n[red]Keyboard interrupt during input processing.[/red]", flush=True) + break + + try: + async for chunk in agent.run(user_input, abort_event=abort_event): + if abort_event.is_set() and not first_sigint: + break + if exit_event.is_set(): + return + + if hasattr(chunk, "choices"): + delta = chunk.choices[0].delta + if delta.content: + print(delta.content, end="", flush=True) + if delta.tool_calls: + for call in delta.tool_calls: + if call.id: + print(f"", end="") + if call.function.name: + print(f"{call.function.name}", end=" ") + if call.function.arguments: + print(f"{call.function.arguments}", end="") + else: + print( + f"\n\n[green]Tool[{chunk.name}] {chunk.tool_call_id}\n{chunk.content}[/green]\n", + flush=True, + ) + + print() + + except Exception as e: + tb_str = traceback.format_exc() + print(f"\n[bold red]Error during agent run: {e}\n{tb_str}[/bold red]", flush=True) + first_sigint = True # Allow graceful interrupt for the next command + + except Exception as e: + tb_str = traceback.format_exc() + print(f"\n[bold red]An unexpected error occurred: {e}\n{tb_str}[/bold red]", flush=True) + raise e + + finally: + if sigint_registered_in_loop: + try: + loop.remove_signal_handler(signal.SIGINT) + except (AttributeError, NotImplementedError): + pass + else: + signal.signal(signal.SIGINT, original_sigint_handler) + + +@run_cli.callback() +def run( + path: Optional[str] = typer.Argument( + None, + help=( + "Path to a local folder containing an agent.json file or a built-in agent " + "stored in the 'tiny-agents/tiny-agents' Hugging Face dataset " + "(https://huggingface.co/datasets/tiny-agents/tiny-agents)" + ), + show_default=False, + ), +): + try: + asyncio.run(run_agent(path)) + except KeyboardInterrupt: + print("\n[red]Application terminated by KeyboardInterrupt.[/red]", flush=True) + raise typer.Exit(code=130) + except Exception as e: + print(f"\n[bold red]An unexpected error occurred: {e}[/bold red]", flush=True) + raise e + + +if __name__ == "__main__": + app() diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/constants.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..968bae48ed19c2ba8011eeb676ac2699f5eb3beb --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/constants.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import sys +from pathlib import Path +from typing import List + +from huggingface_hub import ChatCompletionInputTool + + +FILENAME_CONFIG = "agent.json" +FILENAME_PROMPT = "PROMPT.md" + +DEFAULT_AGENT = { + "model": "Qwen/Qwen2.5-72B-Instruct", + "provider": "nebius", + "servers": [ + { + "type": "stdio", + "command": "npx", + "args": [ + "-y", + "@modelcontextprotocol/server-filesystem", + str(Path.home() / ("Desktop" if sys.platform == "darwin" else "")), + ], + }, + { + "type": "stdio", + "command": "npx", + "args": ["@playwright/mcp@latest"], + }, + ], +} + + +DEFAULT_SYSTEM_PROMPT = """ +You are an agent - please keep going until the user’s query is completely +resolved, before ending your turn and yielding back to the user. Only terminate +your turn when you are sure that the problem is solved, or if you need more +info from the user to solve the problem. +If you are not sure about anything pertaining to the user’s request, use your +tools to read files and gather the relevant information: do NOT guess or make +up an answer. +You MUST plan extensively before each function call, and reflect extensively +on the outcomes of the previous function calls. DO NOT do this entire process +by making function calls only, as this can impair your ability to solve the +problem and think insightfully. +""".strip() + +MAX_NUM_TURNS = 10 + +TASK_COMPLETE_TOOL: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj( # type: ignore[assignment] + { + "type": "function", + "function": { + "name": "task_complete", + "description": "Call this tool when the task given by the user is complete", + "parameters": { + "type": "object", + "properties": {}, + }, + }, + } +) + +ASK_QUESTION_TOOL: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj( # type: ignore[assignment] + { + "type": "function", + "function": { + "name": "ask_question", + "description": "Ask the user for more info required to solve or clarify their problem.", + "parameters": { + "type": "object", + "properties": {}, + }, + }, + } +) + +EXIT_LOOP_TOOLS: List[ChatCompletionInputTool] = [TASK_COMPLETE_TOOL, ASK_QUESTION_TOOL] + + +DEFAULT_REPO_ID = "tiny-agents/tiny-agents" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/mcp_client.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/mcp_client.py new file mode 100644 index 0000000000000000000000000000000000000000..2712dea12127ed69a088d9414f0715de3e103d8b --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/mcp_client.py @@ -0,0 +1,369 @@ +import json +import logging +from contextlib import AsyncExitStack +from datetime import timedelta +from pathlib import Path +from typing import TYPE_CHECKING, Any, AsyncIterable, Dict, List, Literal, Optional, Union, overload + +from typing_extensions import NotRequired, TypeAlias, TypedDict, Unpack + +from ...utils._runtime import get_hf_hub_version +from .._generated._async_client import AsyncInferenceClient +from .._generated.types import ( + ChatCompletionInputMessage, + ChatCompletionInputTool, + ChatCompletionStreamOutput, + ChatCompletionStreamOutputDeltaToolCall, +) +from .._providers import PROVIDER_OR_POLICY_T +from .utils import format_result + + +if TYPE_CHECKING: + from mcp import ClientSession + +logger = logging.getLogger(__name__) + +# Type alias for tool names +ToolName: TypeAlias = str + +ServerType: TypeAlias = Literal["stdio", "sse", "http"] + + +class StdioServerParameters_T(TypedDict): + command: str + args: NotRequired[List[str]] + env: NotRequired[Dict[str, str]] + cwd: NotRequired[Union[str, Path, None]] + + +class SSEServerParameters_T(TypedDict): + url: str + headers: NotRequired[Dict[str, Any]] + timeout: NotRequired[float] + sse_read_timeout: NotRequired[float] + + +class StreamableHTTPParameters_T(TypedDict): + url: str + headers: NotRequired[dict[str, Any]] + timeout: NotRequired[timedelta] + sse_read_timeout: NotRequired[timedelta] + terminate_on_close: NotRequired[bool] + + +class MCPClient: + """ + Client for connecting to one or more MCP servers and processing chat completions with tools. + + + + This class is experimental and might be subject to breaking changes in the future without prior notice. + + + + Args: + model (`str`, `optional`): + The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct` + or a URL to a deployed Inference Endpoint or other local or remote endpoint. + provider (`str`, *optional*): + Name of the provider to use for inference. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. + If model is a URL or `base_url` is passed, then `provider` is not used. + base_url (`str`, *optional*): + The base URL to run inference. Defaults to None. + api_key (`str`, `optional`): + Token to use for authentication. Will default to the locally Hugging Face saved token if not provided. You can also use your own provider API key to interact directly with the provider's service. + """ + + def __init__( + self, + *, + model: Optional[str] = None, + provider: Optional[PROVIDER_OR_POLICY_T] = None, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + ): + # Initialize MCP sessions as a dictionary of ClientSession objects + self.sessions: Dict[ToolName, "ClientSession"] = {} + self.exit_stack = AsyncExitStack() + self.available_tools: List[ChatCompletionInputTool] = [] + # To be able to send the model in the payload if `base_url` is provided + if model is None and base_url is None: + raise ValueError("At least one of `model` or `base_url` should be set in `MCPClient`.") + self.payload_model = model + self.client = AsyncInferenceClient( + model=None if base_url is not None else model, + provider=provider, + api_key=api_key, + base_url=base_url, + ) + + async def __aenter__(self): + """Enter the context manager""" + await self.client.__aenter__() + await self.exit_stack.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Exit the context manager""" + await self.client.__aexit__(exc_type, exc_val, exc_tb) + await self.cleanup() + + async def cleanup(self): + """Clean up resources""" + await self.client.close() + await self.exit_stack.aclose() + + @overload + async def add_mcp_server(self, type: Literal["stdio"], **params: Unpack[StdioServerParameters_T]): ... + + @overload + async def add_mcp_server(self, type: Literal["sse"], **params: Unpack[SSEServerParameters_T]): ... + + @overload + async def add_mcp_server(self, type: Literal["http"], **params: Unpack[StreamableHTTPParameters_T]): ... + + async def add_mcp_server(self, type: ServerType, **params: Any): + """Connect to an MCP server + + Args: + type (`str`): + Type of the server to connect to. Can be one of: + - "stdio": Standard input/output server (local) + - "sse": Server-sent events (SSE) server + - "http": StreamableHTTP server + **params (`Dict[str, Any]`): + Server parameters that can be either: + - For stdio servers: + - command (str): The command to run the MCP server + - args (List[str], optional): Arguments for the command + - env (Dict[str, str], optional): Environment variables for the command + - cwd (Union[str, Path, None], optional): Working directory for the command + - For SSE servers: + - url (str): The URL of the SSE server + - headers (Dict[str, Any], optional): Headers for the SSE connection + - timeout (float, optional): Connection timeout + - sse_read_timeout (float, optional): SSE read timeout + - For StreamableHTTP servers: + - url (str): The URL of the StreamableHTTP server + - headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection + - timeout (timedelta, optional): Connection timeout + - sse_read_timeout (timedelta, optional): SSE read timeout + - terminate_on_close (bool, optional): Whether to terminate on close + """ + from mcp import ClientSession, StdioServerParameters + from mcp import types as mcp_types + + # Determine server type and create appropriate parameters + if type == "stdio": + # Handle stdio server + from mcp.client.stdio import stdio_client + + logger.info(f"Connecting to stdio MCP server with command: {params['command']} {params.get('args', [])}") + + client_kwargs = {"command": params["command"]} + for key in ["args", "env", "cwd"]: + if params.get(key) is not None: + client_kwargs[key] = params[key] + server_params = StdioServerParameters(**client_kwargs) + read, write = await self.exit_stack.enter_async_context(stdio_client(server_params)) + elif type == "sse": + # Handle SSE server + from mcp.client.sse import sse_client + + logger.info(f"Connecting to SSE MCP server at: {params['url']}") + + client_kwargs = {"url": params["url"]} + for key in ["headers", "timeout", "sse_read_timeout"]: + if params.get(key) is not None: + client_kwargs[key] = params[key] + read, write = await self.exit_stack.enter_async_context(sse_client(**client_kwargs)) + elif type == "http": + # Handle StreamableHTTP server + from mcp.client.streamable_http import streamablehttp_client + + logger.info(f"Connecting to StreamableHTTP MCP server at: {params['url']}") + + client_kwargs = {"url": params["url"]} + for key in ["headers", "timeout", "sse_read_timeout", "terminate_on_close"]: + if params.get(key) is not None: + client_kwargs[key] = params[key] + read, write, _ = await self.exit_stack.enter_async_context(streamablehttp_client(**client_kwargs)) + # ^ TODO: should be handle `get_session_id_callback`? (function to retrieve the current session ID) + else: + raise ValueError(f"Unsupported server type: {type}") + + session = await self.exit_stack.enter_async_context( + ClientSession( + read_stream=read, + write_stream=write, + client_info=mcp_types.Implementation( + name="huggingface_hub.MCPClient", + version=get_hf_hub_version(), + ), + ) + ) + + logger.debug("Initializing session...") + await session.initialize() + + # List available tools + response = await session.list_tools() + logger.debug("Connected to server with tools:", [tool.name for tool in response.tools]) + + for tool in response.tools: + if tool.name in self.sessions: + logger.warning(f"Tool '{tool.name}' already defined by another server. Skipping.") + continue + + # Map tool names to their server for later lookup + self.sessions[tool.name] = session + + # Add tool to the list of available tools (for use in chat completions) + self.available_tools.append( + ChatCompletionInputTool.parse_obj_as_instance( + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.inputSchema, + }, + } + ) + ) + + async def process_single_turn_with_tools( + self, + messages: List[Union[Dict, ChatCompletionInputMessage]], + exit_loop_tools: Optional[List[ChatCompletionInputTool]] = None, + exit_if_first_chunk_no_tool: bool = False, + ) -> AsyncIterable[Union[ChatCompletionStreamOutput, ChatCompletionInputMessage]]: + """Process a query using `self.model` and available tools, yielding chunks and tool outputs. + + Args: + messages (`List[Dict]`): + List of message objects representing the conversation history + exit_loop_tools (`List[ChatCompletionInputTool]`, *optional*): + List of tools that should exit the generator when called + exit_if_first_chunk_no_tool (`bool`, *optional*): + Exit if no tool is present in the first chunks. Default to False. + + Yields: + [`ChatCompletionStreamOutput`] chunks or [`ChatCompletionInputMessage`] objects + """ + # Prepare tools list based on options + tools = self.available_tools + if exit_loop_tools is not None: + tools = [*exit_loop_tools, *self.available_tools] + + # Create the streaming request + response = await self.client.chat.completions.create( + model=self.payload_model, + messages=messages, + tools=tools, + tool_choice="auto", + stream=True, + ) + + message: Dict[str, Any] = {"role": "unknown", "content": ""} + final_tool_calls: Dict[int, ChatCompletionStreamOutputDeltaToolCall] = {} + num_of_chunks = 0 + + # Read from stream + async for chunk in response: + num_of_chunks += 1 + delta = chunk.choices[0].delta if chunk.choices and len(chunk.choices) > 0 else None + if not delta: + continue + + # Process message + if delta.role: + message["role"] = delta.role + if delta.content: + message["content"] += delta.content + + # Process tool calls + if delta.tool_calls: + for tool_call in delta.tool_calls: + # Aggregate chunks into tool calls + if tool_call.index not in final_tool_calls: + if ( + tool_call.function.arguments is None or tool_call.function.arguments == "{}" + ): # Corner case (depends on provider) + tool_call.function.arguments = "" + final_tool_calls[tool_call.index] = tool_call + + elif tool_call.function.arguments: + final_tool_calls[tool_call.index].function.arguments += tool_call.function.arguments + + # Optionally exit early if no tools in first chunks + if exit_if_first_chunk_no_tool and num_of_chunks <= 2 and len(final_tool_calls) == 0: + return + + # Yield each chunk to caller + yield chunk + + # Add the assistant message with tool calls (if any) to messages + if message["content"] or final_tool_calls: + # if the role is unknown, set it to assistant + if message.get("role") == "unknown": + message["role"] = "assistant" + # Convert final_tool_calls to the format expected by OpenAI + if final_tool_calls: + tool_calls_list: List[Dict[str, Any]] = [] + for tc in final_tool_calls.values(): + tool_calls_list.append( + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments or "{}", + }, + } + ) + message["tool_calls"] = tool_calls_list + messages.append(message) + + # Process tool calls one by one + for tool_call in final_tool_calls.values(): + function_name = tool_call.function.name + try: + function_args = json.loads(tool_call.function.arguments or "{}") + except json.JSONDecodeError as err: + tool_message = { + "role": "tool", + "tool_call_id": tool_call.id, + "name": function_name, + "content": f"Invalid JSON generated by the model: {err}", + } + tool_message_as_obj = ChatCompletionInputMessage.parse_obj_as_instance(tool_message) + messages.append(tool_message_as_obj) + yield tool_message_as_obj + continue # move to next tool call + + tool_message = {"role": "tool", "tool_call_id": tool_call.id, "content": "", "name": function_name} + + # Check if this is an exit loop tool + if exit_loop_tools and function_name in [t.function.name for t in exit_loop_tools]: + tool_message_as_obj = ChatCompletionInputMessage.parse_obj_as_instance(tool_message) + messages.append(tool_message_as_obj) + yield tool_message_as_obj + return + + # Execute tool call with the appropriate session + session = self.sessions.get(function_name) + if session is not None: + try: + result = await session.call_tool(function_name, function_args) + tool_message["content"] = format_result(result) + except Exception as err: + tool_message["content"] = f"Error: MCP tool call failed with error message: {err}" + else: + tool_message["content"] = f"Error: No session found for tool: {function_name}" + + # Yield tool message + tool_message_as_obj = ChatCompletionInputMessage.parse_obj_as_instance(tool_message) + messages.append(tool_message_as_obj) + yield tool_message_as_obj diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/types.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/types.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb5e0eac91562fb42015e0a59d78bbb172d5bbe --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/types.py @@ -0,0 +1,42 @@ +from typing import Dict, List, Literal, TypedDict, Union + +from typing_extensions import NotRequired + + +class InputConfig(TypedDict, total=False): + id: str + description: str + type: str + password: bool + + +class StdioServerConfig(TypedDict): + type: Literal["stdio"] + command: str + args: List[str] + env: Dict[str, str] + cwd: str + + +class HTTPServerConfig(TypedDict): + type: Literal["http"] + url: str + headers: Dict[str, str] + + +class SSEServerConfig(TypedDict): + type: Literal["sse"] + url: str + headers: Dict[str, str] + + +ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig] + + +# AgentConfig root object +class AgentConfig(TypedDict): + model: str + provider: str + apiKey: NotRequired[str] + inputs: List[InputConfig] + servers: List[ServerConfig] diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/utils.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6e75eef45fe2e1f3deb169acd9c1ca76d9a7e283 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_mcp/utils.py @@ -0,0 +1,124 @@ +""" +Utility functions for MCPClient and Tiny Agents. + +Formatting utilities taken from the JS SDK: https://github.com/huggingface/huggingface.js/blob/main/packages/mcp-client/src/ResultFormatter.ts. +""" + +import json +from pathlib import Path +from typing import TYPE_CHECKING, List, Optional, Tuple + +from huggingface_hub import snapshot_download +from huggingface_hub.errors import EntryNotFoundError + +from .constants import DEFAULT_AGENT, DEFAULT_REPO_ID, FILENAME_CONFIG, FILENAME_PROMPT +from .types import AgentConfig + + +if TYPE_CHECKING: + from mcp import types as mcp_types + + +def format_result(result: "mcp_types.CallToolResult") -> str: + """ + Formats a mcp.types.CallToolResult content into a human-readable string. + + Args: + result (CallToolResult) + Object returned by mcp.ClientSession.call_tool. + + Returns: + str + A formatted string representing the content of the result. + """ + content = result.content + + if len(content) == 0: + return "[No content]" + + formatted_parts: List[str] = [] + + for item in content: + if item.type == "text": + formatted_parts.append(item.text) + + elif item.type == "image": + formatted_parts.append( + f"[Binary Content: Image {item.mimeType}, {_get_base64_size(item.data)} bytes]\n" + f"The task is complete and the content accessible to the User" + ) + + elif item.type == "audio": + formatted_parts.append( + f"[Binary Content: Audio {item.mimeType}, {_get_base64_size(item.data)} bytes]\n" + f"The task is complete and the content accessible to the User" + ) + + elif item.type == "resource": + resource = item.resource + + if hasattr(resource, "text"): + formatted_parts.append(resource.text) + + elif hasattr(resource, "blob"): + formatted_parts.append( + f"[Binary Content ({resource.uri}): {resource.mimeType}, {_get_base64_size(resource.blob)} bytes]\n" + f"The task is complete and the content accessible to the User" + ) + + return "\n".join(formatted_parts) + + +def _get_base64_size(base64_str: str) -> int: + """Estimate the byte size of a base64-encoded string.""" + # Remove any prefix like "data:image/png;base64," + if "," in base64_str: + base64_str = base64_str.split(",")[1] + + padding = 0 + if base64_str.endswith("=="): + padding = 2 + elif base64_str.endswith("="): + padding = 1 + + return (len(base64_str) * 3) // 4 - padding + + +def _load_agent_config(agent_path: Optional[str]) -> Tuple[AgentConfig, Optional[str]]: + """Load server config and prompt.""" + + def _read_dir(directory: Path) -> Tuple[AgentConfig, Optional[str]]: + cfg_file = directory / FILENAME_CONFIG + if not cfg_file.exists(): + raise FileNotFoundError(f" Config file not found in {directory}! Please make sure it exists locally") + + config: AgentConfig = json.loads(cfg_file.read_text(encoding="utf-8")) + prompt_file = directory / FILENAME_PROMPT + prompt: Optional[str] = prompt_file.read_text(encoding="utf-8") if prompt_file.exists() else None + return config, prompt + + if agent_path is None: + return DEFAULT_AGENT, None # type: ignore[return-value] + + path = Path(agent_path).expanduser() + + if path.is_file(): + return json.loads(path.read_text(encoding="utf-8")), None + + if path.is_dir(): + return _read_dir(path) + + # fetch from the Hub + try: + repo_dir = Path( + snapshot_download( + repo_id=DEFAULT_REPO_ID, + allow_patterns=f"{agent_path}/*", + repo_type="dataset", + ) + ) + return _read_dir(repo_dir / agent_path) + except Exception as err: + raise EntryNotFoundError( + f" Agent {agent_path} not found in tiny-agents/tiny-agents! Please make sure it exists in https://huggingface.co/datasets/tiny-agents/tiny-agents." + ) from err diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__init__.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec4866c30deab2e744f7e5105614cd9dc08caf7b --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__init__.py @@ -0,0 +1,210 @@ +from typing import Dict, Literal, Optional, Union + +from huggingface_hub.inference._providers.featherless_ai import ( + FeatherlessConversationalTask, + FeatherlessTextGenerationTask, +) +from huggingface_hub.utils import logging + +from ._common import TaskProviderHelper, _fetch_inference_provider_mapping +from .black_forest_labs import BlackForestLabsTextToImageTask +from .cerebras import CerebrasConversationalTask +from .cohere import CohereConversationalTask +from .fal_ai import ( + FalAIAutomaticSpeechRecognitionTask, + FalAIImageToImageTask, + FalAIImageToVideoTask, + FalAITextToImageTask, + FalAITextToSpeechTask, + FalAITextToVideoTask, +) +from .fireworks_ai import FireworksAIConversationalTask +from .groq import GroqConversationalTask +from .hf_inference import ( + HFInferenceBinaryInputTask, + HFInferenceConversational, + HFInferenceFeatureExtractionTask, + HFInferenceTask, +) +from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask +from .nebius import ( + NebiusConversationalTask, + NebiusFeatureExtractionTask, + NebiusTextGenerationTask, + NebiusTextToImageTask, +) +from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask +from .nscale import NscaleConversationalTask, NscaleTextToImageTask +from .openai import OpenAIConversationalTask +from .replicate import ReplicateImageToImageTask, ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask +from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask +from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask + + +logger = logging.get_logger(__name__) + + +PROVIDER_T = Literal[ + "black-forest-labs", + "cerebras", + "cohere", + "fal-ai", + "featherless-ai", + "fireworks-ai", + "groq", + "hf-inference", + "hyperbolic", + "nebius", + "novita", + "nscale", + "openai", + "replicate", + "sambanova", + "together", +] + +PROVIDER_OR_POLICY_T = Union[PROVIDER_T, Literal["auto"]] + +PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = { + "black-forest-labs": { + "text-to-image": BlackForestLabsTextToImageTask(), + }, + "cerebras": { + "conversational": CerebrasConversationalTask(), + }, + "cohere": { + "conversational": CohereConversationalTask(), + }, + "fal-ai": { + "automatic-speech-recognition": FalAIAutomaticSpeechRecognitionTask(), + "text-to-image": FalAITextToImageTask(), + "text-to-speech": FalAITextToSpeechTask(), + "text-to-video": FalAITextToVideoTask(), + "image-to-video": FalAIImageToVideoTask(), + "image-to-image": FalAIImageToImageTask(), + }, + "featherless-ai": { + "conversational": FeatherlessConversationalTask(), + "text-generation": FeatherlessTextGenerationTask(), + }, + "fireworks-ai": { + "conversational": FireworksAIConversationalTask(), + }, + "groq": { + "conversational": GroqConversationalTask(), + }, + "hf-inference": { + "text-to-image": HFInferenceTask("text-to-image"), + "conversational": HFInferenceConversational(), + "text-generation": HFInferenceTask("text-generation"), + "text-classification": HFInferenceTask("text-classification"), + "question-answering": HFInferenceTask("question-answering"), + "audio-classification": HFInferenceBinaryInputTask("audio-classification"), + "automatic-speech-recognition": HFInferenceBinaryInputTask("automatic-speech-recognition"), + "fill-mask": HFInferenceTask("fill-mask"), + "feature-extraction": HFInferenceFeatureExtractionTask(), + "image-classification": HFInferenceBinaryInputTask("image-classification"), + "image-segmentation": HFInferenceBinaryInputTask("image-segmentation"), + "document-question-answering": HFInferenceTask("document-question-answering"), + "image-to-text": HFInferenceBinaryInputTask("image-to-text"), + "object-detection": HFInferenceBinaryInputTask("object-detection"), + "audio-to-audio": HFInferenceBinaryInputTask("audio-to-audio"), + "zero-shot-image-classification": HFInferenceBinaryInputTask("zero-shot-image-classification"), + "zero-shot-classification": HFInferenceTask("zero-shot-classification"), + "image-to-image": HFInferenceBinaryInputTask("image-to-image"), + "sentence-similarity": HFInferenceTask("sentence-similarity"), + "table-question-answering": HFInferenceTask("table-question-answering"), + "tabular-classification": HFInferenceTask("tabular-classification"), + "text-to-speech": HFInferenceTask("text-to-speech"), + "token-classification": HFInferenceTask("token-classification"), + "translation": HFInferenceTask("translation"), + "summarization": HFInferenceTask("summarization"), + "visual-question-answering": HFInferenceBinaryInputTask("visual-question-answering"), + }, + "hyperbolic": { + "text-to-image": HyperbolicTextToImageTask(), + "conversational": HyperbolicTextGenerationTask("conversational"), + "text-generation": HyperbolicTextGenerationTask("text-generation"), + }, + "nebius": { + "text-to-image": NebiusTextToImageTask(), + "conversational": NebiusConversationalTask(), + "text-generation": NebiusTextGenerationTask(), + "feature-extraction": NebiusFeatureExtractionTask(), + }, + "novita": { + "text-generation": NovitaTextGenerationTask(), + "conversational": NovitaConversationalTask(), + "text-to-video": NovitaTextToVideoTask(), + }, + "nscale": { + "conversational": NscaleConversationalTask(), + "text-to-image": NscaleTextToImageTask(), + }, + "openai": { + "conversational": OpenAIConversationalTask(), + }, + "replicate": { + "image-to-image": ReplicateImageToImageTask(), + "text-to-image": ReplicateTextToImageTask(), + "text-to-speech": ReplicateTextToSpeechTask(), + "text-to-video": ReplicateTask("text-to-video"), + }, + "sambanova": { + "conversational": SambanovaConversationalTask(), + "feature-extraction": SambanovaFeatureExtractionTask(), + }, + "together": { + "text-to-image": TogetherTextToImageTask(), + "conversational": TogetherConversationalTask(), + "text-generation": TogetherTextGenerationTask(), + }, +} + + +def get_provider_helper( + provider: Optional[PROVIDER_OR_POLICY_T], task: str, model: Optional[str] +) -> TaskProviderHelper: + """Get provider helper instance by name and task. + + Args: + provider (`str`, *optional*): name of the provider, or "auto" to automatically select the provider for the model. + task (`str`): Name of the task + model (`str`, *optional*): Name of the model + Returns: + TaskProviderHelper: Helper instance for the specified provider and task + + Raises: + ValueError: If provider or task is not supported + """ + + if (model is None and provider in (None, "auto")) or ( + model is not None and model.startswith(("http://", "https://")) + ): + provider = "hf-inference" + + if provider is None: + logger.info( + "Defaulting to 'auto' which will select the first provider available for the model, sorted by the user's order in https://hf.co/settings/inference-providers." + ) + provider = "auto" + + if provider == "auto": + if model is None: + raise ValueError("Specifying a model is required when provider is 'auto'") + provider_mapping = _fetch_inference_provider_mapping(model) + provider = next(iter(provider_mapping)).provider + + provider_tasks = PROVIDERS.get(provider) # type: ignore + if provider_tasks is None: + raise ValueError( + f"Provider '{provider}' not supported. Available values: 'auto' or any provider from {list(PROVIDERS.keys())}." + "Passing 'auto' (default value) will automatically select the first provider available for the model, sorted " + "by the user's order in https://hf.co/settings/inference-providers." + ) + + if task not in provider_tasks: + raise ValueError( + f"Task '{task}' not supported for provider '{provider}'. Available tasks: {list(provider_tasks.keys())}" + ) + return provider_tasks[task] diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbb50eeefd3d9ba44000d9b4a16ed1bddb4c923f Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/_common.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/_common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bd8f3cf7d88690423ef89172113fbc3808cf3d5 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/_common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/black_forest_labs.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/black_forest_labs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6c05d1c15d8d45eb686e2652cae69629fd406d5 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/black_forest_labs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/cerebras.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/cerebras.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61b99bb7b23d29095dd67fe71dd2e605ac584637 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/cerebras.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/cohere.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/cohere.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bd5f2f1724d67346b7738563fefb2e0c09bb292 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/cohere.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/fal_ai.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/fal_ai.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3b2660b0e1b2ff951c8e7d07f351d508038d7c4 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/fal_ai.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/featherless_ai.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/featherless_ai.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2961d05faf8169ddcd8a6dde6092d339fd77938 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/featherless_ai.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/fireworks_ai.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/fireworks_ai.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8bc63a1c033e9ad440c3beea6f77130fc73c8e0 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/fireworks_ai.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/groq.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/groq.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b28ec0876fd7f284987709f98975d8fae2375d3b Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/groq.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/hf_inference.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/hf_inference.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2ce9772e9a883843de374140d90f22238aa767c Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/hf_inference.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/hyperbolic.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/hyperbolic.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8be007580d15e1d22bf105bb724f334d5b4b9297 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/hyperbolic.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/nebius.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/nebius.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2a9ffc2e767b72cf50da0022d09e13bce2a24b9 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/nebius.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/novita.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/novita.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6a9655f3b90128ce40d249e08a315813f03ef2a Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/novita.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/nscale.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/nscale.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7521254e16890674dc632b0bd2d51314890e876 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/nscale.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/openai.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/openai.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae2f58a62c48ffcff7244717775c91185a72822d Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/openai.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/replicate.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/replicate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fddac0e7814814f732f1a72f8490793ea8ab6e89 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/replicate.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/sambanova.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/sambanova.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adead6d01eca2eb1a383541e9e7caa9317358bed Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/sambanova.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/together.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/together.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c535741c2f3c8e32ec2a37deada7b1caaca00824 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/__pycache__/together.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/_common.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/_common.py new file mode 100644 index 0000000000000000000000000000000000000000..3d41a7c1be54241a977bbfcd305df1be68f93517 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/_common.py @@ -0,0 +1,296 @@ +from functools import lru_cache +from typing import Any, Dict, List, Optional, Union, overload + +from huggingface_hub import constants +from huggingface_hub.hf_api import InferenceProviderMapping +from huggingface_hub.inference._common import RequestParameters +from huggingface_hub.inference._generated.types.chat_completion import ChatCompletionInputMessage +from huggingface_hub.utils import build_hf_headers, get_token, logging + + +logger = logging.get_logger(__name__) + + +# Dev purposes only. +# If you want to try to run inference for a new model locally before it's registered on huggingface.co +# for a given Inference Provider, you can add it to the following dictionary. +HARDCODED_MODEL_INFERENCE_MAPPING: Dict[str, Dict[str, InferenceProviderMapping]] = { + # "HF model ID" => InferenceProviderMapping object initialized with "Model ID on Inference Provider's side" + # + # Example: + # "Qwen/Qwen2.5-Coder-32B-Instruct": InferenceProviderMapping(hf_model_id="Qwen/Qwen2.5-Coder-32B-Instruct", + # provider_id="Qwen2.5-Coder-32B-Instruct", + # task="conversational", + # status="live") + "cerebras": {}, + "cohere": {}, + "fal-ai": {}, + "fireworks-ai": {}, + "groq": {}, + "hf-inference": {}, + "hyperbolic": {}, + "nebius": {}, + "nscale": {}, + "replicate": {}, + "sambanova": {}, + "together": {}, +} + + +@overload +def filter_none(obj: Dict[str, Any]) -> Dict[str, Any]: ... +@overload +def filter_none(obj: List[Any]) -> List[Any]: ... + + +def filter_none(obj: Union[Dict[str, Any], List[Any]]) -> Union[Dict[str, Any], List[Any]]: + if isinstance(obj, dict): + cleaned: Dict[str, Any] = {} + for k, v in obj.items(): + if v is None: + continue + if isinstance(v, (dict, list)): + v = filter_none(v) + cleaned[k] = v + return cleaned + + if isinstance(obj, list): + return [filter_none(v) if isinstance(v, (dict, list)) else v for v in obj] + + raise ValueError(f"Expected dict or list, got {type(obj)}") + + +class TaskProviderHelper: + """Base class for task-specific provider helpers.""" + + def __init__(self, provider: str, base_url: str, task: str) -> None: + self.provider = provider + self.task = task + self.base_url = base_url + + def prepare_request( + self, + *, + inputs: Any, + parameters: Dict[str, Any], + headers: Dict, + model: Optional[str], + api_key: Optional[str], + extra_payload: Optional[Dict[str, Any]] = None, + ) -> RequestParameters: + """ + Prepare the request to be sent to the provider. + + Each step (api_key, model, headers, url, payload) can be customized in subclasses. + """ + # api_key from user, or local token, or raise error + api_key = self._prepare_api_key(api_key) + + # mapped model from HF model ID + provider_mapping_info = self._prepare_mapping_info(model) + + # default HF headers + user headers (to customize in subclasses) + headers = self._prepare_headers(headers, api_key) + + # routed URL if HF token, or direct URL (to customize in '_prepare_route' in subclasses) + url = self._prepare_url(api_key, provider_mapping_info.provider_id) + + # prepare payload (to customize in subclasses) + payload = self._prepare_payload_as_dict(inputs, parameters, provider_mapping_info=provider_mapping_info) + if payload is not None: + payload = recursive_merge(payload, filter_none(extra_payload or {})) + + # body data (to customize in subclasses) + data = self._prepare_payload_as_bytes(inputs, parameters, provider_mapping_info, extra_payload) + + # check if both payload and data are set and return + if payload is not None and data is not None: + raise ValueError("Both payload and data cannot be set in the same request.") + if payload is None and data is None: + raise ValueError("Either payload or data must be set in the request.") + return RequestParameters( + url=url, task=self.task, model=provider_mapping_info.provider_id, json=payload, data=data, headers=headers + ) + + def get_response( + self, + response: Union[bytes, Dict], + request_params: Optional[RequestParameters] = None, + ) -> Any: + """ + Return the response in the expected format. + + Override this method in subclasses for customized response handling.""" + return response + + def _prepare_api_key(self, api_key: Optional[str]) -> str: + """Return the API key to use for the request. + + Usually not overwritten in subclasses.""" + if api_key is None: + api_key = get_token() + if api_key is None: + raise ValueError( + f"You must provide an api_key to work with {self.provider} API or log in with `hf auth login`." + ) + return api_key + + def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping: + """Return the mapped model ID to use for the request. + + Usually not overwritten in subclasses.""" + if model is None: + raise ValueError(f"Please provide an HF model ID supported by {self.provider}.") + + # hardcoded mapping for local testing + if HARDCODED_MODEL_INFERENCE_MAPPING.get(self.provider, {}).get(model): + return HARDCODED_MODEL_INFERENCE_MAPPING[self.provider][model] + + provider_mapping = None + for mapping in _fetch_inference_provider_mapping(model): + if mapping.provider == self.provider: + provider_mapping = mapping + break + + if provider_mapping is None: + raise ValueError(f"Model {model} is not supported by provider {self.provider}.") + + if provider_mapping.task != self.task: + raise ValueError( + f"Model {model} is not supported for task {self.task} and provider {self.provider}. " + f"Supported task: {provider_mapping.task}." + ) + + if provider_mapping.status == "staging": + logger.warning( + f"Model {model} is in staging mode for provider {self.provider}. Meant for test purposes only." + ) + if provider_mapping.status == "error": + logger.warning( + f"Our latest automated health check on model '{model}' for provider '{self.provider}' did not complete successfully. " + "Inference call might fail." + ) + return provider_mapping + + def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: + """Return the headers to use for the request. + + Override this method in subclasses for customized headers. + """ + return {**build_hf_headers(token=api_key), **headers} + + def _prepare_url(self, api_key: str, mapped_model: str) -> str: + """Return the URL to use for the request. + + Usually not overwritten in subclasses.""" + base_url = self._prepare_base_url(api_key) + route = self._prepare_route(mapped_model, api_key) + return f"{base_url.rstrip('/')}/{route.lstrip('/')}" + + def _prepare_base_url(self, api_key: str) -> str: + """Return the base URL to use for the request. + + Usually not overwritten in subclasses.""" + # Route to the proxy if the api_key is a HF TOKEN + if api_key.startswith("hf_"): + logger.info(f"Calling '{self.provider}' provider through Hugging Face router.") + return constants.INFERENCE_PROXY_TEMPLATE.format(provider=self.provider) + else: + logger.info(f"Calling '{self.provider}' provider directly.") + return self.base_url + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + """Return the route to use for the request. + + Override this method in subclasses for customized routes. + """ + return "" + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + """Return the payload to use for the request, as a dict. + + Override this method in subclasses for customized payloads. + Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value. + """ + return None + + def _prepare_payload_as_bytes( + self, + inputs: Any, + parameters: Dict, + provider_mapping_info: InferenceProviderMapping, + extra_payload: Optional[Dict], + ) -> Optional[bytes]: + """Return the body to use for the request, as bytes. + + Override this method in subclasses for customized body data. + Only one of `_prepare_payload_as_dict` and `_prepare_payload_as_bytes` should return a value. + """ + return None + + +class BaseConversationalTask(TaskProviderHelper): + """ + Base class for conversational (chat completion) tasks. + The schema follows the OpenAI API format defined here: https://platform.openai.com/docs/api-reference/chat + """ + + def __init__(self, provider: str, base_url: str): + super().__init__(provider=provider, base_url=base_url, task="conversational") + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + return "/v1/chat/completions" + + def _prepare_payload_as_dict( + self, + inputs: List[Union[Dict, ChatCompletionInputMessage]], + parameters: Dict, + provider_mapping_info: InferenceProviderMapping, + ) -> Optional[Dict]: + return filter_none({"messages": inputs, **parameters, "model": provider_mapping_info.provider_id}) + + +class BaseTextGenerationTask(TaskProviderHelper): + """ + Base class for text-generation (completion) tasks. + The schema follows the OpenAI API format defined here: https://platform.openai.com/docs/api-reference/completions + """ + + def __init__(self, provider: str, base_url: str): + super().__init__(provider=provider, base_url=base_url, task="text-generation") + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + return "/v1/completions" + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + return filter_none({"prompt": inputs, **parameters, "model": provider_mapping_info.provider_id}) + + +@lru_cache(maxsize=None) +def _fetch_inference_provider_mapping(model: str) -> List["InferenceProviderMapping"]: + """ + Fetch provider mappings for a model from the Hub. + """ + from huggingface_hub.hf_api import HfApi + + info = HfApi().model_info(model, expand=["inferenceProviderMapping"]) + provider_mapping = info.inference_provider_mapping + if provider_mapping is None: + raise ValueError(f"No provider mapping found for model {model}") + return provider_mapping + + +def recursive_merge(dict1: Dict, dict2: Dict) -> Dict: + return { + **dict1, + **{ + key: recursive_merge(dict1[key], value) + if (key in dict1 and isinstance(dict1[key], dict) and isinstance(value, dict)) + else value + for key, value in dict2.items() + }, + } diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/black_forest_labs.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/black_forest_labs.py new file mode 100644 index 0000000000000000000000000000000000000000..afa8ed281d8a8e94a054b83b74ec6909f623e300 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/black_forest_labs.py @@ -0,0 +1,69 @@ +import time +from typing import Any, Dict, Optional, Union + +from huggingface_hub.hf_api import InferenceProviderMapping +from huggingface_hub.inference._common import RequestParameters, _as_dict +from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none +from huggingface_hub.utils import logging +from huggingface_hub.utils._http import get_session + + +logger = logging.get_logger(__name__) + +MAX_POLLING_ATTEMPTS = 6 +POLLING_INTERVAL = 1.0 + + +class BlackForestLabsTextToImageTask(TaskProviderHelper): + def __init__(self): + super().__init__(provider="black-forest-labs", base_url="https://api.us1.bfl.ai", task="text-to-image") + + def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: + headers = super()._prepare_headers(headers, api_key) + if not api_key.startswith("hf_"): + _ = headers.pop("authorization") + headers["X-Key"] = api_key + return headers + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + return f"/v1/{mapped_model}" + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + parameters = filter_none(parameters) + if "num_inference_steps" in parameters: + parameters["steps"] = parameters.pop("num_inference_steps") + if "guidance_scale" in parameters: + parameters["guidance"] = parameters.pop("guidance_scale") + + return {"prompt": inputs, **parameters} + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + """ + Polling mechanism for Black Forest Labs since the API is asynchronous. + """ + url = _as_dict(response).get("polling_url") + session = get_session() + for _ in range(MAX_POLLING_ATTEMPTS): + time.sleep(POLLING_INTERVAL) + + response = session.get(url, headers={"Content-Type": "application/json"}) # type: ignore + response.raise_for_status() # type: ignore + response_json: Dict = response.json() # type: ignore + status = response_json.get("status") + logger.info( + f"Polling generation result from {url}. Current status: {status}. " + f"Will retry after {POLLING_INTERVAL} seconds if not ready." + ) + + if ( + status == "Ready" + and isinstance(response_json.get("result"), dict) + and (sample_url := response_json["result"].get("sample")) + ): + image_resp = session.get(sample_url) + image_resp.raise_for_status() + return image_resp.content + + raise TimeoutError(f"Failed to get the image URL after {MAX_POLLING_ATTEMPTS} attempts.") diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/cerebras.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/cerebras.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b9c3aacb3e134a8e755297c15ece198ffe633d --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/cerebras.py @@ -0,0 +1,6 @@ +from ._common import BaseConversationalTask + + +class CerebrasConversationalTask(BaseConversationalTask): + def __init__(self): + super().__init__(provider="cerebras", base_url="https://api.cerebras.ai") diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/cohere.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/cohere.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e9191caec50b0e659dddceba3e817a4ac28307 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/cohere.py @@ -0,0 +1,32 @@ +from typing import Any, Dict, Optional + +from huggingface_hub.hf_api import InferenceProviderMapping + +from ._common import BaseConversationalTask + + +_PROVIDER = "cohere" +_BASE_URL = "https://api.cohere.com" + + +class CohereConversationalTask(BaseConversationalTask): + def __init__(self): + super().__init__(provider=_PROVIDER, base_url=_BASE_URL) + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + return "/compatibility/v1/chat/completions" + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) + response_format = parameters.get("response_format") + if isinstance(response_format, dict) and response_format.get("type") == "json_schema": + json_schema_details = response_format.get("json_schema") + if isinstance(json_schema_details, dict) and "schema" in json_schema_details: + payload["response_format"] = { # type: ignore [index] + "type": "json_object", + "schema": json_schema_details["schema"], + } + + return payload diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/fal_ai.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/fal_ai.py new file mode 100644 index 0000000000000000000000000000000000000000..6a7b7a5779ecdd691637fd6f755c6bf7786fc39e --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/fal_ai.py @@ -0,0 +1,246 @@ +import base64 +import time +from abc import ABC +from typing import Any, Dict, Optional, Union +from urllib.parse import urlparse + +from huggingface_hub import constants +from huggingface_hub.hf_api import InferenceProviderMapping +from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url +from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none +from huggingface_hub.utils import get_session, hf_raise_for_status +from huggingface_hub.utils.logging import get_logger + + +logger = get_logger(__name__) + +# Arbitrary polling interval +_POLLING_INTERVAL = 0.5 + + +class FalAITask(TaskProviderHelper, ABC): + def __init__(self, task: str): + super().__init__(provider="fal-ai", base_url="https://fal.run", task=task) + + def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: + headers = super()._prepare_headers(headers, api_key) + if not api_key.startswith("hf_"): + headers["authorization"] = f"Key {api_key}" + return headers + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + return f"/{mapped_model}" + + +class FalAIQueueTask(TaskProviderHelper, ABC): + def __init__(self, task: str): + super().__init__(provider="fal-ai", base_url="https://queue.fal.run", task=task) + + def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: + headers = super()._prepare_headers(headers, api_key) + if not api_key.startswith("hf_"): + headers["authorization"] = f"Key {api_key}" + return headers + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + if api_key.startswith("hf_"): + # Use the queue subdomain for HF routing + return f"/{mapped_model}?_subdomain=queue" + return f"/{mapped_model}" + + def get_response( + self, + response: Union[bytes, Dict], + request_params: Optional[RequestParameters] = None, + ) -> Any: + response_dict = _as_dict(response) + + request_id = response_dict.get("request_id") + if not request_id: + raise ValueError("No request ID found in the response") + if request_params is None: + raise ValueError( + f"A `RequestParameters` object should be provided to get {self.task} responses with Fal AI." + ) + + # extract the base url and query params + parsed_url = urlparse(request_params.url) + # a bit hacky way to concatenate the provider name without parsing `parsed_url.path` + base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{'/fal-ai' if parsed_url.netloc == 'router.huggingface.co' else ''}" + query_param = f"?{parsed_url.query}" if parsed_url.query else "" + + # extracting the provider model id for status and result urls + # from the response as it might be different from the mapped model in `request_params.url` + model_id = urlparse(response_dict.get("response_url")).path + status_url = f"{base_url}{str(model_id)}/status{query_param}" + result_url = f"{base_url}{str(model_id)}{query_param}" + + status = response_dict.get("status") + logger.info("Generating the output.. this can take several minutes.") + while status != "COMPLETED": + time.sleep(_POLLING_INTERVAL) + status_response = get_session().get(status_url, headers=request_params.headers) + hf_raise_for_status(status_response) + status = status_response.json().get("status") + + return get_session().get(result_url, headers=request_params.headers).json() + + +class FalAIAutomaticSpeechRecognitionTask(FalAITask): + def __init__(self): + super().__init__("automatic-speech-recognition") + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + if isinstance(inputs, str) and inputs.startswith(("http://", "https://")): + # If input is a URL, pass it directly + audio_url = inputs + else: + # If input is a file path, read it first + if isinstance(inputs, str): + with open(inputs, "rb") as f: + inputs = f.read() + + audio_b64 = base64.b64encode(inputs).decode() + content_type = "audio/mpeg" + audio_url = f"data:{content_type};base64,{audio_b64}" + + return {"audio_url": audio_url, **filter_none(parameters)} + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + text = _as_dict(response)["text"] + if not isinstance(text, str): + raise ValueError(f"Unexpected output format from FalAI API. Expected string, got {type(text)}.") + return text + + +class FalAITextToImageTask(FalAITask): + def __init__(self): + super().__init__("text-to-image") + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + payload: Dict[str, Any] = { + "prompt": inputs, + **filter_none(parameters), + } + if "width" in payload and "height" in payload: + payload["image_size"] = { + "width": payload.pop("width"), + "height": payload.pop("height"), + } + if provider_mapping_info.adapter_weights_path is not None: + lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format( + repo_id=provider_mapping_info.hf_model_id, + revision="main", + filename=provider_mapping_info.adapter_weights_path, + ) + payload["loras"] = [{"path": lora_path, "scale": 1}] + if provider_mapping_info.provider_id == "fal-ai/lora": + # little hack: fal requires the base model for stable-diffusion-based loras but not for flux-based + # See payloads in https://fal.ai/models/fal-ai/lora/api vs https://fal.ai/models/fal-ai/flux-lora/api + payload["model_name"] = "stabilityai/stable-diffusion-xl-base-1.0" + + return payload + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + url = _as_dict(response)["images"][0]["url"] + return get_session().get(url).content + + +class FalAITextToSpeechTask(FalAITask): + def __init__(self): + super().__init__("text-to-speech") + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + return {"text": inputs, **filter_none(parameters)} + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + url = _as_dict(response)["audio"]["url"] + return get_session().get(url).content + + +class FalAITextToVideoTask(FalAIQueueTask): + def __init__(self): + super().__init__("text-to-video") + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + return {"prompt": inputs, **filter_none(parameters)} + + def get_response( + self, + response: Union[bytes, Dict], + request_params: Optional[RequestParameters] = None, + ) -> Any: + output = super().get_response(response, request_params) + url = _as_dict(output)["video"]["url"] + return get_session().get(url).content + + +class FalAIImageToImageTask(FalAIQueueTask): + def __init__(self): + super().__init__("image-to-image") + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + image_url = _as_url(inputs, default_mime_type="image/jpeg") + payload: Dict[str, Any] = { + "image_url": image_url, + **filter_none(parameters), + } + if provider_mapping_info.adapter_weights_path is not None: + lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format( + repo_id=provider_mapping_info.hf_model_id, + revision="main", + filename=provider_mapping_info.adapter_weights_path, + ) + payload["loras"] = [{"path": lora_path, "scale": 1}] + + return payload + + def get_response( + self, + response: Union[bytes, Dict], + request_params: Optional[RequestParameters] = None, + ) -> Any: + output = super().get_response(response, request_params) + url = _as_dict(output)["images"][0]["url"] + return get_session().get(url).content + + +class FalAIImageToVideoTask(FalAIQueueTask): + def __init__(self): + super().__init__("image-to-video") + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + image_url = _as_url(inputs, default_mime_type="image/jpeg") + payload: Dict[str, Any] = { + "image_url": image_url, + **filter_none(parameters), + } + if provider_mapping_info.adapter_weights_path is not None: + lora_path = constants.HUGGINGFACE_CO_URL_TEMPLATE.format( + repo_id=provider_mapping_info.hf_model_id, + revision="main", + filename=provider_mapping_info.adapter_weights_path, + ) + payload["loras"] = [{"path": lora_path, "scale": 1}] + return payload + + def get_response( + self, + response: Union[bytes, Dict], + request_params: Optional[RequestParameters] = None, + ) -> Any: + output = super().get_response(response, request_params) + url = _as_dict(output)["video"]["url"] + return get_session().get(url).content diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/featherless_ai.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/featherless_ai.py new file mode 100644 index 0000000000000000000000000000000000000000..6ad1c48134f5c990b6ac4fca5ff919f4cc0d2373 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/featherless_ai.py @@ -0,0 +1,38 @@ +from typing import Any, Dict, Optional, Union + +from huggingface_hub.hf_api import InferenceProviderMapping +from huggingface_hub.inference._common import RequestParameters, _as_dict + +from ._common import BaseConversationalTask, BaseTextGenerationTask, filter_none + + +_PROVIDER = "featherless-ai" +_BASE_URL = "https://api.featherless.ai" + + +class FeatherlessTextGenerationTask(BaseTextGenerationTask): + def __init__(self): + super().__init__(provider=_PROVIDER, base_url=_BASE_URL) + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + params = filter_none(parameters.copy()) + params["max_tokens"] = params.pop("max_new_tokens", None) + + return {"prompt": inputs, **params, "model": provider_mapping_info.provider_id} + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + output = _as_dict(response)["choices"][0] + return { + "generated_text": output["text"], + "details": { + "finish_reason": output.get("finish_reason"), + "seed": output.get("seed"), + }, + } + + +class FeatherlessConversationalTask(BaseConversationalTask): + def __init__(self): + super().__init__(provider=_PROVIDER, base_url=_BASE_URL) diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/fireworks_ai.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/fireworks_ai.py new file mode 100644 index 0000000000000000000000000000000000000000..b4cc19a5700047f6516b2784d9785a99d7e32451 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/fireworks_ai.py @@ -0,0 +1,27 @@ +from typing import Any, Dict, Optional + +from huggingface_hub.hf_api import InferenceProviderMapping + +from ._common import BaseConversationalTask + + +class FireworksAIConversationalTask(BaseConversationalTask): + def __init__(self): + super().__init__(provider="fireworks-ai", base_url="https://api.fireworks.ai") + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + return "/inference/v1/chat/completions" + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) + response_format = parameters.get("response_format") + if isinstance(response_format, dict) and response_format.get("type") == "json_schema": + json_schema_details = response_format.get("json_schema") + if isinstance(json_schema_details, dict) and "schema" in json_schema_details: + payload["response_format"] = { # type: ignore [index] + "type": "json_object", + "schema": json_schema_details["schema"], + } + return payload diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/groq.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/groq.py new file mode 100644 index 0000000000000000000000000000000000000000..11e677504e89bc02b966e7d37d9e11f1b94b297f --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/groq.py @@ -0,0 +1,9 @@ +from ._common import BaseConversationalTask + + +class GroqConversationalTask(BaseConversationalTask): + def __init__(self): + super().__init__(provider="groq", base_url="https://api.groq.com") + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + return "/openai/v1/chat/completions" diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/hf_inference.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/hf_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..a5659eea7ba33e87b75eb8c6ef215dfe6320caab --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/hf_inference.py @@ -0,0 +1,220 @@ +import json +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional, Union +from urllib.parse import urlparse, urlunparse + +from huggingface_hub import constants +from huggingface_hub.hf_api import InferenceProviderMapping +from huggingface_hub.inference._common import RequestParameters, _b64_encode, _bytes_to_dict, _open_as_binary +from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none +from huggingface_hub.utils import build_hf_headers, get_session, get_token, hf_raise_for_status + + +class HFInferenceTask(TaskProviderHelper): + """Base class for HF Inference API tasks.""" + + def __init__(self, task: str): + super().__init__( + provider="hf-inference", + base_url=constants.INFERENCE_PROXY_TEMPLATE.format(provider="hf-inference"), + task=task, + ) + + def _prepare_api_key(self, api_key: Optional[str]) -> str: + # special case: for HF Inference we allow not providing an API key + return api_key or get_token() # type: ignore[return-value] + + def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping: + if model is not None and model.startswith(("http://", "https://")): + return InferenceProviderMapping( + provider="hf-inference", providerId=model, hf_model_id=model, task=self.task, status="live" + ) + model_id = model if model is not None else _fetch_recommended_models().get(self.task) + if model_id is None: + raise ValueError( + f"Task {self.task} has no recommended model for HF Inference. Please specify a model" + " explicitly. Visit https://huggingface.co/tasks for more info." + ) + _check_supported_task(model_id, self.task) + return InferenceProviderMapping( + provider="hf-inference", providerId=model_id, hf_model_id=model_id, task=self.task, status="live" + ) + + def _prepare_url(self, api_key: str, mapped_model: str) -> str: + # hf-inference provider can handle URLs (e.g. Inference Endpoints or TGI deployment) + if mapped_model.startswith(("http://", "https://")): + return mapped_model + return ( + # Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks. + f"{self.base_url}/models/{mapped_model}/pipeline/{self.task}" + if self.task in ("feature-extraction", "sentence-similarity") + # Otherwise, we use the default endpoint + else f"{self.base_url}/models/{mapped_model}" + ) + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + if isinstance(inputs, bytes): + raise ValueError(f"Unexpected binary input for task {self.task}.") + if isinstance(inputs, Path): + raise ValueError(f"Unexpected path input for task {self.task} (got {inputs})") + return filter_none({"inputs": inputs, "parameters": parameters}) + + +class HFInferenceBinaryInputTask(HFInferenceTask): + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + return None + + def _prepare_payload_as_bytes( + self, + inputs: Any, + parameters: Dict, + provider_mapping_info: InferenceProviderMapping, + extra_payload: Optional[Dict], + ) -> Optional[bytes]: + parameters = filter_none(parameters) + extra_payload = extra_payload or {} + has_parameters = len(parameters) > 0 or len(extra_payload) > 0 + + # Raise if not a binary object or a local path or a URL. + if not isinstance(inputs, (bytes, Path)) and not isinstance(inputs, str): + raise ValueError(f"Expected binary inputs or a local path or a URL. Got {inputs}") + + # Send inputs as raw content when no parameters are provided + if not has_parameters: + with _open_as_binary(inputs) as data: + data_as_bytes = data if isinstance(data, bytes) else data.read() + return data_as_bytes + + # Otherwise encode as b64 + return json.dumps({"inputs": _b64_encode(inputs), "parameters": parameters, **extra_payload}).encode("utf-8") + + +class HFInferenceConversational(HFInferenceTask): + def __init__(self): + super().__init__("conversational") + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + payload = filter_none(parameters) + mapped_model = provider_mapping_info.provider_id + payload_model = parameters.get("model") or mapped_model + + if payload_model is None or payload_model.startswith(("http://", "https://")): + payload_model = "dummy" + + response_format = parameters.get("response_format") + if isinstance(response_format, dict) and response_format.get("type") == "json_schema": + payload["response_format"] = { + "type": "json_object", + "value": response_format["json_schema"]["schema"], + } + return {**payload, "model": payload_model, "messages": inputs} + + def _prepare_url(self, api_key: str, mapped_model: str) -> str: + base_url = ( + mapped_model + if mapped_model.startswith(("http://", "https://")) + else f"{constants.INFERENCE_PROXY_TEMPLATE.format(provider='hf-inference')}/models/{mapped_model}" + ) + return _build_chat_completion_url(base_url) + + +def _build_chat_completion_url(model_url: str) -> str: + parsed = urlparse(model_url) + path = parsed.path.rstrip("/") + + # If the path already ends with /chat/completions, we're done! + if path.endswith("/chat/completions"): + return model_url + + # Append /chat/completions if not already present + if path.endswith("/v1"): + new_path = path + "/chat/completions" + # If path was empty or just "/", set the full path + elif not path: + new_path = "/v1/chat/completions" + # Append /v1/chat/completions if not already present + else: + new_path = path + "/v1/chat/completions" + + # Reconstruct the URL with the new path and original query parameters. + return urlunparse(parsed._replace(path=new_path)) + + +@lru_cache(maxsize=1) +def _fetch_recommended_models() -> Dict[str, Optional[str]]: + response = get_session().get(f"{constants.ENDPOINT}/api/tasks", headers=build_hf_headers()) + hf_raise_for_status(response) + return {task: next(iter(details["widgetModels"]), None) for task, details in response.json().items()} + + +@lru_cache(maxsize=None) +def _check_supported_task(model: str, task: str) -> None: + from huggingface_hub.hf_api import HfApi + + model_info = HfApi().model_info(model) + pipeline_tag = model_info.pipeline_tag + tags = model_info.tags or [] + is_conversational = "conversational" in tags + if task in ("text-generation", "conversational"): + if pipeline_tag == "text-generation": + # text-generation + conversational tag -> both tasks allowed + if is_conversational: + return + # text-generation without conversational tag -> only text-generation allowed + if task == "text-generation": + return + raise ValueError(f"Model '{model}' doesn't support task '{task}'.") + + if pipeline_tag == "text2text-generation": + if task == "text-generation": + return + raise ValueError(f"Model '{model}' doesn't support task '{task}'.") + + if pipeline_tag == "image-text-to-text": + if is_conversational and task == "conversational": + return # Only conversational allowed if tagged as conversational + raise ValueError("Non-conversational image-text-to-text task is not supported.") + + if ( + task in ("feature-extraction", "sentence-similarity") + and pipeline_tag in ("feature-extraction", "sentence-similarity") + and task in tags + ): + # feature-extraction and sentence-similarity are interchangeable for HF Inference + return + + # For all other tasks, just check pipeline tag + if pipeline_tag != task: + raise ValueError( + f"Model '{model}' doesn't support task '{task}'. Supported tasks: '{pipeline_tag}', got: '{task}'" + ) + return + + +class HFInferenceFeatureExtractionTask(HFInferenceTask): + def __init__(self): + super().__init__("feature-extraction") + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + if isinstance(inputs, bytes): + raise ValueError(f"Unexpected binary input for task {self.task}.") + if isinstance(inputs, Path): + raise ValueError(f"Unexpected path input for task {self.task} (got {inputs})") + + # Parameters are sent at root-level for feature-extraction task + # See specs: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/tasks/feature-extraction/spec/input.json + return {"inputs": inputs, **filter_none(parameters)} + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + if isinstance(response, bytes): + return _bytes_to_dict(response) + return response diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/hyperbolic.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/hyperbolic.py new file mode 100644 index 0000000000000000000000000000000000000000..6dcb14cc275f6b80db5643361b9dfd3cbf8d91a2 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/hyperbolic.py @@ -0,0 +1,47 @@ +import base64 +from typing import Any, Dict, Optional, Union + +from huggingface_hub.hf_api import InferenceProviderMapping +from huggingface_hub.inference._common import RequestParameters, _as_dict +from huggingface_hub.inference._providers._common import BaseConversationalTask, TaskProviderHelper, filter_none + + +class HyperbolicTextToImageTask(TaskProviderHelper): + def __init__(self): + super().__init__(provider="hyperbolic", base_url="https://api.hyperbolic.xyz", task="text-to-image") + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + return "/v1/images/generations" + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + mapped_model = provider_mapping_info.provider_id + parameters = filter_none(parameters) + if "num_inference_steps" in parameters: + parameters["steps"] = parameters.pop("num_inference_steps") + if "guidance_scale" in parameters: + parameters["cfg_scale"] = parameters.pop("guidance_scale") + # For Hyperbolic, the width and height are required parameters + if "width" not in parameters: + parameters["width"] = 512 + if "height" not in parameters: + parameters["height"] = 512 + return {"prompt": inputs, "model_name": mapped_model, **parameters} + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + response_dict = _as_dict(response) + return base64.b64decode(response_dict["images"][0]["image"]) + + +class HyperbolicTextGenerationTask(BaseConversationalTask): + """ + Special case for Hyperbolic, where text-generation task is handled as a conversational task. + """ + + def __init__(self, task: str): + super().__init__( + provider="hyperbolic", + base_url="https://api.hyperbolic.xyz", + ) + self.task = task diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/nebius.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/nebius.py new file mode 100644 index 0000000000000000000000000000000000000000..85ad67c4c8835d7fb8bfe5f36e426614174a66ba --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/nebius.py @@ -0,0 +1,83 @@ +import base64 +from typing import Any, Dict, Optional, Union + +from huggingface_hub.hf_api import InferenceProviderMapping +from huggingface_hub.inference._common import RequestParameters, _as_dict +from huggingface_hub.inference._providers._common import ( + BaseConversationalTask, + BaseTextGenerationTask, + TaskProviderHelper, + filter_none, +) + + +class NebiusTextGenerationTask(BaseTextGenerationTask): + def __init__(self): + super().__init__(provider="nebius", base_url="https://api.studio.nebius.ai") + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + output = _as_dict(response)["choices"][0] + return { + "generated_text": output["text"], + "details": { + "finish_reason": output.get("finish_reason"), + "seed": output.get("seed"), + }, + } + + +class NebiusConversationalTask(BaseConversationalTask): + def __init__(self): + super().__init__(provider="nebius", base_url="https://api.studio.nebius.ai") + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) + response_format = parameters.get("response_format") + if isinstance(response_format, dict) and response_format.get("type") == "json_schema": + json_schema_details = response_format.get("json_schema") + if isinstance(json_schema_details, dict) and "schema" in json_schema_details: + payload["guided_json"] = json_schema_details["schema"] # type: ignore [index] + return payload + + +class NebiusTextToImageTask(TaskProviderHelper): + def __init__(self): + super().__init__(task="text-to-image", provider="nebius", base_url="https://api.studio.nebius.ai") + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + return "/v1/images/generations" + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + mapped_model = provider_mapping_info.provider_id + parameters = filter_none(parameters) + if "guidance_scale" in parameters: + parameters.pop("guidance_scale") + if parameters.get("response_format") not in ("b64_json", "url"): + parameters["response_format"] = "b64_json" + + return {"prompt": inputs, **parameters, "model": mapped_model} + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + response_dict = _as_dict(response) + return base64.b64decode(response_dict["data"][0]["b64_json"]) + + +class NebiusFeatureExtractionTask(TaskProviderHelper): + def __init__(self): + super().__init__(task="feature-extraction", provider="nebius", base_url="https://api.studio.nebius.ai") + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + return "/v1/embeddings" + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + return {"input": inputs, "model": provider_mapping_info.provider_id} + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + embeddings = _as_dict(response)["data"] + return [embedding["embedding"] for embedding in embeddings] diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/novita.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/novita.py new file mode 100644 index 0000000000000000000000000000000000000000..44adc9017b456f487513cde251086075d84b69f0 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/novita.py @@ -0,0 +1,69 @@ +from typing import Any, Dict, Optional, Union + +from huggingface_hub.hf_api import InferenceProviderMapping +from huggingface_hub.inference._common import RequestParameters, _as_dict +from huggingface_hub.inference._providers._common import ( + BaseConversationalTask, + BaseTextGenerationTask, + TaskProviderHelper, + filter_none, +) +from huggingface_hub.utils import get_session + + +_PROVIDER = "novita" +_BASE_URL = "https://api.novita.ai" + + +class NovitaTextGenerationTask(BaseTextGenerationTask): + def __init__(self): + super().__init__(provider=_PROVIDER, base_url=_BASE_URL) + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + # there is no v1/ route for novita + return "/v3/openai/completions" + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + output = _as_dict(response)["choices"][0] + return { + "generated_text": output["text"], + "details": { + "finish_reason": output.get("finish_reason"), + "seed": output.get("seed"), + }, + } + + +class NovitaConversationalTask(BaseConversationalTask): + def __init__(self): + super().__init__(provider=_PROVIDER, base_url=_BASE_URL) + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + # there is no v1/ route for novita + return "/v3/openai/chat/completions" + + +class NovitaTextToVideoTask(TaskProviderHelper): + def __init__(self): + super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task="text-to-video") + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + return f"/v3/hf/{mapped_model}" + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + return {"prompt": inputs, **filter_none(parameters)} + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + response_dict = _as_dict(response) + if not ( + isinstance(response_dict, dict) + and "video" in response_dict + and isinstance(response_dict["video"], dict) + and "video_url" in response_dict["video"] + ): + raise ValueError("Expected response format: { 'video': { 'video_url': string } }") + + video_url = response_dict["video"]["video_url"] + return get_session().get(video_url).content diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/nscale.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/nscale.py new file mode 100644 index 0000000000000000000000000000000000000000..ce5b20e354e246e93a7dd9831e4acf69ebcfad63 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/nscale.py @@ -0,0 +1,44 @@ +import base64 +from typing import Any, Dict, Optional, Union + +from huggingface_hub.hf_api import InferenceProviderMapping +from huggingface_hub.inference._common import RequestParameters, _as_dict + +from ._common import BaseConversationalTask, TaskProviderHelper, filter_none + + +class NscaleConversationalTask(BaseConversationalTask): + def __init__(self): + super().__init__(provider="nscale", base_url="https://inference.api.nscale.com") + + +class NscaleTextToImageTask(TaskProviderHelper): + def __init__(self): + super().__init__(provider="nscale", base_url="https://inference.api.nscale.com", task="text-to-image") + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + return "/v1/images/generations" + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + mapped_model = provider_mapping_info.provider_id + # Combine all parameters except inputs and parameters + parameters = filter_none(parameters) + if "width" in parameters and "height" in parameters: + parameters["size"] = f"{parameters.pop('width')}x{parameters.pop('height')}" + if "num_inference_steps" in parameters: + parameters.pop("num_inference_steps") + if "cfg_scale" in parameters: + parameters.pop("cfg_scale") + payload = { + "response_format": "b64_json", + "prompt": inputs, + "model": mapped_model, + **parameters, + } + return payload + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + response_dict = _as_dict(response) + return base64.b64decode(response_dict["data"][0]["b64_json"]) diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/openai.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..7a554093c173ea8f664cb7fbd9616ce3a08ce78c --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/openai.py @@ -0,0 +1,25 @@ +from typing import Optional + +from huggingface_hub.hf_api import InferenceProviderMapping +from huggingface_hub.inference._providers._common import BaseConversationalTask + + +class OpenAIConversationalTask(BaseConversationalTask): + def __init__(self): + super().__init__(provider="openai", base_url="https://api.openai.com") + + def _prepare_api_key(self, api_key: Optional[str]) -> str: + if api_key is None: + raise ValueError("You must provide an api_key to work with OpenAI API.") + if api_key.startswith("hf_"): + raise ValueError( + "OpenAI provider is not available through Hugging Face routing, please use your own OpenAI API key." + ) + return api_key + + def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping: + if model is None: + raise ValueError("Please provide an OpenAI model ID, e.g. `gpt-4o` or `o1`.") + return InferenceProviderMapping( + provider="openai", providerId=model, task="conversational", status="live", hf_model_id=model + ) diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/replicate.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..8a1037b6f2b9a8a03cc2e282bb7c50c0c6507847 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/replicate.py @@ -0,0 +1,90 @@ +from typing import Any, Dict, Optional, Union + +from huggingface_hub.hf_api import InferenceProviderMapping +from huggingface_hub.inference._common import RequestParameters, _as_dict, _as_url +from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none +from huggingface_hub.utils import get_session + + +_PROVIDER = "replicate" +_BASE_URL = "https://api.replicate.com" + + +class ReplicateTask(TaskProviderHelper): + def __init__(self, task: str): + super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task=task) + + def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: + headers = super()._prepare_headers(headers, api_key) + headers["Prefer"] = "wait" + return headers + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + if ":" in mapped_model: + return "/v1/predictions" + return f"/v1/models/{mapped_model}/predictions" + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + mapped_model = provider_mapping_info.provider_id + payload: Dict[str, Any] = {"input": {"prompt": inputs, **filter_none(parameters)}} + if ":" in mapped_model: + version = mapped_model.split(":", 1)[1] + payload["version"] = version + return payload + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + response_dict = _as_dict(response) + if response_dict.get("output") is None: + raise TimeoutError( + f"Inference request timed out after 60 seconds. No output generated for model {response_dict.get('model')}" + "The model might be in cold state or starting up. Please try again later." + ) + output_url = ( + response_dict["output"] if isinstance(response_dict["output"], str) else response_dict["output"][0] + ) + return get_session().get(output_url).content + + +class ReplicateTextToImageTask(ReplicateTask): + def __init__(self): + super().__init__("text-to-image") + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore[assignment] + if provider_mapping_info.adapter_weights_path is not None: + payload["input"]["lora_weights"] = f"https://huggingface.co/{provider_mapping_info.hf_model_id}" + return payload + + +class ReplicateTextToSpeechTask(ReplicateTask): + def __init__(self): + super().__init__("text-to-speech") + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore[assignment] + payload["input"]["text"] = payload["input"].pop("prompt") # rename "prompt" to "text" for TTS + return payload + + +class ReplicateImageToImageTask(ReplicateTask): + def __init__(self): + super().__init__("image-to-image") + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + image_url = _as_url(inputs, default_mime_type="image/jpeg") + + payload: Dict[str, Any] = {"input": {"input_image": image_url, **filter_none(parameters)}} + + mapped_model = provider_mapping_info.provider_id + if ":" in mapped_model: + version = mapped_model.split(":", 1)[1] + payload["version"] = version + return payload diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/sambanova.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/sambanova.py new file mode 100644 index 0000000000000000000000000000000000000000..ed96fb766ce49003b605bda8ef8ee34da0ebe2f4 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/sambanova.py @@ -0,0 +1,42 @@ +from typing import Any, Dict, Optional, Union + +from huggingface_hub.hf_api import InferenceProviderMapping +from huggingface_hub.inference._common import RequestParameters, _as_dict +from huggingface_hub.inference._providers._common import BaseConversationalTask, TaskProviderHelper, filter_none + + +class SambanovaConversationalTask(BaseConversationalTask): + def __init__(self): + super().__init__(provider="sambanova", base_url="https://api.sambanova.ai") + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + response_format_config = parameters.get("response_format") + if isinstance(response_format_config, dict): + if response_format_config.get("type") == "json_schema": + json_schema_config = response_format_config.get("json_schema", {}) + strict = json_schema_config.get("strict") + if isinstance(json_schema_config, dict) and (strict is True or strict is None): + json_schema_config["strict"] = False + + payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) + return payload + + +class SambanovaFeatureExtractionTask(TaskProviderHelper): + def __init__(self): + super().__init__(provider="sambanova", base_url="https://api.sambanova.ai", task="feature-extraction") + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + return "/v1/embeddings" + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + parameters = filter_none(parameters) + return {"input": inputs, "model": provider_mapping_info.provider_id, **parameters} + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + embeddings = _as_dict(response)["data"] + return [embedding["embedding"] for embedding in embeddings] diff --git a/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/together.py b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/together.py new file mode 100644 index 0000000000000000000000000000000000000000..de166b7baf8d50b255f29cf8cc9b9d3fa639646e --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/inference/_providers/together.py @@ -0,0 +1,88 @@ +import base64 +from abc import ABC +from typing import Any, Dict, Optional, Union + +from huggingface_hub.hf_api import InferenceProviderMapping +from huggingface_hub.inference._common import RequestParameters, _as_dict +from huggingface_hub.inference._providers._common import ( + BaseConversationalTask, + BaseTextGenerationTask, + TaskProviderHelper, + filter_none, +) + + +_PROVIDER = "together" +_BASE_URL = "https://api.together.xyz" + + +class TogetherTask(TaskProviderHelper, ABC): + """Base class for Together API tasks.""" + + def __init__(self, task: str): + super().__init__(provider=_PROVIDER, base_url=_BASE_URL, task=task) + + def _prepare_route(self, mapped_model: str, api_key: str) -> str: + if self.task == "text-to-image": + return "/v1/images/generations" + elif self.task == "conversational": + return "/v1/chat/completions" + elif self.task == "text-generation": + return "/v1/completions" + raise ValueError(f"Unsupported task '{self.task}' for Together API.") + + +class TogetherTextGenerationTask(BaseTextGenerationTask): + def __init__(self): + super().__init__(provider=_PROVIDER, base_url=_BASE_URL) + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + output = _as_dict(response)["choices"][0] + return { + "generated_text": output["text"], + "details": { + "finish_reason": output.get("finish_reason"), + "seed": output.get("seed"), + }, + } + + +class TogetherConversationalTask(BaseConversationalTask): + def __init__(self): + super().__init__(provider=_PROVIDER, base_url=_BASE_URL) + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + payload = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) + response_format = parameters.get("response_format") + if isinstance(response_format, dict) and response_format.get("type") == "json_schema": + json_schema_details = response_format.get("json_schema") + if isinstance(json_schema_details, dict) and "schema" in json_schema_details: + payload["response_format"] = { # type: ignore [index] + "type": "json_object", + "schema": json_schema_details["schema"], + } + + return payload + + +class TogetherTextToImageTask(TogetherTask): + def __init__(self): + super().__init__("text-to-image") + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + mapped_model = provider_mapping_info.provider_id + parameters = filter_none(parameters) + if "num_inference_steps" in parameters: + parameters["steps"] = parameters.pop("num_inference_steps") + if "guidance_scale" in parameters: + parameters["guidance"] = parameters.pop("guidance_scale") + + return {"prompt": inputs, "response_format": "base64", **parameters, "model": mapped_model} + + def get_response(self, response: Union[bytes, Dict], request_params: Optional[RequestParameters] = None) -> Any: + response_dict = _as_dict(response) + return base64.b64decode(response_dict["data"][0]["b64_json"]) diff --git a/phivenv/Lib/site-packages/huggingface_hub/serialization/__init__.py b/phivenv/Lib/site-packages/huggingface_hub/serialization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8949a22a5f65ab29b7df65aa6a9df9bce0544b7e --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/serialization/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ruff: noqa: F401 +"""Contains helpers to serialize tensors.""" + +from ._base import StateDictSplit, split_state_dict_into_shards_factory +from ._tensorflow import get_tf_storage_size, split_tf_state_dict_into_shards +from ._torch import ( + get_torch_storage_id, + get_torch_storage_size, + load_state_dict_from_file, + load_torch_model, + save_torch_model, + save_torch_state_dict, + split_torch_state_dict_into_shards, +) diff --git a/phivenv/Lib/site-packages/huggingface_hub/serialization/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/serialization/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b51538f602e7e204d85bb77013cf108bb9547e47 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/serialization/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/serialization/__pycache__/_base.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/serialization/__pycache__/_base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1da2a4679c1cbde89b534bba614ae44433d274d5 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/serialization/__pycache__/_base.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/serialization/__pycache__/_dduf.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/serialization/__pycache__/_dduf.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76a461e6c54e5faa5d956d2c74092a50a5fc47b7 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/serialization/__pycache__/_dduf.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/serialization/__pycache__/_tensorflow.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/serialization/__pycache__/_tensorflow.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9077e9e6c9258a78bf6d2522ca6c42ace36b7f90 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/serialization/__pycache__/_tensorflow.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/serialization/__pycache__/_torch.cpython-39.pyc b/phivenv/Lib/site-packages/huggingface_hub/serialization/__pycache__/_torch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af43e70d22286a9223ff155ee506e9088fb67335 Binary files /dev/null and b/phivenv/Lib/site-packages/huggingface_hub/serialization/__pycache__/_torch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/huggingface_hub/serialization/_base.py b/phivenv/Lib/site-packages/huggingface_hub/serialization/_base.py new file mode 100644 index 0000000000000000000000000000000000000000..b7b6454a90e1942854dd0a095a59c92794323279 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/serialization/_base.py @@ -0,0 +1,210 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains helpers to split tensors into shards.""" + +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union + +from .. import logging + + +TensorT = TypeVar("TensorT") +TensorSizeFn_T = Callable[[TensorT], int] +StorageIDFn_T = Callable[[TensorT], Optional[Any]] + +MAX_SHARD_SIZE = "5GB" +SIZE_UNITS = { + "TB": 10**12, + "GB": 10**9, + "MB": 10**6, + "KB": 10**3, +} + + +logger = logging.get_logger(__file__) + + +@dataclass +class StateDictSplit: + is_sharded: bool = field(init=False) + metadata: Dict[str, Any] + filename_to_tensors: Dict[str, List[str]] + tensor_to_filename: Dict[str, str] + + def __post_init__(self): + self.is_sharded = len(self.filename_to_tensors) > 1 + + +def split_state_dict_into_shards_factory( + state_dict: Dict[str, TensorT], + *, + get_storage_size: TensorSizeFn_T, + filename_pattern: str, + get_storage_id: StorageIDFn_T = lambda tensor: None, + max_shard_size: Union[int, str] = MAX_SHARD_SIZE, +) -> StateDictSplit: + """ + Split a model state dictionary in shards so that each shard is smaller than a given size. + + The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization + made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we + have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not + [6+2+2GB], [6+2GB], [6GB]. + + + + If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a + size greater than `max_shard_size`. + + + + Args: + state_dict (`Dict[str, Tensor]`): + The state dictionary to save. + get_storage_size (`Callable[[Tensor], int]`): + A function that returns the size of a tensor when saved on disk in bytes. + get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*): + A function that returns a unique identifier to a tensor storage. Multiple different tensors can share the + same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage + during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id. + filename_pattern (`str`, *optional*): + The pattern to generate the files names in which the model will be saved. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + max_shard_size (`int` or `str`, *optional*): + The maximum size of each shard, in bytes. Defaults to 5GB. + + Returns: + [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. + """ + storage_id_to_tensors: Dict[Any, List[str]] = {} + + shard_list: List[Dict[str, TensorT]] = [] + current_shard: Dict[str, TensorT] = {} + current_shard_size = 0 + total_size = 0 + + if isinstance(max_shard_size, str): + max_shard_size = parse_size_to_int(max_shard_size) + + for key, tensor in state_dict.items(): + # when bnb serialization is used the weights in the state dict can be strings + # check: https://github.com/huggingface/transformers/pull/24416 for more details + if isinstance(tensor, str): + logger.info("Skipping tensor %s as it is a string (bnb serialization)", key) + continue + + # If a `tensor` shares the same underlying storage as another tensor, we put `tensor` in the same `block` + storage_id = get_storage_id(tensor) + if storage_id is not None: + if storage_id in storage_id_to_tensors: + # We skip this tensor for now and will reassign to correct shard later + storage_id_to_tensors[storage_id].append(key) + continue + else: + # This is the first tensor with this storage_id, we create a new entry + # in the storage_id_to_tensors dict => we will assign the shard id later + storage_id_to_tensors[storage_id] = [key] + + # Compute tensor size + tensor_size = get_storage_size(tensor) + + # If this tensor is bigger than the maximal size, we put it in its own shard + if tensor_size > max_shard_size: + total_size += tensor_size + shard_list.append({key: tensor}) + continue + + # If this tensor is going to tip up over the maximal size, we split. + # Current shard already has some tensors, we add it to the list of shards and create a new one. + if current_shard_size + tensor_size > max_shard_size: + shard_list.append(current_shard) + current_shard = {} + current_shard_size = 0 + + # Add the tensor to the current shard + current_shard[key] = tensor + current_shard_size += tensor_size + total_size += tensor_size + + # Add the last shard + if len(current_shard) > 0: + shard_list.append(current_shard) + nb_shards = len(shard_list) + + # Loop over the tensors that share the same storage and assign them together + for storage_id, keys in storage_id_to_tensors.items(): + # Let's try to find the shard where the first tensor of this storage is and put all tensors in the same shard + for shard in shard_list: + if keys[0] in shard: + for key in keys: + shard[key] = state_dict[key] + break + + # If we only have one shard, we return it => no need to build the index + if nb_shards == 1: + filename = filename_pattern.format(suffix="") + return StateDictSplit( + metadata={"total_size": total_size}, + filename_to_tensors={filename: list(state_dict.keys())}, + tensor_to_filename={key: filename for key in state_dict.keys()}, + ) + + # Now that each tensor is assigned to a shard, let's assign a filename to each shard + tensor_name_to_filename = {} + filename_to_tensors = {} + for idx, shard in enumerate(shard_list): + filename = filename_pattern.format(suffix=f"-{idx + 1:05d}-of-{nb_shards:05d}") + for key in shard: + tensor_name_to_filename[key] = filename + filename_to_tensors[filename] = list(shard.keys()) + + # Build the index and return + return StateDictSplit( + metadata={"total_size": total_size}, + filename_to_tensors=filename_to_tensors, + tensor_to_filename=tensor_name_to_filename, + ) + + +def parse_size_to_int(size_as_str: str) -> int: + """ + Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes). + + Supported units are "TB", "GB", "MB", "KB". + + Args: + size_as_str (`str`): The size to convert. Will be directly returned if an `int`. + + Example: + + ```py + >>> parse_size_to_int("5MB") + 5000000 + ``` + """ + size_as_str = size_as_str.strip() + + # Parse unit + unit = size_as_str[-2:].upper() + if unit not in SIZE_UNITS: + raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.") + multiplier = SIZE_UNITS[unit] + + # Parse value + try: + value = float(size_as_str[:-2].strip()) + except ValueError as e: + raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e + + return int(value * multiplier) diff --git a/phivenv/Lib/site-packages/huggingface_hub/serialization/_dduf.py b/phivenv/Lib/site-packages/huggingface_hub/serialization/_dduf.py new file mode 100644 index 0000000000000000000000000000000000000000..a1debadb3ac8a45716f0359b932dc065f09edb84 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/serialization/_dduf.py @@ -0,0 +1,387 @@ +import json +import logging +import mmap +import os +import shutil +import zipfile +from contextlib import contextmanager +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, Generator, Iterable, Tuple, Union + +from ..errors import DDUFCorruptedFileError, DDUFExportError, DDUFInvalidEntryNameError + + +logger = logging.getLogger(__name__) + +DDUF_ALLOWED_ENTRIES = { + # Allowed file extensions in a DDUF file + ".json", + ".model", + ".safetensors", + ".txt", +} + +DDUF_FOLDER_REQUIRED_ENTRIES = { + # Each folder must contain at least one of these entries + "config.json", + "tokenizer_config.json", + "preprocessor_config.json", + "scheduler_config.json", +} + + +@dataclass +class DDUFEntry: + """Object representing a file entry in a DDUF file. + + See [`read_dduf_file`] for how to read a DDUF file. + + Attributes: + filename (str): + The name of the file in the DDUF archive. + offset (int): + The offset of the file in the DDUF archive. + length (int): + The length of the file in the DDUF archive. + dduf_path (str): + The path to the DDUF archive (for internal use). + """ + + filename: str + length: int + offset: int + + dduf_path: Path = field(repr=False) + + @contextmanager + def as_mmap(self) -> Generator[bytes, None, None]: + """Open the file as a memory-mapped file. + + Useful to load safetensors directly from the file. + + Example: + ```py + >>> import safetensors.torch + >>> with entry.as_mmap() as mm: + ... tensors = safetensors.torch.load(mm) + ``` + """ + with self.dduf_path.open("rb") as f: + with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mm: + yield mm[self.offset : self.offset + self.length] + + def read_text(self, encoding: str = "utf-8") -> str: + """Read the file as text. + + Useful for '.txt' and '.json' entries. + + Example: + ```py + >>> import json + >>> index = json.loads(entry.read_text()) + ``` + """ + with self.dduf_path.open("rb") as f: + f.seek(self.offset) + return f.read(self.length).decode(encoding=encoding) + + +def read_dduf_file(dduf_path: Union[os.PathLike, str]) -> Dict[str, DDUFEntry]: + """ + Read a DDUF file and return a dictionary of entries. + + Only the metadata is read, the data is not loaded in memory. + + Args: + dduf_path (`str` or `os.PathLike`): + The path to the DDUF file to read. + + Returns: + `Dict[str, DDUFEntry]`: + A dictionary of [`DDUFEntry`] indexed by filename. + + Raises: + - [`DDUFCorruptedFileError`]: If the DDUF file is corrupted (i.e. doesn't follow the DDUF format). + + Example: + ```python + >>> import json + >>> import safetensors.torch + >>> from huggingface_hub import read_dduf_file + + # Read DDUF metadata + >>> dduf_entries = read_dduf_file("FLUX.1-dev.dduf") + + # Returns a mapping filename <> DDUFEntry + >>> dduf_entries["model_index.json"] + DDUFEntry(filename='model_index.json', offset=66, length=587) + + # Load model index as JSON + >>> json.loads(dduf_entries["model_index.json"].read_text()) + {'_class_name': 'FluxPipeline', '_diffusers_version': '0.32.0.dev0', '_name_or_path': 'black-forest-labs/FLUX.1-dev', ... + + # Load VAE weights using safetensors + >>> with dduf_entries["vae/diffusion_pytorch_model.safetensors"].as_mmap() as mm: + ... state_dict = safetensors.torch.load(mm) + ``` + """ + entries = {} + dduf_path = Path(dduf_path) + logger.info(f"Reading DDUF file {dduf_path}") + with zipfile.ZipFile(str(dduf_path), "r") as zf: + for info in zf.infolist(): + logger.debug(f"Reading entry {info.filename}") + if info.compress_type != zipfile.ZIP_STORED: + raise DDUFCorruptedFileError("Data must not be compressed in DDUF file.") + + try: + _validate_dduf_entry_name(info.filename) + except DDUFInvalidEntryNameError as e: + raise DDUFCorruptedFileError(f"Invalid entry name in DDUF file: {info.filename}") from e + + offset = _get_data_offset(zf, info) + + entries[info.filename] = DDUFEntry( + filename=info.filename, offset=offset, length=info.file_size, dduf_path=dduf_path + ) + + # Consistency checks on the DDUF file + if "model_index.json" not in entries: + raise DDUFCorruptedFileError("Missing required 'model_index.json' entry in DDUF file.") + index = json.loads(entries["model_index.json"].read_text()) + _validate_dduf_structure(index, entries.keys()) + + logger.info(f"Done reading DDUF file {dduf_path}. Found {len(entries)} entries") + return entries + + +def export_entries_as_dduf( + dduf_path: Union[str, os.PathLike], entries: Iterable[Tuple[str, Union[str, Path, bytes]]] +) -> None: + """Write a DDUF file from an iterable of entries. + + This is a lower-level helper than [`export_folder_as_dduf`] that allows more flexibility when serializing data. + In particular, you don't need to save the data on disk before exporting it in the DDUF file. + + Args: + dduf_path (`str` or `os.PathLike`): + The path to the DDUF file to write. + entries (`Iterable[Tuple[str, Union[str, Path, bytes]]]`): + An iterable of entries to write in the DDUF file. Each entry is a tuple with the filename and the content. + The filename should be the path to the file in the DDUF archive. + The content can be a string or a pathlib.Path representing a path to a file on the local disk or directly the content as bytes. + + Raises: + - [`DDUFExportError`]: If anything goes wrong during the export (e.g. invalid entry name, missing 'model_index.json', etc.). + + Example: + ```python + # Export specific files from the local disk. + >>> from huggingface_hub import export_entries_as_dduf + >>> export_entries_as_dduf( + ... dduf_path="stable-diffusion-v1-4-FP16.dduf", + ... entries=[ # List entries to add to the DDUF file (here, only FP16 weights) + ... ("model_index.json", "path/to/model_index.json"), + ... ("vae/config.json", "path/to/vae/config.json"), + ... ("vae/diffusion_pytorch_model.fp16.safetensors", "path/to/vae/diffusion_pytorch_model.fp16.safetensors"), + ... ("text_encoder/config.json", "path/to/text_encoder/config.json"), + ... ("text_encoder/model.fp16.safetensors", "path/to/text_encoder/model.fp16.safetensors"), + ... # ... add more entries here + ... ] + ... ) + ``` + + ```python + # Export state_dicts one by one from a loaded pipeline + >>> from diffusers import DiffusionPipeline + >>> from typing import Generator, Tuple + >>> import safetensors.torch + >>> from huggingface_hub import export_entries_as_dduf + >>> pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") + ... # ... do some work with the pipeline + + >>> def as_entries(pipe: DiffusionPipeline) -> Generator[Tuple[str, bytes], None, None]: + ... # Build an generator that yields the entries to add to the DDUF file. + ... # The first element of the tuple is the filename in the DDUF archive (must use UNIX separator!). The second element is the content of the file. + ... # Entries will be evaluated lazily when the DDUF file is created (only 1 entry is loaded in memory at a time) + ... yield "vae/config.json", pipe.vae.to_json_string().encode() + ... yield "vae/diffusion_pytorch_model.safetensors", safetensors.torch.save(pipe.vae.state_dict()) + ... yield "text_encoder/config.json", pipe.text_encoder.config.to_json_string().encode() + ... yield "text_encoder/model.safetensors", safetensors.torch.save(pipe.text_encoder.state_dict()) + ... # ... add more entries here + + >>> export_entries_as_dduf(dduf_path="stable-diffusion-v1-4.dduf", entries=as_entries(pipe)) + ``` + """ + logger.info(f"Exporting DDUF file '{dduf_path}'") + filenames = set() + index = None + with zipfile.ZipFile(str(dduf_path), "w", zipfile.ZIP_STORED) as archive: + for filename, content in entries: + if filename in filenames: + raise DDUFExportError(f"Can't add duplicate entry: {filename}") + filenames.add(filename) + + if filename == "model_index.json": + try: + index = json.loads(_load_content(content).decode()) + except json.JSONDecodeError as e: + raise DDUFExportError("Failed to parse 'model_index.json'.") from e + + try: + filename = _validate_dduf_entry_name(filename) + except DDUFInvalidEntryNameError as e: + raise DDUFExportError(f"Invalid entry name: {filename}") from e + logger.debug(f"Adding entry '{filename}' to DDUF file") + _dump_content_in_archive(archive, filename, content) + + # Consistency checks on the DDUF file + if index is None: + raise DDUFExportError("Missing required 'model_index.json' entry in DDUF file.") + try: + _validate_dduf_structure(index, filenames) + except DDUFCorruptedFileError as e: + raise DDUFExportError("Invalid DDUF file structure.") from e + + logger.info(f"Done writing DDUF file {dduf_path}") + + +def export_folder_as_dduf(dduf_path: Union[str, os.PathLike], folder_path: Union[str, os.PathLike]) -> None: + """ + Export a folder as a DDUF file. + + AUses [`export_entries_as_dduf`] under the hood. + + Args: + dduf_path (`str` or `os.PathLike`): + The path to the DDUF file to write. + folder_path (`str` or `os.PathLike`): + The path to the folder containing the diffusion model. + + Example: + ```python + >>> from huggingface_hub import export_folder_as_dduf + >>> export_folder_as_dduf(dduf_path="FLUX.1-dev.dduf", folder_path="path/to/FLUX.1-dev") + ``` + """ + folder_path = Path(folder_path) + + def _iterate_over_folder() -> Iterable[Tuple[str, Path]]: + for path in Path(folder_path).glob("**/*"): + if not path.is_file(): + continue + if path.suffix not in DDUF_ALLOWED_ENTRIES: + logger.debug(f"Skipping file '{path}' (file type not allowed)") + continue + path_in_archive = path.relative_to(folder_path) + if len(path_in_archive.parts) >= 3: + logger.debug(f"Skipping file '{path}' (nested directories not allowed)") + continue + yield path_in_archive.as_posix(), path + + export_entries_as_dduf(dduf_path, _iterate_over_folder()) + + +def _dump_content_in_archive(archive: zipfile.ZipFile, filename: str, content: Union[str, os.PathLike, bytes]) -> None: + with archive.open(filename, "w", force_zip64=True) as archive_fh: + if isinstance(content, (str, Path)): + content_path = Path(content) + with content_path.open("rb") as content_fh: + shutil.copyfileobj(content_fh, archive_fh, 1024 * 1024 * 8) # type: ignore[misc] + elif isinstance(content, bytes): + archive_fh.write(content) + else: + raise DDUFExportError(f"Invalid content type for {filename}. Must be str, Path or bytes.") + + +def _load_content(content: Union[str, Path, bytes]) -> bytes: + """Load the content of an entry as bytes. + + Used only for small checks (not to dump content into archive). + """ + if isinstance(content, (str, Path)): + return Path(content).read_bytes() + elif isinstance(content, bytes): + return content + else: + raise DDUFExportError(f"Invalid content type. Must be str, Path or bytes. Got {type(content)}.") + + +def _validate_dduf_entry_name(entry_name: str) -> str: + if "." + entry_name.split(".")[-1] not in DDUF_ALLOWED_ENTRIES: + raise DDUFInvalidEntryNameError(f"File type not allowed: {entry_name}") + if "\\" in entry_name: + raise DDUFInvalidEntryNameError(f"Entry names must use UNIX separators ('/'). Got {entry_name}.") + entry_name = entry_name.strip("/") + if entry_name.count("/") > 1: + raise DDUFInvalidEntryNameError(f"DDUF only supports 1 level of directory. Got {entry_name}.") + return entry_name + + +def _validate_dduf_structure(index: Any, entry_names: Iterable[str]) -> None: + """ + Consistency checks on the DDUF file structure. + + Rules: + - The 'model_index.json' entry is required and must contain a dictionary. + - Each folder name must correspond to an entry in 'model_index.json'. + - Each folder must contain at least a config file ('config.json', 'tokenizer_config.json', 'preprocessor_config.json', 'scheduler_config.json'). + + Args: + index (Any): + The content of the 'model_index.json' entry. + entry_names (Iterable[str]): + The list of entry names in the DDUF file. + + Raises: + - [`DDUFCorruptedFileError`]: If the DDUF file is corrupted (i.e. doesn't follow the DDUF format). + """ + if not isinstance(index, dict): + raise DDUFCorruptedFileError(f"Invalid 'model_index.json' content. Must be a dictionary. Got {type(index)}.") + + dduf_folders = {entry.split("/")[0] for entry in entry_names if "/" in entry} + for folder in dduf_folders: + if folder not in index: + raise DDUFCorruptedFileError(f"Missing required entry '{folder}' in 'model_index.json'.") + if not any(f"{folder}/{required_entry}" in entry_names for required_entry in DDUF_FOLDER_REQUIRED_ENTRIES): + raise DDUFCorruptedFileError( + f"Missing required file in folder '{folder}'. Must contains at least one of {DDUF_FOLDER_REQUIRED_ENTRIES}." + ) + + +def _get_data_offset(zf: zipfile.ZipFile, info: zipfile.ZipInfo) -> int: + """ + Calculate the data offset for a file in a ZIP archive. + + Args: + zf (`zipfile.ZipFile`): + The opened ZIP file. Must be opened in read mode. + info (`zipfile.ZipInfo`): + The file info. + + Returns: + int: The offset of the file data in the ZIP archive. + """ + if zf.fp is None: + raise DDUFCorruptedFileError("ZipFile object must be opened in read mode.") + + # Step 1: Get the local file header offset + header_offset = info.header_offset + + # Step 2: Read the local file header + zf.fp.seek(header_offset) + local_file_header = zf.fp.read(30) # Fixed-size part of the local header + + if len(local_file_header) < 30: + raise DDUFCorruptedFileError("Incomplete local file header.") + + # Step 3: Parse the header fields to calculate the start of file data + # Local file header: https://en.wikipedia.org/wiki/ZIP_(file_format)#File_headers + filename_len = int.from_bytes(local_file_header[26:28], "little") + extra_field_len = int.from_bytes(local_file_header[28:30], "little") + + # Data offset is after the fixed header, filename, and extra fields + data_offset = header_offset + 30 + filename_len + extra_field_len + + return data_offset diff --git a/phivenv/Lib/site-packages/huggingface_hub/serialization/_tensorflow.py b/phivenv/Lib/site-packages/huggingface_hub/serialization/_tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..59ed8110b28f4891d67e754fdfbfa47a26f85be1 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/serialization/_tensorflow.py @@ -0,0 +1,95 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains tensorflow-specific helpers.""" + +import math +import re +from typing import TYPE_CHECKING, Dict, Union + +from .. import constants +from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory + + +if TYPE_CHECKING: + import tensorflow as tf + + +def split_tf_state_dict_into_shards( + state_dict: Dict[str, "tf.Tensor"], + *, + filename_pattern: str = constants.TF2_WEIGHTS_FILE_PATTERN, + max_shard_size: Union[int, str] = MAX_SHARD_SIZE, +) -> StateDictSplit: + """ + Split a model state dictionary in shards so that each shard is smaller than a given size. + + The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization + made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we + have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not + [6+2+2GB], [6+2GB], [6GB]. + + + + If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a + size greater than `max_shard_size`. + + + + Args: + state_dict (`Dict[str, Tensor]`): + The state dictionary to save. + filename_pattern (`str`, *optional*): + The pattern to generate the files names in which the model will be saved. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + Defaults to `"tf_model{suffix}.h5"`. + max_shard_size (`int` or `str`, *optional*): + The maximum size of each shard, in bytes. Defaults to 5GB. + + Returns: + [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. + """ + return split_state_dict_into_shards_factory( + state_dict, + max_shard_size=max_shard_size, + filename_pattern=filename_pattern, + get_storage_size=get_tf_storage_size, + ) + + +def get_tf_storage_size(tensor: "tf.Tensor") -> int: + # Return `math.ceil` since dtype byte size can be a float (e.g., 0.125 for tf.bool). + # Better to overestimate than underestimate. + return math.ceil(tensor.numpy().size * _dtype_byte_size_tf(tensor.dtype)) + + +def _dtype_byte_size_tf(dtype) -> float: + """ + Returns the size (in bytes) occupied by one parameter of type `dtype`. + Taken from https://github.com/huggingface/transformers/blob/74d9d0cebb0263a3f8ab9c280569170cc74651d0/src/transformers/modeling_tf_utils.py#L608. + NOTE: why not `tensor.numpy().nbytes`? + Example: + ```py + >>> _dtype_byte_size(tf.float32) + 4 + ``` + """ + import tensorflow as tf + + if dtype == tf.bool: + return 1 / 8 + bit_search = re.search(r"[^\d](\d+)$", dtype.name) + if bit_search is None: + raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") + bit_size = int(bit_search.groups()[0]) + return bit_size // 8 diff --git a/phivenv/Lib/site-packages/huggingface_hub/serialization/_torch.py b/phivenv/Lib/site-packages/huggingface_hub/serialization/_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..c5c70fc89b58b323756bd33d38cb3f604aaf0da0 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/serialization/_torch.py @@ -0,0 +1,1033 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains pytorch-specific helpers.""" + +import importlib +import json +import os +import re +from collections import defaultdict, namedtuple +from functools import lru_cache +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union + +from packaging import version + +from .. import constants, logging +from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory + + +logger = logging.get_logger(__file__) + +if TYPE_CHECKING: + import torch + +# SAVING + + +def save_torch_model( + model: "torch.nn.Module", + save_directory: Union[str, Path], + *, + filename_pattern: Optional[str] = None, + force_contiguous: bool = True, + max_shard_size: Union[int, str] = MAX_SHARD_SIZE, + metadata: Optional[Dict[str, str]] = None, + safe_serialization: bool = True, + is_main_process: bool = True, + shared_tensors_to_discard: Optional[List[str]] = None, +): + """ + Saves a given torch model to disk, handling sharding and shared tensors issues. + + See also [`save_torch_state_dict`] to save a state dict with more flexibility. + + For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors). + + The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are + saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard, + an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses + [`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as + safetensors (the default). Otherwise, the shards are saved as pickle. + + Before saving the model, the `save_directory` is cleaned from any previous shard files. + + + + If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a + size greater than `max_shard_size`. + + + + + + If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving. + + + + Args: + model (`torch.nn.Module`): + The model to save on disk. + save_directory (`str` or `Path`): + The directory in which the model will be saved. + filename_pattern (`str`, *optional*): + The pattern to generate the files names in which the model will be saved. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin` depending on `safe_serialization` + parameter. + force_contiguous (`boolean`, *optional*): + Forcing the state_dict to be saved as contiguous tensors. This has no effect on the correctness of the + model, but it could potentially change performance if the layout of the tensor was chosen specifically for + that reason. Defaults to `True`. + max_shard_size (`int` or `str`, *optional*): + The maximum size of each shard, in bytes. Defaults to 5GB. + metadata (`Dict[str, str]`, *optional*): + Extra information to save along with the model. Some metadata will be added for each dropped tensors. + This information will not be enough to recover the entire shared structure but might help understanding + things. + safe_serialization (`bool`, *optional*): + Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle. + Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed + in a future version. + is_main_process (`bool`, *optional*): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. Defaults to True. + shared_tensors_to_discard (`List[str]`, *optional*): + List of tensor names to drop when saving shared tensors. If not provided and shared tensors are + detected, it will drop the first name alphabetically. + + Example: + + ```py + >>> from huggingface_hub import save_torch_model + >>> model = ... # A PyTorch model + + # Save state dict to "path/to/folder". The model will be split into shards of 5GB each and saved as safetensors. + >>> save_torch_model(model, "path/to/folder") + + # Load model back + >>> from huggingface_hub import load_torch_model # TODO + >>> load_torch_model(model, "path/to/folder") + >>> + ``` + """ + save_torch_state_dict( + state_dict=model.state_dict(), + filename_pattern=filename_pattern, + force_contiguous=force_contiguous, + max_shard_size=max_shard_size, + metadata=metadata, + safe_serialization=safe_serialization, + save_directory=save_directory, + is_main_process=is_main_process, + shared_tensors_to_discard=shared_tensors_to_discard, + ) + + +def save_torch_state_dict( + state_dict: Dict[str, "torch.Tensor"], + save_directory: Union[str, Path], + *, + filename_pattern: Optional[str] = None, + force_contiguous: bool = True, + max_shard_size: Union[int, str] = MAX_SHARD_SIZE, + metadata: Optional[Dict[str, str]] = None, + safe_serialization: bool = True, + is_main_process: bool = True, + shared_tensors_to_discard: Optional[List[str]] = None, +) -> None: + """ + Save a model state dictionary to the disk, handling sharding and shared tensors issues. + + See also [`save_torch_model`] to directly save a PyTorch model. + + For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors). + + The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are + saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard, + an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses + [`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as + safetensors (the default). Otherwise, the shards are saved as pickle. + + Before saving the model, the `save_directory` is cleaned from any previous shard files. + + + + If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a + size greater than `max_shard_size`. + + + + + + If your model is a `transformers.PreTrainedModel`, you should pass `model._tied_weights_keys` as `shared_tensors_to_discard` to properly handle shared tensors saving. This ensures the correct duplicate tensors are discarded during saving. + + + + Args: + state_dict (`Dict[str, torch.Tensor]`): + The state dictionary to save. + save_directory (`str` or `Path`): + The directory in which the model will be saved. + filename_pattern (`str`, *optional*): + The pattern to generate the files names in which the model will be saved. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin` depending on `safe_serialization` + parameter. + force_contiguous (`boolean`, *optional*): + Forcing the state_dict to be saved as contiguous tensors. This has no effect on the correctness of the + model, but it could potentially change performance if the layout of the tensor was chosen specifically for + that reason. Defaults to `True`. + max_shard_size (`int` or `str`, *optional*): + The maximum size of each shard, in bytes. Defaults to 5GB. + metadata (`Dict[str, str]`, *optional*): + Extra information to save along with the model. Some metadata will be added for each dropped tensors. + This information will not be enough to recover the entire shared structure but might help understanding + things. + safe_serialization (`bool`, *optional*): + Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle. + Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed + in a future version. + is_main_process (`bool`, *optional*): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. Defaults to True. + shared_tensors_to_discard (`List[str]`, *optional*): + List of tensor names to drop when saving shared tensors. If not provided and shared tensors are + detected, it will drop the first name alphabetically. + + Example: + + ```py + >>> from huggingface_hub import save_torch_state_dict + >>> model = ... # A PyTorch model + + # Save state dict to "path/to/folder". The model will be split into shards of 5GB each and saved as safetensors. + >>> state_dict = model_to_save.state_dict() + >>> save_torch_state_dict(state_dict, "path/to/folder") + ``` + """ + save_directory = str(save_directory) + + if filename_pattern is None: + filename_pattern = ( + constants.SAFETENSORS_WEIGHTS_FILE_PATTERN + if safe_serialization + else constants.PYTORCH_WEIGHTS_FILE_PATTERN + ) + + if metadata is None: + metadata = {} + if safe_serialization: + try: + from safetensors.torch import save_file as save_file_fn + except ImportError as e: + raise ImportError( + "Please install `safetensors` to use safe serialization. " + "You can install it with `pip install safetensors`." + ) from e + # Clean state dict for safetensors + state_dict = _clean_state_dict_for_safetensors( + state_dict, + metadata, + force_contiguous=force_contiguous, + shared_tensors_to_discard=shared_tensors_to_discard, + ) + else: + from torch import save as save_file_fn # type: ignore[assignment, no-redef] + + logger.warning( + "You are using unsafe serialization. Due to security reasons, it is recommended not to load " + "pickled models from untrusted sources. If you intend to share your model, we strongly recommend " + "using safe serialization by installing `safetensors` with `pip install safetensors`." + ) + # Split dict + state_dict_split = split_torch_state_dict_into_shards( + state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size + ) + + # Only main process should clean up existing files to avoid race conditions in distributed environment + if is_main_process: + existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?") + for filename in os.listdir(save_directory): + if existing_files_regex.match(filename): + try: + logger.debug(f"Removing existing file '{filename}' from folder.") + os.remove(os.path.join(save_directory, filename)) + except Exception as e: + logger.warning( + f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing..." + ) + + # Save each shard + per_file_metadata = {"format": "pt"} + if not state_dict_split.is_sharded: + per_file_metadata.update(metadata) + safe_file_kwargs = {"metadata": per_file_metadata} if safe_serialization else {} + for filename, tensors in state_dict_split.filename_to_tensors.items(): + shard = {tensor: state_dict[tensor] for tensor in tensors} + save_file_fn(shard, os.path.join(save_directory, filename), **safe_file_kwargs) + logger.debug(f"Shard saved to {filename}") + + # Save the index (if any) + if state_dict_split.is_sharded: + index_path = filename_pattern.format(suffix="") + ".index.json" + index = { + "metadata": {**state_dict_split.metadata, **metadata}, + "weight_map": state_dict_split.tensor_to_filename, + } + with open(os.path.join(save_directory, index_path), "w") as f: + json.dump(index, f, indent=2) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}). " + f"Model weighs have been saved in {len(state_dict_split.filename_to_tensors)} checkpoint shards. " + f"You can find where each parameters has been saved in the index located at {index_path}." + ) + + logger.info(f"Model weights successfully saved to {save_directory}!") + + +def split_torch_state_dict_into_shards( + state_dict: Dict[str, "torch.Tensor"], + *, + filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN, + max_shard_size: Union[int, str] = MAX_SHARD_SIZE, +) -> StateDictSplit: + """ + Split a model state dictionary in shards so that each shard is smaller than a given size. + + The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization + made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we + have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not + [6+2+2GB], [6+2GB], [6GB]. + + + + + To save a model state dictionary to the disk, see [`save_torch_state_dict`]. This helper uses + `split_torch_state_dict_into_shards` under the hood. + + + + + + If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a + size greater than `max_shard_size`. + + + + Args: + state_dict (`Dict[str, torch.Tensor]`): + The state dictionary to save. + filename_pattern (`str`, *optional*): + The pattern to generate the files names in which the model will be saved. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + Defaults to `"model{suffix}.safetensors"`. + max_shard_size (`int` or `str`, *optional*): + The maximum size of each shard, in bytes. Defaults to 5GB. + + Returns: + [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. + + Example: + ```py + >>> import json + >>> import os + >>> from safetensors.torch import save_file as safe_save_file + >>> from huggingface_hub import split_torch_state_dict_into_shards + + >>> def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str): + ... state_dict_split = split_torch_state_dict_into_shards(state_dict) + ... for filename, tensors in state_dict_split.filename_to_tensors.items(): + ... shard = {tensor: state_dict[tensor] for tensor in tensors} + ... safe_save_file( + ... shard, + ... os.path.join(save_directory, filename), + ... metadata={"format": "pt"}, + ... ) + ... if state_dict_split.is_sharded: + ... index = { + ... "metadata": state_dict_split.metadata, + ... "weight_map": state_dict_split.tensor_to_filename, + ... } + ... with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f: + ... f.write(json.dumps(index, indent=2)) + ``` + """ + return split_state_dict_into_shards_factory( + state_dict, + max_shard_size=max_shard_size, + filename_pattern=filename_pattern, + get_storage_size=get_torch_storage_size, + get_storage_id=get_torch_storage_id, + ) + + +# LOADING + + +def load_torch_model( + model: "torch.nn.Module", + checkpoint_path: Union[str, os.PathLike], + *, + strict: bool = False, + safe: bool = True, + weights_only: bool = False, + map_location: Optional[Union[str, "torch.device"]] = None, + mmap: bool = False, + filename_pattern: Optional[str] = None, +) -> NamedTuple: + """ + Load a checkpoint into a model, handling both sharded and non-sharded checkpoints. + + Args: + model (`torch.nn.Module`): + The model in which to load the checkpoint. + checkpoint_path (`str` or `os.PathLike`): + Path to either the checkpoint file or directory containing the checkpoint(s). + strict (`bool`, *optional*, defaults to `False`): + Whether to strictly enforce that the keys in the model state dict match the keys in the checkpoint. + safe (`bool`, *optional*, defaults to `True`): + If `safe` is True, the safetensors files will be loaded. If `safe` is False, the function + will first attempt to load safetensors files if they are available, otherwise it will fall back to loading + pickle files. `filename_pattern` parameter takes precedence over `safe` parameter. + weights_only (`bool`, *optional*, defaults to `False`): + If True, only loads the model weights without optimizer states and other metadata. + Only supported in PyTorch >= 1.13. + map_location (`str` or `torch.device`, *optional*): + A `torch.device` object, string or a dict specifying how to remap storage locations. It + indicates the location where all tensors should be loaded. + mmap (`bool`, *optional*, defaults to `False`): + Whether to use memory-mapped file loading. Memory mapping can improve loading performance + for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints. + filename_pattern (`str`, *optional*): + The pattern to look for the index file. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + Defaults to `"model{suffix}.safetensors"`. + Returns: + `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields. + - `missing_keys` is a list of str containing the missing keys, i.e. keys that are in the model but not in the checkpoint. + - `unexpected_keys` is a list of str containing the unexpected keys, i.e. keys that are in the checkpoint but not in the model. + + Raises: + [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError) + If the checkpoint file or directory does not exist. + [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError) + If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If the checkpoint path is invalid or if the checkpoint format cannot be determined. + + Example: + ```python + >>> from huggingface_hub import load_torch_model + >>> model = ... # A PyTorch model + >>> load_torch_model(model, "path/to/checkpoint") + ``` + """ + checkpoint_path = Path(checkpoint_path) + + if not checkpoint_path.exists(): + raise ValueError(f"Checkpoint path {checkpoint_path} does not exist") + # 1. Check if checkpoint is a single file + if checkpoint_path.is_file(): + state_dict = load_state_dict_from_file( + checkpoint_file=checkpoint_path, + map_location=map_location, + weights_only=weights_only, + ) + return model.load_state_dict(state_dict, strict=strict) + + # 2. If not, checkpoint_path is a directory + if filename_pattern is None: + filename_pattern = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN + index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json") + # Only fallback to pickle format if safetensors index is not found and safe is False. + if not index_path.is_file() and not safe: + filename_pattern = constants.PYTORCH_WEIGHTS_FILE_PATTERN + + index_path = checkpoint_path / (filename_pattern.format(suffix="") + ".index.json") + + if index_path.is_file(): + return _load_sharded_checkpoint( + model=model, + save_directory=checkpoint_path, + strict=strict, + weights_only=weights_only, + filename_pattern=filename_pattern, + ) + + # Look for single model file + model_files = list(checkpoint_path.glob("*.safetensors" if safe else "*.bin")) + if len(model_files) == 1: + state_dict = load_state_dict_from_file( + checkpoint_file=model_files[0], + map_location=map_location, + weights_only=weights_only, + mmap=mmap, + ) + return model.load_state_dict(state_dict, strict=strict) + + raise ValueError( + f"Directory '{checkpoint_path}' does not contain a valid checkpoint. " + "Expected either a sharded checkpoint with an index file, or a single model file." + ) + + +def _load_sharded_checkpoint( + model: "torch.nn.Module", + save_directory: os.PathLike, + *, + strict: bool = False, + weights_only: bool = False, + filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN, +) -> NamedTuple: + """ + Loads a sharded checkpoint into a model. This is the same as + [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict) + but for a sharded checkpoint. Each shard is loaded one by one and removed from memory after being loaded into the model. + + Args: + model (`torch.nn.Module`): + The model in which to load the checkpoint. + save_directory (`str` or `os.PathLike`): + A path to a folder containing the sharded checkpoint. + strict (`bool`, *optional*, defaults to `False`): + Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint. + weights_only (`bool`, *optional*, defaults to `False`): + If True, only loads the model weights without optimizer states and other metadata. + Only supported in PyTorch >= 1.13. + filename_pattern (`str`, *optional*, defaults to `"model{suffix}.safetensors"`): + The pattern to look for the index file. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + Defaults to `"model{suffix}.safetensors"`. + + Returns: + `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields, + - `missing_keys` is a list of str containing the missing keys + - `unexpected_keys` is a list of str containing the unexpected keys + """ + + # 1. Load and validate index file + # The index file contains mapping of parameter names to shard files + index_path = filename_pattern.format(suffix="") + ".index.json" + index_file = os.path.join(save_directory, index_path) + with open(index_file, "r", encoding="utf-8") as f: + index = json.load(f) + + # 2. Validate keys if in strict mode + # This is done before loading any shards to fail fast + if strict: + _validate_keys_for_strict_loading(model, index["weight_map"].keys()) + + # 3. Load each shard using `load_state_dict` + # Get unique shard files (multiple parameters can be in same shard) + shard_files = list(set(index["weight_map"].values())) + for shard_file in shard_files: + # Load shard into memory + shard_path = os.path.join(save_directory, shard_file) + state_dict = load_state_dict_from_file( + shard_path, + map_location="cpu", + weights_only=weights_only, + ) + # Update model with parameters from this shard + model.load_state_dict(state_dict, strict=strict) + # Explicitly remove the state dict from memory + del state_dict + + # 4. Return compatibility info + loaded_keys = set(index["weight_map"].keys()) + model_keys = set(model.state_dict().keys()) + return _IncompatibleKeys( + missing_keys=list(model_keys - loaded_keys), unexpected_keys=list(loaded_keys - model_keys) + ) + + +def load_state_dict_from_file( + checkpoint_file: Union[str, os.PathLike], + map_location: Optional[Union[str, "torch.device"]] = None, + weights_only: bool = False, + mmap: bool = False, +) -> Union[Dict[str, "torch.Tensor"], Any]: + """ + Loads a checkpoint file, handling both safetensors and pickle checkpoint formats. + + Args: + checkpoint_file (`str` or `os.PathLike`): + Path to the checkpoint file to load. Can be either a safetensors or pickle (`.bin`) checkpoint. + map_location (`str` or `torch.device`, *optional*): + A `torch.device` object, string or a dict specifying how to remap storage locations. It + indicates the location where all tensors should be loaded. + weights_only (`bool`, *optional*, defaults to `False`): + If True, only loads the model weights without optimizer states and other metadata. + Only supported for pickle (`.bin`) checkpoints with PyTorch >= 1.13. Has no effect when + loading safetensors files. + mmap (`bool`, *optional*, defaults to `False`): + Whether to use memory-mapped file loading. Memory mapping can improve loading performance + for large models in PyTorch >= 2.1.0 with zipfile-based checkpoints. Has no effect when + loading safetensors files, as the `safetensors` library uses memory mapping by default. + + Returns: + `Union[Dict[str, "torch.Tensor"], Any]`: The loaded checkpoint. + - For safetensors files: always returns a dictionary mapping parameter names to tensors. + - For pickle files: returns any Python object that was pickled (commonly a state dict, but could be + an entire model, optimizer state, or any other Python object). + + Raises: + [`FileNotFoundError`](https://docs.python.org/3/library/exceptions.html#FileNotFoundError) + If the checkpoint file does not exist. + [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError) + If safetensors or torch is not installed when trying to load a .safetensors file or a PyTorch checkpoint respectively. + [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) + If the checkpoint file format is invalid or if git-lfs files are not properly downloaded. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If the checkpoint file path is empty or invalid. + + Example: + ```python + >>> from huggingface_hub import load_state_dict_from_file + + # Load a PyTorch checkpoint + >>> state_dict = load_state_dict_from_file("path/to/model.bin", map_location="cpu") + >>> model.load_state_dict(state_dict) + + # Load a safetensors checkpoint + >>> state_dict = load_state_dict_from_file("path/to/model.safetensors") + >>> model.load_state_dict(state_dict) + ``` + """ + checkpoint_path = Path(checkpoint_file) + + # Check if file exists and is a regular file (not a directory) + if not checkpoint_path.is_file(): + raise FileNotFoundError( + f"No checkpoint file found at '{checkpoint_path}'. Please verify the path is correct and " + "the file has been properly downloaded." + ) + + # Load safetensors checkpoint + if checkpoint_path.suffix == ".safetensors": + try: + from safetensors import safe_open + from safetensors.torch import load_file + except ImportError as e: + raise ImportError( + "Please install `safetensors` to load safetensors checkpoint. " + "You can install it with `pip install safetensors`." + ) from e + + # Check format of the archive + with safe_open(checkpoint_file, framework="pt") as f: # type: ignore[attr-defined] + metadata = f.metadata() + # see comment: https://github.com/huggingface/transformers/blob/3d213b57fe74302e5902d68ed9478c3ad1aaa713/src/transformers/modeling_utils.py#L3966 + if metadata is not None and metadata.get("format") not in ["pt", "mlx"]: + raise OSError( + f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " + "you save your model with the `save_torch_model` method." + ) + device = str(map_location.type) if map_location is not None and hasattr(map_location, "type") else map_location + # meta device is not supported with safetensors, falling back to CPU + if device == "meta": + logger.warning("Meta device is not supported with safetensors. Falling back to CPU device.") + device = "cpu" + return load_file(checkpoint_file, device=device) # type: ignore[arg-type] + # Otherwise, load from pickle + try: + import torch + from torch import load + except ImportError as e: + raise ImportError( + "Please install `torch` to load torch tensors. You can install it with `pip install torch`." + ) from e + # Add additional kwargs, mmap is only supported in torch >= 2.1.0 + additional_kwargs = {} + if version.parse(torch.__version__) >= version.parse("2.1.0"): + additional_kwargs["mmap"] = mmap + + # weights_only is only supported in torch >= 1.13.0 + if version.parse(torch.__version__) >= version.parse("1.13.0"): + additional_kwargs["weights_only"] = weights_only + + return load( + checkpoint_file, + map_location=map_location, + **additional_kwargs, + ) + + +# HELPERS + + +def _validate_keys_for_strict_loading( + model: "torch.nn.Module", + loaded_keys: Iterable[str], +) -> None: + """ + Validate that model keys match loaded keys when strict loading is enabled. + + Args: + model: The PyTorch model being loaded + loaded_keys: The keys present in the checkpoint + + Raises: + RuntimeError: If there are missing or unexpected keys in strict mode + """ + loaded_keys_set = set(loaded_keys) + model_keys = set(model.state_dict().keys()) + missing_keys = model_keys - loaded_keys_set # Keys in model but not in checkpoint + unexpected_keys = loaded_keys_set - model_keys # Keys in checkpoint but not in model + + if missing_keys or unexpected_keys: + error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}" + if missing_keys: + str_missing_keys = ",".join([f'"{k}"' for k in sorted(missing_keys)]) + error_message += f"\nMissing key(s): {str_missing_keys}." + if unexpected_keys: + str_unexpected_keys = ",".join([f'"{k}"' for k in sorted(unexpected_keys)]) + error_message += f"\nUnexpected key(s): {str_unexpected_keys}." + raise RuntimeError(error_message) + + +def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]: + """Returns a unique id for plain tensor + or a (potentially nested) Tuple of unique id for the flattened Tensor + if the input is a wrapper tensor subclass Tensor + """ + + try: + from torch.distributed.tensor import DTensor + + if isinstance(tensor, DTensor): + local_tensor = tensor.to_local() + return local_tensor.storage().data_ptr() + except ImportError: + pass + + try: + # for torch 2.1 and above we can also handle tensor subclasses + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + if is_traceable_wrapper_subclass(tensor): + attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined] + return tuple(_get_unique_id(getattr(tensor, attr)) for attr in attrs) + + except ImportError: + # for torch version less than 2.1, we can fallback to original implementation + pass + + if tensor.device.type == "xla" and is_torch_tpu_available(): + # NOTE: xla tensors dont have storage + # use some other unique id to distinguish. + # this is a XLA tensor, it must be created using torch_xla's + # device. So the following import is safe: + import torch_xla # type: ignore[import] + + unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor) + else: + unique_id = storage_ptr(tensor) + + return unique_id + + +def get_torch_storage_id(tensor: "torch.Tensor") -> Optional[Tuple["torch.device", Union[int, Tuple[Any, ...]], int]]: + """ + Return unique identifier to a tensor storage. + + Multiple different tensors can share the same underlying storage. This identifier is + guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with + non-overlapping lifetimes may have the same id. + In the case of meta tensors, we return None since we can't tell if they share the same storage. + + Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278. + """ + if tensor.device.type == "meta": + return None + else: + return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor) + + +def get_torch_storage_size(tensor: "torch.Tensor") -> int: + """ + Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59 + """ + try: + from torch.distributed.tensor import DTensor + + if isinstance(tensor, DTensor): + # this returns the size of the FULL tensor in bytes + return tensor.nbytes + except ImportError: + pass + + try: + # for torch 2.1 and above we can also handle tensor subclasses + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + if is_traceable_wrapper_subclass(tensor): + attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined] + return sum(get_torch_storage_size(getattr(tensor, attr)) for attr in attrs) + except ImportError: + # for torch version less than 2.1, we can fallback to original implementation + pass + + try: + return tensor.untyped_storage().nbytes() + except AttributeError: + # Fallback for torch==1.10 + try: + return tensor.storage().size() * _get_dtype_size(tensor.dtype) + except NotImplementedError: + # Fallback for meta storage + # On torch >=2.0 this is the tensor size + return tensor.nelement() * _get_dtype_size(tensor.dtype) + + +@lru_cache() +def is_torch_tpu_available(check_device=True): + """ + Checks if `torch_xla` is installed and potentially if a TPU is in the environment + + Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/utils/import_utils.py#L463. + """ + if importlib.util.find_spec("torch_xla") is not None: + if check_device: + # We need to check if `xla_device` can be found, will raise a RuntimeError if not + try: + import torch_xla.core.xla_model as xm # type: ignore[import] + + _ = xm.xla_device() + return True + except RuntimeError: + return False + return True + return False + + +def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]: + """ + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11. + """ + try: + # for torch 2.1 and above we can also handle tensor subclasses + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + if is_traceable_wrapper_subclass(tensor): + return _get_unique_id(tensor) # type: ignore + except ImportError: + # for torch version less than 2.1, we can fallback to original implementation + pass + + try: + return tensor.untyped_storage().data_ptr() + except Exception: + # Fallback for torch==1.10 + try: + return tensor.storage().data_ptr() + except NotImplementedError: + # Fallback for meta storage + return 0 + + +def _clean_state_dict_for_safetensors( + state_dict: Dict[str, "torch.Tensor"], + metadata: Dict[str, str], + force_contiguous: bool = True, + shared_tensors_to_discard: Optional[List[str]] = None, +): + """Remove shared tensors from state_dict and update metadata accordingly (for reloading). + + Warning: `state_dict` and `metadata` are mutated in-place! + + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L155. + """ + to_removes = _remove_duplicate_names(state_dict, discard_names=shared_tensors_to_discard) + for kept_name, to_remove_group in to_removes.items(): + for to_remove in to_remove_group: + if metadata is None: + metadata = {} + + if to_remove not in metadata: + # Do not override user data + metadata[to_remove] = kept_name + del state_dict[to_remove] + if force_contiguous: + state_dict = {k: v.contiguous() for k, v in state_dict.items()} + return state_dict + + +def _end_ptr(tensor: "torch.Tensor") -> int: + """ + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L23. + """ + if tensor.nelement(): + stop = tensor.view(-1)[-1].data_ptr() + _get_dtype_size(tensor.dtype) + else: + stop = tensor.data_ptr() + return stop + + +def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]: + """ + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L44 + """ + filtered_tensors = [] + for shared in tensors: + if len(shared) < 2: + filtered_tensors.append(shared) + continue + + areas = [] + for name in shared: + tensor = state_dict[name] + areas.append((tensor.data_ptr(), _end_ptr(tensor), name)) + areas.sort() + + _, last_stop, last_name = areas[0] + filtered_tensors.append({last_name}) + for start, stop, name in areas[1:]: + if start >= last_stop: + filtered_tensors.append({name}) + else: + filtered_tensors[-1].add(name) + last_stop = stop + + return filtered_tensors + + +def _find_shared_tensors(state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]: + """ + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L69. + """ + import torch + + tensors_dict = defaultdict(set) + for k, v in state_dict.items(): + if v.device != torch.device("meta") and storage_ptr(v) != 0 and get_torch_storage_size(v) != 0: + # Need to add device as key because of multiple GPU. + tensors_dict[(v.device, storage_ptr(v), get_torch_storage_size(v))].add(k) + tensors = list(sorted(tensors_dict.values())) + tensors = _filter_shared_not_shared(tensors, state_dict) + return tensors + + +def _is_complete(tensor: "torch.Tensor") -> bool: + """ + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80 + """ + try: + # for torch 2.1 and above we can also handle tensor subclasses + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + if is_traceable_wrapper_subclass(tensor): + attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined] + return all(_is_complete(getattr(tensor, attr)) for attr in attrs) + except ImportError: + # for torch version less than 2.1, we can fallback to original implementation + pass + + return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _get_dtype_size( + tensor.dtype + ) == get_torch_storage_size(tensor) + + +def _remove_duplicate_names( + state_dict: Dict[str, "torch.Tensor"], + *, + preferred_names: Optional[List[str]] = None, + discard_names: Optional[List[str]] = None, +) -> Dict[str, List[str]]: + """ + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80 + """ + if preferred_names is None: + preferred_names = [] + unique_preferred_names = set(preferred_names) + if discard_names is None: + discard_names = [] + unique_discard_names = set(discard_names) + + shareds = _find_shared_tensors(state_dict) + to_remove = defaultdict(list) + for shared in shareds: + complete_names = set([name for name in shared if _is_complete(state_dict[name])]) + if not complete_names: + raise RuntimeError( + "Error while trying to find names to remove to save state dict, but found no suitable name to keep" + f" for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model" + " since you could be storing much more memory than needed. Please refer to" + " https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an" + " issue." + ) + + keep_name = sorted(list(complete_names))[0] + + # Mechanism to preferentially select keys to keep + # coming from the on-disk file to allow + # loading models saved with a different choice + # of keep_name + preferred = complete_names.difference(unique_discard_names) + if preferred: + keep_name = sorted(list(preferred))[0] + + if unique_preferred_names: + preferred = unique_preferred_names.intersection(complete_names) + if preferred: + keep_name = sorted(list(preferred))[0] + for name in sorted(shared): + if name != keep_name: + to_remove[keep_name].append(name) + return to_remove + + +@lru_cache() +def _get_dtype_size(dtype: "torch.dtype") -> int: + """ + Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L344 + """ + import torch + + # torch.float8 formats require 2.1; we do not support these dtypes on earlier versions + _float8_e4m3fn = getattr(torch, "float8_e4m3fn", None) + _float8_e5m2 = getattr(torch, "float8_e5m2", None) + _SIZE = { + torch.int64: 8, + torch.float32: 4, + torch.int32: 4, + torch.bfloat16: 2, + torch.float16: 2, + torch.int16: 2, + torch.uint8: 1, + torch.int8: 1, + torch.bool: 1, + torch.float64: 8, + _float8_e4m3fn: 1, + _float8_e5m2: 1, + } + return _SIZE[dtype] + + +class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])): + """ + This is used to report missing and unexpected keys in the state dict. + Taken from https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L52. + + """ + + def __repr__(self) -> str: + if not self.missing_keys and not self.unexpected_keys: + return "" + return super().__repr__() + + __str__ = __repr__ diff --git a/phivenv/Lib/site-packages/huggingface_hub/templates/datasetcard_template.md b/phivenv/Lib/site-packages/huggingface_hub/templates/datasetcard_template.md new file mode 100644 index 0000000000000000000000000000000000000000..9af29ebbed93653ec74a8952e314e7554323ef15 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/templates/datasetcard_template.md @@ -0,0 +1,143 @@ +--- +# For reference on dataset card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/datasetcard.md?plain=1 +# Doc / guide: https://huggingface.co/docs/hub/datasets-cards +{{ card_data }} +--- + +# Dataset Card for {{ pretty_name | default("Dataset Name", true) }} + + + +{{ dataset_summary | default("", true) }} + +## Dataset Details + +### Dataset Description + + + +{{ dataset_description | default("", true) }} + +- **Curated by:** {{ curators | default("[More Information Needed]", true)}} +- **Funded by [optional]:** {{ funded_by | default("[More Information Needed]", true)}} +- **Shared by [optional]:** {{ shared_by | default("[More Information Needed]", true)}} +- **Language(s) (NLP):** {{ language | default("[More Information Needed]", true)}} +- **License:** {{ license | default("[More Information Needed]", true)}} + +### Dataset Sources [optional] + + + +- **Repository:** {{ repo | default("[More Information Needed]", true)}} +- **Paper [optional]:** {{ paper | default("[More Information Needed]", true)}} +- **Demo [optional]:** {{ demo | default("[More Information Needed]", true)}} + +## Uses + + + +### Direct Use + + + +{{ direct_use | default("[More Information Needed]", true)}} + +### Out-of-Scope Use + + + +{{ out_of_scope_use | default("[More Information Needed]", true)}} + +## Dataset Structure + + + +{{ dataset_structure | default("[More Information Needed]", true)}} + +## Dataset Creation + +### Curation Rationale + + + +{{ curation_rationale_section | default("[More Information Needed]", true)}} + +### Source Data + + + +#### Data Collection and Processing + + + +{{ data_collection_and_processing_section | default("[More Information Needed]", true)}} + +#### Who are the source data producers? + + + +{{ source_data_producers_section | default("[More Information Needed]", true)}} + +### Annotations [optional] + + + +#### Annotation process + + + +{{ annotation_process_section | default("[More Information Needed]", true)}} + +#### Who are the annotators? + + + +{{ who_are_annotators_section | default("[More Information Needed]", true)}} + +#### Personal and Sensitive Information + + + +{{ personal_and_sensitive_information | default("[More Information Needed]", true)}} + +## Bias, Risks, and Limitations + + + +{{ bias_risks_limitations | default("[More Information Needed]", true)}} + +### Recommendations + + + +{{ bias_recommendations | default("Users should be made aware of the risks, biases and limitations of the dataset. More information needed for further recommendations.", true)}} + +## Citation [optional] + + + +**BibTeX:** + +{{ citation_bibtex | default("[More Information Needed]", true)}} + +**APA:** + +{{ citation_apa | default("[More Information Needed]", true)}} + +## Glossary [optional] + + + +{{ glossary | default("[More Information Needed]", true)}} + +## More Information [optional] + +{{ more_information | default("[More Information Needed]", true)}} + +## Dataset Card Authors [optional] + +{{ dataset_card_authors | default("[More Information Needed]", true)}} + +## Dataset Card Contact + +{{ dataset_card_contact | default("[More Information Needed]", true)}} diff --git a/phivenv/Lib/site-packages/huggingface_hub/templates/modelcard_template.md b/phivenv/Lib/site-packages/huggingface_hub/templates/modelcard_template.md new file mode 100644 index 0000000000000000000000000000000000000000..79ca15e4547debac763b390ef8e4b715e6f6403f --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/templates/modelcard_template.md @@ -0,0 +1,200 @@ +--- +# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 +# Doc / guide: https://huggingface.co/docs/hub/model-cards +{{ card_data }} +--- + +# Model Card for {{ model_id | default("Model ID", true) }} + + + +{{ model_summary | default("", true) }} + +## Model Details + +### Model Description + + + +{{ model_description | default("", true) }} + +- **Developed by:** {{ developers | default("[More Information Needed]", true)}} +- **Funded by [optional]:** {{ funded_by | default("[More Information Needed]", true)}} +- **Shared by [optional]:** {{ shared_by | default("[More Information Needed]", true)}} +- **Model type:** {{ model_type | default("[More Information Needed]", true)}} +- **Language(s) (NLP):** {{ language | default("[More Information Needed]", true)}} +- **License:** {{ license | default("[More Information Needed]", true)}} +- **Finetuned from model [optional]:** {{ base_model | default("[More Information Needed]", true)}} + +### Model Sources [optional] + + + +- **Repository:** {{ repo | default("[More Information Needed]", true)}} +- **Paper [optional]:** {{ paper | default("[More Information Needed]", true)}} +- **Demo [optional]:** {{ demo | default("[More Information Needed]", true)}} + +## Uses + + + +### Direct Use + + + +{{ direct_use | default("[More Information Needed]", true)}} + +### Downstream Use [optional] + + + +{{ downstream_use | default("[More Information Needed]", true)}} + +### Out-of-Scope Use + + + +{{ out_of_scope_use | default("[More Information Needed]", true)}} + +## Bias, Risks, and Limitations + + + +{{ bias_risks_limitations | default("[More Information Needed]", true)}} + +### Recommendations + + + +{{ bias_recommendations | default("Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.", true)}} + +## How to Get Started with the Model + +Use the code below to get started with the model. + +{{ get_started_code | default("[More Information Needed]", true)}} + +## Training Details + +### Training Data + + + +{{ training_data | default("[More Information Needed]", true)}} + +### Training Procedure + + + +#### Preprocessing [optional] + +{{ preprocessing | default("[More Information Needed]", true)}} + + +#### Training Hyperparameters + +- **Training regime:** {{ training_regime | default("[More Information Needed]", true)}} + +#### Speeds, Sizes, Times [optional] + + + +{{ speeds_sizes_times | default("[More Information Needed]", true)}} + +## Evaluation + + + +### Testing Data, Factors & Metrics + +#### Testing Data + + + +{{ testing_data | default("[More Information Needed]", true)}} + +#### Factors + + + +{{ testing_factors | default("[More Information Needed]", true)}} + +#### Metrics + + + +{{ testing_metrics | default("[More Information Needed]", true)}} + +### Results + +{{ results | default("[More Information Needed]", true)}} + +#### Summary + +{{ results_summary | default("", true) }} + +## Model Examination [optional] + + + +{{ model_examination | default("[More Information Needed]", true)}} + +## Environmental Impact + + + +Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). + +- **Hardware Type:** {{ hardware_type | default("[More Information Needed]", true)}} +- **Hours used:** {{ hours_used | default("[More Information Needed]", true)}} +- **Cloud Provider:** {{ cloud_provider | default("[More Information Needed]", true)}} +- **Compute Region:** {{ cloud_region | default("[More Information Needed]", true)}} +- **Carbon Emitted:** {{ co2_emitted | default("[More Information Needed]", true)}} + +## Technical Specifications [optional] + +### Model Architecture and Objective + +{{ model_specs | default("[More Information Needed]", true)}} + +### Compute Infrastructure + +{{ compute_infrastructure | default("[More Information Needed]", true)}} + +#### Hardware + +{{ hardware_requirements | default("[More Information Needed]", true)}} + +#### Software + +{{ software | default("[More Information Needed]", true)}} + +## Citation [optional] + + + +**BibTeX:** + +{{ citation_bibtex | default("[More Information Needed]", true)}} + +**APA:** + +{{ citation_apa | default("[More Information Needed]", true)}} + +## Glossary [optional] + + + +{{ glossary | default("[More Information Needed]", true)}} + +## More Information [optional] + +{{ more_information | default("[More Information Needed]", true)}} + +## Model Card Authors [optional] + +{{ model_card_authors | default("[More Information Needed]", true)}} + +## Model Card Contact + +{{ model_card_contact | default("[More Information Needed]", true)}} diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/_auth.py b/phivenv/Lib/site-packages/huggingface_hub/utils/_auth.py new file mode 100644 index 0000000000000000000000000000000000000000..72be4dedbd94421ee2b4b2ba1073569d71b50569 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/_auth.py @@ -0,0 +1,214 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains an helper to get the token from machine (env variable, secret or config file).""" + +import configparser +import logging +import os +import warnings +from pathlib import Path +from threading import Lock +from typing import Dict, Optional + +from .. import constants +from ._runtime import is_colab_enterprise, is_google_colab + + +_IS_GOOGLE_COLAB_CHECKED = False +_GOOGLE_COLAB_SECRET_LOCK = Lock() +_GOOGLE_COLAB_SECRET: Optional[str] = None + +logger = logging.getLogger(__name__) + + +def get_token() -> Optional[str]: + """ + Get token if user is logged in. + + Note: in most cases, you should use [`huggingface_hub.utils.build_hf_headers`] instead. This method is only useful + if you want to retrieve the token for other purposes than sending an HTTP request. + + Token is retrieved in priority from the `HF_TOKEN` environment variable. Otherwise, we read the token file located + in the Hugging Face home folder. Returns None if user is not logged in. To log in, use [`login`] or + `hf auth login`. + + Returns: + `str` or `None`: The token, `None` if it doesn't exist. + """ + return _get_token_from_google_colab() or _get_token_from_environment() or _get_token_from_file() + + +def _get_token_from_google_colab() -> Optional[str]: + """Get token from Google Colab secrets vault using `google.colab.userdata.get(...)`. + + Token is read from the vault only once per session and then stored in a global variable to avoid re-requesting + access to the vault. + """ + # If it's not a Google Colab or it's Colab Enterprise, fallback to environment variable or token file authentication + if not is_google_colab() or is_colab_enterprise(): + return None + + # `google.colab.userdata` is not thread-safe + # This can lead to a deadlock if multiple threads try to access it at the same time + # (typically when using `snapshot_download`) + # => use a lock + # See https://github.com/huggingface/huggingface_hub/issues/1952 for more details. + with _GOOGLE_COLAB_SECRET_LOCK: + global _GOOGLE_COLAB_SECRET + global _IS_GOOGLE_COLAB_CHECKED + + if _IS_GOOGLE_COLAB_CHECKED: # request access only once + return _GOOGLE_COLAB_SECRET + + try: + from google.colab import userdata # type: ignore + from google.colab.errors import Error as ColabError # type: ignore + except ImportError: + return None + + try: + token = userdata.get("HF_TOKEN") + _GOOGLE_COLAB_SECRET = _clean_token(token) + except userdata.NotebookAccessError: + # Means the user has a secret call `HF_TOKEN` and got a popup "please grand access to HF_TOKEN" and refused it + # => warn user but ignore error => do not re-request access to user + warnings.warn( + "\nAccess to the secret `HF_TOKEN` has not been granted on this notebook." + "\nYou will not be requested again." + "\nPlease restart the session if you want to be prompted again." + ) + _GOOGLE_COLAB_SECRET = None + except userdata.SecretNotFoundError: + # Means the user did not define a `HF_TOKEN` secret => warn + warnings.warn( + "\nThe secret `HF_TOKEN` does not exist in your Colab secrets." + "\nTo authenticate with the Hugging Face Hub, create a token in your settings tab " + "(https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session." + "\nYou will be able to reuse this secret in all of your notebooks." + "\nPlease note that authentication is recommended but still optional to access public models or datasets." + ) + _GOOGLE_COLAB_SECRET = None + except ColabError as e: + # Something happen but we don't know what => recommend to open a GitHub issue + warnings.warn( + f"\nError while fetching `HF_TOKEN` secret value from your vault: '{str(e)}'." + "\nYou are not authenticated with the Hugging Face Hub in this notebook." + "\nIf the error persists, please let us know by opening an issue on GitHub " + "(https://github.com/huggingface/huggingface_hub/issues/new)." + ) + _GOOGLE_COLAB_SECRET = None + + _IS_GOOGLE_COLAB_CHECKED = True + return _GOOGLE_COLAB_SECRET + + +def _get_token_from_environment() -> Optional[str]: + # `HF_TOKEN` has priority (keep `HUGGING_FACE_HUB_TOKEN` for backward compatibility) + return _clean_token(os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")) + + +def _get_token_from_file() -> Optional[str]: + try: + return _clean_token(Path(constants.HF_TOKEN_PATH).read_text()) + except FileNotFoundError: + return None + + +def get_stored_tokens() -> Dict[str, str]: + """ + Returns the parsed INI file containing the access tokens. + The file is located at `HF_STORED_TOKENS_PATH`, defaulting to `~/.cache/huggingface/stored_tokens`. + If the file does not exist, an empty dictionary is returned. + + Returns: `Dict[str, str]` + Key is the token name and value is the token. + """ + tokens_path = Path(constants.HF_STORED_TOKENS_PATH) + if not tokens_path.exists(): + stored_tokens = {} + config = configparser.ConfigParser() + try: + config.read(tokens_path) + stored_tokens = {token_name: config.get(token_name, "hf_token") for token_name in config.sections()} + except configparser.Error as e: + logger.error(f"Error parsing stored tokens file: {e}") + stored_tokens = {} + return stored_tokens + + +def _save_stored_tokens(stored_tokens: Dict[str, str]) -> None: + """ + Saves the given configuration to the stored tokens file. + + Args: + stored_tokens (`Dict[str, str]`): + The stored tokens to save. Key is the token name and value is the token. + """ + stored_tokens_path = Path(constants.HF_STORED_TOKENS_PATH) + + # Write the stored tokens into an INI file + config = configparser.ConfigParser() + for token_name in sorted(stored_tokens.keys()): + config.add_section(token_name) + config.set(token_name, "hf_token", stored_tokens[token_name]) + + stored_tokens_path.parent.mkdir(parents=True, exist_ok=True) + with stored_tokens_path.open("w") as config_file: + config.write(config_file) + + +def _get_token_by_name(token_name: str) -> Optional[str]: + """ + Get the token by name. + + Args: + token_name (`str`): + The name of the token to get. + + Returns: + `str` or `None`: The token, `None` if it doesn't exist. + + """ + stored_tokens = get_stored_tokens() + if token_name not in stored_tokens: + return None + return _clean_token(stored_tokens[token_name]) + + +def _save_token(token: str, token_name: str) -> None: + """ + Save the given token. + + If the stored tokens file does not exist, it will be created. + Args: + token (`str`): + The token to save. + token_name (`str`): + The name of the token. + """ + tokens_path = Path(constants.HF_STORED_TOKENS_PATH) + stored_tokens = get_stored_tokens() + stored_tokens[token_name] = token + _save_stored_tokens(stored_tokens) + logger.info(f"The token `{token_name}` has been saved to {tokens_path}") + + +def _clean_token(token: Optional[str]) -> Optional[str]: + """Clean token by removing trailing and leading spaces and newlines. + + If token is an empty string, return None. + """ + if token is None: + return None + return token.replace("\r", "").replace("\n", "").strip() or None diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/_cache_assets.py b/phivenv/Lib/site-packages/huggingface_hub/utils/_cache_assets.py new file mode 100644 index 0000000000000000000000000000000000000000..e5d435df9b0bb0c67c0bcb5ef65711e9aef367f6 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/_cache_assets.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pathlib import Path +from typing import Union + +from ..constants import HF_ASSETS_CACHE + + +def cached_assets_path( + library_name: str, + namespace: str = "default", + subfolder: str = "default", + *, + assets_dir: Union[str, Path, None] = None, +): + """Return a folder path to cache arbitrary files. + + `huggingface_hub` provides a canonical folder path to store assets. This is the + recommended way to integrate cache in a downstream library as it will benefit from + the builtins tools to scan and delete the cache properly. + + The distinction is made between files cached from the Hub and assets. Files from the + Hub are cached in a git-aware manner and entirely managed by `huggingface_hub`. See + [related documentation](https://huggingface.co/docs/huggingface_hub/how-to-cache). + All other files that a downstream library caches are considered to be "assets" + (files downloaded from external sources, extracted from a .tar archive, preprocessed + for training,...). + + Once the folder path is generated, it is guaranteed to exist and to be a directory. + The path is based on 3 levels of depth: the library name, a namespace and a + subfolder. Those 3 levels grants flexibility while allowing `huggingface_hub` to + expect folders when scanning/deleting parts of the assets cache. Within a library, + it is expected that all namespaces share the same subset of subfolder names but this + is not a mandatory rule. The downstream library has then full control on which file + structure to adopt within its cache. Namespace and subfolder are optional (would + default to a `"default/"` subfolder) but library name is mandatory as we want every + downstream library to manage its own cache. + + Expected tree: + ```text + assets/ + └── datasets/ + │ ├── SQuAD/ + │ │ ├── downloaded/ + │ │ ├── extracted/ + │ │ └── processed/ + │ ├── Helsinki-NLP--tatoeba_mt/ + │ ├── downloaded/ + │ ├── extracted/ + │ └── processed/ + └── transformers/ + ├── default/ + │ ├── something/ + ├── bert-base-cased/ + │ ├── default/ + │ └── training/ + hub/ + └── models--julien-c--EsperBERTo-small/ + ├── blobs/ + │ ├── (...) + │ ├── (...) + ├── refs/ + │ └── (...) + └── [ 128] snapshots/ + ├── 2439f60ef33a0d46d85da5001d52aeda5b00ce9f/ + │ ├── (...) + └── bbc77c8132af1cc5cf678da3f1ddf2de43606d48/ + └── (...) + ``` + + + Args: + library_name (`str`): + Name of the library that will manage the cache folder. Example: `"dataset"`. + namespace (`str`, *optional*, defaults to "default"): + Namespace to which the data belongs. Example: `"SQuAD"`. + subfolder (`str`, *optional*, defaults to "default"): + Subfolder in which the data will be stored. Example: `extracted`. + assets_dir (`str`, `Path`, *optional*): + Path to the folder where assets are cached. This must not be the same folder + where Hub files are cached. Defaults to `HF_HOME / "assets"` if not provided. + Can also be set with `HF_ASSETS_CACHE` environment variable. + + Returns: + Path to the cache folder (`Path`). + + Example: + ```py + >>> from huggingface_hub import cached_assets_path + + >>> cached_assets_path(library_name="datasets", namespace="SQuAD", subfolder="download") + PosixPath('/home/wauplin/.cache/huggingface/extra/datasets/SQuAD/download') + + >>> cached_assets_path(library_name="datasets", namespace="SQuAD", subfolder="extracted") + PosixPath('/home/wauplin/.cache/huggingface/extra/datasets/SQuAD/extracted') + + >>> cached_assets_path(library_name="datasets", namespace="Helsinki-NLP/tatoeba_mt") + PosixPath('/home/wauplin/.cache/huggingface/extra/datasets/Helsinki-NLP--tatoeba_mt/default') + + >>> cached_assets_path(library_name="datasets", assets_dir="/tmp/tmp123456") + PosixPath('/tmp/tmp123456/datasets/default/default') + ``` + """ + # Resolve assets_dir + if assets_dir is None: + assets_dir = HF_ASSETS_CACHE + assets_dir = Path(assets_dir).expanduser().resolve() + + # Avoid names that could create path issues + for part in (" ", "/", "\\"): + library_name = library_name.replace(part, "--") + namespace = namespace.replace(part, "--") + subfolder = subfolder.replace(part, "--") + + # Path to subfolder is created + path = assets_dir / library_name / namespace / subfolder + try: + path.mkdir(exist_ok=True, parents=True) + except (FileExistsError, NotADirectoryError): + raise ValueError(f"Corrupted assets folder: cannot create directory because of an existing file ({path}).") + + # Return + return path diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/_cache_manager.py b/phivenv/Lib/site-packages/huggingface_hub/utils/_cache_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..311e164a4fea0ea323a00915fb2bf7f930e14de1 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/_cache_manager.py @@ -0,0 +1,896 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to manage the HF cache directory.""" + +import os +import shutil +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, FrozenSet, List, Literal, Optional, Set, Union + +from huggingface_hub.errors import CacheNotFound, CorruptedCacheException + +from ..commands._cli_utils import tabulate +from ..constants import HF_HUB_CACHE +from . import logging + + +logger = logging.get_logger(__name__) + +REPO_TYPE_T = Literal["model", "dataset", "space"] + +# List of OS-created helper files that need to be ignored +FILES_TO_IGNORE = [".DS_Store"] + + +@dataclass(frozen=True) +class CachedFileInfo: + """Frozen data structure holding information about a single cached file. + + Args: + file_name (`str`): + Name of the file. Example: `config.json`. + file_path (`Path`): + Path of the file in the `snapshots` directory. The file path is a symlink + referring to a blob in the `blobs` folder. + blob_path (`Path`): + Path of the blob file. This is equivalent to `file_path.resolve()`. + size_on_disk (`int`): + Size of the blob file in bytes. + blob_last_accessed (`float`): + Timestamp of the last time the blob file has been accessed (from any + revision). + blob_last_modified (`float`): + Timestamp of the last time the blob file has been modified/created. + + + + `blob_last_accessed` and `blob_last_modified` reliability can depend on the OS you + are using. See [python documentation](https://docs.python.org/3/library/os.html#os.stat_result) + for more details. + + + """ + + file_name: str + file_path: Path + blob_path: Path + size_on_disk: int + + blob_last_accessed: float + blob_last_modified: float + + @property + def blob_last_accessed_str(self) -> str: + """ + (property) Timestamp of the last time the blob file has been accessed (from any + revision), returned as a human-readable string. + + Example: "2 weeks ago". + """ + return _format_timesince(self.blob_last_accessed) + + @property + def blob_last_modified_str(self) -> str: + """ + (property) Timestamp of the last time the blob file has been modified, returned + as a human-readable string. + + Example: "2 weeks ago". + """ + return _format_timesince(self.blob_last_modified) + + @property + def size_on_disk_str(self) -> str: + """ + (property) Size of the blob file as a human-readable string. + + Example: "42.2K". + """ + return _format_size(self.size_on_disk) + + +@dataclass(frozen=True) +class CachedRevisionInfo: + """Frozen data structure holding information about a revision. + + A revision correspond to a folder in the `snapshots` folder and is populated with + the exact tree structure as the repo on the Hub but contains only symlinks. A + revision can be either referenced by 1 or more `refs` or be "detached" (no refs). + + Args: + commit_hash (`str`): + Hash of the revision (unique). + Example: `"9338f7b671827df886678df2bdd7cc7b4f36dffd"`. + snapshot_path (`Path`): + Path to the revision directory in the `snapshots` folder. It contains the + exact tree structure as the repo on the Hub. + files: (`FrozenSet[CachedFileInfo]`): + Set of [`~CachedFileInfo`] describing all files contained in the snapshot. + refs (`FrozenSet[str]`): + Set of `refs` pointing to this revision. If the revision has no `refs`, it + is considered detached. + Example: `{"main", "2.4.0"}` or `{"refs/pr/1"}`. + size_on_disk (`int`): + Sum of the blob file sizes that are symlink-ed by the revision. + last_modified (`float`): + Timestamp of the last time the revision has been created/modified. + + + + `last_accessed` cannot be determined correctly on a single revision as blob files + are shared across revisions. + + + + + + `size_on_disk` is not necessarily the sum of all file sizes because of possible + duplicated files. Besides, only blobs are taken into account, not the (negligible) + size of folders and symlinks. + + + """ + + commit_hash: str + snapshot_path: Path + size_on_disk: int + files: FrozenSet[CachedFileInfo] + refs: FrozenSet[str] + + last_modified: float + + @property + def last_modified_str(self) -> str: + """ + (property) Timestamp of the last time the revision has been modified, returned + as a human-readable string. + + Example: "2 weeks ago". + """ + return _format_timesince(self.last_modified) + + @property + def size_on_disk_str(self) -> str: + """ + (property) Sum of the blob file sizes as a human-readable string. + + Example: "42.2K". + """ + return _format_size(self.size_on_disk) + + @property + def nb_files(self) -> int: + """ + (property) Total number of files in the revision. + """ + return len(self.files) + + +@dataclass(frozen=True) +class CachedRepoInfo: + """Frozen data structure holding information about a cached repository. + + Args: + repo_id (`str`): + Repo id of the repo on the Hub. Example: `"google/fleurs"`. + repo_type (`Literal["dataset", "model", "space"]`): + Type of the cached repo. + repo_path (`Path`): + Local path to the cached repo. + size_on_disk (`int`): + Sum of the blob file sizes in the cached repo. + nb_files (`int`): + Total number of blob files in the cached repo. + revisions (`FrozenSet[CachedRevisionInfo]`): + Set of [`~CachedRevisionInfo`] describing all revisions cached in the repo. + last_accessed (`float`): + Timestamp of the last time a blob file of the repo has been accessed. + last_modified (`float`): + Timestamp of the last time a blob file of the repo has been modified/created. + + + + `size_on_disk` is not necessarily the sum of all revisions sizes because of + duplicated files. Besides, only blobs are taken into account, not the (negligible) + size of folders and symlinks. + + + + + + `last_accessed` and `last_modified` reliability can depend on the OS you are using. + See [python documentation](https://docs.python.org/3/library/os.html#os.stat_result) + for more details. + + + """ + + repo_id: str + repo_type: REPO_TYPE_T + repo_path: Path + size_on_disk: int + nb_files: int + revisions: FrozenSet[CachedRevisionInfo] + + last_accessed: float + last_modified: float + + @property + def last_accessed_str(self) -> str: + """ + (property) Last time a blob file of the repo has been accessed, returned as a + human-readable string. + + Example: "2 weeks ago". + """ + return _format_timesince(self.last_accessed) + + @property + def last_modified_str(self) -> str: + """ + (property) Last time a blob file of the repo has been modified, returned as a + human-readable string. + + Example: "2 weeks ago". + """ + return _format_timesince(self.last_modified) + + @property + def size_on_disk_str(self) -> str: + """ + (property) Sum of the blob file sizes as a human-readable string. + + Example: "42.2K". + """ + return _format_size(self.size_on_disk) + + @property + def refs(self) -> Dict[str, CachedRevisionInfo]: + """ + (property) Mapping between `refs` and revision data structures. + """ + return {ref: revision for revision in self.revisions for ref in revision.refs} + + +@dataclass(frozen=True) +class DeleteCacheStrategy: + """Frozen data structure holding the strategy to delete cached revisions. + + This object is not meant to be instantiated programmatically but to be returned by + [`~utils.HFCacheInfo.delete_revisions`]. See documentation for usage example. + + Args: + expected_freed_size (`float`): + Expected freed size once strategy is executed. + blobs (`FrozenSet[Path]`): + Set of blob file paths to be deleted. + refs (`FrozenSet[Path]`): + Set of reference file paths to be deleted. + repos (`FrozenSet[Path]`): + Set of entire repo paths to be deleted. + snapshots (`FrozenSet[Path]`): + Set of snapshots to be deleted (directory of symlinks). + """ + + expected_freed_size: int + blobs: FrozenSet[Path] + refs: FrozenSet[Path] + repos: FrozenSet[Path] + snapshots: FrozenSet[Path] + + @property + def expected_freed_size_str(self) -> str: + """ + (property) Expected size that will be freed as a human-readable string. + + Example: "42.2K". + """ + return _format_size(self.expected_freed_size) + + def execute(self) -> None: + """Execute the defined strategy. + + + + If this method is interrupted, the cache might get corrupted. Deletion order is + implemented so that references and symlinks are deleted before the actual blob + files. + + + + + + This method is irreversible. If executed, cached files are erased and must be + downloaded again. + + + """ + # Deletion order matters. Blobs are deleted in last so that the user can't end + # up in a state where a `ref`` refers to a missing snapshot or a snapshot + # symlink refers to a deleted blob. + + # Delete entire repos + for path in self.repos: + _try_delete_path(path, path_type="repo") + + # Delete snapshot directories + for path in self.snapshots: + _try_delete_path(path, path_type="snapshot") + + # Delete refs files + for path in self.refs: + _try_delete_path(path, path_type="ref") + + # Delete blob files + for path in self.blobs: + _try_delete_path(path, path_type="blob") + + logger.info(f"Cache deletion done. Saved {self.expected_freed_size_str}.") + + +@dataclass(frozen=True) +class HFCacheInfo: + """Frozen data structure holding information about the entire cache-system. + + This data structure is returned by [`scan_cache_dir`] and is immutable. + + Args: + size_on_disk (`int`): + Sum of all valid repo sizes in the cache-system. + repos (`FrozenSet[CachedRepoInfo]`): + Set of [`~CachedRepoInfo`] describing all valid cached repos found on the + cache-system while scanning. + warnings (`List[CorruptedCacheException]`): + List of [`~CorruptedCacheException`] that occurred while scanning the cache. + Those exceptions are captured so that the scan can continue. Corrupted repos + are skipped from the scan. + + + + Here `size_on_disk` is equal to the sum of all repo sizes (only blobs). However if + some cached repos are corrupted, their sizes are not taken into account. + + + """ + + size_on_disk: int + repos: FrozenSet[CachedRepoInfo] + warnings: List[CorruptedCacheException] + + @property + def size_on_disk_str(self) -> str: + """ + (property) Sum of all valid repo sizes in the cache-system as a human-readable + string. + + Example: "42.2K". + """ + return _format_size(self.size_on_disk) + + def delete_revisions(self, *revisions: str) -> DeleteCacheStrategy: + """Prepare the strategy to delete one or more revisions cached locally. + + Input revisions can be any revision hash. If a revision hash is not found in the + local cache, a warning is thrown but no error is raised. Revisions can be from + different cached repos since hashes are unique across repos, + + Examples: + ```py + >>> from huggingface_hub import scan_cache_dir + >>> cache_info = scan_cache_dir() + >>> delete_strategy = cache_info.delete_revisions( + ... "81fd1d6e7847c99f5862c9fb81387956d99ec7aa" + ... ) + >>> print(f"Will free {delete_strategy.expected_freed_size_str}.") + Will free 7.9K. + >>> delete_strategy.execute() + Cache deletion done. Saved 7.9K. + ``` + + ```py + >>> from huggingface_hub import scan_cache_dir + >>> scan_cache_dir().delete_revisions( + ... "81fd1d6e7847c99f5862c9fb81387956d99ec7aa", + ... "e2983b237dccf3ab4937c97fa717319a9ca1a96d", + ... "6c0e6080953db56375760c0471a8c5f2929baf11", + ... ).execute() + Cache deletion done. Saved 8.6G. + ``` + + + + `delete_revisions` returns a [`~utils.DeleteCacheStrategy`] object that needs to + be executed. The [`~utils.DeleteCacheStrategy`] is not meant to be modified but + allows having a dry run before actually executing the deletion. + + + """ + hashes_to_delete: Set[str] = set(revisions) + + repos_with_revisions: Dict[CachedRepoInfo, Set[CachedRevisionInfo]] = defaultdict(set) + + for repo in self.repos: + for revision in repo.revisions: + if revision.commit_hash in hashes_to_delete: + repos_with_revisions[repo].add(revision) + hashes_to_delete.remove(revision.commit_hash) + + if len(hashes_to_delete) > 0: + logger.warning(f"Revision(s) not found - cannot delete them: {', '.join(hashes_to_delete)}") + + delete_strategy_blobs: Set[Path] = set() + delete_strategy_refs: Set[Path] = set() + delete_strategy_repos: Set[Path] = set() + delete_strategy_snapshots: Set[Path] = set() + delete_strategy_expected_freed_size = 0 + + for affected_repo, revisions_to_delete in repos_with_revisions.items(): + other_revisions = affected_repo.revisions - revisions_to_delete + + # If no other revisions, it means all revisions are deleted + # -> delete the entire cached repo + if len(other_revisions) == 0: + delete_strategy_repos.add(affected_repo.repo_path) + delete_strategy_expected_freed_size += affected_repo.size_on_disk + continue + + # Some revisions of the repo will be deleted but not all. We need to filter + # which blob files will not be linked anymore. + for revision_to_delete in revisions_to_delete: + # Snapshot dir + delete_strategy_snapshots.add(revision_to_delete.snapshot_path) + + # Refs dir + for ref in revision_to_delete.refs: + delete_strategy_refs.add(affected_repo.repo_path / "refs" / ref) + + # Blobs dir + for file in revision_to_delete.files: + if file.blob_path not in delete_strategy_blobs: + is_file_alone = True + for revision in other_revisions: + for rev_file in revision.files: + if file.blob_path == rev_file.blob_path: + is_file_alone = False + break + if not is_file_alone: + break + + # Blob file not referenced by remaining revisions -> delete + if is_file_alone: + delete_strategy_blobs.add(file.blob_path) + delete_strategy_expected_freed_size += file.size_on_disk + + # Return the strategy instead of executing it. + return DeleteCacheStrategy( + blobs=frozenset(delete_strategy_blobs), + refs=frozenset(delete_strategy_refs), + repos=frozenset(delete_strategy_repos), + snapshots=frozenset(delete_strategy_snapshots), + expected_freed_size=delete_strategy_expected_freed_size, + ) + + def export_as_table(self, *, verbosity: int = 0) -> str: + """Generate a table from the [`HFCacheInfo`] object. + + Pass `verbosity=0` to get a table with a single row per repo, with columns + "repo_id", "repo_type", "size_on_disk", "nb_files", "last_accessed", "last_modified", "refs", "local_path". + + Pass `verbosity=1` to get a table with a row per repo and revision (thus multiple rows can appear for a single repo), with columns + "repo_id", "repo_type", "revision", "size_on_disk", "nb_files", "last_modified", "refs", "local_path". + + Example: + ```py + >>> from huggingface_hub.utils import scan_cache_dir + + >>> hf_cache_info = scan_cache_dir() + HFCacheInfo(...) + + >>> print(hf_cache_info.export_as_table()) + REPO ID REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH + --------------------------------------------------- --------- ------------ -------- ------------- ------------- ---- -------------------------------------------------------------------------------------------------- + roberta-base model 2.7M 5 1 day ago 1 week ago main ~/.cache/huggingface/hub/models--roberta-base + suno/bark model 8.8K 1 1 week ago 1 week ago main ~/.cache/huggingface/hub/models--suno--bark + t5-base model 893.8M 4 4 days ago 7 months ago main ~/.cache/huggingface/hub/models--t5-base + t5-large model 3.0G 4 5 weeks ago 5 months ago main ~/.cache/huggingface/hub/models--t5-large + + >>> print(hf_cache_info.export_as_table(verbosity=1)) + REPO ID REPO TYPE REVISION SIZE ON DISK NB FILES LAST_MODIFIED REFS LOCAL PATH + --------------------------------------------------- --------- ---------------------------------------- ------------ -------- ------------- ---- ----------------------------------------------------------------------------------------------------------------------------------------------------- + roberta-base model e2da8e2f811d1448a5b465c236feacd80ffbac7b 2.7M 5 1 week ago main ~/.cache/huggingface/hub/models--roberta-base/snapshots/e2da8e2f811d1448a5b465c236feacd80ffbac7b + suno/bark model 70a8a7d34168586dc5d028fa9666aceade177992 8.8K 1 1 week ago main ~/.cache/huggingface/hub/models--suno--bark/snapshots/70a8a7d34168586dc5d028fa9666aceade177992 + t5-base model a9723ea7f1b39c1eae772870f3b547bf6ef7e6c1 893.8M 4 7 months ago main ~/.cache/huggingface/hub/models--t5-base/snapshots/a9723ea7f1b39c1eae772870f3b547bf6ef7e6c1 + t5-large model 150ebc2c4b72291e770f58e6057481c8d2ed331a 3.0G 4 5 months ago main ~/.cache/huggingface/hub/models--t5-large/snapshots/150ebc2c4b72291e770f58e6057481c8d2ed331a + ``` + + Args: + verbosity (`int`, *optional*): + The verbosity level. Defaults to 0. + + Returns: + `str`: The table as a string. + """ + if verbosity == 0: + return tabulate( + rows=[ + [ + repo.repo_id, + repo.repo_type, + "{:>12}".format(repo.size_on_disk_str), + repo.nb_files, + repo.last_accessed_str, + repo.last_modified_str, + ", ".join(sorted(repo.refs)), + str(repo.repo_path), + ] + for repo in sorted(self.repos, key=lambda repo: repo.repo_path) + ], + headers=[ + "REPO ID", + "REPO TYPE", + "SIZE ON DISK", + "NB FILES", + "LAST_ACCESSED", + "LAST_MODIFIED", + "REFS", + "LOCAL PATH", + ], + ) + else: + return tabulate( + rows=[ + [ + repo.repo_id, + repo.repo_type, + revision.commit_hash, + "{:>12}".format(revision.size_on_disk_str), + revision.nb_files, + revision.last_modified_str, + ", ".join(sorted(revision.refs)), + str(revision.snapshot_path), + ] + for repo in sorted(self.repos, key=lambda repo: repo.repo_path) + for revision in sorted(repo.revisions, key=lambda revision: revision.commit_hash) + ], + headers=[ + "REPO ID", + "REPO TYPE", + "REVISION", + "SIZE ON DISK", + "NB FILES", + "LAST_MODIFIED", + "REFS", + "LOCAL PATH", + ], + ) + + +def scan_cache_dir(cache_dir: Optional[Union[str, Path]] = None) -> HFCacheInfo: + """Scan the entire HF cache-system and return a [`~HFCacheInfo`] structure. + + Use `scan_cache_dir` in order to programmatically scan your cache-system. The cache + will be scanned repo by repo. If a repo is corrupted, a [`~CorruptedCacheException`] + will be thrown internally but captured and returned in the [`~HFCacheInfo`] + structure. Only valid repos get a proper report. + + ```py + >>> from huggingface_hub import scan_cache_dir + + >>> hf_cache_info = scan_cache_dir() + HFCacheInfo( + size_on_disk=3398085269, + repos=frozenset({ + CachedRepoInfo( + repo_id='t5-small', + repo_type='model', + repo_path=PosixPath(...), + size_on_disk=970726914, + nb_files=11, + revisions=frozenset({ + CachedRevisionInfo( + commit_hash='d78aea13fa7ecd06c29e3e46195d6341255065d5', + size_on_disk=970726339, + snapshot_path=PosixPath(...), + files=frozenset({ + CachedFileInfo( + file_name='config.json', + size_on_disk=1197 + file_path=PosixPath(...), + blob_path=PosixPath(...), + ), + CachedFileInfo(...), + ... + }), + ), + CachedRevisionInfo(...), + ... + }), + ), + CachedRepoInfo(...), + ... + }), + warnings=[ + CorruptedCacheException("Snapshots dir doesn't exist in cached repo: ..."), + CorruptedCacheException(...), + ... + ], + ) + ``` + + You can also print a detailed report directly from the `hf` command line using: + ```text + > hf cache scan + REPO ID REPO TYPE SIZE ON DISK NB FILES REFS LOCAL PATH + --------------------------- --------- ------------ -------- ------------------- ------------------------------------------------------------------------- + glue dataset 116.3K 15 1.17.0, main, 2.4.0 /Users/lucain/.cache/huggingface/hub/datasets--glue + google/fleurs dataset 64.9M 6 main, refs/pr/1 /Users/lucain/.cache/huggingface/hub/datasets--google--fleurs + Jean-Baptiste/camembert-ner model 441.0M 7 main /Users/lucain/.cache/huggingface/hub/models--Jean-Baptiste--camembert-ner + bert-base-cased model 1.9G 13 main /Users/lucain/.cache/huggingface/hub/models--bert-base-cased + t5-base model 10.1K 3 main /Users/lucain/.cache/huggingface/hub/models--t5-base + t5-small model 970.7M 11 refs/pr/1, main /Users/lucain/.cache/huggingface/hub/models--t5-small + + Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G. + Got 1 warning(s) while scanning. Use -vvv to print details. + ``` + + Args: + cache_dir (`str` or `Path`, `optional`): + Cache directory to cache. Defaults to the default HF cache directory. + + + + Raises: + + `CacheNotFound` + If the cache directory does not exist. + + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If the cache directory is a file, instead of a directory. + + + + Returns: a [`~HFCacheInfo`] object. + """ + if cache_dir is None: + cache_dir = HF_HUB_CACHE + + cache_dir = Path(cache_dir).expanduser().resolve() + if not cache_dir.exists(): + raise CacheNotFound( + f"Cache directory not found: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable.", + cache_dir=cache_dir, + ) + + if cache_dir.is_file(): + raise ValueError( + f"Scan cache expects a directory but found a file: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable." + ) + + repos: Set[CachedRepoInfo] = set() + warnings: List[CorruptedCacheException] = [] + for repo_path in cache_dir.iterdir(): + if repo_path.name == ".locks": # skip './.locks/' folder + continue + try: + repos.add(_scan_cached_repo(repo_path)) + except CorruptedCacheException as e: + warnings.append(e) + + return HFCacheInfo( + repos=frozenset(repos), + size_on_disk=sum(repo.size_on_disk for repo in repos), + warnings=warnings, + ) + + +def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo: + """Scan a single cache repo and return information about it. + + Any unexpected behavior will raise a [`~CorruptedCacheException`]. + """ + if not repo_path.is_dir(): + raise CorruptedCacheException(f"Repo path is not a directory: {repo_path}") + + if "--" not in repo_path.name: + raise CorruptedCacheException(f"Repo path is not a valid HuggingFace cache directory: {repo_path}") + + repo_type, repo_id = repo_path.name.split("--", maxsplit=1) + repo_type = repo_type[:-1] # "models" -> "model" + repo_id = repo_id.replace("--", "/") # google/fleurs -> "google/fleurs" + + if repo_type not in {"dataset", "model", "space"}: + raise CorruptedCacheException( + f"Repo type must be `dataset`, `model` or `space`, found `{repo_type}` ({repo_path})." + ) + + blob_stats: Dict[Path, os.stat_result] = {} # Key is blob_path, value is blob stats + + snapshots_path = repo_path / "snapshots" + refs_path = repo_path / "refs" + + if not snapshots_path.exists() or not snapshots_path.is_dir(): + raise CorruptedCacheException(f"Snapshots dir doesn't exist in cached repo: {snapshots_path}") + + # Scan over `refs` directory + + # key is revision hash, value is set of refs + refs_by_hash: Dict[str, Set[str]] = defaultdict(set) + if refs_path.exists(): + # Example of `refs` directory + # ── refs + # ├── main + # └── refs + # └── pr + # └── 1 + if refs_path.is_file(): + raise CorruptedCacheException(f"Refs directory cannot be a file: {refs_path}") + + for ref_path in refs_path.glob("**/*"): + # glob("**/*") iterates over all files and directories -> skip directories + if ref_path.is_dir() or ref_path.name in FILES_TO_IGNORE: + continue + + ref_name = str(ref_path.relative_to(refs_path)) + with ref_path.open() as f: + commit_hash = f.read() + + refs_by_hash[commit_hash].add(ref_name) + + # Scan snapshots directory + cached_revisions: Set[CachedRevisionInfo] = set() + for revision_path in snapshots_path.iterdir(): + # Ignore OS-created helper files + if revision_path.name in FILES_TO_IGNORE: + continue + if revision_path.is_file(): + raise CorruptedCacheException(f"Snapshots folder corrupted. Found a file: {revision_path}") + + cached_files = set() + for file_path in revision_path.glob("**/*"): + # glob("**/*") iterates over all files and directories -> skip directories + if file_path.is_dir(): + continue + + blob_path = Path(file_path).resolve() + if not blob_path.exists(): + raise CorruptedCacheException(f"Blob missing (broken symlink): {blob_path}") + + if blob_path not in blob_stats: + blob_stats[blob_path] = blob_path.stat() + + cached_files.add( + CachedFileInfo( + file_name=file_path.name, + file_path=file_path, + size_on_disk=blob_stats[blob_path].st_size, + blob_path=blob_path, + blob_last_accessed=blob_stats[blob_path].st_atime, + blob_last_modified=blob_stats[blob_path].st_mtime, + ) + ) + + # Last modified is either the last modified blob file or the revision folder + # itself if it is empty + if len(cached_files) > 0: + revision_last_modified = max(blob_stats[file.blob_path].st_mtime for file in cached_files) + else: + revision_last_modified = revision_path.stat().st_mtime + + cached_revisions.add( + CachedRevisionInfo( + commit_hash=revision_path.name, + files=frozenset(cached_files), + refs=frozenset(refs_by_hash.pop(revision_path.name, set())), + size_on_disk=sum( + blob_stats[blob_path].st_size for blob_path in set(file.blob_path for file in cached_files) + ), + snapshot_path=revision_path, + last_modified=revision_last_modified, + ) + ) + + # Check that all refs referred to an existing revision + if len(refs_by_hash) > 0: + raise CorruptedCacheException( + f"Reference(s) refer to missing commit hashes: {dict(refs_by_hash)} ({repo_path})." + ) + + # Last modified is either the last modified blob file or the repo folder itself if + # no blob files has been found. Same for last accessed. + if len(blob_stats) > 0: + repo_last_accessed = max(stat.st_atime for stat in blob_stats.values()) + repo_last_modified = max(stat.st_mtime for stat in blob_stats.values()) + else: + repo_stats = repo_path.stat() + repo_last_accessed = repo_stats.st_atime + repo_last_modified = repo_stats.st_mtime + + # Build and return frozen structure + return CachedRepoInfo( + nb_files=len(blob_stats), + repo_id=repo_id, + repo_path=repo_path, + repo_type=repo_type, # type: ignore + revisions=frozenset(cached_revisions), + size_on_disk=sum(stat.st_size for stat in blob_stats.values()), + last_accessed=repo_last_accessed, + last_modified=repo_last_modified, + ) + + +def _format_size(num: int) -> str: + """Format size in bytes into a human-readable string. + + Taken from https://stackoverflow.com/a/1094933 + """ + num_f = float(num) + for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]: + if abs(num_f) < 1000.0: + return f"{num_f:3.1f}{unit}" + num_f /= 1000.0 + return f"{num_f:.1f}Y" + + +_TIMESINCE_CHUNKS = ( + # Label, divider, max value + ("second", 1, 60), + ("minute", 60, 60), + ("hour", 60 * 60, 24), + ("day", 60 * 60 * 24, 6), + ("week", 60 * 60 * 24 * 7, 6), + ("month", 60 * 60 * 24 * 30, 11), + ("year", 60 * 60 * 24 * 365, None), +) + + +def _format_timesince(ts: float) -> str: + """Format timestamp in seconds into a human-readable string, relative to now. + + Vaguely inspired by Django's `timesince` formatter. + """ + delta = time.time() - ts + if delta < 20: + return "a few seconds ago" + for label, divider, max_value in _TIMESINCE_CHUNKS: # noqa: B007 + value = round(delta / divider) + if max_value is not None and value <= max_value: + break + return f"{value} {label}{'s' if value > 1 else ''} ago" + + +def _try_delete_path(path: Path, path_type: str) -> None: + """Try to delete a local file or folder. + + If the path does not exists, error is logged as a warning and then ignored. + + Args: + path (`Path`) + Path to delete. Can be a file or a folder. + path_type (`str`) + What path are we deleting ? Only for logging purposes. Example: "snapshot". + """ + logger.info(f"Delete {path_type}: {path}") + try: + if path.is_file(): + os.remove(path) + else: + shutil.rmtree(path) + except FileNotFoundError: + logger.warning(f"Couldn't delete {path_type}: file not found ({path})", exc_info=True) + except PermissionError: + logger.warning(f"Couldn't delete {path_type}: permission denied ({path})", exc_info=True) diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/_chunk_utils.py b/phivenv/Lib/site-packages/huggingface_hub/utils/_chunk_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b0af032ae6a68f03676ad7fdb8e483248d9853f8 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/_chunk_utils.py @@ -0,0 +1,65 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains a utility to iterate by chunks over an iterator.""" + +import itertools +from typing import Iterable, TypeVar + + +T = TypeVar("T") + + +def chunk_iterable(iterable: Iterable[T], chunk_size: int) -> Iterable[Iterable[T]]: + """Iterates over an iterator chunk by chunk. + + Taken from https://stackoverflow.com/a/8998040. + See also https://github.com/huggingface/huggingface_hub/pull/920#discussion_r938793088. + + Args: + iterable (`Iterable`): + The iterable on which we want to iterate. + chunk_size (`int`): + Size of the chunks. Must be a strictly positive integer (e.g. >0). + + Example: + + ```python + >>> from huggingface_hub.utils import chunk_iterable + + >>> for items in chunk_iterable(range(17), chunk_size=8): + ... print(items) + # [0, 1, 2, 3, 4, 5, 6, 7] + # [8, 9, 10, 11, 12, 13, 14, 15] + # [16] # smaller last chunk + ``` + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `chunk_size` <= 0. + + + The last chunk can be smaller than `chunk_size`. + + """ + if not isinstance(chunk_size, int) or chunk_size <= 0: + raise ValueError("`chunk_size` must be a strictly positive integer (>0).") + + iterator = iter(iterable) + while True: + try: + next_item = next(iterator) + except StopIteration: + return + yield itertools.chain((next_item,), itertools.islice(iterator, chunk_size - 1)) diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/_datetime.py b/phivenv/Lib/site-packages/huggingface_hub/utils/_datetime.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7f44285d1c826006c97176ca66c3e9c33f61c0 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/_datetime.py @@ -0,0 +1,67 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to handle datetimes in Huggingface Hub.""" + +from datetime import datetime, timezone + + +def parse_datetime(date_string: str) -> datetime: + """ + Parses a date_string returned from the server to a datetime object. + + This parser is a weak-parser is the sense that it handles only a single format of + date_string. It is expected that the server format will never change. The + implementation depends only on the standard lib to avoid an external dependency + (python-dateutil). See full discussion about this decision on PR: + https://github.com/huggingface/huggingface_hub/pull/999. + + Example: + ```py + > parse_datetime('2022-08-19T07:19:38.123Z') + datetime.datetime(2022, 8, 19, 7, 19, 38, 123000, tzinfo=timezone.utc) + ``` + + Args: + date_string (`str`): + A string representing a datetime returned by the Hub server. + String is expected to follow '%Y-%m-%dT%H:%M:%S.%fZ' pattern. + + Returns: + A python datetime object. + + Raises: + :class:`ValueError`: + If `date_string` cannot be parsed. + """ + try: + # Normalize the string to always have 6 digits of fractional seconds + if date_string.endswith("Z"): + # Case 1: No decimal point (e.g., "2024-11-16T00:27:02Z") + if "." not in date_string: + # No fractional seconds - insert .000000 + date_string = date_string[:-1] + ".000000Z" + # Case 2: Has decimal point (e.g., "2022-08-19T07:19:38.123456789Z") + else: + # Get the fractional and base parts + base, fraction = date_string[:-1].split(".") + # fraction[:6] takes first 6 digits and :0<6 pads with zeros if less than 6 digits + date_string = f"{base}.{fraction[:6]:0<6}Z" + + return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%S.%fZ").replace(tzinfo=timezone.utc) + except ValueError as e: + raise ValueError( + f"Cannot parse '{date_string}' as a datetime. Date string is expected to" + " follow '%Y-%m-%dT%H:%M:%S.%fZ' pattern." + ) from e diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/_deprecation.py b/phivenv/Lib/site-packages/huggingface_hub/utils/_deprecation.py new file mode 100644 index 0000000000000000000000000000000000000000..4cb8d6e418c76accd1ecd61158b4bdd265e12f71 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/_deprecation.py @@ -0,0 +1,136 @@ +import warnings +from functools import wraps +from inspect import Parameter, signature +from typing import Iterable, Optional + + +def _deprecate_positional_args(*, version: str): + """Decorator for methods that issues warnings for positional arguments. + Using the keyword-only argument syntax in pep 3102, arguments after the + * will issue a warning when passed as a positional argument. + + Args: + version (`str`): + The version when positional arguments will result in error. + """ + + def _inner_deprecate_positional_args(f): + sig = signature(f) + kwonly_args = [] + all_args = [] + for name, param in sig.parameters.items(): + if param.kind == Parameter.POSITIONAL_OR_KEYWORD: + all_args.append(name) + elif param.kind == Parameter.KEYWORD_ONLY: + kwonly_args.append(name) + + @wraps(f) + def inner_f(*args, **kwargs): + extra_args = len(args) - len(all_args) + if extra_args <= 0: + return f(*args, **kwargs) + # extra_args > 0 + args_msg = [ + f"{name}='{arg}'" if isinstance(arg, str) else f"{name}={arg}" + for name, arg in zip(kwonly_args[:extra_args], args[-extra_args:]) + ] + args_msg = ", ".join(args_msg) + warnings.warn( + f"Deprecated positional argument(s) used in '{f.__name__}': pass" + f" {args_msg} as keyword args. From version {version} passing these" + " as positional arguments will result in an error,", + FutureWarning, + ) + kwargs.update(zip(sig.parameters, args)) + return f(**kwargs) + + return inner_f + + return _inner_deprecate_positional_args + + +def _deprecate_arguments( + *, + version: str, + deprecated_args: Iterable[str], + custom_message: Optional[str] = None, +): + """Decorator to issue warnings when using deprecated arguments. + + TODO: could be useful to be able to set a custom error message. + + Args: + version (`str`): + The version when deprecated arguments will result in error. + deprecated_args (`List[str]`): + List of the arguments to be deprecated. + custom_message (`str`, *optional*): + Warning message that is raised. If not passed, a default warning message + will be created. + """ + + def _inner_deprecate_positional_args(f): + sig = signature(f) + + @wraps(f) + def inner_f(*args, **kwargs): + # Check for used deprecated arguments + used_deprecated_args = [] + for _, parameter in zip(args, sig.parameters.values()): + if parameter.name in deprecated_args: + used_deprecated_args.append(parameter.name) + for kwarg_name, kwarg_value in kwargs.items(): + if ( + # If argument is deprecated but still used + kwarg_name in deprecated_args + # And then the value is not the default value + and kwarg_value != sig.parameters[kwarg_name].default + ): + used_deprecated_args.append(kwarg_name) + + # Warn and proceed + if len(used_deprecated_args) > 0: + message = ( + f"Deprecated argument(s) used in '{f.__name__}':" + f" {', '.join(used_deprecated_args)}. Will not be supported from" + f" version '{version}'." + ) + if custom_message is not None: + message += "\n\n" + custom_message + warnings.warn(message, FutureWarning) + return f(*args, **kwargs) + + return inner_f + + return _inner_deprecate_positional_args + + +def _deprecate_method(*, version: str, message: Optional[str] = None): + """Decorator to issue warnings when using a deprecated method. + + Args: + version (`str`): + The version when deprecated arguments will result in error. + message (`str`, *optional*): + Warning message that is raised. If not passed, a default warning message + will be created. + """ + + def _inner_deprecate_method(f): + name = f.__name__ + if name == "__init__": + name = f.__qualname__.split(".")[0] # class name instead of method name + + @wraps(f) + def inner_f(*args, **kwargs): + warning_message = ( + f"'{name}' (from '{f.__module__}') is deprecated and will be removed from version '{version}'." + ) + if message is not None: + warning_message += " " + message + warnings.warn(warning_message, FutureWarning) + return f(*args, **kwargs) + + return inner_f + + return _inner_deprecate_method diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/_dotenv.py b/phivenv/Lib/site-packages/huggingface_hub/utils/_dotenv.py new file mode 100644 index 0000000000000000000000000000000000000000..23b8a1b70a4827fc8ae4149c2b1b1e4b00ed7ca2 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/_dotenv.py @@ -0,0 +1,55 @@ +# AI-generated module (ChatGPT) +import re +from typing import Dict, Optional + + +def load_dotenv(dotenv_str: str, environ: Optional[Dict[str, str]] = None) -> Dict[str, str]: + """ + Parse a DOTENV-format string and return a dictionary of key-value pairs. + Handles quoted values, comments, export keyword, and blank lines. + """ + env: Dict[str, str] = {} + line_pattern = re.compile( + r""" + ^\s* + (?:export[^\S\n]+)? # optional export + ([A-Za-z_][A-Za-z0-9_]*) # key + [^\S\n]*(=)?[^\S\n]* + ( # value group + (?: + '(?:\\'|[^'])*' # single-quoted value + | \"(?:\\\"|[^\"])*\" # double-quoted value + | [^#\n\r]+? # unquoted value + ) + )? + [^\S\n]*(?:\#.*)?$ # optional inline comment + """, + re.VERBOSE, + ) + + for line in dotenv_str.splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue # Skip comments and empty lines + + match = line_pattern.match(line) + if match: + key = match.group(1) + val = None + if match.group(2): # if there is '=' + raw_val = match.group(3) or "" + val = raw_val.strip() + # Remove surrounding quotes if quoted + if (val.startswith('"') and val.endswith('"')) or (val.startswith("'") and val.endswith("'")): + val = val[1:-1] + val = val.replace(r"\n", "\n").replace(r"\t", "\t").replace(r"\"", '"').replace(r"\\", "\\") + if raw_val.startswith('"'): + val = val.replace(r"\$", "$") # only in double quotes + elif environ is not None: + # Get it from the current environment + val = environ.get(key) + + if val is not None: + env[key] = val + + return env diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/_experimental.py b/phivenv/Lib/site-packages/huggingface_hub/utils/_experimental.py new file mode 100644 index 0000000000000000000000000000000000000000..40b0ed90ff8af6797758d59b93019498cd72f9ad --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/_experimental.py @@ -0,0 +1,68 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to flag a feature as "experimental" in Huggingface Hub.""" + +import warnings +from functools import wraps +from typing import Callable + +from .. import constants + + +def experimental(fn: Callable) -> Callable: + """Decorator to flag a feature as experimental. + + An experimental feature triggers a warning when used as it might be subject to breaking changes without prior notice + in the future. + + Warnings can be disabled by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as environment variable. + + Args: + fn (`Callable`): + The function to flag as experimental. + + Returns: + `Callable`: The decorated function. + + Example: + + ```python + >>> from huggingface_hub.utils import experimental + + >>> @experimental + ... def my_function(): + ... print("Hello world!") + + >>> my_function() + UserWarning: 'my_function' is experimental and might be subject to breaking changes in the future without prior + notice. You can disable this warning by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as environment variable. + Hello world! + ``` + """ + # For classes, put the "experimental" around the "__new__" method => __new__ will be removed in warning message + name = fn.__qualname__[: -len(".__new__")] if fn.__qualname__.endswith(".__new__") else fn.__qualname__ + + @wraps(fn) + def _inner_fn(*args, **kwargs): + if not constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING: + warnings.warn( + f"'{name}' is experimental and might be subject to breaking changes in the future without prior notice." + " You can disable this warning by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as environment" + " variable.", + UserWarning, + ) + return fn(*args, **kwargs) + + return _inner_fn diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/_fixes.py b/phivenv/Lib/site-packages/huggingface_hub/utils/_fixes.py new file mode 100644 index 0000000000000000000000000000000000000000..560003b6222058b03791491b1ce70ea9d7a94404 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/_fixes.py @@ -0,0 +1,133 @@ +# JSONDecodeError was introduced in requests=2.27 released in 2022. +# This allows us to support older requests for users +# More information: https://github.com/psf/requests/pull/5856 +try: + from requests import JSONDecodeError # type: ignore # noqa: F401 +except ImportError: + try: + from simplejson import JSONDecodeError # type: ignore # noqa: F401 + except ImportError: + from json import JSONDecodeError # type: ignore # noqa: F401 +import contextlib +import os +import shutil +import stat +import tempfile +import time +from functools import partial +from pathlib import Path +from typing import Callable, Generator, Optional, Union + +import yaml +from filelock import BaseFileLock, FileLock, SoftFileLock, Timeout + +from .. import constants +from . import logging + + +logger = logging.get_logger(__name__) + +# Wrap `yaml.dump` to set `allow_unicode=True` by default. +# +# Example: +# ```py +# >>> yaml.dump({"emoji": "👀", "some unicode": "日本か"}) +# 'emoji: "\\U0001F440"\nsome unicode: "\\u65E5\\u672C\\u304B"\n' +# +# >>> yaml_dump({"emoji": "👀", "some unicode": "日本か"}) +# 'emoji: "👀"\nsome unicode: "日本か"\n' +# ``` +yaml_dump: Callable[..., str] = partial(yaml.dump, stream=None, allow_unicode=True) # type: ignore + + +@contextlib.contextmanager +def SoftTemporaryDirectory( + suffix: Optional[str] = None, + prefix: Optional[str] = None, + dir: Optional[Union[Path, str]] = None, + **kwargs, +) -> Generator[Path, None, None]: + """ + Context manager to create a temporary directory and safely delete it. + + If tmp directory cannot be deleted normally, we set the WRITE permission and retry. + If cleanup still fails, we give up but don't raise an exception. This is equivalent + to `tempfile.TemporaryDirectory(..., ignore_cleanup_errors=True)` introduced in + Python 3.10. + + See https://www.scivision.dev/python-tempfile-permission-error-windows/. + """ + tmpdir = tempfile.TemporaryDirectory(prefix=prefix, suffix=suffix, dir=dir, **kwargs) + yield Path(tmpdir.name).resolve() + + try: + # First once with normal cleanup + shutil.rmtree(tmpdir.name) + except Exception: + # If failed, try to set write permission and retry + try: + shutil.rmtree(tmpdir.name, onerror=_set_write_permission_and_retry) + except Exception: + pass + + # And finally, cleanup the tmpdir. + # If it fails again, give up but do not throw error + try: + tmpdir.cleanup() + except Exception: + pass + + +def _set_write_permission_and_retry(func, path, excinfo): + os.chmod(path, stat.S_IWRITE) + func(path) + + +@contextlib.contextmanager +def WeakFileLock( + lock_file: Union[str, Path], *, timeout: Optional[float] = None +) -> Generator[BaseFileLock, None, None]: + """A filelock with some custom logic. + + This filelock is weaker than the default filelock in that: + 1. It won't raise an exception if release fails. + 2. It will default to a SoftFileLock if the filesystem does not support flock. + + An INFO log message is emitted every 10 seconds if the lock is not acquired immediately. + If a timeout is provided, a `filelock.Timeout` exception is raised if the lock is not acquired within the timeout. + """ + log_interval = constants.FILELOCK_LOG_EVERY_SECONDS + lock = FileLock(lock_file, timeout=log_interval) + start_time = time.time() + + while True: + elapsed_time = time.time() - start_time + if timeout is not None and elapsed_time >= timeout: + raise Timeout(str(lock_file)) + + try: + lock.acquire(timeout=min(log_interval, timeout - elapsed_time) if timeout else log_interval) + except Timeout: + logger.info( + f"Still waiting to acquire lock on {lock_file} (elapsed: {time.time() - start_time:.1f} seconds)" + ) + except NotImplementedError as e: + if "use SoftFileLock instead" in str(e): + logger.warning( + "FileSystem does not appear to support flock. Falling back to SoftFileLock for %s", lock_file + ) + lock = SoftFileLock(lock_file, timeout=log_interval) + continue + else: + break + + try: + yield lock + finally: + try: + lock.release() + except OSError: + try: + Path(lock_file).unlink() + except OSError: + pass diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/_git_credential.py b/phivenv/Lib/site-packages/huggingface_hub/utils/_git_credential.py new file mode 100644 index 0000000000000000000000000000000000000000..a8ed77f4e49ca88ff4fa9aba48cbf00195036013 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/_git_credential.py @@ -0,0 +1,121 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to manage Git credentials.""" + +import re +import subprocess +from typing import List, Optional + +from ..constants import ENDPOINT +from ._subprocess import run_interactive_subprocess, run_subprocess + + +GIT_CREDENTIAL_REGEX = re.compile( + r""" + ^\s* # start of line + credential\.helper # credential.helper value + \s*=\s* # separator + (\w+) # the helper name (group 1) + (\s|$) # whitespace or end of line + """, + flags=re.MULTILINE | re.IGNORECASE | re.VERBOSE, +) + + +def list_credential_helpers(folder: Optional[str] = None) -> List[str]: + """Return the list of git credential helpers configured. + + See https://git-scm.com/docs/gitcredentials. + + Credentials are saved in all configured helpers (store, cache, macOS keychain,...). + Calls "`git credential approve`" internally. See https://git-scm.com/docs/git-credential. + + Args: + folder (`str`, *optional*): + The folder in which to check the configured helpers. + """ + try: + output = run_subprocess("git config --list", folder=folder).stdout + parsed = _parse_credential_output(output) + return parsed + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + +def set_git_credential(token: str, username: str = "hf_user", folder: Optional[str] = None) -> None: + """Save a username/token pair in git credential for HF Hub registry. + + Credentials are saved in all configured helpers (store, cache, macOS keychain,...). + Calls "`git credential approve`" internally. See https://git-scm.com/docs/git-credential. + + Args: + username (`str`, defaults to `"hf_user"`): + A git username. Defaults to `"hf_user"`, the default user used in the Hub. + token (`str`, defaults to `"hf_user"`): + A git password. In practice, the User Access Token for the Hub. + See https://huggingface.co/settings/tokens. + folder (`str`, *optional*): + The folder in which to check the configured helpers. + """ + with run_interactive_subprocess("git credential approve", folder=folder) as ( + stdin, + _, + ): + stdin.write(f"url={ENDPOINT}\nusername={username.lower()}\npassword={token}\n\n") + stdin.flush() + + +def unset_git_credential(username: str = "hf_user", folder: Optional[str] = None) -> None: + """Erase credentials from git credential for HF Hub registry. + + Credentials are erased from the configured helpers (store, cache, macOS + keychain,...), if any. If `username` is not provided, any credential configured for + HF Hub endpoint is erased. + Calls "`git credential erase`" internally. See https://git-scm.com/docs/git-credential. + + Args: + username (`str`, defaults to `"hf_user"`): + A git username. Defaults to `"hf_user"`, the default user used in the Hub. + folder (`str`, *optional*): + The folder in which to check the configured helpers. + """ + with run_interactive_subprocess("git credential reject", folder=folder) as ( + stdin, + _, + ): + standard_input = f"url={ENDPOINT}\n" + if username is not None: + standard_input += f"username={username.lower()}\n" + standard_input += "\n" + + stdin.write(standard_input) + stdin.flush() + + +def _parse_credential_output(output: str) -> List[str]: + """Parse the output of `git credential fill` to extract the password. + + Args: + output (`str`): + The output of `git credential fill`. + """ + # NOTE: If user has set an helper for a custom URL, it will not we caught here. + # Example: `credential.https://huggingface.co.helper=store` + # See: https://github.com/huggingface/huggingface_hub/pull/1138#discussion_r1013324508 + return sorted( # Sort for nice printing + set( # Might have some duplicates + match[0] for match in GIT_CREDENTIAL_REGEX.findall(output) + ) + ) diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/_headers.py b/phivenv/Lib/site-packages/huggingface_hub/utils/_headers.py new file mode 100644 index 0000000000000000000000000000000000000000..053a92a398f8734ee14cd67e4b514dfc350fcecd --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/_headers.py @@ -0,0 +1,228 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to handle headers to send in calls to Huggingface Hub.""" + +from typing import Dict, Optional, Union + +from huggingface_hub.errors import LocalTokenNotFoundError + +from .. import constants +from ._auth import get_token +from ._deprecation import _deprecate_arguments +from ._runtime import ( + get_fastai_version, + get_fastcore_version, + get_hf_hub_version, + get_python_version, + get_tf_version, + get_torch_version, + is_fastai_available, + is_fastcore_available, + is_tf_available, + is_torch_available, +) +from ._validators import validate_hf_hub_args + + +@_deprecate_arguments( + version="1.0", + deprecated_args="is_write_action", + custom_message="This argument is ignored and we let the server handle the permission error instead (if any).", +) +@validate_hf_hub_args +def build_hf_headers( + *, + token: Optional[Union[bool, str]] = None, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, + headers: Optional[Dict[str, str]] = None, + is_write_action: bool = False, +) -> Dict[str, str]: + """ + Build headers dictionary to send in a HF Hub call. + + By default, authorization token is always provided either from argument (explicit + use) or retrieved from the cache (implicit use). To explicitly avoid sending the + token to the Hub, set `token=False` or set the `HF_HUB_DISABLE_IMPLICIT_TOKEN` + environment variable. + + In case of an API call that requires write access, an error is thrown if token is + `None` or token is an organization token (starting with `"api_org***"`). + + In addition to the auth header, a user-agent is added to provide information about + the installed packages (versions of python, huggingface_hub, torch, tensorflow, + fastai and fastcore). + + Args: + token (`str`, `bool`, *optional*): + The token to be sent in authorization header for the Hub call: + - if a string, it is used as the Hugging Face token + - if `True`, the token is read from the machine (cache or env variable) + - if `False`, authorization header is not set + - if `None`, the token is read from the machine only except if + `HF_HUB_DISABLE_IMPLICIT_TOKEN` env variable is set. + library_name (`str`, *optional*): + The name of the library that is making the HTTP request. Will be added to + the user-agent header. + library_version (`str`, *optional*): + The version of the library that is making the HTTP request. Will be added + to the user-agent header. + user_agent (`str`, `dict`, *optional*): + The user agent info in the form of a dictionary or a single string. It will + be completed with information about the installed packages. + headers (`dict`, *optional*): + Additional headers to include in the request. Those headers take precedence + over the ones generated by this function. + is_write_action (`bool`): + Ignored and deprecated argument. + + Returns: + A `Dict` of headers to pass in your API call. + + Example: + ```py + >>> build_hf_headers(token="hf_***") # explicit token + {"authorization": "Bearer hf_***", "user-agent": ""} + + >>> build_hf_headers(token=True) # explicitly use cached token + {"authorization": "Bearer hf_***",...} + + >>> build_hf_headers(token=False) # explicitly don't use cached token + {"user-agent": ...} + + >>> build_hf_headers() # implicit use of the cached token + {"authorization": "Bearer hf_***",...} + + # HF_HUB_DISABLE_IMPLICIT_TOKEN=True # to set as env variable + >>> build_hf_headers() # token is not sent + {"user-agent": ...} + + >>> build_hf_headers(library_name="transformers", library_version="1.2.3") + {"authorization": ..., "user-agent": "transformers/1.2.3; hf_hub/0.10.2; python/3.10.4; tensorflow/1.55"} + ``` + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If organization token is passed and "write" access is required. + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If "write" access is required but token is not passed and not saved locally. + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If `token=True` but token is not saved locally. + """ + # Get auth token to send + token_to_send = get_token_to_send(token) + + # Combine headers + hf_headers = { + "user-agent": _http_user_agent( + library_name=library_name, + library_version=library_version, + user_agent=user_agent, + ) + } + if token_to_send is not None: + hf_headers["authorization"] = f"Bearer {token_to_send}" + if headers is not None: + hf_headers.update(headers) + return hf_headers + + +def get_token_to_send(token: Optional[Union[bool, str]]) -> Optional[str]: + """Select the token to send from either `token` or the cache.""" + # Case token is explicitly provided + if isinstance(token, str): + return token + + # Case token is explicitly forbidden + if token is False: + return None + + # Token is not provided: we get it from local cache + cached_token = get_token() + + # Case token is explicitly required + if token is True: + if cached_token is None: + raise LocalTokenNotFoundError( + "Token is required (`token=True`), but no token found. You" + " need to provide a token or be logged in to Hugging Face with" + " `hf auth login` or `huggingface_hub.login`. See" + " https://huggingface.co/settings/tokens." + ) + return cached_token + + # Case implicit use of the token is forbidden by env variable + if constants.HF_HUB_DISABLE_IMPLICIT_TOKEN: + return None + + # Otherwise: we use the cached token as the user has not explicitly forbidden it + return cached_token + + +def _http_user_agent( + *, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, +) -> str: + """Format a user-agent string containing information about the installed packages. + + Args: + library_name (`str`, *optional*): + The name of the library that is making the HTTP request. + library_version (`str`, *optional*): + The version of the library that is making the HTTP request. + user_agent (`str`, `dict`, *optional*): + The user agent info in the form of a dictionary or a single string. + + Returns: + The formatted user-agent string. + """ + if library_name is not None: + ua = f"{library_name}/{library_version}" + else: + ua = "unknown/None" + ua += f"; hf_hub/{get_hf_hub_version()}" + ua += f"; python/{get_python_version()}" + + if not constants.HF_HUB_DISABLE_TELEMETRY: + if is_torch_available(): + ua += f"; torch/{get_torch_version()}" + if is_tf_available(): + ua += f"; tensorflow/{get_tf_version()}" + if is_fastai_available(): + ua += f"; fastai/{get_fastai_version()}" + if is_fastcore_available(): + ua += f"; fastcore/{get_fastcore_version()}" + + if isinstance(user_agent, dict): + ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items()) + elif isinstance(user_agent, str): + ua += "; " + user_agent + + # Retrieve user-agent origin headers from environment variable + origin = constants.HF_HUB_USER_AGENT_ORIGIN + if origin is not None: + ua += "; origin/" + origin + + return _deduplicate_user_agent(ua) + + +def _deduplicate_user_agent(user_agent: str) -> str: + """Deduplicate redundant information in the generated user-agent.""" + # Split around ";" > Strip whitespaces > Store as dict keys (ensure unicity) > format back as string + # Order is implicitly preserved by dictionary structure (see https://stackoverflow.com/a/53657523). + return "; ".join({key.strip(): None for key in user_agent.split(";")}.keys()) diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/_hf_folder.py b/phivenv/Lib/site-packages/huggingface_hub/utils/_hf_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..6418bf2fd2c59b4bcf301c1dd82bc468f2f42ddf --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/_hf_folder.py @@ -0,0 +1,68 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contain helper class to retrieve/store token from/to local cache.""" + +from pathlib import Path +from typing import Optional + +from .. import constants +from ._auth import get_token + + +class HfFolder: + # TODO: deprecate when adapted in transformers/datasets/gradio + # @_deprecate_method(version="1.0", message="Use `huggingface_hub.login` instead.") + @classmethod + def save_token(cls, token: str) -> None: + """ + Save token, creating folder as needed. + + Token is saved in the huggingface home folder. You can configure it by setting + the `HF_HOME` environment variable. + + Args: + token (`str`): + The token to save to the [`HfFolder`] + """ + path_token = Path(constants.HF_TOKEN_PATH) + path_token.parent.mkdir(parents=True, exist_ok=True) + path_token.write_text(token) + + # TODO: deprecate when adapted in transformers/datasets/gradio + # @_deprecate_method(version="1.0", message="Use `huggingface_hub.get_token` instead.") + @classmethod + def get_token(cls) -> Optional[str]: + """ + Get token or None if not existent. + + This method is deprecated in favor of [`huggingface_hub.get_token`] but is kept for backward compatibility. + Its behavior is the same as [`huggingface_hub.get_token`]. + + Returns: + `str` or `None`: The token, `None` if it doesn't exist. + """ + return get_token() + + # TODO: deprecate when adapted in transformers/datasets/gradio + # @_deprecate_method(version="1.0", message="Use `huggingface_hub.logout` instead.") + @classmethod + def delete_token(cls) -> None: + """ + Deletes the token from storage. Does not fail if token does not exist. + """ + try: + Path(constants.HF_TOKEN_PATH).unlink() + except FileNotFoundError: + pass diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/_http.py b/phivenv/Lib/site-packages/huggingface_hub/utils/_http.py new file mode 100644 index 0000000000000000000000000000000000000000..5baceb8f8fd511403aa30c93dfe1fd33068c8dfe --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/_http.py @@ -0,0 +1,637 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to handle HTTP requests in Huggingface Hub.""" + +import io +import os +import re +import threading +import time +import uuid +from functools import lru_cache +from http import HTTPStatus +from shlex import quote +from typing import Any, Callable, List, Optional, Tuple, Type, Union + +import requests +from requests import HTTPError, Response +from requests.adapters import HTTPAdapter +from requests.models import PreparedRequest + +from huggingface_hub.errors import OfflineModeIsEnabled + +from .. import constants +from ..errors import ( + BadRequestError, + DisabledRepoError, + EntryNotFoundError, + GatedRepoError, + HfHubHTTPError, + RepositoryNotFoundError, + RevisionNotFoundError, +) +from . import logging +from ._fixes import JSONDecodeError +from ._lfs import SliceFileObj +from ._typing import HTTP_METHOD_T + + +logger = logging.get_logger(__name__) + +# Both headers are used by the Hub to debug failed requests. +# `X_AMZN_TRACE_ID` is better as it also works to debug on Cloudfront and ALB. +# If `X_AMZN_TRACE_ID` is set, the Hub will use it as well. +X_AMZN_TRACE_ID = "X-Amzn-Trace-Id" +X_REQUEST_ID = "x-request-id" + +REPO_API_REGEX = re.compile( + r""" + # staging or production endpoint + ^https://[^/]+ + ( + # on /api/repo_type/repo_id + /api/(models|datasets|spaces)/(.+) + | + # or /repo_id/resolve/revision/... + /(.+)/resolve/(.+) + ) + """, + flags=re.VERBOSE, +) + + +class UniqueRequestIdAdapter(HTTPAdapter): + X_AMZN_TRACE_ID = "X-Amzn-Trace-Id" + + def add_headers(self, request, **kwargs): + super().add_headers(request, **kwargs) + + # Add random request ID => easier for server-side debug + if X_AMZN_TRACE_ID not in request.headers: + request.headers[X_AMZN_TRACE_ID] = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4()) + + # Add debug log + has_token = len(str(request.headers.get("authorization", ""))) > 0 + logger.debug( + f"Request {request.headers[X_AMZN_TRACE_ID]}: {request.method} {request.url} (authenticated: {has_token})" + ) + + def send(self, request: PreparedRequest, *args, **kwargs) -> Response: + """Catch any RequestException to append request id to the error message for debugging.""" + if constants.HF_DEBUG: + logger.debug(f"Send: {_curlify(request)}") + try: + return super().send(request, *args, **kwargs) + except requests.RequestException as e: + request_id = request.headers.get(X_AMZN_TRACE_ID) + if request_id is not None: + # Taken from https://stackoverflow.com/a/58270258 + e.args = (*e.args, f"(Request ID: {request_id})") + raise + + +class OfflineAdapter(HTTPAdapter): + def send(self, request: PreparedRequest, *args, **kwargs) -> Response: + raise OfflineModeIsEnabled( + f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable." + ) + + +def _default_backend_factory() -> requests.Session: + session = requests.Session() + if constants.HF_HUB_OFFLINE: + session.mount("http://", OfflineAdapter()) + session.mount("https://", OfflineAdapter()) + else: + session.mount("http://", UniqueRequestIdAdapter()) + session.mount("https://", UniqueRequestIdAdapter()) + return session + + +BACKEND_FACTORY_T = Callable[[], requests.Session] +_GLOBAL_BACKEND_FACTORY: BACKEND_FACTORY_T = _default_backend_factory + + +def configure_http_backend(backend_factory: BACKEND_FACTORY_T = _default_backend_factory) -> None: + """ + Configure the HTTP backend by providing a `backend_factory`. Any HTTP calls made by `huggingface_hub` will use a + Session object instantiated by this factory. This can be useful if you are running your scripts in a specific + environment requiring custom configuration (e.g. custom proxy or certifications). + + Use [`get_session`] to get a configured Session. Since `requests.Session` is not guaranteed to be thread-safe, + `huggingface_hub` creates 1 Session instance per thread. They are all instantiated using the same `backend_factory` + set in [`configure_http_backend`]. A LRU cache is used to cache the created sessions (and connections) between + calls. Max size is 128 to avoid memory leaks if thousands of threads are spawned. + + See [this issue](https://github.com/psf/requests/issues/2766) to know more about thread-safety in `requests`. + + Example: + ```py + import requests + from huggingface_hub import configure_http_backend, get_session + + # Create a factory function that returns a Session with configured proxies + def backend_factory() -> requests.Session: + session = requests.Session() + session.proxies = {"http": "http://10.10.1.10:3128", "https": "https://10.10.1.11:1080"} + return session + + # Set it as the default session factory + configure_http_backend(backend_factory=backend_factory) + + # In practice, this is mostly done internally in `huggingface_hub` + session = get_session() + ``` + """ + global _GLOBAL_BACKEND_FACTORY + _GLOBAL_BACKEND_FACTORY = backend_factory + reset_sessions() + + +def get_session() -> requests.Session: + """ + Get a `requests.Session` object, using the session factory from the user. + + Use [`get_session`] to get a configured Session. Since `requests.Session` is not guaranteed to be thread-safe, + `huggingface_hub` creates 1 Session instance per thread. They are all instantiated using the same `backend_factory` + set in [`configure_http_backend`]. A LRU cache is used to cache the created sessions (and connections) between + calls. Max size is 128 to avoid memory leaks if thousands of threads are spawned. + + See [this issue](https://github.com/psf/requests/issues/2766) to know more about thread-safety in `requests`. + + Example: + ```py + import requests + from huggingface_hub import configure_http_backend, get_session + + # Create a factory function that returns a Session with configured proxies + def backend_factory() -> requests.Session: + session = requests.Session() + session.proxies = {"http": "http://10.10.1.10:3128", "https": "https://10.10.1.11:1080"} + return session + + # Set it as the default session factory + configure_http_backend(backend_factory=backend_factory) + + # In practice, this is mostly done internally in `huggingface_hub` + session = get_session() + ``` + """ + return _get_session_from_cache(process_id=os.getpid(), thread_id=threading.get_ident()) + + +def reset_sessions() -> None: + """Reset the cache of sessions. + + Mostly used internally when sessions are reconfigured or an SSLError is raised. + See [`configure_http_backend`] for more details. + """ + _get_session_from_cache.cache_clear() + + +@lru_cache +def _get_session_from_cache(process_id: int, thread_id: int) -> requests.Session: + """ + Create a new session per thread using global factory. Using LRU cache (maxsize 128) to avoid memory leaks when + using thousands of threads. Cache is cleared when `configure_http_backend` is called. + """ + return _GLOBAL_BACKEND_FACTORY() + + +def http_backoff( + method: HTTP_METHOD_T, + url: str, + *, + max_retries: int = 5, + base_wait_time: float = 1, + max_wait_time: float = 8, + retry_on_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ( + requests.Timeout, + requests.ConnectionError, + ), + retry_on_status_codes: Union[int, Tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE, + **kwargs, +) -> Response: + """Wrapper around requests to retry calls on an endpoint, with exponential backoff. + + Endpoint call is retried on exceptions (ex: connection timeout, proxy error,...) + and/or on specific status codes (ex: service unavailable). If the call failed more + than `max_retries`, the exception is thrown or `raise_for_status` is called on the + response object. + + Re-implement mechanisms from the `backoff` library to avoid adding an external + dependencies to `hugging_face_hub`. See https://github.com/litl/backoff. + + Args: + method (`Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]`): + HTTP method to perform. + url (`str`): + The URL of the resource to fetch. + max_retries (`int`, *optional*, defaults to `5`): + Maximum number of retries, defaults to 5 (no retries). + base_wait_time (`float`, *optional*, defaults to `1`): + Duration (in seconds) to wait before retrying the first time. + Wait time between retries then grows exponentially, capped by + `max_wait_time`. + max_wait_time (`float`, *optional*, defaults to `8`): + Maximum duration (in seconds) to wait before retrying. + retry_on_exceptions (`Type[Exception]` or `Tuple[Type[Exception]]`, *optional*): + Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types. + By default, retry on `requests.Timeout` and `requests.ConnectionError`. + retry_on_status_codes (`int` or `Tuple[int]`, *optional*, defaults to `503`): + Define on which status codes the request must be retried. By default, only + HTTP 503 Service Unavailable is retried. + **kwargs (`dict`, *optional*): + kwargs to pass to `requests.request`. + + Example: + ``` + >>> from huggingface_hub.utils import http_backoff + + # Same usage as "requests.request". + >>> response = http_backoff("GET", "https://www.google.com") + >>> response.raise_for_status() + + # If you expect a Gateway Timeout from time to time + >>> http_backoff("PUT", upload_url, data=data, retry_on_status_codes=504) + >>> response.raise_for_status() + ``` + + + + When using `requests` it is possible to stream data by passing an iterator to the + `data` argument. On http backoff this is a problem as the iterator is not reset + after a failed call. This issue is mitigated for file objects or any IO streams + by saving the initial position of the cursor (with `data.tell()`) and resetting the + cursor between each call (with `data.seek()`). For arbitrary iterators, http backoff + will fail. If this is a hard constraint for you, please let us know by opening an + issue on [Github](https://github.com/huggingface/huggingface_hub). + + + """ + if isinstance(retry_on_exceptions, type): # Tuple from single exception type + retry_on_exceptions = (retry_on_exceptions,) + + if isinstance(retry_on_status_codes, int): # Tuple from single status code + retry_on_status_codes = (retry_on_status_codes,) + + nb_tries = 0 + sleep_time = base_wait_time + + # If `data` is used and is a file object (or any IO), it will be consumed on the + # first HTTP request. We need to save the initial position so that the full content + # of the file is re-sent on http backoff. See warning tip in docstring. + io_obj_initial_pos = None + if "data" in kwargs and isinstance(kwargs["data"], (io.IOBase, SliceFileObj)): + io_obj_initial_pos = kwargs["data"].tell() + + session = get_session() + while True: + nb_tries += 1 + try: + # If `data` is used and is a file object (or any IO), set back cursor to + # initial position. + if io_obj_initial_pos is not None: + kwargs["data"].seek(io_obj_initial_pos) + + # Perform request and return if status_code is not in the retry list. + response = session.request(method=method, url=url, **kwargs) + if response.status_code not in retry_on_status_codes: + return response + + # Wrong status code returned (HTTP 503 for instance) + logger.warning(f"HTTP Error {response.status_code} thrown while requesting {method} {url}") + if nb_tries > max_retries: + response.raise_for_status() # Will raise uncaught exception + # We return response to avoid infinite loop in the corner case where the + # user ask for retry on a status code that doesn't raise_for_status. + return response + + except retry_on_exceptions as err: + logger.warning(f"'{err}' thrown while requesting {method} {url}") + + if isinstance(err, requests.ConnectionError): + reset_sessions() # In case of SSLError it's best to reset the shared requests.Session objects + + if nb_tries > max_retries: + raise err + + # Sleep for X seconds + logger.warning(f"Retrying in {sleep_time}s [Retry {nb_tries}/{max_retries}].") + time.sleep(sleep_time) + + # Update sleep time for next retry + sleep_time = min(max_wait_time, sleep_time * 2) # Exponential backoff + + +def fix_hf_endpoint_in_url(url: str, endpoint: Optional[str]) -> str: + """Replace the default endpoint in a URL by a custom one. + + This is useful when using a proxy and the Hugging Face Hub returns a URL with the default endpoint. + """ + endpoint = endpoint.rstrip("/") if endpoint else constants.ENDPOINT + # check if a proxy has been set => if yes, update the returned URL to use the proxy + if endpoint not in (constants._HF_DEFAULT_ENDPOINT, constants._HF_DEFAULT_STAGING_ENDPOINT): + url = url.replace(constants._HF_DEFAULT_ENDPOINT, endpoint) + url = url.replace(constants._HF_DEFAULT_STAGING_ENDPOINT, endpoint) + return url + + +def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None) -> None: + """ + Internal version of `response.raise_for_status()` that will refine a + potential HTTPError. Raised exception will be an instance of `HfHubHTTPError`. + + This helper is meant to be the unique method to raise_for_status when making a call + to the Hugging Face Hub. + + + Example: + ```py + import requests + from huggingface_hub.utils import get_session, hf_raise_for_status, HfHubHTTPError + + response = get_session().post(...) + try: + hf_raise_for_status(response) + except HfHubHTTPError as e: + print(str(e)) # formatted message + e.request_id, e.server_message # details returned by server + + # Complete the error message with additional information once it's raised + e.append_to_message("\n`create_commit` expects the repository to exist.") + raise + ``` + + Args: + response (`Response`): + Response from the server. + endpoint_name (`str`, *optional*): + Name of the endpoint that has been called. If provided, the error message + will be more complete. + + + + Raises when the request has failed: + + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it + doesn't exist, because `repo_type` is not set correctly, or because the repo + is `private` and you do not have access. + - [`~utils.GatedRepoError`] + If the repository exists but is gated and the user is not on the authorized + list. + - [`~utils.RevisionNotFoundError`] + If the repository exists but the revision couldn't be find. + - [`~utils.EntryNotFoundError`] + If the repository exists but the entry (e.g. the requested file) couldn't be + find. + - [`~utils.BadRequestError`] + If request failed with a HTTP 400 BadRequest error. + - [`~utils.HfHubHTTPError`] + If request failed for a reason not listed above. + + + """ + try: + response.raise_for_status() + except HTTPError as e: + error_code = response.headers.get("X-Error-Code") + error_message = response.headers.get("X-Error-Message") + + if error_code == "RevisionNotFound": + message = f"{response.status_code} Client Error." + "\n\n" + f"Revision Not Found for url: {response.url}." + raise _format(RevisionNotFoundError, message, response) from e + + elif error_code == "EntryNotFound": + message = f"{response.status_code} Client Error." + "\n\n" + f"Entry Not Found for url: {response.url}." + raise _format(EntryNotFoundError, message, response) from e + + elif error_code == "GatedRepo": + message = ( + f"{response.status_code} Client Error." + "\n\n" + f"Cannot access gated repo for url {response.url}." + ) + raise _format(GatedRepoError, message, response) from e + + elif error_message == "Access to this resource is disabled.": + message = ( + f"{response.status_code} Client Error." + + "\n\n" + + f"Cannot access repository for url {response.url}." + + "\n" + + "Access to this resource is disabled." + ) + raise _format(DisabledRepoError, message, response) from e + + elif error_code == "RepoNotFound" or ( + response.status_code == 401 + and error_message != "Invalid credentials in Authorization header" + and response.request is not None + and response.request.url is not None + and REPO_API_REGEX.search(response.request.url) is not None + ): + # 401 is misleading as it is returned for: + # - private and gated repos if user is not authenticated + # - missing repos + # => for now, we process them as `RepoNotFound` anyway. + # See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9 + message = ( + f"{response.status_code} Client Error." + + "\n\n" + + f"Repository Not Found for url: {response.url}." + + "\nPlease make sure you specified the correct `repo_id` and" + " `repo_type`.\nIf you are trying to access a private or gated repo," + " make sure you are authenticated. For more details, see" + " https://huggingface.co/docs/huggingface_hub/authentication" + ) + raise _format(RepositoryNotFoundError, message, response) from e + + elif response.status_code == 400: + message = ( + f"\n\nBad request for {endpoint_name} endpoint:" if endpoint_name is not None else "\n\nBad request:" + ) + raise _format(BadRequestError, message, response) from e + + elif response.status_code == 403: + message = ( + f"\n\n{response.status_code} Forbidden: {error_message}." + + f"\nCannot access content at: {response.url}." + + "\nMake sure your token has the correct permissions." + ) + raise _format(HfHubHTTPError, message, response) from e + + elif response.status_code == 416: + range_header = response.request.headers.get("Range") + message = f"{e}. Requested range: {range_header}. Content-Range: {response.headers.get('Content-Range')}." + raise _format(HfHubHTTPError, message, response) from e + + # Convert `HTTPError` into a `HfHubHTTPError` to display request information + # as well (request id and/or server error message) + raise _format(HfHubHTTPError, str(e), response) from e + + +def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: Response) -> HfHubHTTPError: + server_errors = [] + + # Retrieve server error from header + from_headers = response.headers.get("X-Error-Message") + if from_headers is not None: + server_errors.append(from_headers) + + # Retrieve server error from body + try: + # Case errors are returned in a JSON format + data = response.json() + + error = data.get("error") + if error is not None: + if isinstance(error, list): + # Case {'error': ['my error 1', 'my error 2']} + server_errors.extend(error) + else: + # Case {'error': 'my error'} + server_errors.append(error) + + errors = data.get("errors") + if errors is not None: + # Case {'errors': [{'message': 'my error 1'}, {'message': 'my error 2'}]} + for error in errors: + if "message" in error: + server_errors.append(error["message"]) + + except JSONDecodeError: + # If content is not JSON and not HTML, append the text + content_type = response.headers.get("Content-Type", "") + if response.text and "html" not in content_type.lower(): + server_errors.append(response.text) + + # Strip all server messages + server_errors = [str(line).strip() for line in server_errors if str(line).strip()] + + # Deduplicate server messages (keep order) + # taken from https://stackoverflow.com/a/17016257 + server_errors = list(dict.fromkeys(server_errors)) + + # Format server error + server_message = "\n".join(server_errors) + + # Add server error to custom message + final_error_message = custom_message + if server_message and server_message.lower() not in custom_message.lower(): + if "\n\n" in custom_message: + final_error_message += "\n" + server_message + else: + final_error_message += "\n\n" + server_message + # Add Request ID + request_id = str(response.headers.get(X_REQUEST_ID, "")) + if request_id: + request_id_message = f" (Request ID: {request_id})" + else: + # Fallback to X-Amzn-Trace-Id + request_id = str(response.headers.get(X_AMZN_TRACE_ID, "")) + if request_id: + request_id_message = f" (Amzn Trace ID: {request_id})" + if request_id and request_id.lower() not in final_error_message.lower(): + if "\n" in final_error_message: + newline_index = final_error_message.index("\n") + final_error_message = ( + final_error_message[:newline_index] + request_id_message + final_error_message[newline_index:] + ) + else: + final_error_message += request_id_message + + # Return + return error_type(final_error_message.strip(), response=response, server_message=server_message or None) + + +def _curlify(request: requests.PreparedRequest) -> str: + """Convert a `requests.PreparedRequest` into a curl command (str). + + Used for debug purposes only. + + Implementation vendored from https://github.com/ofw/curlify/blob/master/curlify.py. + MIT License Copyright (c) 2016 Egor. + """ + parts: List[Tuple[Any, Any]] = [ + ("curl", None), + ("-X", request.method), + ] + + for k, v in sorted(request.headers.items()): + if k.lower() == "authorization": + v = "" # Hide authorization header, no matter its value (can be Bearer, Key, etc.) + parts += [("-H", "{0}: {1}".format(k, v))] + + if request.body: + body = request.body + if isinstance(body, bytes): + body = body.decode("utf-8", errors="ignore") + elif hasattr(body, "read"): + body = "" # Don't try to read it to avoid consuming the stream + if len(body) > 1000: + body = body[:1000] + " ... [truncated]" + parts += [("-d", body.replace("\n", ""))] + + parts += [(None, request.url)] + + flat_parts = [] + for k, v in parts: + if k: + flat_parts.append(quote(k)) + if v: + flat_parts.append(quote(v)) + + return " ".join(flat_parts) + + +# Regex to parse HTTP Range header +RANGE_REGEX = re.compile(r"^\s*bytes\s*=\s*(\d*)\s*-\s*(\d*)\s*$", re.IGNORECASE) + + +def _adjust_range_header(original_range: Optional[str], resume_size: int) -> Optional[str]: + """ + Adjust HTTP Range header to account for resume position. + """ + if not original_range: + return f"bytes={resume_size}-" + + if "," in original_range: + raise ValueError(f"Multiple ranges detected - {original_range!r}, not supported yet.") + + match = RANGE_REGEX.match(original_range) + if not match: + raise RuntimeError(f"Invalid range format - {original_range!r}.") + start, end = match.groups() + + if not start: + if not end: + raise RuntimeError(f"Invalid range format - {original_range!r}.") + + new_suffix = int(end) - resume_size + new_range = f"bytes=-{new_suffix}" + if new_suffix <= 0: + raise RuntimeError(f"Empty new range - {new_range!r}.") + return new_range + + start = int(start) + new_start = start + resume_size + if end: + end = int(end) + new_range = f"bytes={new_start}-{end}" + if new_start > end: + raise RuntimeError(f"Empty new range - {new_range!r}.") + return new_range + + return f"bytes={new_start}-" diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/_lfs.py b/phivenv/Lib/site-packages/huggingface_hub/utils/_lfs.py new file mode 100644 index 0000000000000000000000000000000000000000..307f371ffa79a8ae726ee03458c52e230a792898 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/_lfs.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Git LFS related utilities""" + +import io +import os +from contextlib import AbstractContextManager +from typing import BinaryIO + + +class SliceFileObj(AbstractContextManager): + """ + Utility context manager to read a *slice* of a seekable file-like object as a seekable, file-like object. + + This is NOT thread safe + + Inspired by stackoverflow.com/a/29838711/593036 + + Credits to @julien-c + + Args: + fileobj (`BinaryIO`): + A file-like object to slice. MUST implement `tell()` and `seek()` (and `read()` of course). + `fileobj` will be reset to its original position when exiting the context manager. + seek_from (`int`): + The start of the slice (offset from position 0 in bytes). + read_limit (`int`): + The maximum number of bytes to read from the slice. + + Attributes: + previous_position (`int`): + The previous position + + Examples: + + Reading 200 bytes with an offset of 128 bytes from a file (ie bytes 128 to 327): + ```python + >>> with open("path/to/file", "rb") as file: + ... with SliceFileObj(file, seek_from=128, read_limit=200) as fslice: + ... fslice.read(...) + ``` + + Reading a file in chunks of 512 bytes + ```python + >>> import os + >>> chunk_size = 512 + >>> file_size = os.getsize("path/to/file") + >>> with open("path/to/file", "rb") as file: + ... for chunk_idx in range(ceil(file_size / chunk_size)): + ... with SliceFileObj(file, seek_from=chunk_idx * chunk_size, read_limit=chunk_size) as fslice: + ... chunk = fslice.read(...) + + ``` + """ + + def __init__(self, fileobj: BinaryIO, seek_from: int, read_limit: int): + self.fileobj = fileobj + self.seek_from = seek_from + self.read_limit = read_limit + + def __enter__(self): + self._previous_position = self.fileobj.tell() + end_of_stream = self.fileobj.seek(0, os.SEEK_END) + self._len = min(self.read_limit, end_of_stream - self.seek_from) + # ^^ The actual number of bytes that can be read from the slice + self.fileobj.seek(self.seek_from, io.SEEK_SET) + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.fileobj.seek(self._previous_position, io.SEEK_SET) + + def read(self, n: int = -1): + pos = self.tell() + if pos >= self._len: + return b"" + remaining_amount = self._len - pos + data = self.fileobj.read(remaining_amount if n < 0 else min(n, remaining_amount)) + return data + + def tell(self) -> int: + return self.fileobj.tell() - self.seek_from + + def seek(self, offset: int, whence: int = os.SEEK_SET) -> int: + start = self.seek_from + end = start + self._len + if whence in (os.SEEK_SET, os.SEEK_END): + offset = start + offset if whence == os.SEEK_SET else end + offset + offset = max(start, min(offset, end)) + whence = os.SEEK_SET + elif whence == os.SEEK_CUR: + cur_pos = self.fileobj.tell() + offset = max(start - cur_pos, min(offset, end - cur_pos)) + else: + raise ValueError(f"whence value {whence} is not supported") + return self.fileobj.seek(offset, whence) - self.seek_from + + def __iter__(self): + yield self.read(n=4 * 1024 * 1024) diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/_pagination.py b/phivenv/Lib/site-packages/huggingface_hub/utils/_pagination.py new file mode 100644 index 0000000000000000000000000000000000000000..3ef2b6668ba09d4c6a715509131d157139a1fac0 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/_pagination.py @@ -0,0 +1,52 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to handle pagination on Huggingface Hub.""" + +from typing import Dict, Iterable, Optional + +import requests + +from . import get_session, hf_raise_for_status, http_backoff, logging + + +logger = logging.get_logger(__name__) + + +def paginate(path: str, params: Dict, headers: Dict) -> Iterable: + """Fetch a list of models/datasets/spaces and paginate through results. + + This is using the same "Link" header format as GitHub. + See: + - https://requests.readthedocs.io/en/latest/api/#requests.Response.links + - https://docs.github.com/en/rest/guides/traversing-with-pagination#link-header + """ + session = get_session() + r = session.get(path, params=params, headers=headers) + hf_raise_for_status(r) + yield from r.json() + + # Follow pages + # Next link already contains query params + next_page = _get_next_page(r) + while next_page is not None: + logger.debug(f"Pagination detected. Requesting next page: {next_page}") + r = http_backoff("GET", next_page, max_retries=20, retry_on_status_codes=429, headers=headers) + hf_raise_for_status(r) + yield from r.json() + next_page = _get_next_page(r) + + +def _get_next_page(response: requests.Response) -> Optional[str]: + return response.links.get("next", {}).get("url") diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/_paths.py b/phivenv/Lib/site-packages/huggingface_hub/utils/_paths.py new file mode 100644 index 0000000000000000000000000000000000000000..4f2c0ebce070bbde4900e919a3aca7cfc331e747 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/_paths.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to handle paths in Huggingface Hub.""" + +from fnmatch import fnmatch +from pathlib import Path +from typing import Callable, Generator, Iterable, List, Optional, TypeVar, Union + + +T = TypeVar("T") + +# Always ignore `.git` and `.cache/huggingface` folders in commits +DEFAULT_IGNORE_PATTERNS = [ + ".git", + ".git/*", + "*/.git", + "**/.git/**", + ".cache/huggingface", + ".cache/huggingface/*", + "*/.cache/huggingface", + "**/.cache/huggingface/**", +] +# Forbidden to commit these folders +FORBIDDEN_FOLDERS = [".git", ".cache"] + + +def filter_repo_objects( + items: Iterable[T], + *, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + key: Optional[Callable[[T], str]] = None, +) -> Generator[T, None, None]: + """Filter repo objects based on an allowlist and a denylist. + + Input must be a list of paths (`str` or `Path`) or a list of arbitrary objects. + In the later case, `key` must be provided and specifies a function of one argument + that is used to extract a path from each element in iterable. + + Patterns are Unix shell-style wildcards which are NOT regular expressions. See + https://docs.python.org/3/library/fnmatch.html for more details. + + Args: + items (`Iterable`): + List of items to filter. + allow_patterns (`str` or `List[str]`, *optional*): + Patterns constituting the allowlist. If provided, item paths must match at + least one pattern from the allowlist. + ignore_patterns (`str` or `List[str]`, *optional*): + Patterns constituting the denylist. If provided, item paths must not match + any patterns from the denylist. + key (`Callable[[T], str]`, *optional*): + Single-argument function to extract a path from each item. If not provided, + the `items` must already be `str` or `Path`. + + Returns: + Filtered list of objects, as a generator. + + Raises: + :class:`ValueError`: + If `key` is not provided and items are not `str` or `Path`. + + Example usage with paths: + ```python + >>> # Filter only PDFs that are not hidden. + >>> list(filter_repo_objects( + ... ["aaa.PDF", "bbb.jpg", ".ccc.pdf", ".ddd.png"], + ... allow_patterns=["*.pdf"], + ... ignore_patterns=[".*"], + ... )) + ["aaa.pdf"] + ``` + + Example usage with objects: + ```python + >>> list(filter_repo_objects( + ... [ + ... CommitOperationAdd(path_or_fileobj="/tmp/aaa.pdf", path_in_repo="aaa.pdf") + ... CommitOperationAdd(path_or_fileobj="/tmp/bbb.jpg", path_in_repo="bbb.jpg") + ... CommitOperationAdd(path_or_fileobj="/tmp/.ccc.pdf", path_in_repo=".ccc.pdf") + ... CommitOperationAdd(path_or_fileobj="/tmp/.ddd.png", path_in_repo=".ddd.png") + ... ], + ... allow_patterns=["*.pdf"], + ... ignore_patterns=[".*"], + ... key=lambda x: x.repo_in_path + ... )) + [CommitOperationAdd(path_or_fileobj="/tmp/aaa.pdf", path_in_repo="aaa.pdf")] + ``` + """ + if isinstance(allow_patterns, str): + allow_patterns = [allow_patterns] + + if isinstance(ignore_patterns, str): + ignore_patterns = [ignore_patterns] + + if allow_patterns is not None: + allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns] + if ignore_patterns is not None: + ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns] + + if key is None: + + def _identity(item: T) -> str: + if isinstance(item, str): + return item + if isinstance(item, Path): + return str(item) + raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.") + + key = _identity # Items must be `str` or `Path`, otherwise raise ValueError + + for item in items: + path = key(item) + + # Skip if there's an allowlist and path doesn't match any + if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns): + continue + + # Skip if there's a denylist and path matches any + if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns): + continue + + yield item + + +def _add_wildcard_to_directories(pattern: str) -> str: + if pattern[-1] == "/": + return pattern + "*" + return pattern diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/endpoint_helpers.py b/phivenv/Lib/site-packages/huggingface_hub/utils/endpoint_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..85cd86011b78bcdc57034aeebc3c01e9e721ab50 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/endpoint_helpers.py @@ -0,0 +1,66 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Helpful utility functions and classes in relation to exploring API endpoints +with the aim for a user-friendly interface. +""" + +import math +import re +from typing import TYPE_CHECKING + +from ..repocard_data import ModelCardData + + +if TYPE_CHECKING: + from ..hf_api import ModelInfo + + +def _is_emission_within_threshold(model_info: "ModelInfo", minimum_threshold: float, maximum_threshold: float) -> bool: + """Checks if a model's emission is within a given threshold. + + Args: + model_info (`ModelInfo`): + A model info object containing the model's emission information. + minimum_threshold (`float`): + A minimum carbon threshold to filter by, such as 1. + maximum_threshold (`float`): + A maximum carbon threshold to filter by, such as 10. + + Returns: + `bool`: Whether the model's emission is within the given threshold. + """ + if minimum_threshold is None and maximum_threshold is None: + raise ValueError("Both `minimum_threshold` and `maximum_threshold` cannot both be `None`") + if minimum_threshold is None: + minimum_threshold = -1 + if maximum_threshold is None: + maximum_threshold = math.inf + + card_data = getattr(model_info, "card_data", None) + if card_data is None or not isinstance(card_data, (dict, ModelCardData)): + return False + + # Get CO2 emission metadata + emission = card_data.get("co2_eq_emissions", None) + if isinstance(emission, dict): + emission = emission["emissions"] + if not emission: + return False + + # Filter out if value is missing or out of range + matched = re.search(r"\d+\.\d+|\d+", str(emission)) + if matched is None: + return False + + emission_value = float(matched.group(0)) + return minimum_threshold <= emission_value <= maximum_threshold diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/insecure_hashlib.py b/phivenv/Lib/site-packages/huggingface_hub/utils/insecure_hashlib.py new file mode 100644 index 0000000000000000000000000000000000000000..6901b6d647cc706b85333a66f3bcb7d8c5e2ee9e --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/insecure_hashlib.py @@ -0,0 +1,38 @@ +# Taken from https://github.com/mlflow/mlflow/pull/10119 +# +# DO NOT use this function for security purposes (e.g., password hashing). +# +# In Python >= 3.9, insecure hashing algorithms such as MD5 fail in FIPS-compliant +# environments unless `usedforsecurity=False` is explicitly passed. +# +# References: +# - https://github.com/mlflow/mlflow/issues/9905 +# - https://github.com/mlflow/mlflow/pull/10119 +# - https://docs.python.org/3/library/hashlib.html +# - https://github.com/huggingface/transformers/pull/27038 +# +# Usage: +# ```python +# # Use +# from huggingface_hub.utils.insecure_hashlib import sha256 +# # instead of +# from hashlib import sha256 +# +# # Use +# from huggingface_hub.utils import insecure_hashlib +# # instead of +# import hashlib +# ``` +import functools +import hashlib +import sys + + +if sys.version_info >= (3, 9): + md5 = functools.partial(hashlib.md5, usedforsecurity=False) + sha1 = functools.partial(hashlib.sha1, usedforsecurity=False) + sha256 = functools.partial(hashlib.sha256, usedforsecurity=False) +else: + md5 = hashlib.md5 + sha1 = hashlib.sha1 + sha256 = hashlib.sha256 diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/logging.py b/phivenv/Lib/site-packages/huggingface_hub/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..813719683a54cc65768bab5488e7ea153ad08d7e --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/logging.py @@ -0,0 +1,188 @@ +# coding=utf-8 +# Copyright 2020 Optuna, Hugging Face +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Logging utilities.""" + +import logging +import os +from logging import ( + CRITICAL, # NOQA + DEBUG, # NOQA + ERROR, # NOQA + FATAL, # NOQA + INFO, # NOQA + NOTSET, # NOQA + WARN, # NOQA + WARNING, # NOQA +) +from typing import Optional + +from .. import constants + + +log_levels = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +_default_log_level = logging.WARNING + + +def _get_library_name() -> str: + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + return logging.getLogger(_get_library_name()) + + +def _get_default_logging_level(): + """ + If `HF_HUB_VERBOSITY` env var is set to one of the valid choices return that as the new default level. If it is not + - fall back to `_default_log_level` + """ + env_level_str = os.getenv("HF_HUB_VERBOSITY", None) + if env_level_str: + if env_level_str in log_levels: + return log_levels[env_level_str] + else: + logging.getLogger().warning( + f"Unknown option HF_HUB_VERBOSITY={env_level_str}, has to be one of: {', '.join(log_levels.keys())}" + ) + return _default_log_level + + +def _configure_library_root_logger() -> None: + library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(logging.StreamHandler()) + library_root_logger.setLevel(_get_default_logging_level()) + + +def _reset_library_root_logger() -> None: + library_root_logger = _get_library_root_logger() + library_root_logger.setLevel(logging.NOTSET) + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Returns a logger with the specified name. This function is not supposed + to be directly accessed by library users. + + Args: + name (`str`, *optional*): + The name of the logger to get, usually the filename + + Example: + + ```python + >>> from huggingface_hub import get_logger + + >>> logger = get_logger(__file__) + >>> logger.set_verbosity_info() + ``` + """ + + if name is None: + name = _get_library_name() + + return logging.getLogger(name) + + +def get_verbosity() -> int: + """Return the current level for the HuggingFace Hub's root logger. + + Returns: + Logging level, e.g., `huggingface_hub.logging.DEBUG` and + `huggingface_hub.logging.INFO`. + + + + HuggingFace Hub has following logging levels: + + - `huggingface_hub.logging.CRITICAL`, `huggingface_hub.logging.FATAL` + - `huggingface_hub.logging.ERROR` + - `huggingface_hub.logging.WARNING`, `huggingface_hub.logging.WARN` + - `huggingface_hub.logging.INFO` + - `huggingface_hub.logging.DEBUG` + + + """ + return _get_library_root_logger().getEffectiveLevel() + + +def set_verbosity(verbosity: int) -> None: + """ + Sets the level for the HuggingFace Hub's root logger. + + Args: + verbosity (`int`): + Logging level, e.g., `huggingface_hub.logging.DEBUG` and + `huggingface_hub.logging.INFO`. + """ + _get_library_root_logger().setLevel(verbosity) + + +def set_verbosity_info(): + """ + Sets the verbosity to `logging.INFO`. + """ + return set_verbosity(INFO) + + +def set_verbosity_warning(): + """ + Sets the verbosity to `logging.WARNING`. + """ + return set_verbosity(WARNING) + + +def set_verbosity_debug(): + """ + Sets the verbosity to `logging.DEBUG`. + """ + return set_verbosity(DEBUG) + + +def set_verbosity_error(): + """ + Sets the verbosity to `logging.ERROR`. + """ + return set_verbosity(ERROR) + + +def disable_propagation() -> None: + """ + Disable propagation of the library log outputs. Note that log propagation is + disabled by default. + """ + _get_library_root_logger().propagate = False + + +def enable_propagation() -> None: + """ + Enable propagation of the library log outputs. Please disable the + HuggingFace Hub's default handler to prevent double logging if the root + logger has been configured. + """ + _get_library_root_logger().propagate = True + + +_configure_library_root_logger() + +if constants.HF_DEBUG: + # If `HF_DEBUG` environment variable is set, set the verbosity of `huggingface_hub` logger to `DEBUG`. + set_verbosity_debug() diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/sha.py b/phivenv/Lib/site-packages/huggingface_hub/utils/sha.py new file mode 100644 index 0000000000000000000000000000000000000000..001c3fe8b2f37a64e890888ca3d521c10ec8f03b --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/sha.py @@ -0,0 +1,64 @@ +"""Utilities to efficiently compute the SHA 256 hash of a bunch of bytes.""" + +from typing import BinaryIO, Optional + +from .insecure_hashlib import sha1, sha256 + + +def sha_fileobj(fileobj: BinaryIO, chunk_size: Optional[int] = None) -> bytes: + """ + Computes the sha256 hash of the given file object, by chunks of size `chunk_size`. + + Args: + fileobj (file-like object): + The File object to compute sha256 for, typically obtained with `open(path, "rb")` + chunk_size (`int`, *optional*): + The number of bytes to read from `fileobj` at once, defaults to 1MB. + + Returns: + `bytes`: `fileobj`'s sha256 hash as bytes + """ + chunk_size = chunk_size if chunk_size is not None else 1024 * 1024 + + sha = sha256() + while True: + chunk = fileobj.read(chunk_size) + sha.update(chunk) + if not chunk: + break + return sha.digest() + + +def git_hash(data: bytes) -> str: + """ + Computes the git-sha1 hash of the given bytes, using the same algorithm as git. + + This is equivalent to running `git hash-object`. See https://git-scm.com/docs/git-hash-object + for more details. + + Note: this method is valid for regular files. For LFS files, the proper git hash is supposed to be computed on the + pointer file content, not the actual file content. However, for simplicity, we directly compare the sha256 of + the LFS file content when we want to compare LFS files. + + Args: + data (`bytes`): + The data to compute the git-hash for. + + Returns: + `str`: the git-hash of `data` as an hexadecimal string. + + Example: + ```python + >>> from huggingface_hub.utils.sha import git_hash + >>> git_hash(b"Hello, World!") + 'b45ef6fec89518d314f546fd6c3025367b721684' + ``` + """ + # Taken from https://gist.github.com/msabramo/763200 + # Note: no need to optimize by reading the file in chunks as we're not supposed to hash huge files (5MB maximum). + sha = sha1() + sha.update(b"blob ") + sha.update(str(len(data)).encode()) + sha.update(b"\0") + sha.update(data) + return sha.hexdigest() diff --git a/phivenv/Lib/site-packages/huggingface_hub/utils/tqdm.py b/phivenv/Lib/site-packages/huggingface_hub/utils/tqdm.py new file mode 100644 index 0000000000000000000000000000000000000000..4c1fcef4beb73bae13c57b3f66c5828e775b7cd9 --- /dev/null +++ b/phivenv/Lib/site-packages/huggingface_hub/utils/tqdm.py @@ -0,0 +1,307 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +"""Utility helpers to handle progress bars in `huggingface_hub`. + +Example: + 1. Use `huggingface_hub.utils.tqdm` as you would use `tqdm.tqdm` or `tqdm.auto.tqdm`. + 2. To disable progress bars, either use `disable_progress_bars()` helper or set the + environment variable `HF_HUB_DISABLE_PROGRESS_BARS` to 1. + 3. To re-enable progress bars, use `enable_progress_bars()`. + 4. To check whether progress bars are disabled, use `are_progress_bars_disabled()`. + +NOTE: Environment variable `HF_HUB_DISABLE_PROGRESS_BARS` has the priority. + +Example: + ```py + >>> from huggingface_hub.utils import are_progress_bars_disabled, disable_progress_bars, enable_progress_bars, tqdm + + # Disable progress bars globally + >>> disable_progress_bars() + + # Use as normal `tqdm` + >>> for _ in tqdm(range(5)): + ... pass + + # Still not showing progress bars, as `disable=False` is overwritten to `True`. + >>> for _ in tqdm(range(5), disable=False): + ... pass + + >>> are_progress_bars_disabled() + True + + # Re-enable progress bars globally + >>> enable_progress_bars() + + # Progress bar will be shown ! + >>> for _ in tqdm(range(5)): + ... pass + 100%|███████████████████████████████████████| 5/5 [00:00<00:00, 117817.53it/s] + ``` + +Group-based control: + ```python + # Disable progress bars for a specific group + >>> disable_progress_bars("peft.foo") + + # Check state of different groups + >>> assert not are_progress_bars_disabled("peft")) + >>> assert not are_progress_bars_disabled("peft.something") + >>> assert are_progress_bars_disabled("peft.foo")) + >>> assert are_progress_bars_disabled("peft.foo.bar")) + + # Enable progress bars for a subgroup + >>> enable_progress_bars("peft.foo.bar") + + # Check if enabling a subgroup affects the parent group + >>> assert are_progress_bars_disabled("peft.foo")) + >>> assert not are_progress_bars_disabled("peft.foo.bar")) + + # No progress bar for `name="peft.foo"` + >>> for _ in tqdm(range(5), name="peft.foo"): + ... pass + + # Progress bar will be shown for `name="peft.foo.bar"` + >>> for _ in tqdm(range(5), name="peft.foo.bar"): + ... pass + 100%|███████████████████████████████████████| 5/5 [00:00<00:00, 117817.53it/s] + + ``` +""" + +import io +import logging +import os +import warnings +from contextlib import contextmanager, nullcontext +from pathlib import Path +from typing import ContextManager, Dict, Iterator, Optional, Union + +from tqdm.auto import tqdm as old_tqdm + +from ..constants import HF_HUB_DISABLE_PROGRESS_BARS + + +# The `HF_HUB_DISABLE_PROGRESS_BARS` environment variable can be True, False, or not set (None), +# allowing for control over progress bar visibility. When set, this variable takes precedence +# over programmatic settings, dictating whether progress bars should be shown or hidden globally. +# Essentially, the environment variable's setting overrides any code-based configurations. +# +# If `HF_HUB_DISABLE_PROGRESS_BARS` is not defined (None), it implies that users can manage +# progress bar visibility through code. By default, progress bars are turned on. + + +progress_bar_states: Dict[str, bool] = {} + + +def disable_progress_bars(name: Optional[str] = None) -> None: + """ + Disable progress bars either globally or for a specified group. + + This function updates the state of progress bars based on a group name. + If no group name is provided, all progress bars are disabled. The operation + respects the `HF_HUB_DISABLE_PROGRESS_BARS` environment variable's setting. + + Args: + name (`str`, *optional*): + The name of the group for which to disable the progress bars. If None, + progress bars are disabled globally. + + Raises: + Warning: If the environment variable precludes changes. + """ + if HF_HUB_DISABLE_PROGRESS_BARS is False: + warnings.warn( + "Cannot disable progress bars: environment variable `HF_HUB_DISABLE_PROGRESS_BARS=0` is set and has priority." + ) + return + + if name is None: + progress_bar_states.clear() + progress_bar_states["_global"] = False + else: + keys_to_remove = [key for key in progress_bar_states if key.startswith(f"{name}.")] + for key in keys_to_remove: + del progress_bar_states[key] + progress_bar_states[name] = False + + +def enable_progress_bars(name: Optional[str] = None) -> None: + """ + Enable progress bars either globally or for a specified group. + + This function sets the progress bars to enabled for the specified group or globally + if no group is specified. The operation is subject to the `HF_HUB_DISABLE_PROGRESS_BARS` + environment setting. + + Args: + name (`str`, *optional*): + The name of the group for which to enable the progress bars. If None, + progress bars are enabled globally. + + Raises: + Warning: If the environment variable precludes changes. + """ + if HF_HUB_DISABLE_PROGRESS_BARS is True: + warnings.warn( + "Cannot enable progress bars: environment variable `HF_HUB_DISABLE_PROGRESS_BARS=1` is set and has priority." + ) + return + + if name is None: + progress_bar_states.clear() + progress_bar_states["_global"] = True + else: + keys_to_remove = [key for key in progress_bar_states if key.startswith(f"{name}.")] + for key in keys_to_remove: + del progress_bar_states[key] + progress_bar_states[name] = True + + +def are_progress_bars_disabled(name: Optional[str] = None) -> bool: + """ + Check if progress bars are disabled globally or for a specific group. + + This function returns whether progress bars are disabled for a given group or globally. + It checks the `HF_HUB_DISABLE_PROGRESS_BARS` environment variable first, then the programmatic + settings. + + Args: + name (`str`, *optional*): + The group name to check; if None, checks the global setting. + + Returns: + `bool`: True if progress bars are disabled, False otherwise. + """ + if HF_HUB_DISABLE_PROGRESS_BARS is True: + return True + + if name is None: + return not progress_bar_states.get("_global", True) + + while name: + if name in progress_bar_states: + return not progress_bar_states[name] + name = ".".join(name.split(".")[:-1]) + + return not progress_bar_states.get("_global", True) + + +def is_tqdm_disabled(log_level: int) -> Optional[bool]: + """ + Determine if tqdm progress bars should be disabled based on logging level and environment settings. + + see https://github.com/huggingface/huggingface_hub/pull/2000 and https://github.com/huggingface/huggingface_hub/pull/2698. + """ + if log_level == logging.NOTSET: + return True + if os.getenv("TQDM_POSITION") == "-1": + return False + return None + + +class tqdm(old_tqdm): + """ + Class to override `disable` argument in case progress bars are globally disabled. + + Taken from https://github.com/tqdm/tqdm/issues/619#issuecomment-619639324. + """ + + def __init__(self, *args, **kwargs): + name = kwargs.pop("name", None) # do not pass `name` to `tqdm` + if are_progress_bars_disabled(name): + kwargs["disable"] = True + super().__init__(*args, **kwargs) + + def __delattr__(self, attr: str) -> None: + """Fix for https://github.com/huggingface/huggingface_hub/issues/1603""" + try: + super().__delattr__(attr) + except AttributeError: + if attr != "_lock": + raise + + +@contextmanager +def tqdm_stream_file(path: Union[Path, str]) -> Iterator[io.BufferedReader]: + """ + Open a file as binary and wrap the `read` method to display a progress bar when it's streamed. + + First implemented in `transformers` in 2019 but removed when switched to git-lfs. Used in `huggingface_hub` to show + progress bar when uploading an LFS file to the Hub. See github.com/huggingface/transformers/pull/2078#discussion_r354739608 + for implementation details. + + Note: currently implementation handles only files stored on disk as it is the most common use case. Could be + extended to stream any `BinaryIO` object but we might have to debug some corner cases. + + Example: + ```py + >>> with tqdm_stream_file("config.json") as f: + >>> requests.put(url, data=f) + config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s] + ``` + """ + if isinstance(path, str): + path = Path(path) + + with path.open("rb") as f: + total_size = path.stat().st_size + pbar = tqdm( + unit="B", + unit_scale=True, + total=total_size, + initial=0, + desc=path.name, + ) + + f_read = f.read + + def _inner_read(size: Optional[int] = -1) -> bytes: + data = f_read(size) + pbar.update(len(data)) + return data + + f.read = _inner_read # type: ignore + + yield f + + pbar.close() + + +def _get_progress_bar_context( + *, + desc: str, + log_level: int, + total: Optional[int] = None, + initial: int = 0, + unit: str = "B", + unit_scale: bool = True, + name: Optional[str] = None, + _tqdm_bar: Optional[tqdm] = None, +) -> ContextManager[tqdm]: + if _tqdm_bar is not None: + return nullcontext(_tqdm_bar) + # ^ `contextlib.nullcontext` mimics a context manager that does nothing + # Makes it easier to use the same code path for both cases but in the later + # case, the progress bar is not closed when exiting the context manager. + + return tqdm( + unit=unit, + unit_scale=unit_scale, + total=total, + initial=initial, + desc=desc, + disable=is_tqdm_disabled(log_level=log_level), + name=name, + )