| """ |
| Instruction-tuning with LoRA on the Alpaca dataset. |
| |
| Note: If you run into a CUDA error "Expected is_sm80 to be true, but got false", uncomment the line |
| `torch.backends.cuda.enable_flash_sdp(False)` in the script below (see https://github.com/Lightning-AI/lit-llama/issues/101). |
| """ |
| import sys |
| from pathlib import Path |
| import os |
| import time |
|
|
| import lightning as L |
| import numpy as np |
| import torch |
|
|
| |
# Support running the script without installing the project as a package:
# put the directory two levels above this file on `sys.path` so the
# non-package imports that follow (`generate`, `lit_llama`, `scripts`) resolve.
# NOTE(review): presumably that directory is the repository root — confirm.
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
|
|
| from generate import generate |
| from lit_llama.lora import mark_only_lora_as_trainable, lora, lora_state_dict |
| from lit_llama.model import LLaMA, LLaMAConfig |
| from lit_llama.tokenizer import Tokenizer |
| from scripts.prepare_alpaca import generate_prompt |
|
|
|
|
# Evaluation / checkpointing / logging cadence.
# `eval_interval` and `save_interval` are compared against the optimizer step
# count, `log_interval` against the micro-batch iteration number, and
# `eval_iters` is the number of micro-batches averaged per validation run.
eval_interval = 100
save_interval = 100
eval_iters = 100
log_interval = 1

# Hyperparameters
learning_rate = 3e-4
batch_size = 128
micro_batch_size = 4
# Number of micro-batches whose gradients are accumulated per optimizer step.
gradient_accumulation_steps = batch_size // micro_batch_size
if gradient_accumulation_steps <= 0:
    # Fail early with a clear message; otherwise the training loop would hit a
    # ZeroDivisionError in `(iter_num + 1) % gradient_accumulation_steps`.
    raise ValueError(
        f"batch_size ({batch_size}) must be at least micro_batch_size ({micro_batch_size})"
    )
# Total number of micro-batch iterations run by the training loop.
# NOTE(review): presumably ~50k training samples x 3 epochs — confirm.
max_iters = 50000 * 3 // micro_batch_size
weight_decay = 0.0
# Used both as the model's block size and as the generation length limit.
max_seq_length = 256
lora_r = 8
lora_alpha = 16
lora_dropout = 0.05
# Number of optimizer steps over which the learning rate is ramped up linearly.
warmup_steps = 100
|
|
|
|
def main(
    data_dir: str = "data/alpaca",
    pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
    out_dir: str = "out/lora/alpaca",
):
    """Fine-tune a pretrained LLaMA 7B checkpoint with LoRA and save the LoRA weights.

    Args:
        data_dir: Directory holding the ``train.pt`` and ``test.pt`` files read
            by ``load_datasets``.
        pretrained_path: Path of the pretrained state dict the model is
            initialised from.
        out_dir: Directory that receives the periodic and the final LoRA
            checkpoints; created if it does not exist.
    """
    fabric = L.Fabric(accelerator="cpu", devices=1, precision="bf16-true")
    fabric.launch()
    # Per-rank seed offset.
    fabric.seed_everything(1337 + fabric.global_rank)

    if fabric.global_rank == 0:
        os.makedirs(out_dir, exist_ok=True)

    train_data, val_data = load_datasets(data_dir=data_dir)

    config = LLaMAConfig.from_name("7B")
    config.block_size = max_seq_length

    checkpoint = torch.load(pretrained_path)

    with fabric.init_module(), lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
        model = LLaMA(config)
        # strict=False tolerates key mismatches — presumably needed because the
        # pretrained checkpoint has no entries for the LoRA parameters.
        model.load_state_dict(checkpoint, strict=False)

    mark_only_lora_as_trainable(model)

    # Pass `weight_decay` explicitly: AdamW's own default is 0.01, which would
    # otherwise silently override the hyperparameter configured in this script.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    model, optimizer = fabric.setup(model, optimizer)
    train(fabric, model, optimizer, train_data, val_data, out_dir)

    # Save the final LoRA checkpoint at the end of training.
    checkpoint = lora_state_dict(model)
    fabric.save(os.path.join(out_dir, "lit-llama-lora-finetuned.pth"), checkpoint)
