| from __future__ import annotations | |
| from collections import OrderedDict | |
| from dataclasses import dataclass | |
| from typing import Any, Callable, Iterable, Hashable | |
| import torch | |
| class _CacheEntry: | |
| value: Any | |
| size_bytes: int | |
| class TextEncoderCache: | |
| def __init__(self, max_size_mb: float = 100) -> None: | |
| self.max_size_bytes = int(max_size_mb * 1024 * 1024) | |
| self._entries: "OrderedDict[Hashable, _CacheEntry]" = OrderedDict() | |
| self._size_bytes = 0 | |
| def encode( | |
| self, | |
| encode_fn: Callable[[list[str]], list[Any]], | |
| prompts: Iterable[str] | str, | |
| device: torch.device | str | None = None, | |
| parallel: bool = False, | |
| cache_keys: Iterable[Hashable] | Hashable | None = None, | |
| ) -> list[Any]: | |
| if isinstance(prompts, str): | |
| prompts_list = [prompts] | |
| else: | |
| prompts_list = list(prompts) | |
| if not prompts_list: | |
| return [] | |
| if cache_keys is None: | |
| keys_list = prompts_list | |
| else: | |
| if len(prompts_list) == 1 and not isinstance(cache_keys, list): | |
| keys_list = [cache_keys] | |
| else: | |
| keys_list = list(cache_keys) | |
| if len(keys_list) != len(prompts_list): | |
| raise ValueError("cache_keys must match the number of prompts.") | |
| if not parallel: | |
| results: list[Any] = [] | |
| for prompt, cache_key in zip(prompts_list, keys_list): | |
| cached = self._entries.get(cache_key) | |
| if cached is not None: | |
| self._entries.move_to_end(cache_key) | |
| results.append(self._to_device(cached.value, device)) | |
| continue | |
| encoded = encode_fn([prompt]) | |
| if isinstance(encoded, (list, tuple)): | |
| if not encoded: | |
| raise ValueError("encode_fn returned empty embeddings.") | |
| encoded_item = encoded[0] | |
| else: | |
| encoded_item = encoded | |
| results.append(self._store(cache_key, encoded_item, device)) | |
| return results | |
| results = [None] * len(prompts_list) | |
| missing_prompts: list[str] = [] | |
| missing_indices: list[int] = [] | |
| missing_keys: list[Hashable] = [] | |
| for idx, (prompt, cache_key) in enumerate(zip(prompts_list, keys_list)): | |
| cached = self._entries.get(cache_key) | |
| if cached is None: | |
| missing_prompts.append(prompt) | |
| missing_indices.append(idx) | |
| missing_keys.append(cache_key) | |
| continue | |
| self._entries.move_to_end(cache_key) | |
| results[idx] = self._to_device(cached.value, device) | |
| if missing_prompts: | |
| encoded_batch = encode_fn(missing_prompts) | |
| if not isinstance(encoded_batch, list): | |
| encoded_batch = list(encoded_batch) | |
| if len(encoded_batch) != len(missing_prompts): | |
| raise ValueError("encode_fn returned unexpected number of embeddings.") | |
| for cache_key, idx, encoded in zip(missing_keys, missing_indices, encoded_batch): | |
| results[idx] = self._store(cache_key, encoded, device) | |
| return results | |
| def _store(self, cache_key: Hashable, encoded: Any, device: torch.device | str | None) -> Any: | |
| cached_value = self._detach_to_cpu(encoded) | |
| size_bytes = self._estimate_size_bytes(cached_value) | |
| if size_bytes <= self.max_size_bytes: | |
| existing = self._entries.pop(cache_key, None) | |
| if existing is not None: | |
| self._size_bytes -= existing.size_bytes | |
| self._entries[cache_key] = _CacheEntry(cached_value, size_bytes) | |
| self._size_bytes += size_bytes | |
| self._purge_if_needed() | |
| else: | |
| if cache_key in self._entries: | |
| self._entries.move_to_end(cache_key) | |
| return self._to_device(encoded, device) | |
| def _purge_if_needed(self) -> None: | |
| if self._size_bytes <= self.max_size_bytes: | |
| return | |
| while self._entries and self._size_bytes > self.max_size_bytes: | |
| _, entry = self._entries.popitem(last=False) | |
| self._size_bytes -= entry.size_bytes | |
| def _estimate_size_bytes(self, value: Any) -> int: | |
| if torch.is_tensor(value): | |
| return int(value.numel() * value.element_size()) | |
| if isinstance(value, dict): | |
| return sum(self._estimate_size_bytes(v) for v in value.values()) | |
| if isinstance(value, (list, tuple)): | |
| return sum(self._estimate_size_bytes(v) for v in value) | |
| return 0 | |
| def _detach_to_cpu(self, value: Any) -> Any: | |
| if torch.is_tensor(value): | |
| if value.device.type == "cpu": | |
| return value.detach() | |
| return value.detach().to("cpu") | |
| if isinstance(value, dict): | |
| return {k: self._detach_to_cpu(v) for k, v in value.items()} | |
| if isinstance(value, tuple): | |
| items = [self._detach_to_cpu(v) for v in value] | |
| if hasattr(value, "_fields"): | |
| return value.__class__(*items) | |
| return tuple(items) | |
| if isinstance(value, list): | |
| return [self._detach_to_cpu(v) for v in value] | |
| return value | |
| def _to_device(self, value: Any, device: torch.device | str | None) -> Any: | |
| if device is None: | |
| return value | |
| if torch.is_tensor(value): | |
| return value.to(device) | |
| if isinstance(value, dict): | |
| return {k: self._to_device(v, device) for k, v in value.items()} | |
| if isinstance(value, tuple): | |
| items = [self._to_device(v, device) for v in value] | |
| if hasattr(value, "_fields"): | |
| return value.__class__(*items) | |
| return tuple(items) | |
| if isinstance(value, list): | |
| return [self._to_device(v, device) for v in value] | |
| return value | |
Xet Storage Details
- Size:
- 6.02 kB
- Xet hash:
- 8b00447829919f505d7a71e3c12ff73c4be2702ad20951331bd8068561363b88
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.