| library(tidyverse) |
| library(readr) |
| library(ggthemes) |
| library(ggplot2) |
| library(patchwork) |
| library(plotly) |
|
|
# Build one stacked-bar panel for a single attempt's slice of the weight
# distribution.
#
# df_attempt:  rows for one module and one attempt; uses the columns layer,
#              abs_val, nth_percentile, mod_disp and quant_cfg.
# show_cfg:    if TRUE, overlay each layer's quant_cfg string as a rotated
#              white label, taken from the nth_percentile == "100" rows.
# show_legend: if TRUE, draw the fill legend at the bottom with 16-pt text;
#              otherwise hide the legend.
# Returns a ggplot object.
weight_grid_bar_plot <- function(df_attempt, show_cfg = TRUE,
                                 show_legend = FALSE) {
  p <- ggplot(df_attempt, aes(x = layer, y = abs_val, fill = nth_percentile)) +
    geom_col(color = "gray50") +
    theme_gray(base_size = 14) +
    labs(
      # Assumes mod_disp is constant within one module/attempt slice.
      x = df_attempt$mod_disp[1],
      y = "Absolute Value",
      fill = "nth percentile"
    )

  if (show_cfg) {
    p <- p +
      geom_text(
        data = filter(df_attempt, .data$nth_percentile == "100"),
        aes(x = layer, label = quant_cfg),
        angle = 90,
        vjust = 0.20,
        position = position_stack(vjust = 0.5),
        colour = "white",
        size = 2
      )
  }

  legend_theme <- if (show_legend) {
    theme(
      legend.position = "bottom",
      legend.text = element_text(size = 16),
      legend.title = element_text(size = 16)
    )
  } else {
    theme(legend.position = "none")
  }
  p + legend_theme
}

# Stack a per-layer kurtosis line plot for one module above one
# weight-distribution bar panel per quantization attempt.
#
# df_wdist:    long-format weight distribution with columns module, layer,
#              abs_val, nth_percentile, attempt, mod_disp and quant_cfg.
# df_kurtosis: per-layer kurtosis with columns module, layer and kurtosis.
# mod:         module name to plot (matched against the `module` column).
# show_legend: if TRUE, the bottom bar panel shows the fill legend.
# show_cfg:    if TRUE, bar panels are labelled with each layer's quant_cfg;
#              if FALSE, the labels are omitted.
# attempts:    values of the `attempt` column to plot, one bar panel each,
#              top to bottom.
# Returns a patchwork: the line plot (relative height 1) over the bar
# panels (relative height 3 each).
weight_grid <- function(
    df_wdist, df_kurtosis, mod, show_legend = FALSE, show_cfg = TRUE,
    attempts = c("PCT5", "PCT6", "kurt-scaled")) {
  # .data/.env keep a same-named column from shadowing the argument.
  df_mod_wdist <- filter(df_wdist, .data$module == .env$mod)
  df_mod_kurt <- filter(df_kurtosis, .data$module == .env$mod)

  # theme_minimal() is a complete theme, so it is the only base theme here.
  line_plot <- ggplot(df_mod_kurt, aes(x = layer, y = kurtosis)) +
    geom_line(color = "blue") +
    theme_minimal() +
    theme(
      # The bar panels below carry the layer axis.
      axis.title.x = element_blank(),
      axis.text.x = element_blank()
    )

  n_attempts <- length(attempts)
  bar_plots <- Map(
    \(att, is_last) {
      weight_grid_bar_plot(
        filter(df_mod_wdist, .data$attempt == .env$att),
        show_cfg = show_cfg,
        # Only the bottom panel draws the legend.
        show_legend = show_legend && is_last
      )
    },
    attempts,
    seq_len(n_attempts) == n_attempts
  )

  wrap_plots(
    c(list(line_plot), unname(bar_plots)),
    ncol = 1,
    heights = c(1, rep(3, n_attempts))
  )
}
|
|
# Convenience wrapper: delegates to weight_grid() for module `mod`,
# forwarding show_legend and always passing show_cfg = FALSE.
weight_grid_only <- function(
    df_wdist, df_kurtosis, mod, show_legend = FALSE) {
  weight_grid(
    df_wdist,
    df_kurtosis,
    mod,
    show_legend = show_legend,
    show_cfg = FALSE
  )
}
|
|
# Analysis configuration ----
budget <- 4.51
model_id <- "Llama-2-13b-hf"

# Mixed-quantization configs from three attempts. The `attempt` column tags
# each row with its source file so later plots can split on it.
df_cfg1 <- read_csv("data/mxq-quant-cfgs-mxq1-5pct-tol.csv")
df_cfg2 <- read_csv("data/mxq-quant-cfgs-kurt-scaled-6pct-tol.csv")
df_cfg3 <- read_csv("data/kurt/scaled/llama-mxq-cfgs.csv")
df_cfg1$attempt <- "PCT5"
df_cfg2$attempt <- "PCT6"
df_cfg3$attempt <- "kurt-scaled"
df_cfg <- bind_rows(df_cfg1, df_cfg2, df_cfg3)