|
|
|
|
def train(
    fabric: L.Fabric,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_data: np.ndarray,
    val_data: np.ndarray,
    out_dir: str,
) -> None:
    """The training loop.

    Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.

    Runs ``max_iters`` micro-batch iterations. Gradients are accumulated over
    ``gradient_accumulation_steps`` micro-batches before each optimizer step;
    validation and checkpointing are triggered on optimizer-step boundaries.

    Args:
        fabric: Fabric instance used for backward, logging, barriers and saving.
        model: The model being fine-tuned.
        optimizer: Optimizer whose learning rate is overwritten during warmup.
        train_data: Training samples, forwarded to ``get_batch``.
        val_data: Validation samples, forwarded to ``validate``.
        out_dir: Directory that receives the periodic LoRA checkpoints.
    """
    step_count = 0

    for iter_num in range(max_iters):

        if step_count <= warmup_steps:
            # Linear warmup of the learning rate over the first optimizer steps.
            lr = learning_rate * step_count / warmup_steps
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        t0 = time.time()

        input_ids, targets = get_batch(fabric, train_data)
        logits = model(input_ids)
        loss = loss_fn(logits, targets)
        # Scale the loss so that the accumulated gradient is the mean (not the
        # sum) over the micro-batches that make up one optimizer step.
        fabric.backward(loss / gradient_accumulation_steps)

        if (iter_num + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            step_count += 1

            if step_count % eval_interval == 0:
                val_loss = validate(fabric, model, val_data)
                fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
                fabric.barrier()

            if step_count % save_interval == 0:
                fabric.print(f"Saving LoRA weights to {out_dir}")
                # Only the LoRA weights are saved, not the full model.
                checkpoint = lora_state_dict(model)
                fabric.save(os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"), checkpoint)

        dt = time.time() - t0
        if iter_num % log_interval == 0:
            # Log the unscaled micro-batch loss so values stay comparable with
            # the validation loss.
            fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
|
|
|
|
def generate_response(model, instruction, tokenizer_path: str = "checkpoints/lit-llama/tokenizer.model"):
    """Generate the model's response to a single instruction.

    Args:
        model: Model to sample from; its ``device`` attribute decides where the
            encoded prompt is placed.
        instruction: Instruction text; the prompt is built with an empty "input" field.
        tokenizer_path: Path of the tokenizer model file. Defaults to the
            location that was previously hard-coded.

    Returns:
        The decoded output of ``generate`` (at most 100 new tokens are requested).
    """
    tokenizer = Tokenizer(tokenizer_path)
    sample = {"instruction": instruction, "input": ""}
    prompt = generate_prompt(sample)
    encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)

    output = generate(
        model,
        idx=encoded,
        max_seq_length=max_seq_length,
        max_new_tokens=100,
    )
    return tokenizer.decode(output)
|
|
|
|
@torch.no_grad()
def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> float:
    """Estimate the validation loss and print one sample generation.

    The model is switched to eval mode for the duration of the call and put
    back into train mode before returning.

    Args:
        fabric: Fabric instance used for printing and batch placement.
        model: Model to evaluate.
        val_data: Validation samples, forwarded to ``get_batch``.

    Returns:
        The mean loss over ``eval_iters`` randomly drawn micro-batches, as a
        Python float.
    """
    fabric.print("Validating ...")
    model.eval()
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
        input_ids, targets = get_batch(fabric, val_data)
        logits = model(input_ids)
        losses[k] = loss_fn(logits, targets).item()
    mean_loss = losses.mean().item()

    # Produce an example generation as a qualitative sanity check.
    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
    output = generate_response(model, instruction)
    fabric.print(instruction)
    fabric.print(output)

    model.train()
    return mean_loss
|
|
def loss_fn(logits, targets):
    """Next-token cross-entropy between ``logits`` and ``targets``.

    The last position of ``logits`` and the first position of ``targets`` are
    dropped so that the prediction at position ``i`` is scored against the
    target at position ``i + 1``. Target entries equal to ``-1`` are ignored.
    """
    # Shift so that tokens < n predict token n.
    shifted_logits = logits[..., :-1, :].contiguous()
    shifted_targets = targets[..., 1:].contiguous()
    num_classes = shifted_logits.size(-1)
    return torch.nn.functional.cross_entropy(
        shifted_logits.view(-1, num_classes),
        shifted_targets.view(-1),
        ignore_index=-1,
    )
| |
|
|
def get_batch(fabric: L.Fabric, data: list):
    """Draw a random micro-batch and right-pad it to a common length.

    Args:
        fabric: Fabric instance; its device decides whether the batch is pinned
            and where it is moved.
        data: Indexable collection of samples, each a mapping with
            ``"input_ids"`` and ``"labels"`` tensors.

    Returns:
        A tuple ``(x, y)`` of int64 tensors with ``micro_batch_size`` rows.
        ``x`` holds the input ids padded with ``0``; ``y`` holds the labels
        padded with ``-1``.
    """
    ix = torch.randint(len(data), (micro_batch_size,))

    input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
    labels = [data[i]["labels"].type(torch.int64) for i in ix]

    max_len = max(len(s) for s in input_ids)

    def pad_right(seq, pad_id):
        # Pad on the right up to the longest input sequence in the batch.
        n = max_len - len(seq)
        return torch.cat((seq, torch.full((n,), pad_id, dtype=seq.dtype)))

    x = torch.stack([pad_right(seq, pad_id=0) for seq in input_ids])
    y = torch.stack([pad_right(seq, pad_id=-1) for seq in labels])
    # Pinning host memory only helps host-to-GPU copies and raises when no CUDA
    # runtime is available, so skip it unless the target device is CUDA.
    if fabric.device.type == "cuda":
        x, y = x.pin_memory(), y.pin_memory()
    x, y = fabric.to_device((x, y))
    return x, y
|
|
|
|
def load_datasets(data_dir):
    """Load the training and validation splits from ``data_dir``.

    Reads ``train.pt`` and ``test.pt`` with ``torch.load`` and returns them as
    a ``(train_data, val_data)`` tuple.
    """
    train_data, val_data = (
        torch.load(os.path.join(data_dir, file_name))
        for file_name in ("train.pt", "test.pt")
    )
    return train_data, val_data
|
|
|
|
if __name__ == "__main__":
    from jsonargparse.cli import CLI

    # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
    # torch.backends.cuda.enable_flash_sdp(False)
    torch.set_float32_matmul_precision("high")

    # Expose `main`'s keyword arguments as command-line options.
    CLI(main)