File size: 3,653 Bytes
21ad80b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#!/usr/bin/env Rscript

library(tidyverse)
library(ggthemes)
library(readr)
library(plotly)
library(optparse)

# Return the second-largest distinct value of `x`.
#
# Duplicates are collapsed before ranking, so repeated maxima do not count
# as the runner-up. `sort()` drops missing values by default. The result is
# NA when fewer than two distinct non-missing values are present.
second_largest <- function(x) {
  distinct_desc <- rev(sort(unique(x)))
  distinct_desc[2L]
}

# Command-line interface: two optional string options.
parser <- OptionParser() |>
  add_option(
    c("-d", "--csv_file"),
    type = "character",
    help = "Combined PPL result CSV file",
    metavar = "character"
  ) |>
  add_option(
    c("--attempt"),
    type = "character",
    help = "The attempt to plot",
    metavar = "character"
  )

args <- parse_args(parser)

# Options that were not supplied come back as NULL; use the built-in
# defaults in that case.
csv_fp <- if (is.null(args$csv_file)) "data/combined.csv" else args$csv_file
the_attempt <- if (is.null(args$attempt)) "mxq1" else args$attempt

# Columns kept for plotting. The second set is identical except that it
# picks the left-hand `config` column after a join that suffixes it ".x".
all_cols1 <- c("model", "algo", "config", "bpp", "load_mem_allot")
all_cols2 <- replace(all_cols1, all_cols1 == "config", "config.x")

df_all <- read_csv(csv_fp)

# Every non-MXQ row, trimmed to the plotting columns.
df_wo_mxq <- df_all |>
  filter(algo != "mxq") |>
  select(all_of(all_cols1))

# The FP16 rows, with all columns retained for the join below.
df_fp16 <- df_all |>
  filter(algo == "fp16")

# One row per HQQ row, joined to the FP16 rows of the same model. With
# suffix = c(".x", "") the clashing HQQ columns get ".x" while the FP16
# ones keep their plain names, so `algo`, `bpp` and `load_mem_allot` below
# are the FP16 values and only `config` (via `config.x`) comes from HQQ.
df_baseline <- df_all |>
  filter(algo == "hqq") |>
  left_join(df_fp16, by = "model", suffix = c(".x", "")) |>
  select(all_of(all_cols2)) |>
  rename(config = "config.x")
# HQQ rows whose config does not start with "mxq".
df_hqq <- df_wo_mxq |>
  filter(algo == "hqq", !str_detect(config, "^mxq"))

# MXQ rows, excluding the two kurtosis-based attempts.
df_mxq <- df_all |>
  filter(
    algo == "mxq",
    attempt != "mxq-kurt-global",
    attempt != "mxq-kurt-scaled"
  )

# Attach to each HQQ row the MXQ rows with the same model and bpp. The
# unsuffixed columns come from the MXQ side; only `config` is taken from
# the HQQ side. The algorithm label becomes "<algo>-<attempt>".
df_mxq1 <- left_join(
  df_hqq,
  df_mxq,
  by = join_by(model, bpp),
  suffix = c(".x", "")
) |>
  mutate(algo = paste0(algo, "-", attempt)) |>
  select(model, algo, config = config.x, bpp, load_mem_allot)


# Assemble the rows to display: stack the three frames, keep only configs
# starting with "b4", and turn model / algo / config into factors with an
# explicit level order. Values not listed in `levels` become NA; the final
# filter drops those rows together with the "bnb" rows, because a comparison
# that yields NA is treated as not matching.
df_disp <- bind_rows(df_wo_mxq, df_baseline, df_mxq1) |>
  filter(str_detect(config, "^b4")) |>
  mutate(
    model = factor(
      model,
      levels = c("Llama-2-7b-hf", "Meta-Llama-3-8B", "Llama-2-13b-hf"),
      labels = c("Llama-2-7B", "Llama-3-8B", "Llama-2-13B")
    ),
    algo = factor(
      algo,
      # factor() raises an error on duplicated levels, so de-duplicate in
      # case `--attempt` names one of the fixed levels (e.g. "mxq-mxq1").
      # unique() keeps the first occurrence, so the order is unchanged.
      levels = unique(
        c(
          the_attempt,
          "awq",
          "gptq",
          "hqq",
          "mxq-mxq1",
          "mxq-mxq2",
          "bnb",
          "fp16"
        )
      )
    ),
    config = factor(config, levels = c("b4g32", "b4g64", "b4g128"))
  ) |>
  filter(algo != "bnb")

# Per-model second-largest distinct memory reading. It is joined back onto
# every row as `second_max_mem`, which the plot uses as the x position of
# the FP16 label.
df_2nd_largest <- summarise(
  group_by(df_disp, model),
  second_max_mem = second_largest(load_mem_allot),
  .groups = "drop"
)
df_disp <- left_join(df_disp, df_2nd_largest, by = join_by(model))

# Horizontal bar chart of load-time GPU memory per algorithm, faceted by
# quantisation config (rows) and model (columns). Bars are labelled with
# the upper-cased algorithm name instead of y-axis tick text.
#
# The orientation (memory on x, algorithm on y) is declared once at plot
# level; layers only override what differs.
plt1 <- ggplot(df_disp, aes(x = load_mem_allot, y = algo, fill = algo)) +
  geom_col(show.legend = FALSE) +
  # FP16 label: drawn in white and anchored at the model's second-largest
  # reading (nudged left) rather than at the FP16 value itself.
  geom_text(
    data = subset(df_disp, algo == "fp16"),
    aes(x = second_max_mem, label = toupper(algo)),
    hjust = 0,
    nudge_x = -0.5,
    colour = "white",
    size = 3
  ) +
  # Every other label: drawn in black just past the end of its own bar.
  geom_text(
    data = subset(df_disp, algo != "fp16"),
    aes(label = toupper(algo)),
    hjust = 0,
    nudge_x = 0.3,
    colour = "black",
    size = 3
  ) +
  # NOTE(review): these titles are hidden by `axis.title = element_blank()`
  # in the theme below; kept as in the original.
  labs(y = "Algorithm", x = "GPU Memory(GiB)") +
  scale_x_continuous(
    limits = c(0, 22),
    breaks = seq(0, 22, by = 4),
    expand = c(0, 0),
    position = "bottom"
  ) +
  scale_y_discrete(expand = expansion(add = c(0, 0.5))) +
  theme(
    panel.background = element_rect(fill = "white"),
    # `linewidth` replaces the `size` argument for lines, which has been
    # deprecated since ggplot2 3.4.0.
    panel.grid.major.x = element_line(color = "#A8BAC4", linewidth = 0.3),
    axis.ticks.length = unit(0, "mm"),
    axis.title = element_blank(),
    axis.text.y = element_blank()
  ) +
  facet_grid(config ~ model) +
  # NOTE(review): no layer maps the colour aesthetic (only `fill`), so this
  # scale does not appear to change the output. Switching to
  # scale_fill_solarized() would recolour the bars; confirm intent first.
  scale_color_solarized()
# Auto-print the plot on the current graphics device.
plt1

# Make sure the output directory exists before saving. ggsave() fails on a
# missing directory unless told to create it, and that option is only
# available in newer ggplot2 releases.
dir.create("pdfs", showWarnings = FALSE, recursive = TRUE)
ggsave("pdfs/llama-mem-consumption.pdf", plot = plt1, width = 8, height = 6)

# Interactive plotly version of the same figure.
ggplotly(plt1)