chen459664's picture
Add files using upload-large-folder tool
21ad80b verified
#!/usr/bin/env Rscript
library(tidyverse)
library(ggthemes)
library(ggbreak)
library(readr)
library(patchwork)
library(optparse)
library(ggmagnify)
# Label of the MXQ attempt whose rows proc_data() keeps (compared against the
# `attempt` column after it has been recoded to a factor).
# NOTE(review): the `attempt` factor built when the CSV is loaded only declares
# the labels "MXQ1" and "KURT-SCALED"; every other raw value becomes NA, so
# rows appear to be kept through the `is.na(attempt)` branch rather than by
# matching this constant -- confirm that this is intended.
mxq_attempt <- "MXQ2"
# Plot memory drop vs PPL loss ---------------------------------
# Build the perplexity vs. bit-budget scatter plot for one model.
#
# Args:
#   df_disp: long-format data frame (as returned by proc_data()). The code
#     reads the columns `bpp`, `ppl`, `algo`, `dataset` and `model`, and
#     expects `dataset` to contain the values "WikiText" and "C4".
#
# Returns:
#   The plot object assembled from the ggplot2 / ggbreak / ggmagnify layers
#   below; nothing is drawn or written to disk here.
#
# Notes:
#   * Missing perplexities are ignored when the per-dataset minimum is taken,
#     so a single NA no longer invalidates the axis limits and guide lines.
#   * The magnifier regions and the x-axis break are hard-coded coordinates.
plot_ppl <- function(df_disp) {
  guideline_color <- "coral4"
  magnify_color <- "#fc8d62"

  # Lowest perplexity per dataset. These values anchor the dashed guide lines
  # and the "FP16" text labels (presumably the FP16 baseline is the row with
  # the lowest perplexity -- confirm against the data).
  min_ppl_wt <- df_disp |>
    filter(dataset == "WikiText") |>
    pull(ppl) |>
    min(na.rm = TRUE)
  min_ppl_c4 <- df_disp |>
    filter(dataset == "C4") |>
    pull(ppl) |>
    min(na.rm = TRUE)

  # y-axis range: WikiText minimum rounded down to one decimal, up to 120 %
  # of the C4 minimum.
  ppl_low_bound <- floor(min_ppl_wt * 10) / 10
  ppl_high_bound <- min_ppl_c4 * 1.20

  # Guide lines at +0 %, +1 % and +2 % above each dataset's minimum.
  guide_levels <- as.vector(
    outer(c(min_ppl_wt, min_ppl_c4), c(1, 1.01, 1.02))
  )

  ggplot(
    subset(df_disp, algo != "MXQ"),
    aes(x = bpp, y = ppl)
  ) +
    # MXQ rows are drawn first with smaller markers than the other methods.
    geom_point(
      data = subset(df_disp, algo == "MXQ"),
      size = 0.6,
      aes(shape = algo, color = dataset)
    ) +
    geom_point(size = 1.5, aes(shape = algo, color = dataset)) +
    # `linewidth` replaces the deprecated `size` aesthetic for lines.
    geom_hline(
      yintercept = guide_levels,
      linetype = "dashed",
      linewidth = 0.1,
      color = guideline_color
    ) +
    geom_magnify(
      from = c(4.10, 4.55, 4.65, 4.75),
      to = c(3.70, 5.20, 5.20, 6.20),
      colour = magnify_color,
      linewidth = 0.3,
      axes = "xy"
    ) +
    geom_magnify(
      from = c(4.10, 4.55, 6.47, 6.62),
      to = c(3.80, 5.30, 6.85, 7.80),
      colour = magnify_color,
      linewidth = 0.3,
      axes = "xy"
    ) +
    annotate(
      "text",
      x = 15.8,
      y = c(min_ppl_wt, min_ppl_c4),
      label = "FP16"
    ) +
    # Cut the empty stretch of the x axis between 5.5 and 15.6.
    scale_x_break(c(5.5, 15.6)) +
    scale_x_continuous(
      limits = c(3.0, 16.2),
      breaks = seq(3.0, 16.2, 0.25),
      # Secondary axis: reduction relative to a 16-bit budget.
      sec.axis = sec_axis(~ 100 * (16 - .) / 16, name = "% Memory Reduction")
    ) +
    scale_y_continuous(
      limits = c(ppl_low_bound, ppl_high_bound),
      breaks = seq(
        round(ppl_low_bound, digits = 2),
        round(ppl_high_bound, digits = 2),
        0.25
      ),
      # Secondary axis: relative change with respect to the C4 minimum.
      sec.axis = sec_axis(
        ~ 100 * (. - min_ppl_c4) / min_ppl_c4,
        name = "% Degradation",
        breaks = seq(-30, 20, 5)
      )
    ) +
    labs(x = "Bit Budget", y = "Perplexity") +
    theme_gray(base_size = 12) +
    guides(
      shape = guide_legend(title = "Method:"),
      color = guide_legend(title = "Dataset:")
    ) +
    theme(legend.position = "left") +
    facet_wrap(~model, scales = "free") +
    scale_color_solarized()
}
# Select the rows of one model and reshape them into long format for plotting.
#
# Args:
#   df: data frame with at least the columns `model`, `bpp`, `attempt`,
#     `ppl_wikitext` and `ppl_c4`.
#   model_name: pattern matched against `model` with grepl(), i.e. it is
#     interpreted as a regular expression.
#
# Returns:
#   A data frame in which `ppl_wikitext` / `ppl_c4` have been stacked into a
#   single `ppl` column plus a `dataset` factor with the labels "WikiText"
#   and "C4".
#
# Reads the global `mxq_attempt` to decide which attempt rows to keep.
proc_data <- function(df, model_name) {
  df |>
    # Use the `model_name` argument; the previous code silently relied on a
    # global variable called `mdl` being defined by the caller.
    filter(grepl(model_name, model) & bpp >= 2.5) |>
    # Keep rows without an attempt label and rows of the selected attempt.
    filter(is.na(attempt) | attempt == mxq_attempt) |>
    # "ppl_wikitext" -> .value = "ppl", dataset = "wikitext" (same for c4).
    pivot_longer(
      cols = c("ppl_wikitext", "ppl_c4"),
      names_to = c(".value", "dataset"),
      names_sep = "_"
    ) |>
    mutate(
      dataset = factor(
        dataset,
        levels = c("wikitext", "c4"),
        labels = c("WikiText", "C4")
      )
    )
}
# Command-line interface: one optional flag (-d / --csv_file) naming the CSV
# file to read.
parser <- add_option(
  OptionParser(),
  c("-d", "--csv_file"),
  type = "character",
  help = "Combined PPL result CSV file",
  metavar = "character"
)
args <- parse_args(parser)

