import os
import csv
from datetime import datetime

import gradio as gr
import torch
import pandas as pd
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)

# =========================================================
# CONFIG
# =========================================================

# Small / moderate models that work with AutoModelForCausalLM
MODEL_CHOICES = [
    # Very small / light (good for CPU Spaces)
    "distilgpt2",
    "gpt2",
    "sshleifer/tiny-gpt2",
    "LiquidAI/LFM2-350M",
    "google/gemma-3-270m-it",
    "Qwen/Qwen2.5-0.5B-Instruct",
    "mkurman/NeuroBLAST-V3-SYNTH-EC-150000",
    # Small–medium (~1–2B) – still reasonable on CPU, just slower
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "google/gemma-3-1b-it",
    "meta-llama/Llama-3.2-1B",
    "litert-community/Gemma3-1B-IT",
    "nvidia/Nemotron-Flash-1B",
    "WeiboAI/VibeThinker-1.5B",
    "Qwen/Qwen3-1.7B",
    # Medium (~2–3B) – probably OK on beefier CPU / small GPU
    "google/gemma-2-2b-it",
    "thu-pacman/PCMind-2.1-Kaiyuan-2B",
    "opendatalab/MinerU-HTML",  # 0.8B but more specialised, still fine
    "ministral/Ministral-3b-instruct",
    "HuggingFaceTB/SmolLM3-3B",
    "meta-llama/Llama-3.2-3B-Instruct",
    "nvidia/Nemotron-Flash-3B-Instruct",
    "Qwen/Qwen2.5-3B-Instruct",
    # Heavier (4–8B) – you really want a GPU Space for these
    "Qwen/Qwen3-4B",
    "Qwen/Qwen3-4B-Thinking-2507",
    "Qwen/Qwen3-4B-Instruct-2507",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "allenai/Olmo-3-7B-Instruct",
    "Qwen/Qwen2.5-7B-Instruct",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "meta-llama/Llama-3.1-8B",
    "meta-llama/Llama-3.1-8B-Instruct",
    "openbmb/MiniCPM4.1-8B",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "rl-research/DR-Tulu-8B",
]

DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"  # or TinyLlama, or stick with distilgpt2

device = 0 if torch.cuda.is_available() else -1

# Paths for fact storage and snapshots (runtime, but in the app dir)
ROOT_DIR = os.path.dirname(__file__)
FACTS_FILE = os.path.join(ROOT_DIR, "facts_log.csv")
BASE_SNAPSHOT_DIR = os.path.join(ROOT_DIR, "base_snapshot")
FT_SNAPSHOT_DIR = os.path.join(ROOT_DIR, "ft_snapshot")

# Globals for current model / tokenizer / generator
tokenizer = None
model = None
text_generator = None


# =========================================================
# MODEL LOADING
# =========================================================
def load_model(model_name: str) -> str:
    """
    Load tokenizer + model + text generation pipeline for the given model_name.
    Updates global variables so the rest of the app uses the selected model.
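    Returns a short status string ("Loaded model: <name>") for display in the UI.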
""" global tokenizer, model, text_generator tokenizer = AutoTokenizer.from_pretrained(model_name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained(model_name) text_generator = pipeline( "text-generation", model=model, tokenizer=tokenizer, device=device, ) return f"Loaded model: {model_name}" def init_facts_file(): """Create CSV with header if it doesn't exist yet.""" if not os.path.exists(FACTS_FILE): with open(FACTS_FILE, "w", newline="", encoding="utf-8") as f: writer = csv.writer(f) writer.writerow(["timestamp", "fact_text"]) # initial setup model_status_text = load_model(DEFAULT_MODEL) init_facts_file() # ========================================================= # FACT LOGGING # ========================================================= def log_fact(text: str): """Append one fact statement to facts_log.csv.""" if not text: return with open(FACTS_FILE, "a", newline="", encoding="utf-8") as f: writer = csv.writer(f) writer.writerow([datetime.utcnow().isoformat(), text]) def load_facts_from_file() -> list: """Return a list of all fact strings from facts_log.csv.""" if not os.path.exists(FACTS_FILE): return [] df = pd.read_csv(FACTS_FILE) if "fact_text" not in df.columns: return [] return [str(x) for x in df["fact_text"].tolist()] def reset_facts_file(): """Delete and recreate facts_log.csv.""" if os.path.exists(FACTS_FILE): os.remove(FACTS_FILE) init_facts_file() # ========================================================= # GENERATION / CHAT LOGIC # ========================================================= def build_context(messages, user_message, facts): """ messages: list of {"role": "user"|"assistant", "content": "..."} facts: list of user-approved fact strings Build a prompt for a small causal LM for CHAT USE. Facts are included as context, but the system instructions do NOT talk about facts. 
""" # Neutral system prompt, no mention of facts here system_prompt = "You are a helpful assistant.\n\n" convo = system_prompt if facts: convo += "Previously approved user statements:\n" # use only last N to avoid context explosion for f in facts[-50:]: convo += f"- {f}\n" convo += "\n" convo += "Conversation:\n" for m in messages: if m["role"] == "user": convo += f"User: {m['content']}\n" elif m["role"] == "assistant": convo += f"Assistant: {m['content']}\n" convo += f"User: {user_message}\nAssistant:" return convo def generate_response(user_message, messages, facts): """ - messages: list of message dicts (Chatbot "messages" format) - facts: list of fact strings Returns: - cleared textbox content - updated messages (for Chatbot) - updated messages (for state) - last_user (for thumbs) - last_bot (for thumbs) """ if not user_message.strip(): return "", messages, messages, "", "" prompt_text = build_context(messages, user_message, facts) outputs = text_generator( prompt_text, max_new_tokens=120, do_sample=True, top_p=0.9, temperature=0.7, pad_token_id=tokenizer.eos_token_id, ) full_text = outputs[0]["generated_text"] # Use the LAST Assistant: block (the newly generated part) if "Assistant:" in full_text: bot_part = full_text.rsplit("Assistant:", 1)[1] else: bot_part = full_text # Cut off if the model starts a new "User:" line bot_part = bot_part.split("\nUser:")[0].strip() bot_reply = bot_part messages = messages + [ {"role": "user", "content": user_message}, {"role": "assistant", "content": bot_reply}, ] return "", messages, messages, user_message, bot_reply # ========================================================= # THUMBS HANDLERS # ========================================================= def thumb_up(last_user, facts): """ Thumbs-up means: treat the LAST USER MESSAGE as a fact to be learned. """ if not last_user: return "No user message to save as fact.", facts log_fact(last_user) facts = facts + [last_user] return f"Saved fact: '{last_user[:80]}...'", facts def thumb_down(last_user): """ Thumbs-down just gives feedback. We don't store anything for this simple demo. """ if not last_user: return "No user message to rate." return "Ignored this message as a fact (not stored)." # ========================================================= # TRAINING ON FACTS + SNAPSHOTS # ========================================================= def train_on_facts(): """ Supervised fine-tuning on fact statements provided by the user. Each fact is turned into a simple training text. Also: - saves a snapshot of the pre-training (base) model if not already saved - saves a snapshot of the fine-tuned model after training """ global model, text_generator, tokenizer if not os.path.exists(FACTS_FILE): return "No facts_log.csv file found." df = pd.read_csv(FACTS_FILE) if "fact_text" not in df.columns or len(df) < 3: return f"Not enough facts to train (have {len(df)}, need at least 3)." texts = [] for _, row in df.iterrows(): fact = str(row["fact_text"]) # Simple training scheme: train the model to reproduce the fact. 
texts.append(f"Fact: {fact}") dataset = Dataset.from_dict({"text": texts}) def tokenize_function(batch): return tokenizer( batch["text"], truncation=True, padding="max_length", max_length=128, ) tokenized_dataset = dataset.map( tokenize_function, batched=True, remove_columns=["text"], ) data_collator = DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=False, ) training_args = TrainingArguments( output_dir="facts_ft", overwrite_output_dir=True, num_train_epochs=3, per_device_train_batch_size=2, learning_rate=5e-5, logging_steps=5, save_steps=0, report_to=[], ) trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_dataset, data_collator=data_collator, ) # --- Save base snapshot (before training) if not already there --- if not os.path.exists(BASE_SNAPSHOT_DIR) or len(os.listdir(BASE_SNAPSHOT_DIR)) == 0: os.makedirs(BASE_SNAPSHOT_DIR, exist_ok=True) model.save_pretrained(BASE_SNAPSHOT_DIR) tokenizer.save_pretrained(BASE_SNAPSHOT_DIR) # --- Train --- trainer.train() # Update pipeline with the fine-tuned model model = trainer.model text_generator = pipeline( "text-generation", model=model, tokenizer=tokenizer, device=device, ) # --- Save fine-tuned snapshot --- os.makedirs(FT_SNAPSHOT_DIR, exist_ok=True) model.save_pretrained(FT_SNAPSHOT_DIR) tokenizer.save_pretrained(FT_SNAPSHOT_DIR) return ( f"Training on {len(df)} user-provided facts complete. " f"The model has been tuned toward your facts. " f"Base and fine-tuned snapshots saved." ) # ========================================================= # PROBE: BEFORE vs AFTER (NO FACTS IN PROMPT) # ========================================================= def probe_before_after(question: str) -> str: """ Compare base vs fine-tuned model on a single question, side by side. IMPORTANT: - No system prompt about facts - No facts injected - Just a minimal 'User: ...\\nAssistant:' prompt """ question = (question or "").strip() if not question: return "Please enter a question to probe." # Check that we at least have a base snapshot if not os.path.exists(BASE_SNAPSHOT_DIR) or len(os.listdir(BASE_SNAPSHOT_DIR)) == 0: return ( "No base snapshot found. Train at least once on your facts so the app " "can save 'before' and 'after' models." ) # Load base snapshot try: base_tokenizer = AutoTokenizer.from_pretrained(BASE_SNAPSHOT_DIR) base_model = AutoModelForCausalLM.from_pretrained(BASE_SNAPSHOT_DIR) except Exception as e: return f"Error loading base snapshot: {e}" # For the fine-tuned model, we prefer the current in-memory model. # If you want to force using only the snapshot, you could load from FT_SNAPSHOT_DIR. ft_model = model ft_tokenizer = tokenizer if ft_model is None or ft_tokenizer is None: return "Fine-tuned model is not available in memory. Try training on facts first." 
    # Build a minimal probe prompt (no facts, no special system instructions)
    prompt = f"User: {question}\nAssistant:"

    # Create pipelines for base and fine-tuned (greedy for stability)
    base_pipe = pipeline(
        "text-generation",
        model=base_model,
        tokenizer=base_tokenizer,
        device=device,
    )
    ft_pipe = pipeline(
        "text-generation",
        model=ft_model,
        tokenizer=ft_tokenizer,
        device=device,
    )

    def run_pipe(p):
        out = p(
            prompt,
            max_new_tokens=64,
            do_sample=False,  # greedy for deterministic comparison
            pad_token_id=base_tokenizer.eos_token_id,
        )
        full = out[0]["generated_text"]
        if "Assistant:" in full:
            ans = full.split("Assistant:", 1)[1].strip()
        else:
            ans = full.strip()
        return ans

    try:
        base_answer = run_pipe(base_pipe)
    except Exception as e:
        base_answer = f"Error generating with base model: {e}"

    try:
        ft_answer = run_pipe(ft_pipe)
    except Exception as e:
        ft_answer = f"Error generating with fine-tuned model: {e}"

    report = f"""### Comparison Probe

**Question**

> {question}

**Base model (before fine-tuning)**

{base_answer}

---

**Fine-tuned model (after training on your facts)**

{ft_answer}
"""
    return report


# =========================================================
# RESET / UTILS
# =========================================================
def reset_model_to_base(selected_model: str):
    """
    Reload the currently selected base model and discard any fine-tuning
    done in this session.

    Note: This does NOT remove saved snapshots on disk.
    """
    msg = load_model(selected_model)
    return msg


def reset_facts():
    """
    Clear all stored facts (file + in-memory list).
    """
    reset_facts_file()
    return "All stored facts have been cleared.", []


def view_facts():
    """
    Show a preview of stored facts.
    """
    facts = load_facts_from_file()
    if not facts:
        return "No facts stored yet."
    preview = ""
    for i, f in enumerate(facts[:50]):
        preview += f"{i+1}. {f}\n"
    if len(facts) > 50:
        preview += f"... and {len(facts) - 50} more.\n"
    return preview


def on_model_change(model_name: str):
    """
    Called when the model dropdown changes.
    Reloads the model and returns a status string.
    (Snapshots on disk are not touched.)
    """
    msg = load_model(model_name)
    return msg


# =========================================================
# GRADIO UI
# =========================================================
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🧪 Fact-Tuning Demo (with Before/After Comparison)

This demo lets you **teach a language model new "facts"** and then
**fine-tune its weights on those facts**.

- Send a message (a claim or statement).
- Click 👍 to treat that message as a fact.
- When you've added a few facts, click **"Train on my facts"**.
- Then use the **comparison probe** to see how the base and fine-tuned models
  answer the **same question**, side by side, **without any facts injected
  into the prompt**.

> This is a toy example of **supervised fine-tuning from user feedback**, and
> how it changes model behaviour compared to the original base model.
""" ) with gr.Row(): model_dropdown = gr.Dropdown( choices=MODEL_CHOICES, value=DEFAULT_MODEL, label="Base model", ) model_status = gr.Markdown(model_status_text) chatbot = gr.Chatbot(height=400, label="Conversation") msg = gr.Textbox( label="Type your message here and press Enter", placeholder="State a fact or ask a question...", ) state_messages = gr.State([]) # list[{"role":..., "content":...}] state_last_user = gr.State("") state_last_bot = gr.State("") state_facts = gr.State(load_facts_from_file()) # in-memory facts list fact_status = gr.Markdown("", label="Fact status") train_status = gr.Markdown("", label="Training status") facts_preview = gr.Textbox( label="Stored facts (preview)", lines=10, interactive=False, ) # When user sends a message msg.submit( generate_response, inputs=[msg, state_messages, state_facts], outputs=[msg, chatbot, state_messages, state_last_user, state_last_bot], ) with gr.Row(): btn_up = gr.Button("👍 Treat last user message as fact") btn_down = gr.Button("👎 Do not treat as fact") btn_up.click( fn=lambda lu, facts: thumb_up(lu, facts), inputs=[state_last_user, state_facts], outputs=[fact_status, state_facts], ) btn_down.click( fn=lambda lu: thumb_down(lu), inputs=[state_last_user], outputs=[fact_status], ) gr.Markdown("---") gr.Markdown("## 🧠 Training") btn_train_facts = gr.Button("Train on my facts") btn_train_facts.click( fn=train_on_facts, inputs=[], outputs=[train_status], ) with gr.Row(): btn_reset_model = gr.Button("Reset model to base weights") btn_reset_facts = gr.Button("Reset all facts") btn_reset_model.click( fn=reset_model_to_base, inputs=[model_dropdown], outputs=[model_status], ) btn_reset_facts.click( fn=reset_facts, inputs=[], outputs=[fact_status, state_facts], ) gr.Markdown("## 📄 Inspect facts") btn_view_facts = gr.Button("Refresh facts preview") btn_view_facts.click( fn=view_facts, inputs=[], outputs=[facts_preview], ) gr.Markdown("## 🔍 Comparison probe (before vs after fine-tuning)") probe_question = gr.Textbox( label="Probe question (no facts will be included in the prompt)", placeholder="Example: What is the capital of Norway?", ) probe_output = gr.Markdown(label="Probe result") btn_probe = gr.Button("Run comparison probe") btn_probe.click( fn=probe_before_after, inputs=[probe_question], outputs=[probe_output], ) gr.Markdown("## 🧠 Model status") model_dropdown.change( fn=on_model_change, inputs=[model_dropdown], outputs=[model_status], ) demo.launch()