2ira's picture
offline_compression_graph_code
72c0672 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
from dataclasses import asdict, dataclass, field
from datetime import datetime
import json
import logging
from pathlib import Path
from typing import Any, Optional
from lm_eval import simple_evaluate
from omegaconf import OmegaConf
import torch
from apps.main.eval import (
ValidationArgs,
EvalHarnessLM,
LMHarnessArgs,
eval_on_val,
)
from apps.fastRNN.generate import (
PackedRNNGenerator,
PackedRNNGeneratorArgs,
load_consolidated_model_and_tokenizer,
)
from lingua.args import dump_config
from lingua.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
from lingua.distributed import DistributedArgs, get_global_rank, setup_torch_distributed
# Format string for an evaluation folder name: an integer zero-padded to 10
# digits (e.g. "{:010d}".format(1000) -> "0000001000"). It is not referenced
# elsewhere in this file as shown; presumably callers use it to name per-step
# eval folders — verify.
EVAL_FOLDER_NAME = "{:010d}"
# getLogger() with no name returns the root logger, so records from this module
# go to whatever handlers are configured on the root logger.
logger = logging.getLogger()
@dataclass
class EvalArgs:
    """Configuration for one evaluation run performed by ``launch_eval``.

    ``main`` builds it with ``OmegaConf.structured(EvalArgs())`` and merges in
    the config file and command-line overrides, so field names, order, types
    and defaults form the accepted config schema.
    """

    # Run name. It is not read by launch_eval in this file; presumably it is
    # consumed elsewhere — verify.
    name: str = "evals"
    # Output directory for config.yaml, results.json and validation.json.
    # launch_eval calls Path(dump_dir) on it, so it must be set in practice
    # despite the None default.
    dump_dir: Optional[str] = None
    # If set, global rank 0 appends one JSON line per run to
    # metrics.eval.jsonl / metrics.validation.jsonl inside this directory.
    metric_log_dir: Optional[str] = None
    # Either a directory that already contains params.json and at least one
    # *.pth file, or a checkpoint directory whose CONSOLIDATE_FOLDER
    # sub-directory holds (or will be created to hold) the consolidated weights.
    ckpt_dir: str = ""
    # Arguments for the PackedRNNGenerator used by both the harness and the
    # validation evaluation.
    generator: PackedRNNGeneratorArgs = field(default_factory=PackedRNNGeneratorArgs)
    # Converted with dataclasses.asdict() and passed as keyword arguments to
    # lm_eval.simple_evaluate.
    harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs)
    # Passed to eval_on_val; a falsy value (e.g. None) skips validation.
    validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs)
    # Not read by launch_eval in this file; presumably a W&B config/handle
    # supplied by callers — TODO confirm.
    wandb: Optional[Any] = None
    # When set, added as "global_step" to the metric-log records.
    global_step: Optional[int] = None  # for in-training evaluation
def _resolve_consolidate_path(ckpt_dir):
    """Return the directory holding a consolidated checkpoint for ``ckpt_dir``.

    ``ckpt_dir`` itself is used when it already contains ``params.json`` and at
    least one ``*.pth`` file. Otherwise the ``CONSOLIDATE_FOLDER`` sub-directory
    is used. If that sub-directory does not exist yet, global rank 0 creates it
    with ``consolidate_checkpoints`` and returns whatever path that call gives
    back. Other ranks get the expected location and must wait on a barrier
    before reading it.
    """
    ckpt_path = Path(ckpt_dir)
    if (
        ckpt_path.exists()
        and (ckpt_path / "params.json").exists()
        and next(ckpt_path.glob("*.pth"), None) is not None
    ):
        return ckpt_path
    consolidate_path = ckpt_path / CONSOLIDATE_FOLDER
    if not consolidate_path.exists() and get_global_rank() == 0:
        consolidate_path = consolidate_checkpoints(ckpt_dir)
    return consolidate_path


def _append_jsonl(path, record):
    """Append ``record`` as one JSON line to ``path``, closing the file afterwards."""
    with open(path, mode="a") as f:
        f.write(json.dumps(record) + "\n")