# Keep the configs for this model and bit budget, and collapse b1/g1
# (presumably bits and group size -- confirm against the CSV producer)
# into a compact "b<b1>g<g1>" label.
# - near() instead of == so a budget stored with float noise still matches.
# - .env makes sure the script variables are used even if a CSV column
#   happens to share their name.
df_cfg_1 <- df_cfg |>
  filter(near(bit_budget, .env$budget), model == .env$model_id) |>
  mutate(quant_cfg = paste0("b", b1, "g", g1)) |>
  select(-c("b1", "g1", "b2", "g2", "bit_budget"))
|
|
# Per-layer weight statistics for the model ----
df_all <- read_csv(paste0("data/wdist/wdist-", model_id, ".csv"))
percentiles <- c("0", "99", "99.9", "99.99", "100")

# Total parameter count per module (summed over layers) and a display
# label of the form "<module>(<count with thousands separators>)".
df_module_param_count <- df_all |>
  select(module, param_count) |>
  group_by(module) |>
  summarise(param_count = sum(param_count)) |>
  mutate(
    # readr reads numeric columns as double by default. formatC() uses
    # "g" format for doubles, which renders large counts in scientific
    # notation and ignores big.mark. Print the full whole number instead.
    mod_disp = paste0(
      module,
      "(",
      formatC(param_count, format = "f", digits = 0, big.mark = ","),
      ")"
    )
  )

# Attach the module label to every config row, prefixed with its attempt.
df_cfg_1 <- df_cfg_1 |>
  left_join(df_module_param_count, by = "module") |>
  mutate(mod_disp = paste0(attempt, " ", mod_disp))
|
|
# Long-format weight distribution ----
# One row per (module, layer, percentile band). Each band's abs_val is the
# increment over the previous percentile, so the bands of a layer sum to
# percentile_100.
all_cols <- c("module", "layer", percentiles)
df_wdist <- df_all |>
  mutate(
    `0` = percentile_0,
    `99` = percentile_99 - percentile_0,
    `99.9` = percentile_999 - percentile_99,
    `99.99` = percentile_9999 - percentile_999,
    `100` = percentile_100 - percentile_9999
  ) |>
  select(all_of(all_cols)) |>
  pivot_longer(
    # all_of(): selecting with a bare external vector is deprecated.
    cols = all_of(percentiles),
    names_to = "nth_percentile",
    names_transform = list(nth_percentile = as.numeric),
    values_to = "abs_val"
  ) |>
  # Reversed levels make "100" the first factor level.
  mutate(nth_percentile = factor(nth_percentile, levels = rev(percentiles))) |>
  filter(!grepl("_layernorm", module)) |>
  # Each (module, layer) has several percentile rows and, in df_cfg_1, one
  # row per attempt, so the join is many-to-many.
  left_join(df_cfg_1, by = c("module", "layer"), relationship = "many-to-many")
|
|
|
|
# Per-layer kurtosis; feeds the line plot at the top of each weight grid.
k_cols <- c("module", "layer", "kurtosis")
df_kurtosis <- select(df_all, all_of(k_cols))
|
|
# Figures ----
# One grid per projection module; q_proj's grid carries the fill legend.
p1 <- weight_grid(df_wdist, df_kurtosis, "mlp.down_proj")
p2 <- weight_grid(df_wdist, df_kurtosis, "mlp.gate_proj")
p3 <- weight_grid(df_wdist, df_kurtosis, "mlp.up_proj")
p4 <- weight_grid(df_wdist, df_kurtosis, "self_attn.k_proj")
p5 <- weight_grid(df_wdist, df_kurtosis, "self_attn.o_proj")
p6 <- weight_grid(df_wdist, df_kurtosis, "self_attn.q_proj", TRUE)
p7 <- weight_grid(df_wdist, df_kurtosis, "self_attn.v_proj")

# Print a page (for interactive viewing) and save it as
# pdfs/<model>-mxq-cfgs-from-model<page>.pdf. Returns the plot invisibly.
save_mxq_page <- function(plot, page, width = 11, height = 6,
                          model = model_id) {
  print(plot)
  ggsave(
    paste0("pdfs/", model, "-mxq-cfgs-from-model", page, ".pdf"),
    plot = plot, width = width, height = height
  )
  invisible(plot)
}

# ggsave() does not create a missing output directory by default.
dir.create("pdfs", showWarnings = FALSE)

final_plot1 <- p1 | p2
save_mxq_page(final_plot1, 1, width = 16, height = 9)

# Seven modules do not pair up evenly. Give up_proj an empty partner so
# its panel keeps the same width as on the other pages, without drawing
# it twice.
final_plot2 <- p3 | plot_spacer()
save_mxq_page(final_plot2, 2)

final_plot3 <- p4 | p5
save_mxq_page(final_plot3, 3)

final_plot4 <- p6 | p7
save_mxq_page(final_plot4, 4)
|
|