Byte-lingua-code / lingua /profiling.py
2ira's picture
offline_compression_graph_code
72c0672 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import contextlib
from dataclasses import dataclass
import os
from pathlib import Path
import torch.distributed
import logging
from torch.profiler.profiler import profile
import xformers.profiler
from xformers.profiler import (
MemSnapshotsProfiler,
PyTorchProfiler,
)
from lingua.distributed import get_is_master
import wandb
@dataclass
class ProfilerArgs:
    """Settings consumed by ``maybe_run_profiler``.

    The ``*_warmup`` / ``*_steps`` pairs are turned into
    ``(profiler_cls, warmup, warmup + steps)`` entries of the xformers
    profiler schedule.
    """

    # Master switch: when False, ``maybe_run_profiler`` yields ``None`` and
    # no profiler is created.
    run: bool = False
    # Sub-folder (joined onto the dump directory) where traces are written.
    trace_folder: str = "profiling"
    # Schedule window of the memory-snapshot profiler:
    # from ``mem_warmup`` to ``mem_warmup + mem_steps``.
    mem_warmup: int = 100
    mem_steps: int = 2
    # Schedule window of the PyTorch profiler:
    # from ``profile_warmup`` to ``profile_warmup + profile_steps``.
    profile_warmup: int = 102
    profile_steps: int = 2
# Module-level logger; ``logging.getLogger()`` without a name returns the root logger.
logger = logging.getLogger()
def perfetto_to_html(json_file, html_file):
    """Embed a trace JSON file into a standalone HTML viewer page.

    The page is assembled from the two HTML templates found in the ``html``
    folder of the installed ``viztracer`` package: the embedder template is
    filled with the full trace viewer and with the trace contents.

    Args:
        json_file: Path (``str`` or ``Path``) of the trace. Files whose name
            ends in ``.gz`` are read as gzip, everything else as plain text.
            Contents are decoded as UTF-8 in both cases.
        html_file: Path of the HTML file to write (overwritten if present).

    Raises:
        OSError: If the trace or one of the viztracer templates cannot be read.
        KeyError: If the embedder template contains a ``$placeholder`` other
            than ``trace_viewer_full`` and ``json_data``.
    """
    # Imported here rather than at module level, so viztracer is only needed
    # when a trace is actually converted.
    import gzip
    import string

    import viztracer

    html_root = os.path.join(os.path.dirname(viztracer.__file__), "html")

    with open(
        os.path.join(html_root, "trace_viewer_embedder.html"), encoding="utf-8"
    ) as f:
        embedder_template = f.read()
    with open(
        os.path.join(html_root, "trace_viewer_full.html"), encoding="utf-8"
    ) as f:
        viewer_html = f.read()

    # Decide on the file name suffix only: a ".gz" elsewhere in the path
    # (e.g. in a directory name) must not switch a plain JSON file to gzip.
    opener = gzip.open if str(json_file).endswith(".gz") else open
    # Text mode with an explicit encoding gives a ``str`` for both openers and
    # keeps the handle scoped to this ``with`` block.
    with opener(json_file, "rt", encoding="utf-8") as trace:
        trace_json = trace.read()

    substitutions = {
        "trace_viewer_full": viewer_html,
        # The JSON is inlined in the page; a literal "</script>" inside it
        # would end the surrounding script element early.
        "json_data": trace_json.replace("</script>", "<\\/script>"),
    }
    with open(html_file, "w", encoding="utf-8") as output_file:
        output_file.write(
            string.Template(embedder_template).substitute(substitutions)
        )
