| | from pathlib import Path |
| |
|
| | import matplotlib as matplotlib |
| | import matplotlib.cm as cm |
| | import pandas as pd |
| | import streamlit as st |
| | import tokenizers |
| | import torch |
| | import torch.nn.functional as F |
| | from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode |
| |
|
# Directory that contains this module.
PROJ = Path(__file__).parents[0]
| |
|
# Custom hash functions for st.cache: tokenizer objects are not hashed by content;
# every Tokenizer / AddedToken simply hashes to None.
tokenizer_hash_funcs = {
    tokenizer_cls: (lambda _: None)
    for tokenizer_cls in (tokenizers.Tokenizer, tokenizers.AddedToken)
}
| | |
# Default torch device for this module: CUDA when PyTorch reports it available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
| |
|
# Display icon for each entity class. Both short tags (PER/LOC/ORG/MISC) and
# long-form class names are listed so either label scheme maps to the same icon.
classmap = dict(
    O="O",
    PER="🙎",
    person="🙎",
    LOC="🌎",
    location="🌎",
    ORG="🏤",
    corporation="🏤",
    product="📱",
    creative="🎷",
    MISC="🎷",
)
| |
|
| |
|
def aggrid_interactive_table(df: pd.DataFrame) -> dict:
    """Renders ``df`` as an interactive st-aggrid table with single-row selection.

    Row grouping, value aggregation and pivoting are enabled for every column, the
    side bar is shown, and enterprise modules are turned on.

    Args:
        df (pd.DataFrame): Source dataframe.

    Returns:
        dict: The value returned by ``AgGrid`` for the rendered grid, from which the
        (single) selected row can be read.
    """
    builder = GridOptionsBuilder.from_dataframe(
        df, enableRowGroup=True, enableValue=True, enablePivot=True
    )
    builder.configure_side_bar()
    builder.configure_selection("single")
    grid_options = builder.build()

    return AgGrid(
        df,
        gridOptions=grid_options,
        enable_enterprise_modules=True,
        theme="light",
        update_mode=GridUpdateMode.NO_UPDATE,
        allow_unsafe_jscode=True,
    )
| |
|
| |
|
def explode_df(df: pd.DataFrame) -> pd.DataFrame:
    """Explodes every (list-valued) column of ``df`` into one row per list element.

    Each column is exploded independently, so the original index value is repeated
    for every element. If a ``losses`` column exists, it is cast to ``float``
    (exploding leaves it with object dtype).
    """
    exploded = df.apply(lambda column: column.explode())
    if "losses" not in df.columns:
        return exploded
    exploded["losses"] = exploded["losses"].astype(float)
    return exploded
| |
|
| |
|
def _strip_subword_markers(token: str) -> str:
    """Removes one leading "▁", one leading "##" and one trailing "@@" subword marker.

    Each marker is removed at most once, as a whole prefix/suffix. ``str.lstrip`` /
    ``str.rstrip`` would instead strip any run of those characters and so erase
    tokens such as a literal "#" or "@".
    """
    if token.startswith("▁"):
        token = token[len("▁"):]
    if token.startswith("##"):
        token = token[2:]
    if token.endswith("@@"):
        token = token[:-2]
    return token


def align_sample(row: pd.Series) -> dict:
    """Uses word_ids to align all per-subword lists in a sample to whole words.

    Positions whose word id is negative (-1 marks special tokens) are skipped. A
    position starts a new word when its word id differs from the previous
    position's. Subword strings of the same word are concatenated after their
    markers are removed. For ``preds``, ``labels``, ``losses``, ``probs`` and
    ``hidden_states`` (when present), only the value at each word's first subword
    is kept. ``ids`` is passed through unchanged.

    Args:
        row (pd.Series): One sample. Must provide ``tokens`` and ``word_ids``
            sequences of equal length; the optional per-position columns listed
            above must be indexable by position.

    Returns:
        dict: ``{"tokens": [...], ...}`` with one entry per word for every present
        per-position column, in the order tokens, preds, labels, losses, probs,
        hidden_states, ids.
    """
    columns = row.axes[0].to_list()
    word_ids = row.word_ids

    def _starts_word(i: int) -> bool:
        # Position 0 has no predecessor; comparing it with word_ids[-1] would
        # wrap around to the last position.
        return word_ids[i] >= 0 and (i == 0 or word_ids[i] != word_ids[i - 1])

    starts = [i for i in range(len(word_ids)) if _starts_word(i)]

    tokens = []
    for i, tok in enumerate(row.tokens):
        if word_ids[i] < 0:
            continue
        piece = _strip_subword_markers(tok)
        if _starts_word(i):
            tokens.append(piece)
        else:
            tokens[-1] += piece

    out = {"tokens": tokens}
    for name in ("preds", "labels", "losses", "probs", "hidden_states"):
        if name in columns:
            values = row[name]
            out[name] = [values[i] for i in starts]

    if "ids" in columns:
        out["ids"] = row.ids

    # Internal invariant: one merged token per word start.
    assert len(tokens) == len(starts), (tokens, row.tokens)

    return out
