quantization2 / lm-quant-toolkit /data-vis /plot-mem-consumption.R
chen459664's picture
Add files using upload-large-folder tool
21ad80b verified
#!/usr/bin/env Rscript
library(tidyverse)
library(ggthemes)
library(readr)
library(plotly)
library(optparse)
# second_largest(): the runner-up among the distinct values of `x`.
# sort() silently drops NA/NaN, and indexing past the end yields NA, so the
# result is NA whenever `x` has fewer than two distinct non-missing values.
second_largest <- function(x) {
  ascending <- sort(unique(x))
  rev(ascending)[2L]
}
# Command-line interface ----
# -d / --csv_file : CSV to read; falls back to "data/combined.csv" if omitted.
# --attempt       : attempt label added to the algorithm factor levels;
#                   falls back to "mxq1" if omitted.
parser <- OptionParser() |>
  add_option(
    c("-d", "--csv_file"),
    type = "character",
    help = "Combined PPL result CSV file",
    metavar = "character"
  ) |>
  add_option(
    c("--attempt"),
    type = "character",
    help = "The attempt to plot",
    metavar = "character"
  )
args <- parse_args(parser)
# Options that were not supplied come back as NULL; substitute the defaults.
csv_fp <- if (is.null(args$csv_file)) "data/combined.csv" else args$csv_file
the_attempt <- if (is.null(args$attempt)) "mxq1" else args$attempt
# Columns kept for plotting. `all_cols2` lists the same columns, but names
# the config column `config.x`, the name it carries after a suffixed join.
all_cols1 <- c("model", "algo", "config", "bpp", "load_mem_allot")
all_cols2 <- replace(all_cols1, all_cols1 == "config", "config.x")
# Load the combined results and derive the per-algorithm frames ----
df_all <- read_csv(csv_fp)

# Every non-MXQ row, reduced to the plotted columns.
df_wo_mxq <- df_all |>
  filter(algo != "mxq") |>
  select(all_of(all_cols1))

df_fp16 <- df_all |>
  filter(algo == "fp16")

# FP16 baseline, repeated once per HQQ row of the same model. The FP16 side
# keeps the bare column names (suffix ""), so `algo`, `bpp` and
# `load_mem_allot` are FP16's while `config.x` is the HQQ row's config.
df_baseline <- df_all |>
  filter(algo == "hqq") |>
  left_join(df_fp16, by = c("model"), suffix = c(".x", "")) |>
  select(all_of(all_cols2)) |>
  rename(config = config.x)

df_hqq <- df_wo_mxq |> filter(algo == "hqq" & !str_detect(config, "^mxq"))

# MXQ rows minus the attempts listed below. The explicit NA guard keeps rows
# with a missing `attempt` excluded (filter() drops NA conditions).
excluded_attempts <- c("mxq-kurt-global", "mxq-kurt-scaled")
df_mxq <- df_all |>
  filter(
    algo == "mxq" &
      !is.na(attempt) &
      !(attempt %in% excluded_attempts)
  )

# Pair each HQQ row with the MXQ rows sharing its model and bpp. The MXQ side
# keeps the bare column names, so `algo`, `attempt` and `load_mem_allot` come
# from the MXQ row and `config.x` is the HQQ config. The label becomes
# "<algo>-<attempt>"; unmatched HQQ rows get "NA-NA" because paste0()
# renders NA as the string "NA".
df_mxq1 <- df_hqq |>
  left_join(df_mxq, join_by(model, bpp), suffix = c(".x", "")) |>
  mutate(algo = paste0(algo, "-", attempt)) |>
  select(all_of(all_cols2)) |>
  rename(config = config.x)