def launch_eval(cfg: EvalArgs):
    """Evaluate one checkpoint with lm_eval and/or the validation sets.

    Steps:
      1. Initialise ``torch.distributed`` if it is not already initialised.
      2. Resolve (and, on rank 0, create if missing) the consolidated
         checkpoint directory. Dump ``cfg`` to ``<dump_dir>/config.yaml``.
      3. After a barrier, load the model and tokenizer. Then run
         ``simple_evaluate`` when ``cfg.harness`` is not None, and
         ``eval_on_val`` when ``cfg.validation`` is truthy.
      4. On global rank 0, write ``results.json`` / ``validation.json`` to
         ``cfg.dump_dir``. If ``cfg.metric_log_dir`` is set, also append one
         line each to ``metrics.eval.jsonl`` / ``metrics.validation.jsonl``.

    Args:
        cfg: Evaluation configuration. ``cfg.dump_dir`` must be set.
    """
    if not torch.distributed.is_initialized():
        setup_torch_distributed(DistributedArgs())

    consolidate_path = _resolve_consolidate_path(cfg.ckpt_dir)

    dump_dir = Path(cfg.dump_dir)
    dump_dir.mkdir(parents=True, exist_ok=True)
    dump_config(cfg, dump_dir / "config.yaml", log_config=False)

    # Rank 0 may still be consolidating; every rank waits before loading.
    torch.distributed.barrier()
    logger.info("Loading model")
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        str(consolidate_path)
    )
    logger.info("Model loaded")
    model.eval()
    generator = PackedRNNGenerator(cfg.generator, model, tokenizer)

    # `harness` is Optional in EvalArgs: skip the harness run instead of
    # crashing in asdict(None).
    results = None
    if cfg.harness is not None:
        wrap = EvalHarnessLM(generator)
        results = simple_evaluate(wrap, **asdict(cfg.harness))

    val_results = None
    if cfg.validation:
        val_results = eval_on_val(generator, cfg.validation, train_cfg)

    if get_global_rank() == 0:
        # simple_evaluate's result is only consumed on rank 0; guard anyway
        # since it can be None (e.g. harness disabled).
        if results is not None:
            with open(dump_dir / "results.json", "w") as f:
                f.write(json.dumps(results))
            logger.info(f"All evaluation results: {results['results']}")
        if val_results is not None:
            with open(dump_dir / "validation.json", "w") as f:
                f.write(json.dumps(val_results))
            logger.info(f"All validation results: {val_results}")

        if cfg.metric_log_dir:
            metric_log_dir = Path(cfg.metric_log_dir)
            metric_log_path = metric_log_dir / "metrics.eval.jsonl"
            logger.info(f"Writing metric logs to {metric_log_path}")
            timestamp = {
                "created_at": datetime.utcnow().isoformat(),
            }
            if cfg.global_step is not None:
                timestamp["global_step"] = cfg.global_step
            if results is not None:
                _append_jsonl(metric_log_path, timestamp | results["results"])
            if val_results is not None:
                _append_jsonl(
                    metric_log_dir / "metrics.validation.jsonl",
                    timestamp | val_results,
                )

    del generator
def main():
    """Build an ``EvalArgs`` from the command line and run ``launch_eval``.

    The command line interface uses OmegaConf
    https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
    and accepts arguments as a dot list. For an illustrative pair of dataclasses

        @dataclass
        class DummyArgs:
            name: str
            model: LMMambaArgs

        @dataclass
        class LMMambaArgs:
            dim: int

    you would pass ``model.dim=32`` to change a value in ``LMMambaArgs``, or
    ``name=tictac`` for a top-level attribute.

    The final config is built as follows:
      1. Instantiate ``EvalArgs`` with its default values.
      2. Override those defaults with the values from the YAML file given as
         ``config=<path>``. This is optional; without it, step 2 is skipped.
      3. Override the result with the remaining command-line arguments.

    For example, if the config file contains

        ckpt_dir: /checkpoints/0000010000
        dump_dir: /evals/run_a

    and you call ``eval.py config=eval.yaml dump_dir=/evals/run_b``, the final
    ``EvalArgs`` has ``ckpt_dir=/checkpoints/0000010000`` and
    ``dump_dir=/evals/run_b``, plus the defaults for all other fields.
    """
    cli_args = OmegaConf.from_cli()
    # 'config' is not a field of EvalArgs, so it must be removed before the
    # merge into the structured config.
    config_path = cli_args.pop("config", None)
    file_cfg = (
        OmegaConf.load(config_path) if config_path is not None else OmegaConf.create()
    )
    default_cfg = OmegaConf.structured(EvalArgs())
    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
    cfg = OmegaConf.to_object(cfg)
    launch_eval(cfg)
# Script entry point: build the config from the CLI and run the evaluation.
if __name__ == "__main__":
    main()