class PyTorchProfilerWandb(PyTorchProfiler):
    """PyTorch profiler that also publishes its trace to Weights & Biases.

    After the parent class has handled a ready trace, the master process
    converts the trace file found under the main profiler's output directory
    into a standalone HTML page and logs it to the active wandb run.
    """

    def __init__(self, main_profiler) -> None:
        """Create the underlying ``torch.profiler.profile`` instance.

        Args:
            main_profiler: Owning profiler object; its ``output_dir`` is read
                later to locate the written trace.

        NOTE(review): the parent ``__init__`` is not called here; this
        presumably mirrors it with ``with_stack=False`` - confirm against the
        installed xformers version.
        """
        self.main_profiler = main_profiler
        self.num_steps = 0
        self.pytorch_profiler = torch.profiler.profile(
            on_trace_ready=self._on_trace,
            profile_memory=True,
            record_shapes=True,
            # With stack gives huge profile traces
            # and bugs out because of some non ascii
            # character somewhere in pytorch
            with_stack=False,
            with_flops=True,
            activities=self.ACTIVITIES,
        )

    def _analyze_trace(self, prof: profile):
        """Run the parent's trace analysis, logging when it starts and ends."""
        logger.info("Begin analyze trace")
        super()._analyze_trace(prof)
        logger.info("End analyze trace")

    def _on_trace(self, prof: torch.profiler.profiler.profile) -> None:
        """Handle a ready trace, then upload it to wandb as an HTML page.

        The upload only happens on the master process and only while a wandb
        run is active. If no trace file is found, a warning is logged and the
        upload is skipped instead of raising.
        """
        super()._on_trace(prof)
        if not (get_is_master() and wandb.run is not None):
            return

        traces = list(
            Path(self.main_profiler.output_dir).glob(
                "profile_CPU_CUDA*/*.pt.trace.json*"
            )
        )
        if not traces:
            logger.warning(
                "No profiler trace found under %s; skipping wandb upload",
                self.main_profiler.output_dir,
            )
            return
        # Prefer the most recently written file so that stale traces left in
        # the same folder are not picked over the new one.
        trace_file = max(traces, key=lambda p: p.stat().st_mtime)

        # Strip ".gz" / ".json" from the file name only, so the page is named
        # "*.html" even for gzipped traces and directory names stay untouched.
        base_name = trace_file.name
        for suffix in (".gz", ".json"):
            if base_name.endswith(suffix):
                base_name = base_name[: -len(suffix)]
        html_path = str(trace_file.with_name(base_name + ".html"))

        perfetto_to_html(trace_file, html_path)
        wandb.log({"profile_trace": wandb.Html(html_path)})
class MemSnapshotsProfilerWandb(MemSnapshotsProfiler):
    """Memory-snapshot profiler that also logs its HTML plot to Weights & Biases."""

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Finish the parent profiler, then upload the memory trace plot.

        The upload only happens on the master process and only while a wandb
        run is active. If no plot is found under ``memory_trace_plot/``, a
        warning is logged and the upload is skipped instead of raising.
        """
        super().__exit__(exc_type, exc_val, exc_tb)
        if not (get_is_master() and wandb.run is not None):
            return

        plots = list(
            Path(self.main_profiler.output_dir).glob("memory_trace_plot/*.html")
        )
        if not plots:
            # NOTE(review): presumably the parent writes no plot when nothing
            # was recorded - confirm against the installed xformers version.
            logger.warning(
                "No memory trace plot found under %s; skipping wandb upload",
                self.main_profiler.output_dir,
            )
            return
        # Prefer the most recently written file so that stale plots left in
        # the same folder are not picked over the new one.
        plot_path = max(plots, key=lambda p: p.stat().st_mtime)

        # Keep the handle open for the whole logging call and close it
        # afterwards, so it is not leaked.
        with open(plot_path) as plot_file:
            wandb.log({"memory_trace": wandb.Html(plot_file, inject=False)})
@contextlib.contextmanager
def maybe_run_profiler(dump_dir, module, config: ProfilerArgs):
    """Run the xformers profiler around a block when ``config.run`` is set.

    Args:
        dump_dir: Base directory; traces are written to
            ``os.path.join(dump_dir, config.trace_folder)``.
        module: Forwarded as ``module`` to ``xformers.profiler.profile``.
        config: Profiler settings (on/off switch, trace folder and the
            warmup/steps values used to build the profiler schedule).

    Yields:
        The object returned by ``xformers.profiler.profile`` when profiling
        is enabled, otherwise ``None``.
    """
    if not config.run:
        yield None
        return

    trace_dir = os.path.join(dump_dir, config.trace_folder)
    logger.info(f"Profiling active. Traces will be saved at {trace_dir}")

    if get_is_master():
        # exist_ok avoids failing when the directory already exists or is
        # created between a check and the call.
        os.makedirs(trace_dir, exist_ok=True)
    if torch.distributed.is_initialized():
        # Synchronize ranks so none proceeds before the master has passed
        # the directory creation above.
        torch.distributed.barrier()

    schedule = [
        (
            MemSnapshotsProfilerWandb,
            config.mem_warmup,
            config.mem_warmup + config.mem_steps,
        ),
        (
            PyTorchProfilerWandb,
            config.profile_warmup,
            config.profile_warmup + config.profile_steps,
        ),
    ]
    with xformers.profiler.profile(
        output_dir=trace_dir,
        module=module,
        schedule=schedule,
    ) as profiler:
        yield profiler