| |
| |
| import math |
| import sys |
| import time |
| from pathlib import Path |
| from typing import Optional |
|
|
| import lightning as L |
| import torch |
| import tqdm |
|
|
| from lit_llama import LLaMA, Tokenizer |
| from lit_llama.utils import EmptyInitOnDevice, llama_model_lookup |
|
|
| from datasets import load_dataset |
|
|
|
|
def load_eval_data(dataset_name: str) -> str:
    """Return the test split of an evaluation corpus as a single string.

    Args:
        dataset_name: One of ``"wikitext"`` (WikiText-2 raw, test split),
            ``"ptb"`` (Penn Treebank text-only, test split) or ``"c4"`` (the
            first 1100 documents of the ``en/c4-validation.00000-of-00008``
            shard). Matching is exact and case-sensitive.

    Returns:
        The documents concatenated into one string. WikiText and PTB are joined
        with blank lines; C4 is joined with single spaces.

    Raises:
        ValueError: If ``dataset_name`` is not one of the names above.
    """

    def wikitext() -> str:
        split = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        return "\n\n".join(split["text"])

    def ptb() -> str:
        split = load_dataset("ptb_text_only", "penn_treebank", split="test")
        return "\n\n".join(split["sentence"])

    def c4() -> str:
        # Only a single validation shard is fetched (via data_files), and only
        # its first 1100 documents are kept.
        # NOTE(review): newer `datasets` releases may reject the "allenai--c4"
        # config name; confirm against the installed version.
        shard = load_dataset(
            "allenai/c4",
            "allenai--c4",
            data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
            split="validation",
        )
        return " ".join(shard[:1100]["text"])

    loaders = {"wikitext": wikitext, "ptb": ptb, "c4": c4}
    loader = loaders.get(dataset_name)
    if loader is None:
        raise ValueError("invalid dataset name (wikitext, ptb, c4 are allowed)")
    return loader()
|
|
|
|
def _sum_nll(model, encoded_text: torch.Tensor, block_size: int):
    """Sum next-token negative log-likelihood over non-overlapping windows.

    ``encoded_text`` is cut into windows of ``block_size`` along dim 1. Only
    row 0 of each window provides targets, so a single-row batch is expected
    (``main`` builds one with ``[None, ...]``). Each window is scored
    independently: its first token has no context and is never a target, so a
    window of ``w`` tokens contributes ``w - 1`` predictions.

    Returns:
        ``(nll_sum, n_predicted)``: the summed cross-entropy (natural log) and
        the number of predicted tokens.
    """
    nll_sum = 0.0
    n_predicted = 0
    with torch.inference_mode():
        for start in tqdm.tqdm(range(0, encoded_text.shape[1], block_size)):
            window = encoded_text[:, start : start + block_size]
            # assumes the model returns (batch, seq, vocab) logits; [0] picks row 0
            logits = model(window)[0]
            # Upcast before the loss. cross_entropy returns the input dtype, so
            # under float16/bfloat16 a "sum" over ~block_size terms would come
            # back in half precision and lose accuracy.
            nll = torch.nn.functional.cross_entropy(
                logits[:-1].float(),
                window[0, 1:].to(dtype=torch.long),
                reduction="sum",
            )
            nll_sum += nll.item()
            n_predicted += window.size(1) - 1
    return nll_sum, n_predicted


def main(
    datasets: str = "wikitext,ptb,c4",
    *,
    accelerator: str = "auto",
    checkpoint_path: Optional[Path] = None,
    tokenizer_path: Optional[Path] = None,
    dtype: str = "float32",
    quantize: Optional[str] = None,
) -> None:
    """Report the perplexity of a pre-trained LLaMA checkpoint on evaluation datasets.

    Each dataset is tokenized, capped at 256 context windows, and scored in
    non-overlapping windows of the model's ``block_size``. The perplexity is
    printed to stdout. Timing and memory statistics go to stderr.

    Args:
        datasets: Comma-separated dataset names (see ``load_eval_data``);
            surrounding whitespace is ignored.
        accelerator: The hardware to run on. Possible choices are:
            ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
        checkpoint_path: The checkpoint path to load.
        tokenizer_path: The tokenizer path to load.
        dtype: Name of a ``torch`` dtype attribute (e.g. ``"float32"``,
            ``"bfloat16"``) used to materialize the model.
        quantize: Whether to quantize the model and using which method:
            ``"llm.int8"``: LLM.int8() mode,
            ``"gptq.int4"``: GPTQ 4-bit mode.

    Raises:
        FileNotFoundError: If the checkpoint or tokenizer file does not exist.
        ValueError: If ``dtype`` is not a torch dtype, or ``datasets`` names
            no dataset.
    """
    checkpoint_path = Path(checkpoint_path or "./checkpoints/lit-llama/7B/lit-llama.pth")
    tokenizer_path = Path(tokenizer_path or "./checkpoints/lit-llama/tokenizer.model")
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    if not tokenizer_path.is_file():
        raise FileNotFoundError(f"Tokenizer not found: {tokenizer_path}")

    dataset_names = [name.strip() for name in datasets.split(",") if name.strip()]
    if not dataset_names:
        raise ValueError("no dataset names given")

    fabric = L.Fabric(accelerator=accelerator, devices=1)

    dt = getattr(torch, dtype, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{dtype} is not a valid dtype.")
    dtype = dt

    with EmptyInitOnDevice(
        device=fabric.device, dtype=dtype, quantization_mode=quantize
    ):
        print("Loading model ...", file=sys.stderr)
        t0 = time.time()
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path)
        name = llama_model_lookup(checkpoint)
        model = LLaMA.from_name(name)
        model.load_state_dict(checkpoint)
        print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
    # load_state_dict copied the weights; drop the raw state dict so it does not
    # stay resident for the whole evaluation.
    del checkpoint

    model.eval()
    # Score in windows matching the model's context length, and read it before
    # Fabric wraps the module.
    block_size = model.config.block_size
    model = fabric.setup_module(model)

    tokenizer = Tokenizer(tokenizer_path)

    total_toks = 0
    total_time = 0.0
    for dsname in dataset_names:
        test_string = load_eval_data(dsname)
        encoded_text = tokenizer.encode(
            test_string, bos=True, eos=False, device=fabric.device
        )
        # Add a batch dimension and cap the evaluation at 256 context windows.
        encoded_text = encoded_text[None, : 256 * block_size]

        t0 = time.perf_counter()
        nlls, toks = _sum_nll(model, encoded_text, block_size)
        # Accumulate per-dataset time so the throughput below covers every dataset.
        total_time += time.perf_counter() - t0

        if toks == 0:
            print(f"Skipping {dsname}: not enough tokens to score.", file=sys.stderr)
            continue
        ppl = math.exp(nlls / toks)
        print(f"Perplexity on {dsname}: {ppl:.2f}")
        total_toks += toks

    if total_time > 0:
        print(
            f"\n\nTime for inference: {total_time:.02f} sec total, {total_toks / total_time:.02f} tokens/sec",
            file=sys.stderr,
        )
    if fabric.device.type == "cuda":
        print(
            f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB",
            file=sys.stderr,
        )
|
|
|
|
if __name__ == "__main__":
    # Let float32 matmuls use faster reduced-precision kernels (e.g. TF32) where
    # the backend supports them.
    torch.set_float32_matmul_precision("high")

    # Imported lazily so the CLI dependency is only needed when run as a script.
    from jsonargparse import CLI

    CLI(main)
|
|