| |
|
|
| library(tidyverse) |
| library(ggthemes) |
| library(readr) |
| library(optparse) |
|
|
# Command-line options ----
# -d / --csv_file: path to the combined PPL result CSV. The default is declared
# on the option itself so it is shown by --help and args$csv_file is never NULL.
parser <- OptionParser()
parser <- add_option(
  parser, c("-d", "--csv_file"),
  type = "character",
  default = "data/combined.csv",
  help = "Combined PPL result CSV file [default: %default]",
  metavar = "character"
)
args <- parse_args(parser)
csv_fp <- args$csv_file
|
|
# Load and reshape results ----
# Columns kept for the quantization-time comparison.
all_cols1 <- c(
  "model",
  "algo",
  "config",
  "bpp",
  "quant_duration"
)
df_all <- read_csv(csv_fp)

# Baseline algorithms: every row except the MXQ runs and the fp16 reference.
df_wo_mxq <- df_all |>
  filter(algo != "mxq" & algo != "fp16") |>
  select(all_of(all_cols1))

# Attach each MXQ run to the HQQ config with the same model and bpp, so MXQ
# timings are shown under that HQQ config's label. With suffix c(".x", ""),
# HQQ columns gain ".x" and MXQ columns keep their original names.
# NOTE(review): joins on bpp by exact equality — assumes HQQ and MXQ rows
# store identical bpp values in the CSV.
df_hqq <- df_all |> filter(algo == "hqq" & !str_detect(config, "^mxq"))
df_mxq <- df_all |> filter(algo == "mxq")
df_mxq1 <- df_hqq |>
  # inner_join rather than left_join: an HQQ row with no matching MXQ run
  # would otherwise yield algo "NA-NA" (paste0 of two NAs) and a missing
  # quant_duration, which shows up as an empty NA category in the plot.
  inner_join(
    df_mxq,
    by = join_by(model, bpp),
    suffix = c(".x", "")
  ) |>
  # Label each MXQ variant by its attempt, e.g. "mxq-<attempt>".
  mutate(algo = paste0(algo, "-", attempt)) |>
  # Keep the HQQ config label (config.x) and the MXQ timing.
  select(model, algo, config = config.x, bpp, quant_duration)

df_all <- bind_rows(df_wo_mxq, df_mxq1)
|
|
# Plot data ----
# Keep only configs whose name starts with "b4" and turn model / algo / config
# into factors with a fixed level order. Model ids are relabelled to short
# display names. Values absent from a level list become NA.
disp <- local({
  model_ids <- c("Llama-2-7b-hf", "Meta-Llama-3-8B", "Llama-2-13b-hf")
  model_names <- c("Llama-2-7B", "Llama-3-8B", "Llama-2-13B")
  algo_order <- c(
    "awq", "gptq", "hqq",
    "mxq-mxq1", "mxq-kurt-global", "mxq-kurt-scaled",
    "fp16"
  )
  config_order <- c("b4g32", "b4g64", "b4g128")

  four_bit <- filter(df_all, str_detect(config, "^b4"))
  mutate(
    four_bit,
    model = factor(model, levels = model_ids, labels = model_names),
    algo = factor(algo, levels = algo_order),
    config = factor(config, levels = config_order)
  )
})
|
|
# Quantization-time plot ----
# One horizontal bar per algorithm on a log10 time axis, faceted by config
# (rows) and model (columns). Bars carry their own labels, so the y-axis text
# is hidden. Note that the theme also hides both axis titles set via labs().
speed_plot <- ggplot(disp, aes(x = quant_duration, y = algo, fill = algo)) +
  geom_col(show.legend = FALSE) +
  # Short bars (<= 200 s): label just past the end of the bar.
  geom_text(
    data = filter(disp, quant_duration <= 200),
    aes(x = quant_duration * 1.1, label = toupper(algo)),
    hjust = 0,
    nudge_x = 0.3,
    size = 3
  ) +
  # Long bars (> 200 s): white label inside the bar, starting at the axis
  # minimum. x = 1 is the lower limit (log10(1) = 0). The previous x = 0
  # mapped to -Inf on the log scale, which warned and made nudge_x a no-op.
  geom_text(
    data = filter(disp, quant_duration > 200),
    aes(x = 1, label = toupper(algo)),
    hjust = 0,
    nudge_x = 0.3,
    colour = "white",
    size = 3
  ) +
  labs(y = "Algorithm", x = "Quantization Time (Seconds)") +
  # NOTE(review): durations outside [1, 1600] s are dropped by these limits.
  scale_x_log10(
    limits = c(1, 1600),
    expand = c(0, 0),
    position = "bottom"
  ) +
  scale_y_discrete(expand = expansion(add = c(0, 0.5))) +
  # Bars are coloured through `fill`, so the fill variant of the palette is
  # required; the colour variant had no effect.
  scale_fill_tableau() +
  facet_grid(config ~ model) +
  theme(
    panel.background = element_rect(fill = "white"),
    panel.grid.major.x = element_line(color = "#A8BAC4", linewidth = 0.2),
    axis.ticks.length = unit(0, "mm"),
    axis.title = element_blank(),
    axis.text.x = element_text(angle = 45, vjust = 0.9, hjust = 1),
    axis.text.y = element_blank()
  )

# Save the plot explicitly instead of relying on last_plot(); ggsave() fails
# if the output directory does not exist.
dir.create("pdfs", showWarnings = FALSE, recursive = TRUE)
ggsave(
  "pdfs/llama-quant-speed.pdf",
  plot = speed_plot,
  width = 8,
  height = 6
)
|
|