| |
|
|
| library(tidyverse) |
| library(ggthemes) |
| library(ggbreak) |
| library(readr) |
| library(patchwork) |
| library(optparse) |
| library(ggmagnify) |
|
|
# Attempt label of the MXQ runs to plot. proc_data() keeps rows whose
# `attempt` equals this value or is NA, so it must match one of the labels
# assigned to the `attempt` factor when the CSV is loaded.
mxq_attempt <- "MXQ2"
| |
# Build the perplexity-vs-bit-budget figure for the rows in `df_disp`.
#
# MXQ results are drawn as small points and all other methods as larger
# points, shaped by method and coloured by dataset. Dashed guidelines mark the
# lowest perplexity of each dataset and +1% / +2% above it, and "FP16" labels
# are placed at x = 15.8 at those two minima. The x axis is broken between 5.5
# and 15.6 (ggbreak) and two regions are magnified (ggmagnify); the magnifier
# regions and x range are hard-coded for the current results.
#
# Args:
#   df_disp: long-format data as produced by proc_data(), with columns `bpp`,
#     `ppl`, `algo`, `dataset` ("WikiText" / "C4") and `model`.
#   guideline_color: colour of the dashed guidelines.
#
# Returns:
#   The ggplot object; nothing is drawn or saved here.
plot_ppl <- function(df_disp, guideline_color = "coral4") {
  # Lowest perplexity of one dataset. NA perplexities (e.g. a missing
  # evaluation) are ignored instead of turning every guideline and axis limit
  # into NA.
  best_ppl <- function(ds) {
    min(df_disp$ppl[df_disp$dataset %in% ds], na.rm = TRUE)
  }
  min_ppl_wt <- best_ppl("WikiText")
  min_ppl_c4 <- best_ppl("C4")

  # y range: from the best perplexity of either dataset (floored to one
  # decimal) so no dataset's best point is cut off, up to 20% above the C4
  # minimum.
  ppl_low_bound <- floor(min(min_ppl_wt, min_ppl_c4) * 10) / 10
  ppl_high_bound <- min_ppl_c4 * 1.20

  # Guidelines at the minimum and at +1% / +2% above it, for both datasets.
  guide_factors <- c(1, 1.01, 1.02)
  guide_y <- c(min_ppl_wt * guide_factors, min_ppl_c4 * guide_factors)

  df_mxq <- subset(df_disp, algo == "MXQ")
  df_others <- subset(df_disp, algo != "MXQ")

  ggplot(df_others, aes(x = bpp, y = ppl)) +
    geom_point(
      data = df_mxq,
      mapping = aes(shape = algo, color = dataset),
      size = 0.6
    ) +
    geom_point(mapping = aes(shape = algo, color = dataset), size = 1.5) +
    geom_hline(
      yintercept = guide_y,
      linetype = "dashed",
      linewidth = 0.1,
      color = guideline_color
    ) +
    geom_magnify(
      from = c(4.10, 4.55, 4.65, 4.75),
      to = c(3.70, 5.20, 5.20, 6.20),
      colour = "#fc8d62",
      linewidth = 0.3,
      axes = "xy"
    ) +
    geom_magnify(
      from = c(4.10, 4.55, 6.47, 6.62),
      to = c(3.80, 5.30, 6.85, 7.80),
      colour = "#fc8d62",
      linewidth = 0.3,
      axes = "xy"
    ) +
    annotate("text", x = 15.8, y = min_ppl_wt, label = "FP16") +
    annotate("text", x = 15.8, y = min_ppl_c4, label = "FP16") +
    scale_x_break(c(5.5, 15.6)) +
    scale_x_continuous(
      limits = c(3.0, 16.2),
      breaks = seq(3.0, 16.2, 0.25),
      # Relative to 16 bits per value.
      sec.axis = sec_axis(~ 100 * (16 - .) / 16, name = "% Memory Reduction")
    ) +
    scale_y_continuous(
      limits = c(ppl_low_bound, ppl_high_bound),
      breaks = seq(
        round(ppl_low_bound, digits = 2),
        round(ppl_high_bound, digits = 2),
        0.25
      ),
      # Percentage change relative to the best C4 perplexity.
      sec.axis = sec_axis(
        ~ 100 * (. - min_ppl_c4) / min_ppl_c4,
        name = "% Degradation",
        breaks = seq(-30, 20, 5)
      )
    ) +
    labs(x = "Bit Budget", y = "Perplexity") +
    theme_gray(base_size = 12) +
    guides(
      shape = guide_legend(title = "Method:"),
      color = guide_legend(title = "Dataset:")
    ) +
    theme(legend.position = "left") +
    facet_wrap(~model, scales = "free") +
    scale_color_solarized()
}
|
|
# Select and reshape the rows of `df` that belong to one model.
#
# Keeps rows whose `model` equals `model_name` and whose `bpp` is at least
# `min_bpp`, keeps only rows whose `attempt` is NA or equals `attempt_label`,
# then pivots `ppl_wikitext` / `ppl_c4` into a long format with one `ppl`
# value per `dataset` (a factor with levels "WikiText", "C4").
#
# Args:
#   df: data frame with at least the columns `model`, `bpp`, `attempt`,
#     `ppl_wikitext` and `ppl_c4`.
#   model_name: model label to keep (exact match against `model`).
#   attempt_label: attempt label to keep; defaults to the global `mxq_attempt`.
#   min_bpp: smallest bit budget (`bpp`) to keep.
#
# Returns:
#   The filtered, long-format data.
proc_data <- function(df, model_name, attempt_label = mxq_attempt,
                      min_bpp = 2.5) {
  df |>
    # `.env$` makes sure the function arguments are used, never a column or a
    # global variable of the same name.
    filter(
      model == .env$model_name,
      bpp >= .env$min_bpp
    ) |>
    filter(is.na(attempt) | attempt == .env$attempt_label) |>
    # "ppl_wikitext" -> .value = "ppl", dataset = "wikitext".
    pivot_longer(
      cols = c("ppl_wikitext", "ppl_c4"),
      names_to = c(".value", "dataset"),
      names_sep = "_"
    ) |>
    mutate(
      dataset = factor(
        dataset,
        levels = c("wikitext", "c4"),
        labels = c("WikiText", "C4")
      )
    )
}
|
|
# Command line: an optional path to the combined PPL CSV, falling back to
# data/combined.csv when the option is not given.
parser <- add_option(
  OptionParser(),
  c("-d", "--csv_file"),
  type = "character",
  help = "Combined PPL result CSV file",
  metavar = "character"
)
args <- parse_args(parser)