| |
|
| |
|
@st.cache(
    allow_output_mutation=True,
    hash_funcs=tokenizer_hash_funcs,
)
def tag_text(text: str, tokenizer, model, device: torch.device) -> pd.DataFrame:
    """Tags a given text and creates an (exploded) DataFrame with the predicted labels and probabilities.

    The text is tokenized once and run through ``model`` without gradient tracking.
    The per-subword outputs are then merged back to whole words with
    ``align_sample`` and exploded to one row per word.

    Args:
        text (str): The text to be processed.
        tokenizer: Tokenizer to use; its output must support ``.tokens()`` and
            ``.word_ids()``.
        model: Model to use; its output must provide ``logits`` and
            ``hidden_states``, and ``model.config.id2label`` must map class ids to
            label strings.
        device (torch.device): The device we want PyTorch to use for its calculations.

    Returns:
        pd.DataFrame: One row per word with ``tokens``, ``preds``, ``probs`` and
        ``hidden_states`` columns, indexed 0..n-1.
    """
    tokenized = tokenizer(text, return_tensors="pt")
    # One tokenizer call provides both the token strings and the word alignment.
    tokens = tokenized.tokens()
    # Special tokens have no word id (None); align_sample expects -1 for them.
    word_ids = [w if w is not None else -1 for w in tokenized.word_ids()]
    input_ids = tokenized.input_ids.to(device)

    # Pure inference: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True)

    pred_ids = torch.argmax(outputs.logits, dim=2)
    preds = [model.config.id2label[p] for p in pred_ids[0].cpu().numpy()]
    # Last layer's hidden states for the first (and only) sequence in the batch.
    hidden_states = outputs.hidden_states[-1][0].detach().cpu().numpy()

    # NOTE(review): `1 //` floor-divides 1 by each token's *smallest* class
    # probability, yielding large integer-valued floats rather than probabilities.
    # Confirm this is intended (e.g. versus `1 - max probability`).
    probs = 1 // (
        torch.min(F.softmax(outputs.logits, dim=-1), dim=-1).values[0].detach().cpu().numpy()
    )

    df = pd.DataFrame(
        [[tokens, word_ids, preds, probs, hidden_states]],
        columns="tokens word_ids preds probs hidden_states".split(),
    )
    merged_df = pd.DataFrame(df.apply(align_sample, axis=1).tolist())
    return explode_df(merged_df).reset_index(drop=True)
| |
|
| |
|
def get_bg_color(label: str):
    """Looks up the background color configured for ``label``.

    The color is read from Streamlit's session state under the key
    ``"color_<label>"``; a missing key is not handled here.
    """
    key = f"color_{label}"
    return st.session_state[key]
| |
|
| |
|
def get_fg_color(bg_color_hex: str) -> str:
    """Chooses the proper (foreground) text color (black/white) for a given background color, maximizing contrast.

    Adapted from https://gomakethings.com/dynamically-changing-the-text-color-based-on-background-color-contrast-with-vanilla-js/

    The perceived brightness of the background is estimated with the YIQ luma
    weights (0.299 R + 0.587 G + 0.114 B, on a 0-255 scale). Backgrounds with a
    value of at least 128 get black text; darker ones get white text.

    Args:
        bg_color_hex (str): The background color given as a HEX string, e.g.
            ``"#rrggbb"``. The leading ``#`` is optional, the CSS shorthand
            ``"#rgb"`` is expanded, and an alpha channel (``"#rrggbbaa"`` /
            ``"#rgba"``) is ignored.

    Returns:
        str: Either "black" or "white".

    Raises:
        ValueError: If ``bg_color_hex`` is not a hex color of one of the forms above.
    """
    digits = bg_color_hex[1:] if bg_color_hex.startswith("#") else bg_color_hex
    if len(digits) in (3, 4):
        # Shorthand: every digit stands for a doubled pair ("#f80" == "#ff8800").
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(f"Invalid hex color: {bg_color_hex!r}")

    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    yiq = (r * 299 + g * 587 + b * 114) / 1000
    return "black" if yiq >= 128 else "white"
