| | |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| | import datasets |
| | import plotly.graph_objects as go |
| | import numpy as np |
| | import polars as pl |
| |
|
| |
|
# Tokenizer used to measure the length of every sample below.
# NOTE(review): trust_remote_code=True allows code shipped with the hub
# repository to run locally -- confirm this checkpoint is trusted.
tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-34B", trust_remote_code=True)

# Alpaca training split, extended with a "tokens" column that stores the number
# of input ids the tokenizer produces for each example's "text" field.
alpaca = datasets.load_dataset("tatsu-lab/alpaca", split="train").map(
    lambda ex: {"tokens": len(tokenizer(ex["text"])["input_ids"])},
    num_proc=4,
)
| |
|
| |
|
# Load the tokenised dataset into polars, then attach a running row number
# ("index") that later code integer-divides to assign rows to batches.
pdf = pl.DataFrame(alpaca.to_pandas())
pdf = pdf.with_columns(index=pl.int_range(0, pl.count()))

# Token count of every sample as a numpy array, in dataset order.
tokens = pdf.get_column("tokens").to_numpy()
| |
|
| | |
| |
|
| |
|
def plot_batch(batch_size):
    """Draw the padding needed by a batch made of the first samples.

    Every sample becomes one horizontal stacked bar: the blue segment is its
    token count and the red segment is the gap up to the longest sample in
    the batch.

    Args:
        batch_size: Number of leading rows of the module-level ``pdf`` frame
            to include. Must be at least 1.

    Returns:
        A ``plotly.graph_objects.Figure`` containing the stacked bar chart.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # An empty slice has no maximum, and a negative size would silently
        # drop rows from the end instead of selecting a batch.
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    # Slicing is read-only here, so no defensive copy of the column is needed.
    data = pdf["tokens"].to_numpy()[:batch_size]
    max_value = data.max()
    rows = list(range(1, len(data) + 1))

    fig = go.Figure()
    # One trace per colour is enough: stacking pairs the bars by their y
    # position, so this draws the same chart as one trace per sample while
    # keeping the figure small for large batches.
    fig.add_trace(
        go.Bar(
            x=data,
            y=rows,
            orientation="h",
            marker_color="blue",
        )
    )
    fig.add_trace(
        go.Bar(
            x=max_value - data,
            y=rows,
            orientation="h",
            marker_color="red",
        )
    )

    fig.update_layout(
        barmode="stack",
        showlegend=False,
        xaxis=dict(range=[0, max_value]),
    )
    return fig
| |
|
| |
|
def packing(pocket=8192, lengths=None):
    """Return the capacity-to-content ratio of greedy sequential packing.

    Samples are placed, in order, into fixed-size pockets. A pocket is closed
    as soon as the next sample would not fit, and that sample starts the next
    pocket. The result is the total allocated capacity
    (``pockets * pocket``) divided by the total number of real tokens, so
    1.0 means no wasted space.

    Args:
        pocket: Capacity of a single pocket, in tokens.
        lengths: Iterable of per-sample token counts. Defaults to the
            module-level ``tokens`` array when omitted.

    Returns:
        The ratio ``number_of_pockets * pocket / sum(lengths)``.

    Raises:
        ValueError: If the lengths sum to zero, which leaves the ratio
            undefined.
    """
    if lengths is None:
        lengths = tokens
    # Plain ints iterate much faster than numpy scalars and keep the
    # arithmetic below free of dtype surprises.
    sizes = [int(size) for size in lengths]

    total = sum(sizes)
    if total <= 0:
        raise ValueError("cannot compute a packing ratio without any tokens")

    num_pocket = 0
    filled = 0
    for size in sizes:
        # Only close a pocket that actually holds something; otherwise a
        # sample longer than ``pocket`` would be billed an extra empty pocket.
        # Such a sample is counted as occupying exactly one pocket.
        if filled and filled + size > pocket:
            num_pocket += 1
            filled = size
        else:
            filled += size
    if filled:
        # Account for the last, partially filled pocket.
        num_pocket += 1
    return num_pocket * pocket / total
| |
|
| |
|
| | |
| |
|
# Build the padding illustration for a 30-sample batch.
# NOTE(review): the returned figure is discarded here; presumably this was the
# last expression of a notebook cell -- confirm, or call .show() on it.
plot_batch(30)

# For every batch size collect one (throughput, cost) point:
#   throughput = mean over samples of the real token total of their batch
#   cost       = padded tokens / real tokens over the whole dataset
arrs = []
for batch_size in range(1, 100):
    # Consecutive rows share a batch id, so batches follow dataset order.
    batch_id = pl.col("index") // batch_size
    arr = (
        pdf.with_columns(batch=pl.col("tokens").max().over(batch_id))
        .select(
            pl.col("tokens").sum().over(batch_id).mean(),
            # Ratio of totals rather than a mean of per-sample ratios: this
            # matches how the "Worst" line (max / mean) and packing()
            # (capacity / total tokens) are defined, so the three curves are
            # comparable and batching cannot rise above its own upper bound.
            (pl.col("batch").sum() / pl.col("tokens").sum()),
        )
        .to_numpy()
    )
    arrs.append(arr)
x_values, y_values = np.concatenate(arrs).transpose()

# Evaluate packing on pocket sizes from the longest sample up to the largest
# batching throughput, so both curves cover a comparable x range.
pxs = np.linspace(tokens.max(), x_values[-1], 100)
pys = [packing(pocket) for pocket in pxs]
| |
|
| |
|
# Reference level: longest sample length divided by the mean sample length.
worst = tokens.max() / tokens.mean()

# Four series on shared axes: the batching curve, the packing curve, the
# constant reference level, and a single marker at a pocket size of 8192.
fig = go.Figure(
    data=[
        go.Scatter(x=x_values, y=y_values, mode="lines", name="Batching"),
        go.Scatter(x=pxs, y=pys, mode="lines", name="Packing"),
        go.Scatter(
            x=x_values,
            y=[worst] * len(x_values),
            mode="lines",
            name="Worst",
            line={"dash": "dash"},
        ),
        go.Scatter(
            x=[8192],
            y=[packing(8192)],
            mode="markers",
            name="Chosen",
        ),
    ]
)

# Label the axes and leave headroom above the reference line.
fig.update_layout(
    xaxis_title="throughput(tokens)",
    yaxis_title="computational cost(ratio)",
    yaxis={"range": [0, worst + 1]},
)

fig.show()
| |
|