import sys
from pathlib import Path

sys.path.append(str(Path(__file__).resolve().parent.parent))

import torch
import torch.nn as nn
import torch.optim as optim
from models.model import Microformer
from tokenizers import Tokenizer
from config import VOCAB_SIZE, EMBED_DIM, NUM_HEADS, FF_DIM, NUM_LAYERS, MAX_SEQ_LEN, ADAPTER_DIM
import sqlite3
from datetime import datetime

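# Load the trained tokenizer; its vocab size overrides the value imported from config.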
tokenizer = Tokenizer.from_file("data/tokenizer.json")
VOCAB_SIZE = tokenizer.get_vocab_size()

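# Rebuild the model and load the trained weights onto whichever device is available.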
device = "cuda" if torch.cuda.is_available() else "cpu"

model = Microformer(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    ff_dim=FF_DIM,
    num_layers=NUM_LAYERS,
    max_seq_len=MAX_SEQ_LEN,
    long_term_adapter_dim=ADAPTER_DIM,
    session_adapter_dim=ADAPTER_DIM,
)
model.load_state_dict(torch.load("microformer.pt", map_location=device))
model.to(device)
model.eval()

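# Freeze everything except the session adapters and the output head.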
model.freeze_except_adapters(session_only=True, include_output=True)

PAD_ID = tokenizer.token_to_id("<PAD>")
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)  # don't compute loss on padding positions
optimizer = optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-2,
)

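# On-disk memory of past prompt/response pairs.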
conn = sqlite3.connect("memory.db")
c = conn.cursor()
c.execute("""
    CREATE TABLE IF NOT EXISTS memory (
        timestamp TEXT,
        prompt TEXT,
        response TEXT
    )
""")
conn.commit()

def top_k_top_p_filtering(logits, top_k=50, top_p=0.9):
    """Sample a single token id from `logits` using combined top-k / nucleus (top-p) filtering."""
    logits = logits.squeeze(0)
    probs = torch.softmax(logits, dim=-1)

    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Mask tokens outside the nucleus. Shift the mask right by one so the
    # token that first crosses top_p is kept and at least one token survives.
    sorted_mask = cumulative_probs > top_p
    sorted_mask[1:] = sorted_mask[:-1].clone()
    sorted_mask[0] = False

    # Additionally drop everything beyond the top_k most likely tokens.
    if top_k < sorted_probs.size(0):
        sorted_mask[top_k:] = True

    sorted_probs[sorted_mask] = 0.0

    # Renormalize the surviving probability mass and sample from it.
    sorted_probs /= sorted_probs.sum()
    sampled_relative_index = torch.multinomial(sorted_probs, 1).item()
    sampled_token_id = sorted_indices[sampled_relative_index].item()

    return sampled_token_id

def generate(prompt, length=100, temperature=1.0, top_p=0.9, top_k=50):
    input_ids = tokenizer.encode(prompt).ids
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)

    eos_token_id = tokenizer.token_to_id("<EOS>")

    for _ in range(length):
        with torch.no_grad():
            # Feed at most the last MAX_SEQ_LEN tokens so long generations
            # never exceed the model's positional range.
            logits = model(input_tensor[:, -MAX_SEQ_LEN:])
            logits = logits[:, -1, :] / temperature

        # Repetition penalty (CTRL-style): divide positive logits and multiply
        # negative ones, so already-seen tokens become less likely regardless
        # of the logit's sign.
        for token_id in set(input_tensor[0].tolist()):
            if logits[0, token_id] > 0:
                logits[0, token_id] /= 1.25
            else:
                logits[0, token_id] *= 1.25

        next_token_id = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

        input_tensor = torch.cat([input_tensor, torch.tensor([[next_token_id]], device=device)], dim=1)

        if eos_token_id is not None and next_token_id == eos_token_id:
            break

    output_ids = input_tensor[0].tolist()
    decoded = tokenizer.decode(output_ids)

    # Trim anything after the end-of-sequence marker.
    if "<EOS>" in decoded:
        decoded = decoded.split("<EOS>")[0].strip()

    return decoded

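# Online learning: adapt the session adapters with a single gradient step on one piece of text.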
def online_unsupervised_update(model, tokenizer, text, optimizer, loss_fn, device, max_len=64):
    # Build a next-token prediction example from the raw text, padded to max_len.
    ids = tokenizer.encode(text).ids + [tokenizer.token_to_id("<EOS>")]
    if len(ids) < 2:
        return None

    ids = ids[:max_len + 1]
    input_ids = ids[:-1]
    target_ids = ids[1:]
    input_ids += [tokenizer.token_to_id("<PAD>")] * (max_len - len(input_ids))
    target_ids += [tokenizer.token_to_id("<PAD>")] * (max_len - len(target_ids))
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
    target_tensor = torch.tensor([target_ids], dtype=torch.long, device=device)

    # One gradient step on the unfrozen (adapter) parameters.
    model.train()
    logits = model(input_tensor)
    logits = logits.view(-1, logits.size(-1))
    targets = target_tensor.view(-1)
    loss = loss_fn(logits, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.eval()
    return loss.item()

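# Utility to discard session-specific learning; not called in the interactive loop below.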
def reset_session_adapters(model):
    # Re-zero the session adapters' parameters.
    for layer in model.layers:
        if getattr(layer, 'session_adapter', None) is not None:
            for param in layer.session_adapter.parameters():
                nn.init.zeros_(param)

if __name__ == "__main__":
    while True:
        prompt = input("\nEnter a prompt (or 'exit' to quit): ")
        if prompt.lower() in {"exit", "quit"}:
            break
        temp_str = input("Temperature (e.g. 0.7, 1.0): ").strip()
        temp = float(temp_str) if temp_str else 1.0  # default to 1.0 on empty input

        output = generate(prompt, length=100, temperature=temp, top_p=0.9, top_k=50)
        print("\nGenerated text:\n")
        print(output)

        # Learn from the user's correction if one is offered; otherwise
        # reinforce the model's own output.
        teach = input("\nDo you want to teach the model a better answer? (y/N): ").strip().lower()
        if teach == "y":
            your_answer = input("Type your ideal response for this prompt: ")
            model.freeze_except_adapters(session_only=True, include_output=True)
            online_text = prompt + " " + your_answer
            loss = online_unsupervised_update(
                model, tokenizer, online_text, optimizer, criterion, device, max_len=MAX_SEQ_LEN
            )
            if loss is not None:
                print(f"[Online update loss: {loss:.4f}]")
        else:
            model.freeze_except_adapters(session_only=True, include_output=True)
            online_text = prompt + " " + output
            loss = online_unsupervised_update(
                model, tokenizer, online_text, optimizer, criterion, device, max_len=MAX_SEQ_LEN
            )
            if loss is not None:
                print(f"[Online (self-improve) update loss: {loss:.4f}]")

        # Log the exchange and show the most recent history.
        c.execute("INSERT INTO memory (timestamp, prompt, response) VALUES (?, ?, ?)",
                  (datetime.now().isoformat(timespec='seconds'), prompt, output))
        conn.commit()

        print("\nRecent memory:")
        for row in c.execute("SELECT * FROM memory ORDER BY timestamp DESC LIMIT 5"):
            print(f"[{row[0]}] {row[1]} → {row[2]}")

    conn.close()