| |
|
| |
|
def colorize_classes(df: pd.DataFrame) -> pd.DataFrame:
    """Prepares a per-token DataFrame for display as a transposed table.

    The index is replaced by a fresh 0..n-1 range and the frame is transposed.
    Each original row (token) becomes a column, and each original column (e.g.
    ``tokens``, ``labels``, ``preds``) becomes a row.

    NOTE: despite its name, no styling is applied; the result is a plain
    DataFrame.

    Args:
        df (pd.DataFrame): Per-token data to display.

    Returns:
        pd.DataFrame: The re-indexed, transposed frame.
    """
    # drop=True discards the old index directly instead of inserting it as an
    # "index" column and dropping that afterwards. This also works when the index
    # is named, and keeps a genuine data column that happens to be called "index".
    return df.reset_index(drop=True).T
| |
|
| |
|
def htmlify_labeled_example(example: pd.DataFrame) -> str:
    """Builds an HTML (string) representation of a single labeled example.

    Each row of ``example`` becomes one ``<span>``:

    * rows whose label and prediction are both ``"O"`` are rendered as plain text;
    * every other row is rendered as a bordered span. The span text is colored with
      the predicted class's color and prefixed by the class icon for non-``I-``
      predictions. When the prediction is wrong, a badge showing the true class
      icon follows it.

    Token and label text is HTML-escaped, so arbitrary input text cannot inject
    markup.

    Args:
        example (pd.DataFrame): The example to process. Must provide ``tokens``,
            ``preds`` and ``labels`` columns holding tags such as ``"B-PER"`` or
            ``"O"``.

    Returns:
        str: An HTML string representation of the example (spans joined by spaces).

    Raises:
        ValueError: If a row that is not ``"O"``/``"O"`` has the label ``"IGN"``.
    """
    from html import escape  # function-local so this block stays self-contained

    parts = []

    for _, row in example.iterrows():
        token = escape(str(row.tokens))

        if row.labels == row.preds == "O":
            parts.append(f"<span>{token}</span>")
            continue
        if row.labels == "IGN":
            # Explicit raise (instead of `assert False`) so the check survives `python -O`.
            raise ValueError(f"Unexpected 'IGN' label for token {row.tokens!r}")

        pred = row.preds.split("-")[1] if "-" in row.preds else "O"
        label_class = row.labels.split("-")[1] if "-" in row.labels else "O"
        color = get_bg_color(pred) if "-" in row.preds else "#000000"

        # Badge with the true class, shown only for wrong predictions.
        correct = ""
        if row.preds != row.labels:
            true_color = get_bg_color(label_class) if "-" in row.labels else "#000000"
            true_font_color = get_fg_color(true_color)
            true_icon = classmap.get(label_class, escape(label_class))
            correct = (
                f"<span title='{escape(row.labels)}' style='background-color: {true_color}; "
                f"opacity: 1; color: {true_font_color}; padding: 0 5px; "
                f"border: 1px solid black; min-width: 30px'>{true_icon}</span>"
            )

        # Icon only at the beginning of an entity (not on I- continuation tags).
        pred_icon = (
            classmap.get(pred, escape(pred))
            if pred != "O" and not row.preds.startswith("I-")
            else ""
        )
        parts.append(
            f"<span style='border: 1px solid black; color: {color}; padding: 0 5px;' "
            f"title='{escape(row.preds)}'>{pred_icon} {token}</span>{correct}"
        )

    return " ".join(parts)
| |
|
| |
|
def color_map_color(value: float, cmap_name="Set1", vmin=0, vmax=1) -> str:
    """Turns a value into a color using a color map.

    ``abs(value)`` is normalized linearly from ``[vmin, vmax]`` to ``[0, 1]`` and
    looked up in the colormap; the alpha channel of the result is dropped.

    Args:
        value (float): Value to map; its sign is ignored.
        cmap_name: Name of a registered Matplotlib colormap, or a ``Colormap``
            instance.
        vmin: Value mapped to the low end of the colormap.
        vmax: Value mapped to the high end of the colormap.

    Returns:
        str: The color as a ``"#rrggbb"`` hex string.

    Raises:
        ValueError: If ``cmap_name`` is not a known colormap name.
    """
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)

    if isinstance(cmap_name, matplotlib.colors.Colormap):
        cmap = cmap_name
    elif hasattr(matplotlib, "colormaps"):
        # Registry API (Matplotlib >= 3.5); `cm.get_cmap` was deprecated in 3.7
        # and removed in 3.9.
        try:
            cmap = matplotlib.colormaps[cmap_name]
        except KeyError as err:
            raise ValueError(f"Unknown colormap name: {cmap_name!r}") from err
    else:
        cmap = cm.get_cmap(cmap_name)  # older Matplotlib without the registry

    rgba = cmap(norm(abs(value)))
    return matplotlib.colors.rgb2hex(rgba[:3])
| |
|