import csv
import os
from datetime import datetime, timezone

import gradio as gr
import pandas as pd
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    pipeline,
)


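# Candidate base models, grouped roughly by parameter count. Note that the gated
# checkpoints (e.g. the meta-llama and google/gemma entries) may require an
# authenticated Hugging Face token, and anything much above ~1B parameters needs
# substantial memory for the full-weight fine-tuning done below.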
MODEL_CHOICES = [
    # Sub-1B models
    "distilgpt2",
    "gpt2",
    "sshleifer/tiny-gpt2",
    "LiquidAI/LFM2-350M",
    "google/gemma-3-270m-it",
    "Qwen/Qwen2.5-0.5B-Instruct",
    "mkurman/NeuroBLAST-V3-SYNTH-EC-150000",
    # ~1-2B models
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "google/gemma-3-1b-it",
    "meta-llama/Llama-3.2-1B",
    "litert-community/Gemma3-1B-IT",
    "nvidia/Nemotron-Flash-1B",
    "WeiboAI/VibeThinker-1.5B",
    "Qwen/Qwen3-1.7B",
    # ~2-3B models
    "google/gemma-2-2b-it",
    "thu-pacman/PCMind-2.1-Kaiyuan-2B",
    "opendatalab/MinerU-HTML",
    "ministral/Ministral-3b-instruct",
    "HuggingFaceTB/SmolLM3-3B",
    "meta-llama/Llama-3.2-3B-Instruct",
    "nvidia/Nemotron-Flash-3B-Instruct",
    "Qwen/Qwen2.5-3B-Instruct",
    # 4-8B models
    "Qwen/Qwen3-4B",
    "Qwen/Qwen3-4B-Thinking-2507",
    "Qwen/Qwen3-4B-Instruct-2507",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "allenai/Olmo-3-7B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "meta-llama/Llama-3.1-8B",
    "meta-llama/Llama-3.1-8B-Instruct",
    "openbmb/MiniCPM4.1-8B",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "rl-research/DR-Tulu-8B",
]

DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"

# transformers pipeline device convention: 0 = first CUDA GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1


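# Paths for the facts log and the before/after model snapshots, all stored
# next to this script.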
ROOT_DIR = os.path.dirname(__file__)
FACTS_FILE = os.path.join(ROOT_DIR, "facts_log.csv")
BASE_SNAPSHOT_DIR = os.path.join(ROOT_DIR, "base_snapshot")
FT_SNAPSHOT_DIR = os.path.join(ROOT_DIR, "ft_snapshot")


# Populated by load_model(); shared by the chat, training, and probe handlers.
tokenizer = None
model = None
text_generator = None


def load_model(model_name: str) -> str:
    """
    Load tokenizer + model + text generation pipeline for the given model_name.
    Updates global variables so the rest of the app uses the selected model.
    """
    global tokenizer, model, text_generator

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(model_name)

    text_generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )

    return f"Loaded model: {model_name}"


def init_facts_file():
    """Create CSV with header if it doesn't exist yet."""
    if not os.path.exists(FACTS_FILE):
        with open(FACTS_FILE, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["timestamp", "fact_text"])


# Load the default model and make sure the facts file exists at startup.
model_status_text = load_model(DEFAULT_MODEL)
init_facts_file()


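# facts_log.csv holds one row per approved fact, e.g. (illustrative):
#   timestamp,fact_text
#   2024-01-01T00:00:00+00:00,The capital of Norway is Oslo.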
def log_fact(text: str):
    """Append one fact statement to facts_log.csv."""
    if not text:
        return
    with open(FACTS_FILE, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        # datetime.utcnow() is deprecated; use a timezone-aware timestamp.
        writer.writerow([datetime.now(timezone.utc).isoformat(), text])


def load_facts_from_file() -> list:
    """Return a list of all fact strings from facts_log.csv."""
    if not os.path.exists(FACTS_FILE):
        return []
    df = pd.read_csv(FACTS_FILE)
    if "fact_text" not in df.columns:
        return []
    return [str(x) for x in df["fact_text"].tolist()]


def reset_facts_file():
    """Delete and recreate facts_log.csv."""
    if os.path.exists(FACTS_FILE):
        os.remove(FACTS_FILE)
    init_facts_file()


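# Prompt construction. The models are driven as plain causal LMs here: the chat
# is framed as a "User:/Assistant:" transcript rather than going through each
# model's own chat template, so instruct-tuned choices may behave suboptimally.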
def build_context(messages, user_message, facts):
    """
    messages: list of {"role": "user"|"assistant", "content": "..."}
    facts: list of user-approved fact strings

    Build a prompt for a small causal LM for CHAT USE.
    Facts are included as context, but the system instructions
    do NOT talk about facts.
    """
    system_prompt = "You are a helpful assistant.\n\n"

    convo = system_prompt

    if facts:
        convo += "Previously approved user statements:\n"
        # Keep the prompt bounded: only the 50 most recent facts are included.
        for f in facts[-50:]:
            convo += f"- {f}\n"
        convo += "\n"

    convo += "Conversation:\n"
    for m in messages:
        if m["role"] == "user":
            convo += f"User: {m['content']}\n"
        elif m["role"] == "assistant":
            convo += f"Assistant: {m['content']}\n"

    convo += f"User: {user_message}\nAssistant:"
    return convo


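# For reference, the prompt that generate_response() sends to the model looks
# like this (one approved fact, empty history):
#
#   You are a helpful assistant.
#
#   Previously approved user statements:
#   - The capital of Norway is Oslo.
#
#   Conversation:
#   User: <new message>
#   Assistant: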
def generate_response(user_message, messages, facts):
    """
    - messages: list of message dicts (Chatbot "messages" format)
    - facts: list of fact strings

    Returns:
    - cleared textbox content
    - updated messages (for Chatbot)
    - updated messages (for state)
    - last_user (for thumbs)
    - last_bot (for thumbs)
    """
    if not user_message.strip():
        return "", messages, messages, "", ""

    prompt_text = build_context(messages, user_message, facts)

    outputs = text_generator(
        prompt_text,
        max_new_tokens=120,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )

    # The pipeline returns the prompt plus the continuation; keep only the text
    # after the final "Assistant:" marker.
    full_text = outputs[0]["generated_text"]
    if "Assistant:" in full_text:
        bot_part = full_text.rsplit("Assistant:", 1)[1]
    else:
        bot_part = full_text

    # Truncate at the first hallucinated "User:" turn, if the model produced one.
    bot_reply = bot_part.split("\nUser:")[0].strip()

    messages = messages + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": bot_reply},
    ]

    return "", messages, messages, user_message, bot_reply


def thumb_up(last_user, facts):
    """
    Thumbs-up means: treat the LAST USER MESSAGE as a fact to be learned.
    """
    if not last_user:
        return "No user message to save as fact.", facts

    log_fact(last_user)
    facts = facts + [last_user]
    preview = last_user[:80] + ("..." if len(last_user) > 80 else "")
    return f"Saved fact: '{preview}'", facts


def thumb_down(last_user):
    """
    Thumbs-down just gives feedback. We don't store anything for this simple demo.
    """
    if not last_user:
        return "No user message to rate."
    return "Ignored this message as a fact (not stored)."


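# Note: this is full-weight fine-tuning; every parameter of the selected model
# is updated. For the larger MODEL_CHOICES entries, a parameter-efficient method
# (e.g. LoRA via the peft library) would be the usual choice, but full tuning
# keeps this demo dependency-free.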
def train_on_facts():
    """
    Supervised fine-tuning on fact statements provided by the user.
    Each fact is turned into a simple training text.
    Also:
    - saves a snapshot of the pre-training (base) model if not already saved
    - saves a snapshot of the fine-tuned model after training
    """
    global model, text_generator, tokenizer

    if not os.path.exists(FACTS_FILE):
        return "No facts_log.csv file found."

    df = pd.read_csv(FACTS_FILE)
    if "fact_text" not in df.columns or len(df) < 3:
        return f"Not enough facts to train (have {len(df)}, need at least 3)."

    texts = []
    for _, row in df.iterrows():
        fact = str(row["fact_text"])
        texts.append(f"Fact: {fact}")

    dataset = Dataset.from_dict({"text": texts})

    def tokenize_function(batch):
        return tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=128,
        )

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["text"],
    )

    # mlm=False gives plain causal-LM labels (labels = input_ids, with pad
    # tokens masked out of the loss). Since pad_token may equal eos_token here,
    # EOS positions are masked as well.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    training_args = TrainingArguments(
        output_dir="facts_ft",
        overwrite_output_dir=True,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        learning_rate=5e-5,
        logging_steps=5,
        save_strategy="no",  # snapshots are saved manually below
        report_to=[],
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )

    # Save the pre-fine-tuning weights once, so the probe has a "before" model.
    if not os.path.exists(BASE_SNAPSHOT_DIR) or len(os.listdir(BASE_SNAPSHOT_DIR)) == 0:
        os.makedirs(BASE_SNAPSHOT_DIR, exist_ok=True)
        model.save_pretrained(BASE_SNAPSHOT_DIR)
        tokenizer.save_pretrained(BASE_SNAPSHOT_DIR)

    trainer.train()

    # Rebuild the chat pipeline around the fine-tuned weights.
    model = trainer.model
    text_generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )

    os.makedirs(FT_SNAPSHOT_DIR, exist_ok=True)
    model.save_pretrained(FT_SNAPSHOT_DIR)
    tokenizer.save_pretrained(FT_SNAPSHOT_DIR)

    return (
        f"Training on {len(df)} user-provided facts complete. "
        f"The model has been tuned toward your facts. "
        f"Base and fine-tuned snapshots saved."
    )


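# The snapshots written above are ordinary save_pretrained() directories, so the
# fine-tuned weights can be reloaded later with, e.g.,
# AutoModelForCausalLM.from_pretrained(FT_SNAPSHOT_DIR), mirroring how
# probe_before_after() reloads the base snapshot below.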
def probe_before_after(question: str) -> str:
    """
    Compare base vs fine-tuned model on a single question, side by side.

    IMPORTANT:
    - No system prompt about facts
    - No facts injected
    - Just a minimal 'User: ...\\nAssistant:' prompt
    """
    question = (question or "").strip()
    if not question:
        return "Please enter a question to probe."

    if not os.path.exists(BASE_SNAPSHOT_DIR) or len(os.listdir(BASE_SNAPSHOT_DIR)) == 0:
        return (
            "No base snapshot found. Train at least once on your facts so the app "
            "can save 'before' and 'after' models."
        )

    try:
        base_tokenizer = AutoTokenizer.from_pretrained(BASE_SNAPSHOT_DIR)
        base_model = AutoModelForCausalLM.from_pretrained(BASE_SNAPSHOT_DIR)
    except Exception as e:
        return f"Error loading base snapshot: {e}"

    ft_model = model
    ft_tokenizer = tokenizer

    if ft_model is None or ft_tokenizer is None:
        return "Fine-tuned model is not available in memory. Try training on facts first."

    prompt = f"User: {question}\nAssistant:"

    base_pipe = pipeline(
        "text-generation",
        model=base_model,
        tokenizer=base_tokenizer,
        device=device,
    )

    ft_pipe = pipeline(
        "text-generation",
        model=ft_model,
        tokenizer=ft_tokenizer,
        device=device,
    )

    def run_pipe(p):
        out = p(
            prompt,
            max_new_tokens=64,
            do_sample=False,
            # Use each pipeline's own tokenizer here, not just the base one.
            pad_token_id=p.tokenizer.eos_token_id,
        )
        full = out[0]["generated_text"]
        if "Assistant:" in full:
            ans = full.split("Assistant:", 1)[1].strip()
        else:
            ans = full.strip()
        return ans

    try:
        base_answer = run_pipe(base_pipe)
    except Exception as e:
        base_answer = f"Error generating with base model: {e}"

    try:
        ft_answer = run_pipe(ft_pipe)
    except Exception as e:
        ft_answer = f"Error generating with fine-tuned model: {e}"

    report = f"""### Comparison Probe

**Question**

> {question}

**Base model (before fine-tuning)**

{base_answer}

---

**Fine-tuned model (after training on your facts)**

{ft_answer}
"""
    return report


def reset_model_to_base(selected_model: str):
    """
    Reload the currently selected base model and discard any fine-tuning
    done in this session.
    Note: This does NOT remove saved snapshots on disk.
    """
    return load_model(selected_model)


def reset_facts():
    """
    Clear all stored facts (file + in-memory list).
    """
    reset_facts_file()
    return "All stored facts have been cleared.", []


def view_facts():
    """
    Show a preview of stored facts.
    """
    facts = load_facts_from_file()
    if not facts:
        return "No facts stored yet."
    preview = ""
    for i, f in enumerate(facts[:50]):
        preview += f"{i + 1}. {f}\n"
    if len(facts) > 50:
        preview += f"... and {len(facts) - 50} more.\n"
    return preview


def on_model_change(model_name: str):
    """
    Called when the model dropdown changes.
    Reloads the model and returns a status string.
    (Snapshots on disk are not touched.)
    """
    return load_model(model_name)


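# Gradio UI. Each browser session gets its own copies of the gr.State values;
# the facts list is seeded from facts_log.csv when the app starts.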
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🧪 Fact-Tuning Demo (with Before/After Comparison)

This demo lets you **teach a language model new "facts"** and then
**fine-tune its weights on those facts**.

- Send a message (a claim or statement).
- Click 👍 to treat that message as a fact.
- When you've added a few facts, click **"Train on my facts"**.
- Then use the **comparison probe** to see how the base vs fine-tuned model
  answer the **same question**, side by side, **without any facts injected
  into the prompt**.

> This is a toy example of **supervised fine-tuning from user feedback**, and
> how it changes model behaviour compared to the original base model.
"""
    )

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=MODEL_CHOICES,
            value=DEFAULT_MODEL,
            label="Base model",
        )

        model_status = gr.Markdown(model_status_text)

    # type="messages" matches the {"role", "content"} dicts used throughout.
    chatbot = gr.Chatbot(height=400, label="Conversation", type="messages")

    msg = gr.Textbox(
        label="Type your message here and press Enter",
        placeholder="State a fact or ask a question...",
    )

    state_messages = gr.State([])
    state_last_user = gr.State("")
    state_last_bot = gr.State("")
    state_facts = gr.State(load_facts_from_file())

    fact_status = gr.Markdown("", label="Fact status")
    train_status = gr.Markdown("", label="Training status")
    facts_preview = gr.Textbox(
        label="Stored facts (preview)",
        lines=10,
        interactive=False,
    )

    msg.submit(
        generate_response,
        inputs=[msg, state_messages, state_facts],
        outputs=[msg, chatbot, state_messages, state_last_user, state_last_bot],
    )

    with gr.Row():
        btn_up = gr.Button("👍 Treat last user message as fact")
        btn_down = gr.Button("👎 Do not treat as fact")

    btn_up.click(
        fn=thumb_up,
        inputs=[state_last_user, state_facts],
        outputs=[fact_status, state_facts],
    )

    btn_down.click(
        fn=thumb_down,
        inputs=[state_last_user],
        outputs=[fact_status],
    )

    gr.Markdown("---")

    gr.Markdown("## 🧠 Training")

    btn_train_facts = gr.Button("Train on my facts")

    btn_train_facts.click(
        fn=train_on_facts,
        inputs=[],
        outputs=[train_status],
    )

    with gr.Row():
        btn_reset_model = gr.Button("Reset model to base weights")
        btn_reset_facts = gr.Button("Reset all facts")

    btn_reset_model.click(
        fn=reset_model_to_base,
        inputs=[model_dropdown],
        outputs=[model_status],
    )

    btn_reset_facts.click(
        fn=reset_facts,
        inputs=[],
        outputs=[fact_status, state_facts],
    )

    gr.Markdown("## 📋 Inspect facts")

    btn_view_facts = gr.Button("Refresh facts preview")

    btn_view_facts.click(
        fn=view_facts,
        inputs=[],
        outputs=[facts_preview],
    )

    gr.Markdown("## 🔍 Comparison probe (before vs after fine-tuning)")

    probe_question = gr.Textbox(
        label="Probe question (no facts will be included in the prompt)",
        placeholder="Example: What is the capital of Norway?",
    )

    probe_output = gr.Markdown(label="Probe result")

    btn_probe = gr.Button("Run comparison probe")

    btn_probe.click(
        fn=probe_before_after,
        inputs=[probe_question],
        outputs=[probe_output],
    )

    gr.Markdown("## 🔧 Model status")

    model_dropdown.change(
        fn=on_model_change,
        inputs=[model_dropdown],
        outputs=[model_status],
    )


if __name__ == "__main__":
    demo.launch()