csv_fp <- if (!is.null(args$csv_file)) args$csv_file else "data/combined.csv"
|
|
# Columns needed from the combined results CSV.
all_cols <- c(
  "model", "algo", "config", "attempt",
  "bpp", "ppl_wikitext", "ppl_c4"
)

# Load the results and map raw identifiers to display labels. Any value not
# listed in `levels` becomes NA.
df_all <- read_csv(csv_fp) |>
  select(all_of(all_cols)) |>
  mutate(
    model = factor(
      model,
      levels = c("Llama-2-7b-hf", "Meta-Llama-3-8B", "Llama-2-13b-hf"),
      labels = c("Llama-2-7B", "Llama-3-8B", "Llama-2-13B")
    ),
    algo = factor(
      algo,
      levels = c("mxq", "fp16", "awq", "gptq", "bnb", "hqq"),
      labels = c("MXQ", "FP16", "AWQ", "GPTQ", "BnB", "HQQ")
    ),
    # "mxq2" must be listed so its rows receive the "MXQ2" label that
    # `mxq_attempt` selects on. Without it, those rows became NA and were only
    # kept by accident, as if they had no attempt at all.
    # NOTE(review): assumes the CSV tag is "mxq2" — confirm against the data.
    attempt = factor(
      attempt,
      levels = c("mxq1", "mxq2", "kurt-scaled"),
      labels = c("MXQ1", "MXQ2", "KURT-SCALED")
    )
  )
|
|
# Render one PDF per model into pdfs/.
dir.create("pdfs", showWarnings = FALSE, recursive = TRUE)
# Set once, before the loop: ggsave()'s pdf device picks up these defaults.
# NOTE(review): onefile = FALSE is presumably needed so the ggbreak plot is
# written as a single page — confirm.
pdf.options(reset = TRUE, onefile = FALSE)

models <- unique(df_all$model)
# Model names not among the known factor levels are NA; skip them instead of
# writing "ppl-mem-NA.pdf".
models <- models[!is.na(models)]

# `for` over a factor iterates over its labels as character strings.
for (mdl in models) {
  df_disp <- proc_data(df_all, mdl)
  if (nrow(df_disp) == 0) {
    message("No rows to plot for ", mdl, "; skipping.")
    next
  }
  plt <- plot_ppl(df_disp)

  ggsave(
    file.path("pdfs", paste0("ppl-mem-", mdl, ".pdf")),
    plot = plt,
    width = 8,
    height = 5
  )
}
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|