"""Gradio demo for text generation with the FlameF0X/i3-80m model."""

import json
from pathlib import Path

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

# Run on the GPU when available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hub repo and local fallback paths for the weights and chunk vocabulary.
MODEL_NAME = "FlameF0X/i3-80m"
LOCAL_SAFETENSORS = Path("model.safetensors")
LOCAL_BIN = Path("pytorch_model.bin")
VOCAB_JSON = Path("chunk_vocab_combined.json")

# Read the vocabulary size from the chunk-vocabulary JSON.
with open(VOCAB_JSON, "r") as f:
    vocab_data = json.load(f)
VOCAB_SIZE = vocab_data["vocab_size"]
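# Note: only the "vocab_size" field is read here; the rest of the file's schema
# is whatever ChunkTokenizer.load() below expects, which is defined in
# app_classes rather than in this script.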

# Model and tokenizer classes ship alongside this app in app_classes.py.
from app_classes import i3Model, ChunkTokenizer

tokenizer = ChunkTokenizer()
tokenizer.load(VOCAB_JSON)

# Instantiate the model and move it to the target device.
model = i3Model(
    vocab_size=VOCAB_SIZE,
    d_model=512,
    n_heads=16,
    max_seq_len=256,
    d_state=32,
).to(DEVICE)

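# Load pretrained weights, preferring local files over a Hub download. The
# hyperparameters above must match the checkpoint exactly, or load_state_dict()
# will fail with missing/unexpected keys or shape mismatches. Caveat:
# torch.load(..., weights_only=False) unpickles arbitrary Python objects, so it
# should only be used on trusted checkpoints; safetensors is the safer format.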
try:
    if LOCAL_SAFETENSORS.exists():
        from safetensors.torch import load_file

        state_dict = load_file(LOCAL_SAFETENSORS)
        model.load_state_dict(state_dict)
        print("✅ Loaded weights from local safetensors")
    elif LOCAL_BIN.exists():
        state_dict = torch.load(LOCAL_BIN, map_location=DEVICE, weights_only=False)
        model.load_state_dict(state_dict)
        print("✅ Loaded weights from local .bin")
    else:
        print("⚡ Downloading model from HuggingFace...")
        bin_file = hf_hub_download(repo_id=MODEL_NAME, filename="pytorch_model.bin")
        state_dict = torch.load(bin_file, map_location=DEVICE, weights_only=False)
        model.load_state_dict(state_dict)
        print("✅ Loaded weights from HuggingFace")
except Exception as e:
    raise RuntimeError(f"Failed to load model weights: {e}") from e

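# If only the .bin exists, it can be converted once so the safer safetensors
# branch is taken on later runs (a sketch, assuming the state dict holds only
# plain, non-shared tensors):
#   from safetensors.torch import save_file
#   save_file(state_dict, "model.safetensors")
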
model.eval()  # inference mode: disables dropout and other train-time behavior


def generate_text(prompt, max_tokens=100, temperature=0.8, top_k=40):
    """Encode the prompt, sample a continuation from the model, and decode it."""
    if not prompt.strip():
        return "⚠️ Please enter a prompt to generate text."

    try:
        # Batch of one: shape (1, prompt_length).
        idx = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long).to(DEVICE)
        with torch.no_grad():  # no gradients needed at inference time
            out_idx = model.generate(idx, max_new_tokens=max_tokens, temperature=temperature, top_k=top_k)
        return tokenizer.decode(out_idx[0].cpu())
    except Exception as e:
        return f"❌ Generation error: {e}"

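# Example (a sketch; assumes the weights and vocab above loaded successfully):
#   print(generate_text("Once upon a time", max_tokens=50, temperature=0.8, top_k=40))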

custom_css = """
.gradio-container {
    max-width: 1200px !important;
}
.main-header {
    text-align: center;
    margin-bottom: 2rem;
}
.param-card {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 1.5rem;
    border-radius: 12px;
    margin-bottom: 1rem;
}
"""

with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    with gr.Row():
        gr.Markdown(
            """
            # 🚀 i3-80M Text Generation
            ### Powered by a Mamba-Based Architecture
            Generate creative text using the i3-80M language model with customizable parameters.
            """,
            elem_classes="main-header",
        )

    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="✍️ Enter Your Prompt",
                placeholder="Once upon a time in a distant galaxy...",
                lines=4,
                max_lines=8,
            )

            with gr.Accordion("⚙️ Generation Parameters", open=True):
                with gr.Row():
                    max_tokens_input = gr.Slider(
                        10, 500,
                        value=100,
                        step=10,
                        label="Max Tokens",
                        info="Maximum number of tokens to generate",
                    )
                    temp_input = gr.Slider(
                        0.1, 2.0,
                        value=0.8,
                        step=0.05,
                        label="Temperature",
                        info="Higher = more creative, lower = more focused",
                    )

                    topk_input = gr.Slider(
                        1, 100,
                        value=40,
                        step=1,
                        label="Top-k Sampling",
                        info="Number of top tokens to consider",
                    )
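            # How temperature and top-k typically interact during sampling (a
            # sketch of the standard recipe, not necessarily what i3Model.generate
            # does internally):
            #   logits = logits / temperature                      # rescale
            #   v, _ = torch.topk(logits, top_k)                   # k best logits
            #   logits[logits < v[..., -1, None]] = -float("inf")  # mask the rest
            #   next_id = torch.multinomial(logits.softmax(-1), 1)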

            with gr.Row():
                generate_btn = gr.Button("🎨 Generate Text", variant="primary", size="lg")
                clear_btn = gr.ClearButton(components=[prompt_input], value="🗑️ Clear", size="lg")

        with gr.Column(scale=2):
            output_text = gr.Textbox(
                label="📝 Generated Output",
                lines=12,
                max_lines=20,
                show_copy_button=True,
            )

    with gr.Row():
        gr.Examples(
            examples=[
                ["The future of artificial intelligence is", 150, 0.7, 50],
                ["In a world where technology and nature coexist", 200, 0.9, 40],
                ["The scientist discovered something remarkable", 120, 0.8, 45],
            ],
            inputs=[prompt_input, max_tokens_input, temp_input, topk_input],
            label="💡 Try These Examples",
        )

    with gr.Accordion("🔧 Developer Info", open=False):
        # Count only trainable parameters.
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        with gr.Row():
            with gr.Column():
                gr.Markdown(f"""
                **Model Architecture:**
                - **Model:** i3-80M
                - **Device:** {DEVICE}
                - **Vocab Size:** {VOCAB_SIZE:,}
                - **Parameters:** {total_params:,} ({total_params/1e6:.2f}M)
                """)

            with gr.Column():
                gr.Markdown("""
                **Configuration:**
                - **d_model:** 512
                - **n_heads:** 16
                - **max_seq_len:** 256
                - **d_state:** 32
                """)

    gr.Markdown(
        """
        ---
        <div style="text-align: center; color: #666;">
        <p>Built with ❤️ using Gradio | Model: FlameF0X/i3-80m</p>
        </div>
        """
    )

    # Wire the generate button to the sampling function.
    generate_btn.click(
        generate_text,
        inputs=[prompt_input, max_tokens_input, temp_input, topk_input],
        outputs=[output_text],
    )


if __name__ == "__main__":
    demo.launch(share=False)  # local only; share=True would expose a temporary public link