|
|
| from collections import deque
|
| import torch
|
| import torch.nn.functional as F
|
| import config
|
|
|
class ActiveMemory:
    """Episodic memory of past model passes with similarity-based retrieval.

    Each entry stores the inputs, hidden states, predictions and reasoning
    trace of one example, all detached and kept on the CPU so the memory
    does not pin GPU memory or the autograd graph.  A mean-pooled summary
    vector is precomputed per entry and used for cosine-similarity lookup.

    NOTE(review): the summary/query pooling squeezes the leading (batch)
    dimension, so retrieval effectively assumes batch size 1 per stored and
    query example — confirm against the callers.
    """

    def __init__(self, max_size=config.MEMORY_MAX_SIZE, retrieval_k=config.MEMORY_RETRIEVAL_K):
        """Create an empty memory.

        Args:
            max_size: Maximum number of entries kept; oldest are evicted first.
            retrieval_k: Default number of neighbours returned by retrieve().
        """
        self.max_size = max_size
        self.retrieval_k = retrieval_k
        # deque(maxlen=...) gives O(1) append with automatic FIFO eviction.
        self.memory = deque(maxlen=max_size)
        self.device = config.DEVICE
        print(f"Initialized ActiveMemory with max size {self.max_size}, retrieval_k={self.retrieval_k}")

    @staticmethod
    def _to_cpu(tensor):
        """Detach *tensor* from the graph and move it to CPU (None passes through)."""
        return tensor.cpu().detach() if tensor is not None else None

    @classmethod
    def _detach_struct(cls, value):
        """CPU-detach a tensor, a dict of tensors, or a tuple of tensors."""
        if isinstance(value, dict):
            return {k: cls._to_cpu(v) for k, v in value.items()}
        if isinstance(value, tuple):
            return tuple(cls._to_cpu(v) for v in value)
        return cls._to_cpu(value)

    @staticmethod
    def _masked_mean(hidden_states, attention_mask):
        """Mean-pool *hidden_states* over the sequence dimension (dim=1).

        Ignores padding positions via *attention_mask* when it is provided,
        otherwise averages over the full sequence.  Returns the pooled
        tensor with the leading (batch) dimension squeezed away.
        """
        if attention_mask is None:
            return hidden_states.mean(dim=1).squeeze(0)
        mask = attention_mask.unsqueeze(-1).float()
        summed = torch.sum(hidden_states * mask, dim=1)
        # Clamp so an all-padding row cannot divide by zero.
        count = torch.clamp(torch.sum(mask, dim=1), min=1e-9)
        return (summed / count).squeeze(0)

    def add(self, input_data, hidden_states, output, reasoning_trace, final_hidden_states=None, final_output=None):
        """Add one example to the memory.

        All tensors are detached and moved to the CPU before storage, and a
        mean-pooled summary vector is precomputed for later retrieval.

        Args:
            input_data: Mapping with optional 'input_ids', 'attention_mask'
                and 'token_type_ids' tensors.
            hidden_states (H0): Initial hidden states from the base model.
            output (y0): Initial prediction; a tensor or a dict of tensors.
            reasoning_trace (T): Reasoning trace; a tensor or tuple of tensors.
            final_hidden_states (H1, optional): Final hidden states after the
                retroactive update.
            final_output (y1, optional): Final prediction after the
                retroactive update; a tensor or a dict of tensors.
        """
        entry = {
            'input_ids': self._to_cpu(input_data.get('input_ids')),
            'attention_mask': self._to_cpu(input_data.get('attention_mask')),
            'token_type_ids': self._to_cpu(input_data.get('token_type_ids')),
            'hidden_states': self._to_cpu(hidden_states),
            'output': self._detach_struct(output),
            'reasoning_trace': self._detach_struct(reasoning_trace),
        }

        if final_hidden_states is not None:
            entry['final_hidden_states'] = self._to_cpu(final_hidden_states)
        if final_output is not None:
            entry['final_output'] = self._detach_struct(final_output)

        # Precompute the retrieval key so lookups never re-pool stored states.
        entry['summary_vector'] = self._masked_mean(entry['hidden_states'], entry['attention_mask'])

        self.memory.append(entry)

    def retrieve(self, query_hidden_states, query_attention_mask=None, k=None):
        """Return the k entries most similar to the query, best first.

        Similarity is the cosine between mean-pooled summary vectors.
        Returned entries are shallow copies whose 'hidden_states' /
        'final_hidden_states' tensors are moved to the query's device;
        the stored entries themselves stay on the CPU.

        Args:
            query_hidden_states: Hidden states to compare against memory.
            query_attention_mask: Optional attention mask for the query.
            k: Number of entries to return (defaults to self.retrieval_k).

        Returns:
            List of entry dicts ordered by decreasing similarity; an empty
            list when the memory is empty.
        """
        if not self.memory:
            return []

        k = min(self.retrieval_k if k is None else k, len(self.memory))

        query_vector = self._masked_mean(query_hidden_states, query_attention_mask)
        query_vector = query_vector.cpu().detach()

        # Score every stored entry against the query summary vector.
        scored = [
            (idx, F.cosine_similarity(query_vector, entry['summary_vector'], dim=0).item())
            for idx, entry in enumerate(self.memory)
        ]
        scored.sort(key=lambda pair: pair[1], reverse=True)

        device = query_hidden_states.device
        retrieved = []
        for idx, _ in scored[:k]:
            # Shallow-copy before moving tensors to the query device.  The
            # previous in-place ``entry[...] = entry[...].to(device)`` mutated
            # the stored entry, silently migrating the whole memory onto the
            # GPU over repeated retrievals.
            entry = dict(self.memory[idx])
            for key in ('hidden_states', 'final_hidden_states'):
                if entry.get(key) is not None:
                    entry[key] = entry[key].to(device)
            retrieved.append(entry)
        return retrieved

    def get_memory_context(self, query_hidden_states, query_attention_mask=None):
        """Build a single context tensor from retrieved memory entries.

        Retrieves the most similar entries, prefers their post-update states
        ('final_hidden_states') over the initial ones, aligns each to the
        query's sequence length, and averages them.

        Args:
            query_hidden_states: Hidden states to compare against memory,
                shape (batch_size, seq_len, hidden_dim).
            query_attention_mask: Optional attention mask for the query.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim) with the
            averaged memory information, or None when the memory is empty
            or no retrieved entry carries hidden states.
        """
        retrieved = self.retrieve(query_hidden_states, query_attention_mask)
        if not retrieved:
            return None

        device = query_hidden_states.device
        batch_size, seq_len, hidden_dim = query_hidden_states.shape

        # Prefer the retro-updated states (H1) when available, else H0.
        memory_tensors = []
        for entry in retrieved:
            states = entry.get('final_hidden_states')
            if states is None:
                states = entry.get('hidden_states')
            if states is not None:
                memory_tensors.append(states)
        if not memory_tensors:
            return None

        # Align every memory tensor to the query's sequence length.
        aligned = []
        for tensor in memory_tensors:
            mem_len = tensor.size(1)
            if mem_len < seq_len:
                # Pad with zeros.  Match the tensor's own batch dim and dtype;
                # the original hard-coded batch 1 and the default dtype, which
                # breaks torch.cat for batched or half-precision entries.
                pad = torch.zeros(tensor.size(0), seq_len - mem_len, hidden_dim,
                                  device=device, dtype=tensor.dtype)
                aligned.append(torch.cat([tensor, pad], dim=1))
            elif mem_len > seq_len:
                aligned.append(tensor[:, :seq_len, :])
            else:
                aligned.append(tensor)

        # Average across retrieved entries.
        memory_context = torch.stack(aligned).mean(dim=0)

        # Broadcast a single-example context across the query batch.
        if memory_context.size(0) == 1 and batch_size > 1:
            memory_context = memory_context.expand(batch_size, -1, -1)

        return memory_context

    def clear(self):
        """Clears all entries from memory."""
        self.memory.clear()

    def __len__(self):
        """Number of entries currently stored."""
        return len(self.memory)
|
|
|