# Assemble the display frame ----
# Bar order for the algorithm axis. unique() keeps `factor()` from failing
# with a duplicated-level error when --attempt names one of the fixed labels
# (e.g. "mxq-mxq1"); the user-supplied attempt stays first.
# NOTE(review): MXQ labels are built as "mxq-<attempt>", so a bare attempt
# such as the default "mxq1" matches no row unless the CSV itself carries
# that algo label — confirm the intended form of --attempt.
algo_levels <- unique(c(
  the_attempt,
  "awq",
  "gptq",
  "hqq",
  "mxq-mxq1",
  "mxq-mxq2",
  "bnb",
  "fp16"
))
# Keep only 4-bit ("b4...") configs. Algorithms absent from `algo_levels`
# become NA and, like "bnb", are removed by the final filter (NA conditions
# are dropped by filter()).
df_disp <- bind_rows(df_wo_mxq, df_baseline, df_mxq1) |>
  filter(str_detect(config, "^b4")) |>
  mutate(
    model = factor(
      model,
      levels = c("Llama-2-7b-hf", "Meta-Llama-3-8B", "Llama-2-13b-hf"),
      labels = c("Llama-2-7B", "Llama-3-8B", "Llama-2-13B")
    ),
    algo = factor(algo, levels = algo_levels),
    config = factor(config, levels = c("b4g32", "b4g64", "b4g128"))
  ) |>
  filter(algo != "bnb")
# Per-model second-largest memory value, attached to every row of df_disp as
# `second_max_mem` (the x position of the FP16 label in the plot below).
# arrange() reproduces the level-ordered rows that group_by() would yield.
df_2nd_largest <- df_disp |>
  summarise(
    second_max_mem = second_largest(load_mem_allot),
    .by = model
  ) |>
  arrange(model)
df_disp <- left_join(df_disp, df_2nd_largest, by = join_by(model))
# Plot ----
# Horizontal bars of GPU memory per algorithm, faceted by config (rows) and
# model (columns). Non-FP16 bars are labelled in black just past their end;
# the FP16 label is drawn in white starting 0.5 left of the model's
# second-largest memory value.
# NOTE(review): with `limits = c(0, 22)` the x scale's default out-of-bounds
# handling turns values above 22 into NA, so such bars/labels are presumably
# dropped with a "Removed rows" warning — confirm, or zoom with
# coord_cartesian(xlim = ...) instead.
plt1 <- ggplot(df_disp, aes(x = load_mem_allot, y = algo, fill = algo)) +
  geom_col(show.legend = FALSE) +
  geom_text(
    data = subset(df_disp, algo == "fp16"),
    aes(x = second_max_mem, y = algo, label = toupper(algo)),
    hjust = 0,
    nudge_x = -0.5,
    colour = "white",
    size = 3
  ) +
  geom_text(
    data = subset(df_disp, algo != "fp16"),
    aes(x = load_mem_allot, y = algo, label = toupper(algo)),
    hjust = 0,
    nudge_x = 0.3,
    colour = "black",
    size = 3
  ) +
  # Axis titles are hidden by the theme below; kept for interactive use.
  labs(y = "Algorithm", x = "GPU Memory(GiB)") +
  scale_x_continuous(
    limits = c(0, 22),
    breaks = seq(0, 22, by = 4),
    expand = c(0, 0),
    position = "bottom"
  ) +
  scale_y_discrete(expand = expansion(add = c(0, 0.5))) +
  theme(
    panel.background = element_rect(fill = "white"),
    # `linewidth` replaces element_line()'s `size`, deprecated in ggplot2 3.4.0.
    panel.grid.major.x = element_line(color = "#A8BAC4", linewidth = 0.3),
    axis.ticks.length = unit(0, "mm"),
    axis.title = element_blank(),
    axis.text.y = element_blank()
  ) +
  facet_grid(config ~ model) +
  # The bars are coloured through `fill`; `colour` is never mapped (text
  # colours are constants), so a colour scale would have no effect.
  scale_fill_solarized()
plt1
out_pdf <- "pdfs/llama-mem-consumption.pdf"
# ggsave() errors if the target directory does not exist.
dir.create(dirname(out_pdf), showWarnings = FALSE, recursive = TRUE)
ggsave(out_pdf, plot = plt1, width = 8, height = 6)
ggplotly(plt1)