| """ |
| Text Generation Utilities for ASA Models |
| |
| Simple, dependency-free text generation with common decoding strategies. |
| |
| Repository: https://github.com/DigitalDaimyo/AddressedStateAttention |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| from typing import Optional, Set, Tuple, List |
|
|
|
|
# Public API: only the high-level entry point is exported; helpers stay private.
__all__ = ['generate']
|
|
|
|
| def _forward_logits(model, input_ids, attention_mask=None): |
| """Extract logits from various model output formats.""" |
| out = model(input_ids, attention_mask=attention_mask) if attention_mask is not None else model(input_ids) |
| |
| if isinstance(out, torch.Tensor): |
| return out |
| if isinstance(out, (tuple, list)): |
| return out[0] |
| if isinstance(out, dict): |
| for key in ["logits", "out", "y", "pred"]: |
| if key in out: |
| return out[key] |
| raise TypeError(f"Unrecognized model output type: {type(out)}") |
|
|
|
|
| def _apply_repetition_penalty(logits: torch.Tensor, input_ids: torch.Tensor, penalty: float): |
| """Apply repetition penalty to logits (GPT-2 style).""" |
| if penalty is None or penalty == 1.0: |
| return logits |
| |
| B = logits.size(0) |
| for b in range(B): |
| prev_tokens = torch.unique(input_ids[b]) |
| l = logits[b, prev_tokens] |
| logits[b, prev_tokens] = torch.where(l < 0, l * penalty, l / penalty) |
| return logits |
|
|
|
|
| def _top_k_top_p_filtering( |
| logits: torch.Tensor, |
| top_k: int = 0, |
| top_p: float = 1.0, |
| min_tokens_to_keep: int = 1 |
| ): |
| """Filter logits using top-k and nucleus (top-p) filtering.""" |
| B, V = logits.shape |
| top_k = int(top_k) if top_k is not None else 0 |
| top_p = float(top_p) if top_p is not None else 1.0 |
|
|
| if top_k > 0 and top_k < V: |
| kth = torch.topk(logits, top_k, dim=-1).values[:, -1].unsqueeze(-1) |
| logits = logits.masked_fill(logits < kth, float("-inf")) |
|
|
| if top_p < 1.0: |
| sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1) |
| probs = F.softmax(sorted_logits, dim=-1) |
| cum = probs.cumsum(dim=-1) |
|
|
| remove = cum > top_p |
| if min_tokens_to_keep > 1: |
| remove[:, :min_tokens_to_keep] = False |
| remove = torch.cat([ |
| torch.zeros((B, 1), device=logits.device, dtype=torch.bool), |
| remove[:, :-1] |
| ], dim=-1) |
|
|
| sorted_logits = sorted_logits.masked_fill(remove, float("-inf")) |
| logits = torch.full_like(logits, float("-inf")) |
| logits.scatter_(dim=-1, index=sorted_idx, src=sorted_logits) |
|
|
| return logits |
|
|
|
|
| def _update_seen_ngrams(seen: Set, tokens: List[int], n: int): |
| """Add n-gram to seen set.""" |
| if n > 0 and len(tokens) >= n: |
| seen.add(tuple(tokens[-n:])) |
|
|
|
|
| def _seed_seen_ngrams(input_ids: torch.Tensor, n: int) -> Set: |
| """Initialize seen n-grams from input.""" |
| seen = set() |
| if n <= 0: |
| return seen |
| tokens = input_ids[0].tolist() |
| if len(tokens) >= n: |
| for i in range(len(tokens) - n + 1): |
| seen.add(tuple(tokens[i:i+n])) |
| return seen |
|
|
|
|
| def _banned_from_seen(seen: Set, input_ids: torch.Tensor, n: int) -> Set: |
| """Get tokens banned by n-gram constraint.""" |
| if n <= 0 or input_ids.shape[1] < n - 1: |
| return set() |
| |
| prefix = tuple(input_ids[0, -(n - 1):].tolist()) |
| banned = set() |
| for ng in seen: |
| if ng[:-1] == prefix: |
| banned.add(ng[-1]) |
| return banned |
|
|
|
|
@torch.no_grad()
def generate(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 120,
    max_seq_len: int = 1024,
    strategy: str = "sample",
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.9,
    repetition_penalty: float = 1.0,
    no_repeat_ngram_size: int = 0,
    eos_token_id: Optional[int] = None,
    device: str = "cuda",
) -> str:
    """
    Generate text from a prompt using various decoding strategies.

    Args:
        model: ASA language model
        tokenizer: HuggingFace tokenizer
        prompt: Input text prompt
        max_new_tokens: Maximum tokens to generate
        max_seq_len: Maximum sequence length (truncates context if exceeded)
        strategy: "greedy" or "sample"
        temperature: Sampling temperature (higher = more random)
        top_k: Keep only top k tokens (0 = disabled)
        top_p: Nucleus sampling threshold (1.0 = disabled)
        repetition_penalty: Penalty for repeating tokens (1.0 = disabled)
        no_repeat_ngram_size: Block repeating n-grams (0 = disabled)
        eos_token_id: Stop generation at this token (defaults to the
            tokenizer's eos_token_id when None)
        device: Device to run on

    Returns:
        Generated text (including prompt)

    Raises:
        ValueError: If `strategy` is neither "greedy" nor "sample".

    Example:
        >>> text = generate(
        ...     model, tokenizer,
        ...     prompt="The capital of France is",
        ...     max_new_tokens=20,
        ...     strategy="greedy"
        ... )
    """
    model.eval()

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id

    seen = _seed_seen_ngrams(input_ids, no_repeat_ngram_size)

    for _ in range(max_new_tokens):
        # Bound the context window. Truncation invalidates the n-gram
        # bookkeeping, so rebuild it from the surviving tokens.
        if input_ids.shape[1] > max_seq_len:
            input_ids = input_ids[:, -max_seq_len:]
            seen = _seed_seen_ngrams(input_ids, no_repeat_ngram_size)

        # Work on a float32 copy of the final-position logits so the
        # in-place penalty never touches the model's own output tensor.
        step_logits = _forward_logits(model, input_ids)[:, -1, :].to(torch.float32).clone()
        step_logits = _apply_repetition_penalty(step_logits, input_ids, repetition_penalty)

        banned = _banned_from_seen(seen, input_ids, no_repeat_ngram_size)
        if banned:
            step_logits[0, list(banned)] = float("-inf")

        if strategy == "greedy":
            chosen = torch.argmax(step_logits, dim=-1, keepdim=True)
        elif strategy == "sample":
            # Temperature floor avoids division by zero at temperature=0.
            step_logits = step_logits / max(1e-6, float(temperature))
            step_logits = _top_k_top_p_filtering(step_logits, top_k=top_k, top_p=top_p)
            chosen = torch.multinomial(F.softmax(step_logits, dim=-1), num_samples=1)
        else:
            raise ValueError(f"Unknown strategy '{strategy}'. Use 'greedy' or 'sample'.")

        input_ids = torch.cat([input_ids, chosen], dim=1)
        _update_seen_ngrams(seen, input_ids[0].tolist(), no_repeat_ngram_size)

        if eos_token_id is not None and chosen.item() == eos_token_id:
            break

    return tokenizer.decode(input_ids[0], skip_special_tokens=False)
|
|