| |
|
|
| library(tidyverse) |
| library(ggthemes) |
| library(readr) |
| library(plotly) |
| library(optparse) |
|
|
| second_largest <- function(x) { |
| sort(unique(x), decreasing = TRUE)[2L] |
| } |
|
|
| parser <- OptionParser() |
| parser <- add_option( |
| parser, c("-d", "--csv_file"), |
| type = "character", |
| help = "Combined PPL result CSV file", |
| metavar = "character" |
| ) |
| parser <- add_option( |
| parser, c("--attempt"), |
| type = "character", |
| help = "The attempt to plot", |
| metavar = "character" |
| ) |
|
|
| args <- parse_args(parser) |
|
|
| if (is.null(args$csv_file)) { |
| csv_fp <- "data/combined.csv" |
| } else { |
| csv_fp <- args$csv_file |
| } |
| if (is.null(args$attempt)) { |
| the_attempt <- "mxq1" |
| } else { |
| the_attempt <- args$attempt |
| } |
|
|
| all_cols1 <- c( |
| "model", |
| "algo", |
| "config", |
| "bpp", |
| "load_mem_allot" |
| ) |
| all_cols2 <- c( |
| "model", |
| "algo", |
| "config.x", |
| "bpp", |
| "load_mem_allot" |
| ) |
| df_all <- read_csv(csv_fp) |
| df_wo_mxq <- df_all |> |
| filter(algo != "mxq") |> |
| select(all_of(all_cols1)) |
|
|
| df_fp16 <- df_all |> |
| filter(algo == "fp16") |
|
|
| df_baseline <- df_all |> |
| filter(algo == "hqq") |> |
| left_join(df_fp16, by = c("model"), suffix = c(".x", "")) |> |
| select(all_of(all_cols2)) |> |
| rename( |
| config = config.x, |
| ) |
|
|
| df_hqq <- df_wo_mxq |> filter(algo == "hqq" & !str_detect(config, "^mxq")) |
| df_mxq <- df_all |> |
| filter( |
| algo == "mxq" & attempt != "mxq-kurt-global" & attempt != "mxq-kurt-scaled" |
| ) |
| df_mxq1 <- df_hqq |> |
| left_join( |
| df_mxq, |
| suffix = c(".x", ""), |
| join_by(model, bpp) |
| ) |> |
| mutate( |
| algo = paste0(algo, "-", attempt) |
| ) |> |
| select(c("model", "algo", "config.x", "bpp", "load_mem_allot")) |> |
| rename( |
| config = config.x |
| ) |
|
|
|
|
| df_disp <- bind_rows(df_wo_mxq, df_baseline, df_mxq1) |> |
| filter(str_detect(config, "^b4")) |> |
| mutate( |
| model = factor( |
| model, |
| levels = c("Llama-2-7b-hf", "Meta-Llama-3-8B", "Llama-2-13b-hf"), |
| labels = c("Llama-2-7B", "Llama-3-8B", "Llama-2-13B") |
| ), |
| algo = factor( |
| algo, |
| levels = ( |
| c( |
| the_attempt, |
| "awq", |
| "gptq", |
| "hqq", |
| "mxq-mxq1", |
| "mxq-mxq2", |
| "bnb", |
| "fp16" |
| ) |
| ) |
| ), |
| config = factor(config, levels = (c("b4g32", "b4g64", "b4g128"))) |
| ) |> |
| filter(algo != "bnb") |
|
|
| df_2nd_largest <- df_disp |> |
| group_by(model) |> |
| summarise( |
| second_max_mem = second_largest(load_mem_allot) |
| ) |> |
| ungroup() |
| df_disp <- df_disp |> |
| left_join( |
| df_2nd_largest, |
| join_by(model) |
| ) |
|
|
| plt1 <- ggplot(df_disp, aes(x = algo, y = load_mem_allot, fill = algo)) + |
| geom_col(aes(x = load_mem_allot, y = algo), show.legend = FALSE) + |
| geom_text( |
| data = subset(df_disp, algo == "fp16"), |
| aes(second_max_mem, y = algo, label = toupper(algo)), |
| hjust = 0, |
| nudge_x = -0.5, |
| colour = "white", |
| size = 3 |
| ) + |
| geom_text( |
| data = subset(df_disp, algo != "fp16"), |
| aes(x = load_mem_allot, y = algo, label = toupper(algo)), |
| hjust = 0, |
| nudge_x = 0.3, |
| colour = "black", |
| size = 3 |
| ) + |
| labs(y = "Algorithm", x = "GPU Memory(GiB)") + |
| scale_x_continuous( |
| limits = c(0, 22), |
| breaks = seq(0, 22, by = 4), |
| expand = c(0, 0), |
| position = "bottom" |
| ) + |
| scale_y_discrete(expand = expansion(add = c(0, 0.5))) + |
| theme( |
| panel.background = element_rect(fill = "white"), |
| panel.grid.major.x = element_line(color = "#A8BAC4", size = 0.3), |
| axis.ticks.length = unit(0, "mm"), |
| axis.title = element_blank(), |
| axis.text.y = element_blank() |
| ) + |
| facet_grid(config ~ model) + |
| scale_color_solarized() |
| plt1 |
| ggsave("pdfs/llama-mem-consumption.pdf", plot = plt1, width = 8, height = 6) |
|
|
| ggplotly(plt1) |
|
|