# BiasTest / app.py
import os
import csv
from datetime import datetime, timezone
import gradio as gr
import torch
import pandas as pd
from datasets import Dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
pipeline,
Trainer,
TrainingArguments,
DataCollatorForLanguageModeling,
)
# =========================================================
# CONFIG
# =========================================================
# Small / moderate models that work with AutoModelForCausalLM
MODEL_CHOICES = [
# Very small / light (good for CPU Spaces)
"distilgpt2",
"gpt2",
"sshleifer/tiny-gpt2",
"LiquidAI/LFM2-350M",
"google/gemma-3-270m-it",
"Qwen/Qwen2.5-0.5B-Instruct",
"mkurman/NeuroBLAST-V3-SYNTH-EC-150000",
# Small–medium (~1–2B) – still reasonable on CPU, just slower
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"google/gemma-3-1b-it",
"meta-llama/Llama-3.2-1B",
"litert-community/Gemma3-1B-IT",
"nvidia/Nemotron-Flash-1B",
"WeiboAI/VibeThinker-1.5B",
"Qwen/Qwen3-1.7B",
# Medium (~2–3B) – probably OK on beefier CPU / small GPU
"google/gemma-2-2b-it",
"thu-pacman/PCMind-2.1-Kaiyuan-2B",
"opendatalab/MinerU-HTML", # 0.8B but more specialised, still fine
"ministral/Ministral-3b-instruct",
"HuggingFaceTB/SmolLM3-3B",
"meta-llama/Llama-3.2-3B-Instruct",
"nvidia/Nemotron-Flash-3B-Instruct",
"Qwen/Qwen2.5-3B-Instruct",
# Heavier (4–8B) – you really want a GPU Space for these
"Qwen/Qwen3-4B",
"Qwen/Qwen3-4B-Thinking-2507",
"Qwen/Qwen3-4B-Instruct-2507",
"mistralai/Mistral-7B-Instruct-v0.2",
"allenai/Olmo-3-7B-Instruct",
"Qwen/Qwen2.5-7B-Instruct",
"meta-llama/Meta-Llama-3-8B-Instruct",
"meta-llama/Llama-3.1-8B",
"meta-llama/Llama-3.1-8B-Instruct",
"openbmb/MiniCPM4.1-8B",
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"rl-research/DR-Tulu-8B",
]
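# Note: several of the larger entries above (e.g. the meta-llama/*, mistralai/* and
# google/gemma-* repos) are gated on the Hugging Face Hub; loading them is assumed to
# require an access token with the licence accepted (for example an HF_TOKEN secret
# on the Space), in addition to enough RAM / GPU memory.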
DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct" # or TinyLlama, or stick with distilgpt2
device = 0 if torch.cuda.is_available() else -1  # pipeline() device: 0 = first GPU, -1 = CPU
# Paths for fact storage and snapshots (runtime, but in the app dir)
ROOT_DIR = os.path.dirname(__file__)
FACTS_FILE = os.path.join(ROOT_DIR, "facts_log.csv")
BASE_SNAPSHOT_DIR = os.path.join(ROOT_DIR, "base_snapshot")
FT_SNAPSHOT_DIR = os.path.join(ROOT_DIR, "ft_snapshot")
# Globals for current model / tokenizer / generator
tokenizer = None
model = None
text_generator = None
# =========================================================
# MODEL LOADING
# =========================================================
def load_model(model_name: str) -> str:
"""
Load tokenizer + model + text generation pipeline for the given model_name.
Updates global variables so the rest of the app uses the selected model.
"""
global tokenizer, model, text_generator
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)
text_generator = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
device=device,
)
return f"Loaded model: {model_name}"
def init_facts_file():
"""Create CSV with header if it doesn't exist yet."""
if not os.path.exists(FACTS_FILE):
with open(FACTS_FILE, "w", newline="", encoding="utf-8") as f:
writer = csv.writer(f)
writer.writerow(["timestamp", "fact_text"])
# initial setup
model_status_text = load_model(DEFAULT_MODEL)
init_facts_file()
# =========================================================
# FACT LOGGING
# =========================================================
def log_fact(text: str):
"""Append one fact statement to facts_log.csv."""
if not text:
return
with open(FACTS_FILE, "a", newline="", encoding="utf-8") as f:
writer = csv.writer(f)
        writer.writerow([datetime.now(timezone.utc).isoformat(), text])
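# Each appended row is "<ISO-8601 UTC timestamp>,<fact text>"; an illustrative line
# (not real data) would be:
#   2024-05-01T12:34:56.789012+00:00,The company mascot is a blue penguin.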
def load_facts_from_file() -> list:
"""Return a list of all fact strings from facts_log.csv."""
if not os.path.exists(FACTS_FILE):
return []
df = pd.read_csv(FACTS_FILE)
if "fact_text" not in df.columns:
return []
return [str(x) for x in df["fact_text"].tolist()]
def reset_facts_file():
"""Delete and recreate facts_log.csv."""
if os.path.exists(FACTS_FILE):
os.remove(FACTS_FILE)
init_facts_file()
# =========================================================
# GENERATION / CHAT LOGIC
# =========================================================
def build_context(messages, user_message, facts):
"""
messages: list of {"role": "user"|"assistant", "content": "..."}
facts: list of user-approved fact strings
Build a prompt for a small causal LM for CHAT USE.
Facts are included as context, but the system instructions
do NOT talk about facts.
"""
# Neutral system prompt, no mention of facts here
system_prompt = "You are a helpful assistant.\n\n"
convo = system_prompt
if facts:
convo += "Previously approved user statements:\n"
# use only last N to avoid context explosion
for f in facts[-50:]:
convo += f"- {f}\n"
convo += "\n"
convo += "Conversation:\n"
for m in messages:
if m["role"] == "user":
convo += f"User: {m['content']}\n"
elif m["role"] == "assistant":
convo += f"Assistant: {m['content']}\n"
convo += f"User: {user_message}\nAssistant:"
return convo
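# With one approved fact and a short history, build_context() produces a prompt of
# roughly this shape (the fact and messages below are illustrative, not real data):
#
#   You are a helpful assistant.
#
#   Previously approved user statements:
#   - The company mascot is a blue penguin.
#
#   Conversation:
#   User: Hi there!
#   Assistant: Hello! How can I help?
#   User: What is our mascot?
#   Assistant: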
def generate_response(user_message, messages, facts):
"""
- messages: list of message dicts (Chatbot "messages" format)
- facts: list of fact strings
Returns:
- cleared textbox content
- updated messages (for Chatbot)
- updated messages (for state)
- last_user (for thumbs)
- last_bot (for thumbs)
"""
if not user_message.strip():
return "", messages, messages, "", ""
prompt_text = build_context(messages, user_message, facts)
outputs = text_generator(
prompt_text,
max_new_tokens=120,
do_sample=True,
top_p=0.9,
temperature=0.7,
pad_token_id=tokenizer.eos_token_id,
)
full_text = outputs[0]["generated_text"]
# Use the LAST Assistant: block (the newly generated part)
if "Assistant:" in full_text:
bot_part = full_text.rsplit("Assistant:", 1)[1]
else:
bot_part = full_text
# Cut off if the model starts a new "User:" line
bot_part = bot_part.split("\nUser:")[0].strip()
bot_reply = bot_part
messages = messages + [
{"role": "user", "content": user_message},
{"role": "assistant", "content": bot_reply},
]
return "", messages, messages, user_message, bot_reply
# =========================================================
# THUMBS HANDLERS
# =========================================================
def thumb_up(last_user, facts):
"""
Thumbs-up means: treat the LAST USER MESSAGE as a fact to be learned.
"""
if not last_user:
return "No user message to save as fact.", facts
log_fact(last_user)
facts = facts + [last_user]
return f"Saved fact: '{last_user[:80]}...'", facts
def thumb_down(last_user):
"""
Thumbs-down just gives feedback. We don't store anything for this simple demo.
"""
if not last_user:
return "No user message to rate."
return "Ignored this message as a fact (not stored)."
# =========================================================
# TRAINING ON FACTS + SNAPSHOTS
# =========================================================
def train_on_facts():
"""
Supervised fine-tuning on fact statements provided by the user.
Each fact is turned into a simple training text.
Also:
- saves a snapshot of the pre-training (base) model if not already saved
- saves a snapshot of the fine-tuned model after training
"""
global model, text_generator, tokenizer
if not os.path.exists(FACTS_FILE):
return "No facts_log.csv file found."
df = pd.read_csv(FACTS_FILE)
if "fact_text" not in df.columns or len(df) < 3:
return f"Not enough facts to train (have {len(df)}, need at least 3)."
texts = []
for _, row in df.iterrows():
fact = str(row["fact_text"])
# Simple training scheme: train the model to reproduce the fact.
texts.append(f"Fact: {fact}")
dataset = Dataset.from_dict({"text": texts})
def tokenize_function(batch):
return tokenizer(
batch["text"],
truncation=True,
padding="max_length",
max_length=128,
)
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
remove_columns=["text"],
)
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False,
)
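    # With mlm=False the collator simply copies input_ids into labels (masking
    # pad-token positions with -100 in current transformers versions), i.e. plain
    # next-token prediction. When the pad token was set to the EOS token at load
    # time, EOS positions are therefore excluded from the loss as well.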
training_args = TrainingArguments(
output_dir="facts_ft",
overwrite_output_dir=True,
num_train_epochs=3,
per_device_train_batch_size=2,
learning_rate=5e-5,
logging_steps=5,
        save_strategy="no",  # don't write intermediate checkpoints
report_to=[],
)
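    # With only a handful of short "Fact: ..." strings, 3 epochs at lr 5e-5 should be
    # enough to visibly over-fit the model to them, which is what this toy demo is
    # after: the shift ought to show up in the before/after probe.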
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset,
data_collator=data_collator,
)
# --- Save base snapshot (before training) if not already there ---
if not os.path.exists(BASE_SNAPSHOT_DIR) or len(os.listdir(BASE_SNAPSHOT_DIR)) == 0:
os.makedirs(BASE_SNAPSHOT_DIR, exist_ok=True)
model.save_pretrained(BASE_SNAPSHOT_DIR)
tokenizer.save_pretrained(BASE_SNAPSHOT_DIR)
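    # Note: the base snapshot is written only once (guarded by the check above), so
    # if a different base model is selected after the first training run, the
    # "before" snapshot on disk still belongs to the earlier model.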
# --- Train ---
trainer.train()
# Update pipeline with the fine-tuned model
model = trainer.model
text_generator = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
device=device,
)
# --- Save fine-tuned snapshot ---
os.makedirs(FT_SNAPSHOT_DIR, exist_ok=True)
model.save_pretrained(FT_SNAPSHOT_DIR)
tokenizer.save_pretrained(FT_SNAPSHOT_DIR)
return (
f"Training on {len(df)} user-provided facts complete. "
f"The model has been tuned toward your facts. "
f"Base and fine-tuned snapshots saved."
)
# =========================================================
# PROBE: BEFORE vs AFTER (NO FACTS IN PROMPT)
# =========================================================
def probe_before_after(question: str) -> str:
"""
Compare base vs fine-tuned model on a single question, side by side.
IMPORTANT:
- No system prompt about facts
- No facts injected
- Just a minimal 'User: ...\\nAssistant:' prompt
"""
question = (question or "").strip()
if not question:
return "Please enter a question to probe."
# Check that we at least have a base snapshot
if not os.path.exists(BASE_SNAPSHOT_DIR) or len(os.listdir(BASE_SNAPSHOT_DIR)) == 0:
return (
"No base snapshot found. Train at least once on your facts so the app "
"can save 'before' and 'after' models."
)
# Load base snapshot
try:
base_tokenizer = AutoTokenizer.from_pretrained(BASE_SNAPSHOT_DIR)
base_model = AutoModelForCausalLM.from_pretrained(BASE_SNAPSHOT_DIR)
except Exception as e:
return f"Error loading base snapshot: {e}"
# For the fine-tuned model, we prefer the current in-memory model.
# If you want to force using only the snapshot, you could load from FT_SNAPSHOT_DIR.
ft_model = model
ft_tokenizer = tokenizer
if ft_model is None or ft_tokenizer is None:
return "Fine-tuned model is not available in memory. Try training on facts first."
# Build a minimal probe prompt (no facts, no special system instructions)
prompt = f"User: {question}\nAssistant:"
# Create pipelines for base and fine-tuned (greedy for stability)
base_pipe = pipeline(
"text-generation",
model=base_model,
tokenizer=base_tokenizer,
device=device,
)
ft_pipe = pipeline(
"text-generation",
model=ft_model,
tokenizer=ft_tokenizer,
device=device,
)
    def run_pipe(p):
        out = p(
            prompt,
            max_new_tokens=64,
            do_sample=False,  # greedy for a deterministic comparison
            pad_token_id=p.tokenizer.eos_token_id,
        )
        full = out[0]["generated_text"]
        if "Assistant:" in full:
            ans = full.split("Assistant:", 1)[1]
        else:
            ans = full
        # Cut off if the model starts a new "User:" turn, as in the chat path
        return ans.split("\nUser:")[0].strip()
try:
base_answer = run_pipe(base_pipe)
except Exception as e:
base_answer = f"Error generating with base model: {e}"
try:
ft_answer = run_pipe(ft_pipe)
except Exception as e:
ft_answer = f"Error generating with fine-tuned model: {e}"
report = f"""### Comparison Probe
**Question**
> {question}
**Base model (before fine-tuning)**
{base_answer}
---
**Fine-tuned model (after training on your facts)**
{ft_answer}
"""
return report
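# The base snapshot is re-loaded from disk on every probe call; that is slow for the
# larger models, but it keeps the "before" answer independent of whatever has
# happened in the current chat/training session.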
# =========================================================
# RESET / UTILS
# =========================================================
def reset_model_to_base(selected_model: str):
"""
Reload the currently selected base model and discard any fine-tuning
done in this session.
Note: This does NOT remove saved snapshots on disk.
"""
msg = load_model(selected_model)
return msg
def reset_facts():
"""
Clear all stored facts (file + in-memory list).
"""
reset_facts_file()
return "All stored facts have been cleared.", []
def view_facts():
"""
Show a preview of stored facts.
"""
facts = load_facts_from_file()
if not facts:
return "No facts stored yet."
preview = ""
for i, f in enumerate(facts[:50]):
preview += f"{i+1}. {f}\n"
if len(facts) > 50:
preview += f"... and {len(facts) - 50} more.\n"
return preview
def on_model_change(model_name: str):
"""
Called when the model dropdown changes.
Reloads the model and returns a status string.
(Snapshots on disk are not touched.)
"""
msg = load_model(model_name)
return msg
# =========================================================
# GRADIO UI
# =========================================================
with gr.Blocks() as demo:
gr.Markdown(
"""
# 🧪 Fact-Tuning Demo (with Before/After Comparison)
This demo lets you **teach a language model new "facts"** and then
**fine-tune its weights on those facts**.
- Send a message (a claim or statement).
- Click 👍 to treat that message as a fact.
- When you've added a few facts, click **"Train on my facts"**.
- Then use the **comparison probe** to see how the base and fine-tuned models
answer the **same question**, side by side, **without any facts injected
into the prompt**.
> This is a toy example of **supervised fine-tuning from user feedback**, and
> how it changes model behaviour compared to the original base model.
"""
)
with gr.Row():
model_dropdown = gr.Dropdown(
choices=MODEL_CHOICES,
value=DEFAULT_MODEL,
label="Base model",
)
model_status = gr.Markdown(model_status_text)
    chatbot = gr.Chatbot(type="messages", height=400, label="Conversation")  # expects role/content dicts, matching state_messages
msg = gr.Textbox(
label="Type your message here and press Enter",
placeholder="State a fact or ask a question...",
)
state_messages = gr.State([]) # list[{"role":..., "content":...}]
state_last_user = gr.State("")
state_last_bot = gr.State("")
state_facts = gr.State(load_facts_from_file()) # in-memory facts list
fact_status = gr.Markdown("", label="Fact status")
train_status = gr.Markdown("", label="Training status")
facts_preview = gr.Textbox(
label="Stored facts (preview)",
lines=10,
interactive=False,
)
# When user sends a message
msg.submit(
generate_response,
inputs=[msg, state_messages, state_facts],
outputs=[msg, chatbot, state_messages, state_last_user, state_last_bot],
)
with gr.Row():
        btn_up = gr.Button("👍 Treat last user message as fact")
        btn_down = gr.Button("👎 Do not treat as fact")
btn_up.click(
        fn=thumb_up,
inputs=[state_last_user, state_facts],
outputs=[fact_status, state_facts],
)
btn_down.click(
        fn=thumb_down,
inputs=[state_last_user],
outputs=[fact_status],
)
gr.Markdown("---")
gr.Markdown("## 🧠 Training")
btn_train_facts = gr.Button("Train on my facts")
btn_train_facts.click(
fn=train_on_facts,
inputs=[],
outputs=[train_status],
)
with gr.Row():
btn_reset_model = gr.Button("Reset model to base weights")
btn_reset_facts = gr.Button("Reset all facts")
btn_reset_model.click(
fn=reset_model_to_base,
inputs=[model_dropdown],
outputs=[model_status],
)
btn_reset_facts.click(
fn=reset_facts,
inputs=[],
outputs=[fact_status, state_facts],
)
gr.Markdown("## πŸ“„ Inspect facts")
btn_view_facts = gr.Button("Refresh facts preview")
btn_view_facts.click(
fn=view_facts,
inputs=[],
outputs=[facts_preview],
)
gr.Markdown("## πŸ” Comparison probe (before vs after fine-tuning)")
probe_question = gr.Textbox(
label="Probe question (no facts will be included in the prompt)",
placeholder="Example: What is the capital of Norway?",
)
probe_output = gr.Markdown(label="Probe result")
btn_probe = gr.Button("Run comparison probe")
btn_probe.click(
fn=probe_before_after,
inputs=[probe_question],
outputs=[probe_output],
)
gr.Markdown("## 🧠 Model status")
model_dropdown.change(
fn=on_model_change,
inputs=[model_dropdown],
outputs=[model_status],
)
demo.launch()