# Fall back to the default path when the flag was not supplied.
csv_fp <- if (is.null(args$csv_file)) "data/combined.csv" else args$csv_file
# Columns kept from the CSV; all_of() makes the selection fail if one is
# missing.
all_cols <- c(
  "model", "algo", "config", "attempt",
  "bpp", "ppl_wikitext", "ppl_c4"
)

# Load the results and keep only the columns listed above.
df_all <- read_csv(csv_fp)
df_all <- select(df_all, all_of(all_cols))

# Recode the identifier columns as factors with display labels. Values that
# are not listed in `levels` become NA.
df_all <- mutate(
  df_all,
  model = factor(
    model,
    levels = c("Llama-2-7b-hf", "Meta-Llama-3-8B", "Llama-2-13b-hf"),
    labels = c("Llama-2-7B", "Llama-3-8B", "Llama-2-13B")
  ),
  algo = factor(
    algo,
    levels = c("mxq", "fp16", "awq", "gptq", "bnb", "hqq"),
    labels = c("MXQ", "FP16", "AWQ", "GPTQ", "BnB", "HQQ")
  ),
  attempt = factor(
    attempt,
    levels = c("mxq1", "kurt-scaled"),
    labels = c("MXQ1", "KURT-SCALED")
  )
)
# Render one PDF per model found in the data.
models <- unique(df_all$model)

# Make sure the output directory exists before the first ggsave() call.
out_dir <- "pdfs"
dir.create(out_dir, showWarnings = FALSE, recursive = TRUE)

# Loop-invariant device setting, applied once: reset the pdf() defaults and
# turn `onefile` off.
pdf.options(reset = TRUE, onefile = FALSE)

# Models that were not among the declared factor levels show up as NA; skip
# them, since an NA pattern would select no rows and produce an empty plot.
# The loop variable keeps the name `mdl` and holds the label as a string.
for (mdl in as.character(models[!is.na(models)])) {
  df_disp <- proc_data(df_all, mdl)
  plt <- plot_ppl(df_disp)
  ggsave(
    file.path(out_dir, paste0("ppl-mem-", mdl, ".pdf")),
    plot = plt,
    width = 8,
    height = 5
  )
}
# mdl <- "Llama-2-13B"
# df_disp <- proc_data(df_all, mdl)
# plt <- plot_ppl(df_disp)
#
# pdf.options(reset = TRUE, onefile = FALSE)
# ggsave(
# paste0("pdfs/ppl-mem-", mdl, ".png"),
# plot = plt,
# width = 8,
# height = 5,
# dpi = 600
# )