|
|
import torch
import torch.nn.functional as F
from transformers.generation.stopping_criteria import (
    MaxLengthCriteria,
    StoppingCriteriaList,
)
from typing import Union, List
from .eva_cache import EvaStaticCacheForTriton
from .eva_prep_kv_kernel import triton_eva_prep_kv_fwd


class MultibyteEosTokenCriteria:
    """
    A simple stopping criterion that halts generation whenever an
    end-of-sequence token appears among the last `new_tokens` tokens.

    Adapted from
    https://github.com/huggingface/transformers/blob/main/src/transformers/generation/stopping_criteria.py#L446
    By default, callers pass `model.generation_config.eos_token_id`.

    Args:
        eos_token_ids (`Union[int, List[int]]`):
            The id(s) of the *end-of-sequence* token.
    """

    def __init__(self, eos_token_ids: Union[int, List[int]]):
        if isinstance(eos_token_ids, int):
            eos_token_ids = [eos_token_ids]
        self.eos_token_ids = eos_token_ids

    def __call__(self, input_ids: torch.LongTensor, new_tokens: int) -> bool:
        # Only the tokens appended in the current iteration need checking.
        current_input_len = input_ids.shape[-1]
        new_token_ids = input_ids[:, current_input_len - new_tokens:]
        for eos_token_id in self.eos_token_ids:
            if torch.any(new_token_ids == eos_token_id):
                return True
        return False
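
# Illustrative usage: only the tokens appended in the current iteration are
# scanned, so earlier occurrences of the EOS id do not trigger a stop.
#
#     criteria = MultibyteEosTokenCriteria(eos_token_ids=1)
#     criteria(torch.tensor([[5, 7, 9, 1]]), new_tokens=2)  # True: 1 is in [9, 1]
#     criteria(torch.tensor([[5, 1, 7, 9]]), new_tokens=2)  # False: [7, 9] has no 1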


def build_tree(spec):
    """
    Expand a per-depth branching spec into an explicit list of tree paths.

    `spec[d - 1][i]` gives the number of children of the i-th node at depth d
    (in creation order); nodes without an entry get no children. Each returned
    node is the tuple of child indices along its path from the root.
    """
    nodes_at_depth = []
    nodes_at_depth.append([()])  # depth 0: the (virtual) root

    for d in range(1, len(spec) + 1):
        prev_nodes = nodes_at_depth[d - 1]
        spec_list = spec[d - 1]
        current_nodes = []
        for node_idx, node in enumerate(prev_nodes):
            if node_idx < len(spec_list):
                num_children = spec_list[node_idx]
            else:
                num_children = 0
            for child_idx in range(num_children):
                new_node = node + (child_idx,)
                current_nodes.append(new_node)
        nodes_at_depth.append(current_nodes)

    # Flatten across depths, dropping the empty root tuple.
    all_nodes = [node for depth_nodes in nodes_at_depth for node in depth_nodes if node]
    return all_nodes
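
# Example: with a root that has 2 children, each of which has one child,
#     build_tree([[2], [1, 1]])  ->  [(0,), (1,), (0, 0), (1, 0)]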


evabyte_7b_95 = build_tree(
    [
        [10],
        [10, 8, 2, 2, 1, 1],
        [10, 4, 2, 1, 0, 0, 0, 0, 0, 0, 2, 1, 1, 0, 0, 0, 0, 0, 1],
        [8, 2, 2, 1, 0, 0, 0, 0, 0, 0, 1],
        [6, 2, 1, 1],
        [4, 2, 1, 1],
        [4, 2, 1],
    ]
)
evabyte_7b_31 = build_tree(
    [
        [4],
        [3, 2, 1, 1],
        [3, 2, 1, 1],
        [2, 1, 1],
        [2, 1],
        [2, 1],
        [2, 1],
    ]
)
# Number of top candidates drawn from each prediction head.
TOPK = 10
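
# Quick sanity check: the numeric suffix of each tree name is its number of
# candidate nodes.
assert len(evabyte_7b_95) == 95
assert len(evabyte_7b_31) == 31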


def pad_path(path, length, pad_value=-2):
    """
    Pad the given path list with a specific value up to a specified length.

    Parameters:
    - path (list): The original list that needs padding.
    - length (int): The desired length of the padded list.
    - pad_value (optional, default=-2): The value to use for padding.

    Returns:
    - list: A new list based on the original path but padded to the desired length.

    Example:
    >>> pad_path([1, 2, 3], 5)
    [1, 2, 3, -2, -2]

    Note:
    If the given path is already longer than the specified length,
    then no padding occurs, and the original path is returned.
    """
    return path + [pad_value] * (length - len(path))


def reset_past_key_values(passed_key_values):
    """
    Resets the current lengths in the passed key-values to zero.

    This function is designed to be used during the evaluation of a baseline model.
    It iterates through each layer's key-values and sets their current lengths to zero,
    effectively resetting their state.

    Args:
    - passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer.

    Returns:
    - passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths.
    """
    for i in range(len(passed_key_values)):
        for j in range(2):
            passed_key_values[i][j].current_length.fill_(0)
    return passed_key_values


def get_nucleus_one_token(logit, temperature, top_p):
    """
    Performs token sampling based on the nucleus (top-p) sampling method.

    This function selects a token from a given logit distribution using the nucleus sampling strategy.
    It allows for more controlled and diverse generation compared to traditional top-k sampling.

    Args:
        logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor (BxC).
        temperature (float): A temperature parameter to control the randomness in sampling.
            Higher values increase diversity, lower values make selections more deterministic.
        top_p (float): The cumulative probability threshold for nucleus sampling.
            It controls the size of the set of high-probability tokens to consider for sampling.

    Returns:
        torch.Tensor: A tensor containing the indices of the sampled tokens.
    """
    if top_p >= 1:
        return torch.multinomial(F.softmax(logit / temperature, dim=-1), 1)
    logit = logit / temperature
    probs = torch.softmax(logit, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cum_probs > top_p
    # Shift the mask right by one so the first token that crosses top_p is kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    # Scatter the mask from sorted order back to vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
    logit[indices_to_remove] = float('-inf')
    sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
    return sampled_tokens
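
# Worked example: with probabilities (0.5, 0.3, 0.2) and top_p = 0.7, the
# cumulative sums are (0.5, 0.8, 1.0); after the one-position shift only the
# last token is masked, so sampling draws from (0.5, 0.3) renormalized.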


def get_typical_one_token(logit, temperature, posterior_threshold, posterior_alpha):
    """
    Implements token sampling based on the typical sampling method.

    This function selects a token from a given logit distribution using the typical sampling strategy,
    aiming to balance between diversity and likelihood in a more nuanced way compared to traditional methods.

    Args:
        logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor.
        temperature (float): A parameter to control the randomness in sampling.
            Higher values increase diversity, lower values make selections more deterministic.
        posterior_threshold (float): A threshold to decide the lower bound of probabilities to be considered for sampling.
        posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.

    Returns:
        torch.Tensor: A tensor containing the indices of the sampled tokens.
    """
    logit = logit / temperature
    probs = torch.softmax(logit, dim=-1)
    entropy = -torch.sum(
        probs * torch.log(probs + 1e-5), dim=-1
    )
    # Keep tokens whose probability clears min(posterior_threshold,
    # exp(-entropy) * posterior_alpha); the cutoff adapts to how peaked the
    # distribution is.
    threshold = torch.minimum(
        torch.ones_like(entropy) * posterior_threshold,
        torch.exp(-entropy) * posterior_alpha,
    )
    indices_to_remove = probs < threshold.unsqueeze(-1)
    logit[indices_to_remove] = float('-inf')
    sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
    return sampled_tokens
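
# Worked example: for a sharply peaked distribution the entropy is near 0, so
# the adaptive cutoff exp(-entropy) * posterior_alpha stays close to
# posterior_alpha; for a flat distribution the entropy is large, the cutoff
# decays toward 0, and nearly all tokens remain eligible.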


def generate_medusa_buffers(medusa_choices, device="cuda"):
    """
    Generate buffers for the Medusa structure based on the provided choices.

    Parameters:
    - medusa_choices (list): A nested list representing the tree in the Medusa structure.
    - device (str): Device to which the tensors should be moved. Default is "cuda".

    Returns:
    - dict: A dictionary containing buffers related to the Medusa structure.
    """
    # Sort the choices by depth, then lexicographically, and reserve slot 0
    # for the root (the token produced by the base LM head).
    sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
    medusa_len = len(sorted_medusa_choices) + 1

    # Count the number of choices at each depth.
    depth_counts = [0] * max([len(path) for path in sorted_medusa_choices])
    for path in sorted_medusa_choices:
        depth_counts[len(path) - 1] += 1

    # Build the tree attention mask: every node attends to the root and to its
    # own ancestors.
    medusa_attn_mask = torch.eye(medusa_len, medusa_len)
    medusa_attn_mask[:, 0] = 1
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_medusa_choice = sorted_medusa_choices[start + j]
            # Depth-1 nodes only attend to the root.
            if len(cur_medusa_choice) == 1:
                continue
            ancestor_idx = []
            for c in range(len(cur_medusa_choice) - 1):
                ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c + 1]) + 1)
            medusa_attn_mask[j + start + 1, ancestor_idx] = 1
        start += depth_counts[i]

    # Tree indices map each node into the flattened candidate list
    # (1 root token followed by TOPK candidates per head).
    medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
    medusa_tree_indices[0] = 0
    start = 0
    for i in range(len(depth_counts)):
        for j in range(depth_counts[i]):
            cur_medusa_choice = sorted_medusa_choices[start + j]
            medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1
        start += depth_counts[i]

    # Position ids: each node's position offset equals its depth in the tree.
    medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
    start = 0
    for i in range(len(depth_counts)):
        medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
        start += depth_counts[i]

    # Retrieval indices: one row per root-to-leaf path, padded with -2
    # (-1 after the +1 shift) so every row has the same length.
    retrieve_indices_nest = []
    retrieve_paths = []
    for i in range(len(sorted_medusa_choices)):
        cur_medusa_choice = sorted_medusa_choices[-i - 1]
        retrieve_indice = []
        if cur_medusa_choice in retrieve_paths:
            continue
        else:
            for c in range(len(cur_medusa_choice)):
                retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c + 1]))
                retrieve_paths.append(cur_medusa_choice[:c + 1])
        retrieve_indices_nest.append(retrieve_indice)
    max_length = max([len(x) for x in retrieve_indices_nest])
    retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest]
    retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
    retrieve_indices = retrieve_indices + 1
    retrieve_indices = torch.cat(
        [torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices],
        dim=1,
    )

    # Aggregate the buffers and move them to the target device.
    medusa_buffers = {
        "medusa_attn_mask": medusa_attn_mask.unsqueeze(0).unsqueeze(0),
        "tree_indices": medusa_tree_indices,
        "medusa_position_ids": medusa_position_ids.unsqueeze(0),
        "retrieve_indices": retrieve_indices,
    }

    medusa_buffers = {
        k: v.clone().to(device)
        if isinstance(v, torch.Tensor)
        else torch.tensor(v, device=device)
        for k, v in medusa_buffers.items()
    }
    return medusa_buffers
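
# Illustrative output for a tiny tree (two depth-1 choices, one depth-2
# choice), traced by hand through the construction above with TOPK = 10:
#
#     buf = generate_medusa_buffers([(0,), (1,), (0, 0)], device="cpu")
#     buf["medusa_attn_mask"].shape   # (1, 1, 4, 4)
#     buf["tree_indices"]             # tensor([0, 1, 2, 11])
#     buf["medusa_position_ids"]      # tensor([[0, 1, 1, 2]])
#     buf["retrieve_indices"]         # tensor([[0, 1, 3], [0, 2, -1]])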


def generate_candidates(
    medusa_logits,
    logits,
    tree_indices,
    retrieve_indices,
    temperature=0,
    posterior_threshold=0.3,
    posterior_alpha=0.09,
    top_p=0.8,
    sampling='typical',
    fast=False,
):
    """
    Generate tree-structured candidate token ids from the base-model logits
    and the Medusa-head logits.

    Returns:
    - tree_candidate_ids: candidates laid out flat in tree order, (1, medusa_len).
    - unflattened_candidate_ids: candidates regrouped per root-to-leaf path,
      (num_paths, max_depth + 1).
    """
    # Greedy decoding (or fast mode) takes the argmax for the next token;
    # otherwise sample it with the requested strategy.
    if temperature == 0 or fast:
        candidates_ids = torch.argmax(logits[:, -1]).unsqueeze(0)
    else:
        if sampling == 'typical':
            candidates_ids = get_typical_one_token(logits[:, -1], temperature, posterior_threshold, posterior_alpha).squeeze(0)
        elif sampling == 'nucleus':
            candidates_ids = get_nucleus_one_token(logits[:, -1], temperature, top_p).squeeze(0)
        else:
            raise NotImplementedError

    # Top-k candidates from each Medusa head.
    candidates_medusa_ids = torch.topk(medusa_logits[:, 0, -1], TOPK, dim=-1).indices

    # Combine the base-model token with the Medusa-head candidates.
    candidate_ids = torch.cat([candidates_ids, candidates_medusa_ids.view(-1)], dim=-1)

    # Map the flat candidate list onto the tree layout.
    tree_candidate_ids = candidate_ids[tree_indices]

    # Append a zero entry so the -1 padding in retrieve_indices gathers a
    # dummy token instead of wrapping around.
    tree_candidate_ids_ext = torch.cat(
        [
            tree_candidate_ids,
            torch.zeros((1), dtype=torch.long, device=tree_candidate_ids.device)
        ],
        dim=0
    )

    # Regroup the tree candidates into per-path rows for verification.
    unflattened_candidate_ids = tree_candidate_ids_ext[retrieve_indices]

    # Add a batch dimension for the forward pass.
    tree_candidate_ids = tree_candidate_ids.unsqueeze(0)

    return tree_candidate_ids, unflattened_candidate_ids


def get_nucleus_posterior_mask(logits, candidates, temperature, top_p):
    """
    Generates a posterior mask for token candidates using nucleus (top-p) sampling.

    This function applies nucleus sampling to a set of logits, and then generates a mask indicating
    which candidate tokens are selected. It adapts the sampling strategy to accommodate for
    temperature scaling and cumulative probability thresholding.

    Args:
        logits (torch.Tensor): A tensor of logits from a language model output.
        candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
        temperature (float): A parameter to scale the logits, controlling randomness in sampling.
        top_p (float): The cumulative probability threshold for nucleus sampling.

    Returns:
        torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
    """
    # Position i scores candidates[:, i + 1], so the last position is dropped;
    # (samples, positions) is then flattened into one batch for sampling.
    logits = logits[:, :-1] / temperature
    n_samples, n_tokens = logits.shape[0], logits.shape[1]
    logits = logits.view(n_samples * n_tokens, -1)
    if top_p >= 1:
        sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
        sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
        posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
        return posterior_mask

    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_indices_to_remove = cum_probs > top_p
    # Shift the mask right by one so the first token that crosses top_p is kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    # Scatter the mask from sorted order back to vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')
    sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
    sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
    posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
    return posterior_mask


def get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha):
    """
    Generates a posterior mask for token candidates using typical sampling.

    Args:
        logits (torch.Tensor): A tensor of logits from a language model output.
        candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
        temperature (float): A parameter to scale the logits, controlling randomness in sampling.
        posterior_threshold (float): The minimum threshold for probabilities to be considered in sampling.
        posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.

    Returns:
        torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
    """
    # Position i scores candidates[:, i + 1], so the last position is dropped.
    logits = logits[:, :-1] / temperature
    n_samples, n_tokens = logits.shape[0], logits.shape[1]
    logits = logits.view(n_samples * n_tokens, -1)
    probs = F.softmax(logits, dim=-1)
    entropy = -torch.sum(
        probs * torch.log(probs + 1e-5), dim=-1
    )
    threshold = torch.minimum(
        torch.ones_like(entropy) * posterior_threshold,
        torch.exp(-entropy) * posterior_alpha,
    )
    indices_to_remove = probs < threshold.unsqueeze(-1)
    logits[indices_to_remove] = float('-inf')
    sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
    sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
    posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
    return posterior_mask


def evaluate_posterior(
    logits,
    candidates,
    temperature,
    posterior_threshold=0.3,
    posterior_alpha=0.09,
    top_p=0.8,
    sampling='typical',
    fast=True,
):
    """
    Select the best candidate path and how many of its tokens to accept,
    given the verifier logits at every candidate position.
    """
    if logits.shape[1] <= 1:
        return torch.tensor(0, dtype=torch.long, device=candidates.device), 0

    if temperature == 0:
        # Greedy verification: a candidate token is accepted iff it equals the
        # argmax prediction at the previous position.
        posterior_mask = (
            candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1)
        ).int()
        # cumprod keeps only the leading run of accepted tokens per path.
        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
        accept_length = candidates_accept_length.max().item()
        if accept_length == 0:
            # Default to the first candidate path if none are accepted.
            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
        return best_candidate, accept_length
    elif sampling == 'typical':
        if fast:
            posterior_prob = torch.softmax(logits[:, :-1] / temperature, dim=-1)
            candidates_prob = torch.gather(
                posterior_prob, dim=-1, index=candidates[:, 1:].unsqueeze(-1)
            ).squeeze(-1)
            posterior_entropy = -torch.sum(
                posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1
            )
            threshold = torch.minimum(
                torch.ones_like(posterior_entropy) * posterior_threshold,
                torch.exp(-posterior_entropy) * posterior_alpha,
            )
            posterior_mask = candidates_prob > threshold
            candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)

            accept_length = candidates_accept_length.max().item()
            if accept_length == 0:
                # Default to the first candidate path if none are accepted.
                best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
            else:
                best_candidates = torch.where(candidates_accept_length == accept_length)[0]
                # Break ties by the accumulated log-likelihood of each path.
                likelihood = torch.sum(
                    torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1
                )
                best_candidate = best_candidates[torch.argmax(likelihood)]
            return best_candidate, accept_length

        posterior_mask = get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha)
        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
        accept_length = candidates_accept_length.max().item()
        if accept_length == 0:
            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
        return best_candidate, accept_length
    elif sampling == 'nucleus':
        assert top_p < 1.0 + 1e-6, "top_p should be between 0 and 1"
        posterior_mask = get_nucleus_posterior_mask(logits, candidates, temperature, top_p)
        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
        accept_length = candidates_accept_length.max().item()
        if accept_length == 0:
            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
        return best_candidate, accept_length
    else:
        raise NotImplementedError
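
# Greedy acceptance, illustrated: if a candidate row is [t0, a, b, c] and the
# verifier's argmax continuations after t0, a, b are [a, b, x], the
# element-wise mask is [1, 1, 0]; the cumulative product keeps the leading run
# of matches, giving accept_length = 2 for that row.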


def update_inference_inputs(
    input_ids,
    medusa_logits,
    logits,
    candidate_ids,
    best_candidate,
    accept_length,
):
    # Append the accepted tokens (the base token plus `accept_length`
    # verified continuations) to the running sequence.
    input_ids = torch.cat(
        [
            input_ids,
            candidate_ids[None, best_candidate, : accept_length + 1]
        ],
        dim=-1
    )
    # Carry over the logits at the last accepted position; they seed the next
    # round of candidate generation.
    logits = logits[
        None, best_candidate, accept_length: accept_length + 1
    ]
    medusa_logits = medusa_logits[
        :, None, best_candidate, accept_length: accept_length + 1
    ]

    new_token = accept_length + 1
    return input_ids, medusa_logits, logits, new_token


def split_logits(full_logits):
    # The first prediction slot is the base LM head; the rest are Medusa heads.
    logits = full_logits[..., 0, :]
    medusa_logits = full_logits[..., 1:, :].permute(2, 0, 1, 3)
    return medusa_logits, logits
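
# Shape contract (assuming the base head occupies slot 0): full_logits is
# (batch, seq, num_heads + 1, vocab); split_logits returns logits of shape
# (batch, seq, vocab) and medusa_logits of shape (num_heads, batch, seq, vocab).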


class MultiByteDecodingMixin:
    def multi_byte_pred_update_cache(
        self,
        past_key_values,
        retrieve_indices,
        best_candidate,
        new_tokens,
    ):
        prev_window_len = past_key_values.get_past_window_pos(0)
        # Positions of the accepted path's tokens inside the tree-sized KV
        # block that sits after the previously cached window tokens.
        select_indices = (
            retrieve_indices[best_candidate, : new_tokens] + prev_window_len
        )
        for layer_idx in range(self.config.num_hidden_layers):
            past_key_values.update_past_len(new_tokens, layer_idx)

            past_window_k = past_key_values.past_window_k[layer_idx]
            past_window_v = past_key_values.past_window_v[layer_idx]

            # Gather the KV entries of the accepted path...
            tgt_window_k = past_window_k[..., select_indices, :]
            tgt_window_v = past_window_v[..., select_indices, :]

            # ...and compact them so they sit contiguously after the previous
            # window tokens.
            dst_window_k = past_window_k[..., prev_window_len: prev_window_len + new_tokens, :]
            dst_window_v = past_window_v[..., prev_window_len: prev_window_len + new_tokens, :]

            dst_window_k.copy_(tgt_window_k, non_blocking=True)
            dst_window_v.copy_(tgt_window_v, non_blocking=True)

            # If the window fills up, dump one full window into chunk-level
            # RFAs and shift any overflow tokens to the front of the window.
            new_window_len = prev_window_len + new_tokens
            if new_window_len >= self.config.window_size:
                assert new_window_len < 2 * self.config.window_size

                dump_k = past_window_k[..., :self.config.window_size, :].clone()
                dump_v = past_window_v[..., :self.config.window_size, :].clone()

                _window_len = new_window_len - self.config.window_size
                if _window_len > 0:
                    new_window_k = past_window_k[..., self.config.window_size: new_window_len, :]
                    new_window_v = past_window_v[..., self.config.window_size: new_window_len, :]

                    _dst_window_k = past_window_k[..., : _window_len, :]
                    _dst_window_v = past_window_v[..., : _window_len, :]

                    _dst_window_k.copy_(new_window_k, non_blocking=True)
                    _dst_window_v.copy_(new_window_v, non_blocking=True)

                past_key_values.past_window_pos[layer_idx] = _window_len
            else:
                dump_k = None
                dump_v = None
                past_key_values.past_window_pos[layer_idx] = new_window_len

            if dump_k is not None and dump_v is not None:
                rfa_k, rfa_v = triton_eva_prep_kv_fwd(
                    dump_k, dump_v,
                    self.model.layers[layer_idx].self_attn.adaptive_mu_k,
                    self.model.layers[layer_idx].self_attn.adaptive_phi,
                    None,
                    self.model.layers[layer_idx].self_attn.head_dim_scaling,
                    self.model.layers[layer_idx].self_attn.chunk_size
                )
                rfa_k, rfa_v = past_key_values.update_chunk_rfas(
                    rfa_k, rfa_v, layer_idx
                )
        return past_key_values
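
    # Worked example of the rollover above (say window_size = 8): if 6 tokens
    # were cached and 5 candidates are accepted, new_window_len = 11; the
    # first 8 positions are dumped into chunk-level RFAs, the remaining 3
    # tokens are shifted to the front of the window, and past_window_pos
    # becomes 3.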


    def _multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
        self,
        past_key_values,
    ):
        prev_window_len = past_key_values.get_past_window_pos(0)
        for layer_idx in range(self.config.num_hidden_layers):
            past_window_k = past_key_values.past_window_k[layer_idx]
            past_window_v = past_key_values.past_window_v[layer_idx]

            # Nothing to dump unless the prefill exactly filled the window.
            dump_k = None
            dump_v = None
            new_window_len = prev_window_len
            if new_window_len == self.config.window_size:
                # Dump the full window into chunk-level RFAs and start the
                # next window from scratch.
                dump_k = past_window_k[..., :self.config.window_size, :].clone()
                dump_v = past_window_v[..., :self.config.window_size, :].clone()
                past_key_values.past_window_pos[layer_idx] = 0

            if dump_k is not None and dump_v is not None:
                rfa_k, rfa_v = triton_eva_prep_kv_fwd(
                    dump_k, dump_v,
                    self.model.layers[layer_idx].self_attn.adaptive_mu_k,
                    self.model.layers[layer_idx].self_attn.adaptive_phi,
                    None,
                    self.model.layers[layer_idx].self_attn.head_dim_scaling,
                    self.model.layers[layer_idx].self_attn.chunk_size
                )
                rfa_k, rfa_v = past_key_values.update_chunk_rfas(
                    rfa_k, rfa_v, layer_idx
                )
        return past_key_values


    def multi_byte_pred_update_attn_mask(
        self,
        last_iter_new_tokens,
        tree_candidate_ids,
        past_attn_mask,
        medusa_attn_mask,
        past_key_values,
    ):
        batch_size, tree_candidate_len = tree_candidate_ids.shape
        seen_tokens = past_key_values.get_seq_length()
        # Multi-byte decoding assumes a non-empty prefix, and each iteration
        # accepts fewer tokens than one attention window.
        assert seen_tokens > 0
        assert last_iter_new_tokens < self.config.window_size

        if past_attn_mask is not None and seen_tokens < self.config.window_size:
            # Still inside the first window: extend the cached mask by the
            # tokens accepted in the previous iteration.
            past_attn_mask = torch.cat(
                [
                    past_attn_mask,
                    torch.ones(
                        [batch_size, 1, tree_candidate_len, last_iter_new_tokens],
                        dtype=torch.bool,
                        device=self.device
                    )
                ],
                dim=-1
            )
        else:
            # Rebuild the mask: each fully dumped window is represented by its
            # chunk-level RFA slots, and the current window contributes one
            # slot per token.
            chunks_per_window = int(self.config.window_size // self.config.chunk_size)
            window_tokens = seen_tokens % self.config.window_size
            num_windows_seen_so_far = seen_tokens // self.config.window_size
            attn_mask_len = num_windows_seen_so_far * chunks_per_window + window_tokens
            past_attn_mask = torch.ones(
                (batch_size, 1, tree_candidate_len, attn_mask_len),
                dtype=torch.bool,
                device=self.device
            )

        # The full mask lets every tree candidate attend to the cached past
        # plus its own ancestors within the tree.
        tree_attn_mask = torch.cat(
            [
                past_attn_mask,
                medusa_attn_mask.to(torch.bool)
            ],
            dim=-1
        )
        return tree_attn_mask, past_attn_mask
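
    # Worked example (say window_size = 2048, chunk_size = 256): after
    # seen_tokens = 5000, the rebuilt mask covers 2 * 8 = 16 RFA slots for the
    # two dumped windows plus 5000 - 4096 = 904 in-window tokens, i.e. a key
    # length of 920 before the tree mask is appended.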


    @torch.no_grad()
    def multi_byte_generate(
        self,
        input_ids,
        attention_mask=None,
        temperature=0.0,
        max_length=None,
        max_new_tokens=None,
        stopping_criteria=None,
        posterior_threshold=0.09,
        posterior_alpha=0.3,
        top_p=0.8,
        sampling='typical',
        fast=True,
        do_sample=False,
        medusa_choices=None,
        return_acc_lengths=False,
    ):
        # Fast (argmax-only) verification is only valid for greedy decoding.
        if do_sample or temperature > 0.0:
            fast = False

        # Resolve the target length: max_new_tokens takes precedence; fall
        # back to the model's maximum context if neither is given.
        if max_new_tokens is not None:
            max_length = max_new_tokens + input_ids.shape[-1]
        elif max_length is None:
            max_length = getattr(self.config, "max_position_embeddings", 32768)

        # Assemble stopping criteria: EOS, max length, and any user-supplied ones.
        eos_stop_criteria = MultibyteEosTokenCriteria(self.generation_config.eos_token_id)
        stop_criteria = StoppingCriteriaList()
        if max_length is not None:
            max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
            stop_criteria.append(
                MaxLengthCriteria(
                    max_length=max_length,
                    max_position_embeddings=max_position_embeddings,
                )
            )
        if stopping_criteria is not None and len(stopping_criteria) > 0:
            stop_criteria.extend(stopping_criteria)

        assert input_ids.shape[0] == 1, "Only support batch size 1 for now"
        assert attention_mask is None, "Only support attention mask None for now"

        input_ids = input_ids.clone()
        position_ids = torch.arange(0, input_ids.shape[1], device=self.device, dtype=int).reshape(1, -1)

        # Build the tree buffers (attention mask, tree/retrieve indices,
        # position ids) for the chosen candidate tree.
        if medusa_choices is None:
            medusa_choices = evabyte_7b_95
        medusa_buffers = generate_medusa_buffers(
            medusa_choices, device=self.device
        )

        # Static KV cache sized to one window plus headroom for tree candidates.
        past_key_values = EvaStaticCacheForTriton(
            input_ids.shape[0],
            self.config.num_attention_heads,
            self.config.window_size + 256,
            self.config.hidden_size // self.config.num_attention_heads,
            self.config.num_hidden_layers,
            self.lm_head.weight.dtype,
            self.lm_head.weight.device,
        )

        # Prefill: one regular forward pass over the prompt.
        full_logits, past_key_values = self.forward(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=True,
            past_key_values=past_key_values,
            return_all_pred_logits=True,
            multibyte_decoding=False,
        )

        # If the prefill exactly filled the window, dump it into chunk-level
        # RFAs before decoding starts.
        past_key_values = self._multi_byte_pred_update_cache_when_prefil_len_eq_window_size(
            past_key_values
        )
        medusa_logits, logits = split_logits(full_logits)

        past_attn_mask = None
        last_iter_new_tokens = 0
        max_iters = 32768
        if return_acc_lengths:
            acc_lengths = []
        for _ in range(max_iters):
            # 1) Propose: turn the current base/Medusa logits into a tree of
            #    candidate continuations.
            tree_candidate_ids, unflattened_candidate_ids = generate_candidates(
                medusa_logits,
                logits,
                medusa_buffers["tree_indices"],
                medusa_buffers["retrieve_indices"],
                temperature=temperature,
                posterior_alpha=posterior_alpha,
                posterior_threshold=posterior_threshold,
                top_p=top_p,
                sampling=sampling,
                fast=fast,
            )

            # 2) Build the attention mask so every candidate sees the cached
            #    past plus its own ancestors, and offset the position ids.
            medusa_attn_mask, past_attn_mask = self.multi_byte_pred_update_attn_mask(
                last_iter_new_tokens,
                tree_candidate_ids,
                past_attn_mask,
                medusa_buffers["medusa_attn_mask"],
                past_key_values,
            )
            medusa_position_ids = medusa_buffers["medusa_position_ids"] + input_ids.shape[1]

            # 3) Verify: score the whole candidate tree in a single forward
            #    pass and regroup the logits per root-to-leaf path.
            tree_full_logits, past_key_values = self.forward(
                tree_candidate_ids,
                past_key_values=past_key_values,
                attention_mask=medusa_attn_mask,
                position_ids=medusa_position_ids,
                return_all_pred_logits=True,
                multibyte_decoding=True,
            )
            _medusa_logits, _logits = split_logits(tree_full_logits)
            medusa_logits = _medusa_logits[..., 0, medusa_buffers["retrieve_indices"], :]
            logits = _logits[..., 0, medusa_buffers["retrieve_indices"], :]

            # 4) Accept: pick the best path and its accepted prefix length,
            #    trimming candidates so accepted tokens never spill past the
            #    end of the current attention window.
            tree_depth = unflattened_candidate_ids.shape[-1]
            if tree_depth + past_key_values.get_past_window_pos(0) > self.config.window_size:
                max_acc_len = self.config.window_size - past_key_values.get_past_window_pos(0)
                _trimmed_unflattened_candidate_ids = unflattened_candidate_ids[:, :max_acc_len]
                _trimmed_logits = logits[:, :max_acc_len]
            else:
                _trimmed_unflattened_candidate_ids = unflattened_candidate_ids
                _trimmed_logits = logits
            best_candidate, accept_length = evaluate_posterior(
                _trimmed_logits,
                _trimmed_unflattened_candidate_ids,
                temperature,
                posterior_threshold,
                posterior_alpha,
                top_p=top_p,
                sampling=sampling,
                fast=fast
            )

            # 5) Commit: append the accepted tokens and slice out the logits
            #    that seed the next iteration.
            input_ids, medusa_logits, logits, last_iter_new_tokens = update_inference_inputs(
                input_ids,
                medusa_logits,
                logits,
                unflattened_candidate_ids,
                best_candidate,
                accept_length,
            )

            # 6) Compact the KV cache along the accepted path and roll the
            #    window if it filled up.
            past_key_values = self.multi_byte_pred_update_cache(
                past_key_values,
                medusa_buffers["retrieve_indices"],
                best_candidate,
                last_iter_new_tokens,
            )

            if return_acc_lengths:
                acc_lengths.append(last_iter_new_tokens)
            if stop_criteria(input_ids, None) or eos_stop_criteria(input_ids, last_iter_new_tokens):
                if return_acc_lengths:
                    return input_ids, acc_lengths
                else:
                    return input_ids
        if return_acc_lengths:
            return input_ids, acc_lengths
        else:
            return input_ids
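
# Illustrative use of the mixin (hypothetical `model` object exposing it):
#
#     output_ids = model.multi_byte_generate(
#         input_ids,                     # (1, prompt_len) tensor of byte ids
#         max_new_tokens=256,
#         temperature=0.0,               # greedy verification path
#         medusa_choices=evabyte_7b_31,  # smaller tree, cheaper per step
#     )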