teszenofficial committed on
Commit
0968217
·
verified ·
1 Parent(s): 6e24dfc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +231 -67
app.py CHANGED
@@ -4,6 +4,7 @@ import torch
4
  import json
5
  import time
6
  import gc
 
7
  from fastapi import FastAPI, Request
8
  from fastapi.responses import HTMLResponse, StreamingResponse
9
  from fastapi.middleware.cors import CORSMiddleware
@@ -143,27 +144,27 @@ class MTPModel(nn.Module):
143
  logits = self.lm_head(x)
144
  return logits
145
 
146
- def generate(self, input_ids, max_new_tokens=100, temperature=0.8, top_k=50, top_p=0.9, repetition_penalty=1.1):
147
- """Método de generación compatible con la interfaz"""
148
  generated = input_ids
 
149
 
150
  for _ in range(max_new_tokens):
151
- # Obtener logits para el último token
152
  with torch.no_grad():
153
  logits = self(generated)
154
  next_logits = logits[0, -1, :] / temperature
155
 
156
- # Aplicar repetition penalty
157
  if repetition_penalty != 1.0:
158
  for token_id in set(generated[0].tolist()):
159
  next_logits[token_id] /= repetition_penalty
160
 
161
- # Top-k filtering
162
  if top_k > 0:
163
  indices_to_remove = next_logits < torch.topk(next_logits, top_k)[0][..., -1, None]
164
  next_logits[indices_to_remove] = float('-inf')
165
 
166
- # Top-p filtering
167
  if top_p < 1.0:
168
  sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
169
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
@@ -173,17 +174,23 @@ class MTPModel(nn.Module):
173
  indices_to_remove = sorted_indices[sorted_indices_to_remove]
174
  next_logits[indices_to_remove] = float('-inf')
175
 
176
- # Sampling
177
  probs = F.softmax(next_logits, dim=-1)
178
  next_token = torch.multinomial(probs, num_samples=1).item()
179
 
180
- # Parar en EOS
181
- if next_token == 3: # EOS ID para SentencePiece
 
182
  break
183
 
 
 
 
 
 
 
184
  generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
185
 
186
- return generated
187
 
188
  # ======================
189
  # DESCARGA Y CARGA DEL MODELO
@@ -216,12 +223,16 @@ tokenizer_path = os.path.join(repo_path, "mtp_tokenizer.model")
216
  sp = spm.SentencePieceProcessor()
217
  sp.load(tokenizer_path)
218
  VOCAB_SIZE = sp.get_piece_size()
 
 
219
 
220
  # Actualizar vocab_size en config
221
  config["vocab_size"] = VOCAB_SIZE
222
 
223
  print(f"🧠 Inicializando modelo MTP-1.1...")
224
  print(f" → Vocabulario: {VOCAB_SIZE}")
 
 
225
  print(f" → Dimensión: {config['d_model']}")
226
  print(f" → Capas: {config['n_layers']}")
227
  print(f" → Heads: {config['n_heads']}")
@@ -252,6 +263,69 @@ if DEVICE == "cpu":
252
  param_count = sum(p.numel() for p in model.parameters())
253
  print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  # ======================
256
  # API CONFIG
257
  # ======================
@@ -270,8 +344,8 @@ app.add_middleware(
270
 
271
  class PromptRequest(BaseModel):
272
  text: str = Field(..., max_length=2000, description="Texto de entrada")
273
- max_tokens: int = Field(default=150, ge=10, le=300, description="Tokens máximos a generar")
274
- temperature: float = Field(default=0.7, ge=0.1, le=2.0, description="Temperatura de muestreo")
275
  top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
276
  top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
277
  repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0, description="Penalización por repetición")
@@ -310,50 +384,78 @@ async def generate(req: PromptRequest):
310
  global ACTIVE_REQUESTS
311
  ACTIVE_REQUESTS += 1
312
 
313
- dyn_max_tokens = req.max_tokens
314
- dyn_temperature = req.temperature
315
-
316
- if ACTIVE_REQUESTS > 2:
317
- print(f"⚠️ Carga alta ({ACTIVE_REQUESTS} requests). Ajustando parámetros.")
318
- dyn_max_tokens = min(dyn_max_tokens, 120)
319
- dyn_temperature = max(0.5, dyn_temperature * 0.9)
320
-
321
- user_input = req.text.strip()
322
- if not user_input:
323
- ACTIVE_REQUESTS -= 1
324
- return {"reply": "", "tokens_generated": 0}
325
-
326
- full_prompt = build_prompt(user_input)
327
- tokens = [tokenizer_wrapper.bos_id()] + tokenizer_wrapper.encode(full_prompt)
328
- input_ids = torch.tensor([tokens], device=DEVICE)
329
-
330
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  with torch.no_grad():
332
- output_ids = model.generate(
333
  input_ids,
334
  max_new_tokens=dyn_max_tokens,
335
  temperature=dyn_temperature,
336
  top_k=req.top_k,
337
  top_p=req.top_p,
338
- repetition_penalty=req.repetition_penalty
 
339
  )
340
 
 
341
  gen_tokens = output_ids[0, len(tokens):].tolist()
342
 
 
343
  safe_tokens = [
344
  t for t in gen_tokens
345
  if 0 <= t < VOCAB_SIZE and t != tokenizer_wrapper.eos_id()
346
  ]
347
 
348
- response = tokenizer_wrapper.decode(safe_tokens).strip()
 
349
 
350
- if "###" in response:
351
- response = response.split("###")[0].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
 
353
  return {
354
- "reply": response,
355
  "tokens_generated": len(safe_tokens),
356
- "model": "MTP-1.1"
 
357
  }
358
 
359
  except Exception as e:
@@ -393,7 +495,7 @@ def model_info():
393
  }
394
 
395
  # ======================
396
- # INTERFAZ WEB (MODERNA DE MTP-3)
397
  # ======================
398
  @app.get("/", response_class=HTMLResponse)
399
  def chat_ui():
@@ -403,7 +505,7 @@ def chat_ui():
403
  <head>
404
  <meta charset="UTF-8">
405
  <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
406
- <title>MTP 1.1</title>
407
  <link rel="preconnect" href="https://fonts.googleapis.com">
408
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
409
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap" rel="stylesheet">
@@ -415,8 +517,6 @@ def chat_ui():
415
  --text-primary: #e3e3e3;
416
  --text-secondary: #9aa0a6;
417
  --user-bubble: #282a2c;
418
- --bot-actions-color: #c4c7c5;
419
- --logo-url: url('https://i.postimg.cc/yxS54PF3/IMG-3082.jpg');
420
  }
421
  * { box-sizing: border-box; outline: none; -webkit-tap-highlight-color: transparent; }
422
  body {
@@ -452,10 +552,13 @@ header {
452
  width: 32px;
453
  height: 32px;
454
  border-radius: 50%;
455
- background-image: var(--logo-url);
456
- background-size: cover;
457
- background-position: center;
458
- border: 1px solid rgba(255,255,255,0.1);
 
 
 
459
  }
460
  .brand-text {
461
  font-weight: 500;
@@ -472,6 +575,28 @@ header {
472
  border-radius: 12px;
473
  font-weight: 600;
474
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
  .chat-scroll {
476
  flex: 1;
477
  overflow-y: auto;
@@ -522,8 +647,13 @@ header {
522
  height: 34px;
523
  min-width: 34px;
524
  border-radius: 50%;
525
- background-image: var(--logo-url);
526
- background-size: cover;
 
 
 
 
 
527
  box-shadow: 0 2px 6px rgba(0,0,0,0.2);
528
  }
529
  .bot-actions {
@@ -590,6 +720,9 @@ header {
590
  font-family: inherit;
591
  padding: 10px 0;
592
  }
 
 
 
593
  #mainBtn {
594
  background: white;
595
  color: black;
@@ -616,12 +749,6 @@ header {
616
  to { opacity: 1; transform: translateY(0); }
617
  }
618
  @keyframes blink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
619
- @keyframes pulseAvatar {
620
- 0% { box-shadow: 0 0 0 0 rgba(74, 158, 255, 0.4); }
621
- 70% { box-shadow: 0 0 0 8px rgba(74, 158, 255, 0); }
622
- 100% { box-shadow: 0 0 0 0 rgba(74, 158, 255, 0); }
623
- }
624
- .pulsing { animation: pulseAvatar 1.5s infinite; }
625
  ::-webkit-scrollbar { width: 8px; }
626
  ::-webkit-scrollbar-track { background: transparent; }
627
  ::-webkit-scrollbar-thumb { background: #333; border-radius: 4px; }
@@ -630,18 +757,22 @@ header {
630
  <body>
631
  <header>
632
  <div class="brand-wrapper" onclick="location.reload()">
633
- <div class="brand-logo"></div>
634
  <div class="brand-text">
635
- MTP <span class="version-badge">1.1</span>
636
  </div>
637
  </div>
 
 
 
 
638
  </header>
639
  <div id="chatScroll" class="chat-scroll">
640
  <div class="msg-row bot" style="animation-delay: 0.1s;">
641
- <div class="bot-avatar"></div>
642
  <div class="msg-content-wrapper">
643
  <div class="msg-text">
644
- ¡Hola! Soy MTP 1.1. ¿En qué puedo ayudarte hoy?
645
  </div>
646
  </div>
647
  </div>
@@ -659,26 +790,34 @@ header {
659
  const chatScroll = document.getElementById('chatScroll');
660
  const userInput = document.getElementById('userInput');
661
  const mainBtn = document.getElementById('mainBtn');
 
662
  let isGenerating = false;
663
  let abortController = null;
664
  let typingTimeout = null;
665
  let lastUserPrompt = "";
 
666
  const ICON_SEND = `<svg width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M22 2L11 13M22 2l-7 20-4-9-9-4 20-7z"></path></svg>`;
667
  const ICON_STOP = `<svg width="14" height="14" viewBox="0 0 24 24" fill="currentColor" stroke="currentColor" stroke-width="0"><rect x="2" y="2" width="20" height="20" rx="4" ry="4"></rect></svg>`;
 
668
  mainBtn.innerHTML = ICON_SEND;
 
669
  function scrollToBottom() {
670
  chatScroll.scrollTop = chatScroll.scrollHeight;
671
  }
 
672
  function setBtnState(state) {
673
  if (state === 'sending') {
674
  mainBtn.innerHTML = ICON_STOP;
675
  isGenerating = true;
 
676
  } else {
677
  mainBtn.innerHTML = ICON_SEND;
678
  isGenerating = false;
679
  abortController = null;
 
680
  }
681
  }
 
682
  function handleBtnClick() {
683
  if (isGenerating) {
684
  stopGeneration();
@@ -686,39 +825,44 @@ function handleBtnClick() {
686
  sendMessage();
687
  }
688
  }
 
689
  function stopGeneration() {
690
  if (abortController) abortController.abort();
691
  if (typingTimeout) clearTimeout(typingTimeout);
692
  const activeCursor = document.querySelector('.typing-cursor');
693
  if (activeCursor) activeCursor.classList.remove('typing-cursor');
694
- const activeAvatar = document.querySelector('.pulsing');
695
- if (activeAvatar) activeAvatar.classList.remove('pulsing');
696
  setBtnState('idle');
697
  userInput.focus();
698
  }
 
699
  async function sendMessage(textOverride = null) {
700
  const text = textOverride || userInput.value.trim();
701
  if (!text) return;
702
  lastUserPrompt = text;
 
703
  if (!textOverride) {
704
  userInput.value = '';
705
  addMessage(text, 'user');
706
  }
 
707
  setBtnState('sending');
708
  abortController = new AbortController();
 
709
  const botRow = document.createElement('div');
710
  botRow.className = 'msg-row bot';
711
  const avatar = document.createElement('div');
712
- avatar.className = 'bot-avatar pulsing';
 
713
  const wrapper = document.createElement('div');
714
  wrapper.className = 'msg-content-wrapper';
715
  const msgText = document.createElement('div');
716
- msgText.className = 'msg-text';
717
  wrapper.appendChild(msgText);
718
  botRow.appendChild(avatar);
719
  botRow.appendChild(wrapper);
720
  chatScroll.appendChild(botRow);
721
  scrollToBottom();
 
722
  try {
723
  const response = await fetch('/generate', {
724
  method: 'POST',
@@ -726,11 +870,15 @@ async function sendMessage(textOverride = null) {
726
  body: JSON.stringify({ text: text }),
727
  signal: abortController.signal
728
  });
 
729
  const data = await response.json();
730
- if (!isGenerating) return;
731
- avatar.classList.remove('pulsing');
732
- const reply = data.reply || "No entendí eso.";
 
 
733
  await typeWriter(msgText, reply);
 
734
  if (isGenerating) {
735
  addActions(wrapper, reply);
736
  setBtnState('idle');
@@ -739,13 +887,13 @@ async function sendMessage(textOverride = null) {
739
  if (error.name === 'AbortError') {
740
  msgText.textContent += " [Detenido]";
741
  } else {
742
- avatar.classList.remove('pulsing');
743
- msgText.textContent = "Error de conexión.";
744
  msgText.style.color = "#ff8b8b";
745
- setBtnState('idle');
746
  }
 
747
  }
748
  }
 
749
  function addMessage(text, sender) {
750
  const row = document.createElement('div');
751
  row.className = `msg-row ${sender}`;
@@ -756,10 +904,13 @@ function addMessage(text, sender) {
756
  chatScroll.appendChild(row);
757
  scrollToBottom();
758
  }
 
759
  function typeWriter(element, text, speed = 12) {
760
  return new Promise(resolve => {
761
  let i = 0;
 
762
  element.classList.add('typing-cursor');
 
763
  function type() {
764
  if (!isGenerating) {
765
  element.classList.remove('typing-cursor');
@@ -779,30 +930,43 @@ function typeWriter(element, text, speed = 12) {
779
  type();
780
  });
781
  }
 
782
  function addActions(wrapperElement, textToCopy) {
783
  const actionsDiv = document.createElement('div');
784
  actionsDiv.className = 'bot-actions';
 
785
  const copyBtn = document.createElement('button');
786
  copyBtn.className = 'action-btn';
787
  copyBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>`;
788
  copyBtn.onclick = () => {
789
  navigator.clipboard.writeText(textToCopy);
 
 
 
 
790
  };
 
791
  const regenBtn = document.createElement('button');
792
  regenBtn.className = 'action-btn';
793
  regenBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M23 4v6h-6"></path><path d="M1 20v-6h6"></path><path d="M3.51 9a9 9 0 0 1 14.85-3.36L23 10M1 14l4.64 4.36A9 9 0 0 0 20.49 15"></path></svg>`;
794
  regenBtn.onclick = () => {
795
  sendMessage(lastUserPrompt);
796
  };
 
797
  actionsDiv.appendChild(copyBtn);
798
  actionsDiv.appendChild(regenBtn);
799
  wrapperElement.appendChild(actionsDiv);
800
  requestAnimationFrame(() => actionsDiv.style.opacity = "1");
801
  scrollToBottom();
802
  }
 
803
  userInput.addEventListener('keydown', (e) => {
804
- if (e.key === 'Enter') handleBtnClick();
 
 
 
805
  });
 
806
  window.onload = () => userInput.focus();
807
  </script>
808
  </body>
 
4
  import json
5
  import time
6
  import gc
7
+ import re
8
  from fastapi import FastAPI, Request
9
  from fastapi.responses import HTMLResponse, StreamingResponse
10
  from fastapi.middleware.cors import CORSMiddleware
 
144
  logits = self.lm_head(x)
145
  return logits
146
 
147
+ def generate(self, input_ids, max_new_tokens=100, temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.1, eos_token_id=3):
148
+ """Método de generación mejorado con parada limpia"""
149
  generated = input_ids
150
+ eos_detected = False
151
 
152
  for _ in range(max_new_tokens):
 
153
  with torch.no_grad():
154
  logits = self(generated)
155
  next_logits = logits[0, -1, :] / temperature
156
 
157
+ # Repetition penalty
158
  if repetition_penalty != 1.0:
159
  for token_id in set(generated[0].tolist()):
160
  next_logits[token_id] /= repetition_penalty
161
 
162
+ # Top-k
163
  if top_k > 0:
164
  indices_to_remove = next_logits < torch.topk(next_logits, top_k)[0][..., -1, None]
165
  next_logits[indices_to_remove] = float('-inf')
166
 
167
+ # Top-p
168
  if top_p < 1.0:
169
  sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
170
  cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
 
174
  indices_to_remove = sorted_indices[sorted_indices_to_remove]
175
  next_logits[indices_to_remove] = float('-inf')
176
 
 
177
  probs = F.softmax(next_logits, dim=-1)
178
  next_token = torch.multinomial(probs, num_samples=1).item()
179
 
180
+ # Detener en EOS o tokens sospechosos
181
+ if next_token == eos_token_id:
182
+ eos_detected = True
183
  break
184
 
185
+ # Detener si detectamos repetición excesiva del mismo token
186
+ if len(generated[0]) > 10:
187
+ last_tokens = generated[0][-10:].tolist()
188
+ if len(set(last_tokens)) == 1:
189
+ break
190
+
191
  generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
192
 
193
+ return generated, eos_detected
194
 
195
  # ======================
196
  # DESCARGA Y CARGA DEL MODELO
 
223
  sp = spm.SentencePieceProcessor()
224
  sp.load(tokenizer_path)
225
  VOCAB_SIZE = sp.get_piece_size()
226
+ EOS_TOKEN_ID = sp.eos_id()
227
+ BOS_TOKEN_ID = sp.bos_id()
228
 
229
  # Actualizar vocab_size en config
230
  config["vocab_size"] = VOCAB_SIZE
231
 
232
  print(f"🧠 Inicializando modelo MTP-1.1...")
233
  print(f" → Vocabulario: {VOCAB_SIZE}")
234
+ print(f" → EOS token ID: {EOS_TOKEN_ID}")
235
+ print(f" → BOS token ID: {BOS_TOKEN_ID}")
236
  print(f" → Dimensión: {config['d_model']}")
237
  print(f" → Capas: {config['n_layers']}")
238
  print(f" → Heads: {config['n_heads']}")
 
263
  param_count = sum(p.numel() for p in model.parameters())
264
  print(f"✅ Modelo cargado: {param_count:,} parámetros ({param_count/1e6:.1f}M)")
265
 
266
+ # ======================
267
+ # FUNCIONES DE LIMPIEZA DE RESPUESTAS
268
+ # ======================
269
def clean_response(text: str, original_prompt: str = None) -> str:
    """Clean raw model output before it is sent back to the client.

    The steps run in this order:

    1. Remove an echo of the user's prompt when the reply starts with it,
       or when the whole echo sits inside the first 50 characters.
    2. Discard everything from the first ``###`` marker onwards.
    3. For replies longer than 10 words, keep at most three consecutive
       copies of the same word.
    4. Delete runs that look like noise (20+ lowercase ASCII letters,
       5+ digits, 10+ unusual symbols).
    5. Collapse whitespace and upper-case the first character.

    Args:
        text: Decoded model output.
        original_prompt: The user's message, used only to detect an echo.
            ``None``, empty or whitespace-only values disable step 1.

    Returns:
        The cleaned reply, or a fixed Spanish fallback sentence when the
        input is empty or fewer than 3 characters survive the clean-up.
    """
    import itertools

    if not text:
        return "Lo siento, no pude generar una respuesta."

    # Match on the *stripped* prompt and cut exactly at the end of the match.
    # Slicing by the length of the unstripped prompt would remove characters
    # that belong to the answer, and a whitespace-only prompt would "match"
    # every reply.
    prompt = original_prompt.strip() if original_prompt else ""
    if prompt:
        echo = re.search(re.escape(prompt), text, flags=re.IGNORECASE)
        if echo is not None and (echo.start() == 0 or echo.end() <= 50):
            text = text[echo.end():].strip()

    # Keep only what precedes the first "###" marker.
    if "###" in text:
        text = text.split("###")[0].strip()

    # Cap runs of an identical word at three occurrences (long replies only).
    words = text.split()
    if len(words) > 10:
        kept = []
        for _, run in itertools.groupby(words):
            kept.extend(itertools.islice(run, 3))
        text = " ".join(kept)

    # Applied one after another on purpose: removing one kind of noise can
    # join neighbours into a run that a later pattern then removes.
    garbage_patterns = (
        r'[a-z]{20,}',                        # very long lowercase runs
        r'\d{5,}',                            # very long numbers
        r'[^\w\s\.\,\!\?\-áéíóúüñ]{10,}',     # long runs of odd symbols
    )
    for pattern in garbage_patterns:
        text = re.sub(pattern, '', text)

    # Normalise whitespace left behind by the removals above.
    text = re.sub(r'\s+', ' ', text).strip()

    # Upper-case the first character only; the rest is left untouched.
    if text:
        text = text[0].upper() + text[1:]

    # Too little left to be a useful answer: return the canned reply.
    if len(text) < 3:
        return "Entendido. ¿Algo más en lo que pueda ayudarte?"

    return text
328
+
329
  # ======================
330
  # API CONFIG
331
  # ======================
 
344
 
345
  class PromptRequest(BaseModel):
346
  text: str = Field(..., max_length=2000, description="Texto de entrada")
347
+ max_tokens: int = Field(default=100, ge=10, le=200, description="Tokens máximos a generar")
348
+ temperature: float = Field(default=0.7, ge=0.1, le=1.5, description="Temperatura de muestreo")
349
  top_k: int = Field(default=50, ge=1, le=100, description="Top-k sampling")
350
  top_p: float = Field(default=0.9, ge=0.1, le=1.0, description="Top-p (nucleus) sampling")
351
  repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0, description="Penalización por repetición")
 
384
  global ACTIVE_REQUESTS
385
  ACTIVE_REQUESTS += 1
386
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
  try:
388
+ user_input = req.text.strip()
389
+ if not user_input:
390
+ return {"reply": "", "tokens_generated": 0}
391
+
392
+ # Construir prompt
393
+ full_prompt = build_prompt(user_input)
394
+ tokens = [tokenizer_wrapper.bos_id()] + tokenizer_wrapper.encode(full_prompt)
395
+ input_ids = torch.tensor([tokens], device=DEVICE)
396
+
397
+ # Parámetros dinámicos según carga
398
+ dyn_max_tokens = req.max_tokens
399
+ dyn_temperature = req.temperature
400
+
401
+ if ACTIVE_REQUESTS > 2:
402
+ dyn_max_tokens = min(dyn_max_tokens, 80)
403
+ dyn_temperature = max(0.5, dyn_temperature * 0.9)
404
+
405
+ # Generar
406
  with torch.no_grad():
407
+ output_ids, eos_detected = model.generate(
408
  input_ids,
409
  max_new_tokens=dyn_max_tokens,
410
  temperature=dyn_temperature,
411
  top_k=req.top_k,
412
  top_p=req.top_p,
413
+ repetition_penalty=req.repetition_penalty,
414
+ eos_token_id=tokenizer_wrapper.eos_id()
415
  )
416
 
417
+ # Extraer solo los tokens generados (excluyendo el prompt)
418
  gen_tokens = output_ids[0, len(tokens):].tolist()
419
 
420
+ # Filtrar tokens inválidos
421
  safe_tokens = [
422
  t for t in gen_tokens
423
  if 0 <= t < VOCAB_SIZE and t != tokenizer_wrapper.eos_id()
424
  ]
425
 
426
+ # Decodificar
427
+ raw_response = tokenizer_wrapper.decode(safe_tokens).strip()
428
 
429
+ # Limpiar respuesta
430
+ clean_reply = clean_response(raw_response, user_input)
431
+
432
+ # Si EOS no fue detectado y la respuesta parece incompleta, truncar
433
+ if not eos_detected and len(clean_reply) > 200:
434
+ # Buscar un punto final para truncar
435
+ last_period = clean_reply.rfind('.')
436
+ if last_period > 100:
437
+ clean_reply = clean_reply[:last_period + 1]
438
+
439
+ # Eliminar frases sin sentido comunes
440
+ nonsense_phrases = [
441
+ "foompañances", "ciudadores", "mejtedon", "calportedon",
442
+ "rápidodcor", "rápidodarse", "miel", "baon", "domol"
443
+ ]
444
+ for phrase in nonsense_phrases:
445
+ clean_reply = clean_reply.replace(phrase, "")
446
+
447
+ # Limpiar espacios dobles nuevamente
448
+ clean_reply = re.sub(r'\s+', ' ', clean_reply).strip()
449
+
450
+ # Si la respuesta sigue siendo muy larga y no tiene puntos, cortar
451
+ if len(clean_reply) > 300 and '.' not in clean_reply[-50:]:
452
+ clean_reply = clean_reply[:250] + "..."
453
 
454
  return {
455
+ "reply": clean_reply,
456
  "tokens_generated": len(safe_tokens),
457
+ "model": "MTP-1.1",
458
+ "eos_detected": eos_detected
459
  }
460
 
461
  except Exception as e:
 
495
  }
496
 
497
  # ======================
498
+ # INTERFAZ WEB (MODERNA)
499
  # ======================
500
  @app.get("/", response_class=HTMLResponse)
501
  def chat_ui():
 
505
  <head>
506
  <meta charset="UTF-8">
507
  <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
508
+ <title>MTP 1.1 - Chat IA</title>
509
  <link rel="preconnect" href="https://fonts.googleapis.com">
510
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
511
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap" rel="stylesheet">
 
517
  --text-primary: #e3e3e3;
518
  --text-secondary: #9aa0a6;
519
  --user-bubble: #282a2c;
 
 
520
  }
521
  * { box-sizing: border-box; outline: none; -webkit-tap-highlight-color: transparent; }
522
  body {
 
552
  width: 32px;
553
  height: 32px;
554
  border-radius: 50%;
555
+ background: linear-gradient(135deg, #4a9eff, #8a6eff);
556
+ display: flex;
557
+ align-items: center;
558
+ justify-content: center;
559
+ font-weight: bold;
560
+ font-size: 14px;
561
+ color: white;
562
  }
563
  .brand-text {
564
  font-weight: 500;
 
575
  border-radius: 12px;
576
  font-weight: 600;
577
  }
578
+ .status-badge {
579
+ font-size: 0.7rem;
580
+ background: rgba(76, 175, 80, 0.15);
581
+ color: #4caf50;
582
+ padding: 2px 8px;
583
+ border-radius: 12px;
584
+ font-weight: 500;
585
+ display: flex;
586
+ align-items: center;
587
+ gap: 6px;
588
+ }
589
+ .status-badge .dot {
590
+ width: 8px;
591
+ height: 8px;
592
+ background: #4caf50;
593
+ border-radius: 50%;
594
+ animation: pulse 1.5s infinite;
595
+ }
596
+ @keyframes pulse {
597
+ 0%, 100% { opacity: 1; transform: scale(1); }
598
+ 50% { opacity: 0.5; transform: scale(0.8); }
599
+ }
600
  .chat-scroll {
601
  flex: 1;
602
  overflow-y: auto;
 
647
  height: 34px;
648
  min-width: 34px;
649
  border-radius: 50%;
650
+ background: linear-gradient(135deg, #4a9eff, #8a6eff);
651
+ display: flex;
652
+ align-items: center;
653
+ justify-content: center;
654
+ font-weight: bold;
655
+ font-size: 14px;
656
+ color: white;
657
  box-shadow: 0 2px 6px rgba(0,0,0,0.2);
658
  }
659
  .bot-actions {
 
720
  font-family: inherit;
721
  padding: 10px 0;
722
  }
723
+ #userInput::placeholder {
724
+ color: var(--text-secondary);
725
+ }
726
  #mainBtn {
727
  background: white;
728
  color: black;
 
749
  to { opacity: 1; transform: translateY(0); }
750
  }
751
  @keyframes blink { 0%, 100% { opacity: 1; } 50% { opacity: 0; } }
 
 
 
 
 
 
752
  ::-webkit-scrollbar { width: 8px; }
753
  ::-webkit-scrollbar-track { background: transparent; }
754
  ::-webkit-scrollbar-thumb { background: #333; border-radius: 4px; }
 
757
  <body>
758
  <header>
759
  <div class="brand-wrapper" onclick="location.reload()">
760
+ <div class="brand-logo">M</div>
761
  <div class="brand-text">
762
+ MTP <span class="version-badge">2.5</span>
763
  </div>
764
  </div>
765
+ <div class="status-badge">
766
+ <span class="dot"></span>
767
+ <span id="statusText">Conectado</span>
768
+ </div>
769
  </header>
770
  <div id="chatScroll" class="chat-scroll">
771
  <div class="msg-row bot" style="animation-delay: 0.1s;">
772
+ <div class="bot-avatar">M</div>
773
  <div class="msg-content-wrapper">
774
  <div class="msg-text">
775
+ ¡Hola! Soy MTP 2.5 ¿En qué puedo ayudarte hoy?
776
  </div>
777
  </div>
778
  </div>
 
790
  const chatScroll = document.getElementById('chatScroll');
791
  const userInput = document.getElementById('userInput');
792
  const mainBtn = document.getElementById('mainBtn');
793
+ const statusText = document.getElementById('statusText');
794
  let isGenerating = false;
795
  let abortController = null;
796
  let typingTimeout = null;
797
  let lastUserPrompt = "";
798
+
799
  const ICON_SEND = `<svg width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M22 2L11 13M22 2l-7 20-4-9-9-4 20-7z"></path></svg>`;
800
  const ICON_STOP = `<svg width="14" height="14" viewBox="0 0 24 24" fill="currentColor" stroke="currentColor" stroke-width="0"><rect x="2" y="2" width="20" height="20" rx="4" ry="4"></rect></svg>`;
801
+
802
  mainBtn.innerHTML = ICON_SEND;
803
+
804
  function scrollToBottom() {
805
  chatScroll.scrollTop = chatScroll.scrollHeight;
806
  }
807
+
808
  function setBtnState(state) {
809
  if (state === 'sending') {
810
  mainBtn.innerHTML = ICON_STOP;
811
  isGenerating = true;
812
+ statusText.textContent = "Pensando...";
813
  } else {
814
  mainBtn.innerHTML = ICON_SEND;
815
  isGenerating = false;
816
  abortController = null;
817
+ statusText.textContent = "Conectado";
818
  }
819
  }
820
+
821
  function handleBtnClick() {
822
  if (isGenerating) {
823
  stopGeneration();
 
825
  sendMessage();
826
  }
827
  }
828
+
829
  function stopGeneration() {
830
  if (abortController) abortController.abort();
831
  if (typingTimeout) clearTimeout(typingTimeout);
832
  const activeCursor = document.querySelector('.typing-cursor');
833
  if (activeCursor) activeCursor.classList.remove('typing-cursor');
 
 
834
  setBtnState('idle');
835
  userInput.focus();
836
  }
837
+
838
  async function sendMessage(textOverride = null) {
839
  const text = textOverride || userInput.value.trim();
840
  if (!text) return;
841
  lastUserPrompt = text;
842
+
843
  if (!textOverride) {
844
  userInput.value = '';
845
  addMessage(text, 'user');
846
  }
847
+
848
  setBtnState('sending');
849
  abortController = new AbortController();
850
+
851
  const botRow = document.createElement('div');
852
  botRow.className = 'msg-row bot';
853
  const avatar = document.createElement('div');
854
+ avatar.className = 'bot-avatar';
855
+ avatar.textContent = 'M';
856
  const wrapper = document.createElement('div');
857
  wrapper.className = 'msg-content-wrapper';
858
  const msgText = document.createElement('div');
859
+ msgText.className = 'msg-text';
860
  wrapper.appendChild(msgText);
861
  botRow.appendChild(avatar);
862
  botRow.appendChild(wrapper);
863
  chatScroll.appendChild(botRow);
864
  scrollToBottom();
865
+
866
  try {
867
  const response = await fetch('/generate', {
868
  method: 'POST',
 
870
  body: JSON.stringify({ text: text }),
871
  signal: abortController.signal
872
  });
873
+
874
  const data = await response.json();
875
+
876
+ if (!isGenerating) return;
877
+
878
+ const reply = data.reply || "Lo siento, no pude procesar tu solicitud.";
879
+
880
  await typeWriter(msgText, reply);
881
+
882
  if (isGenerating) {
883
  addActions(wrapper, reply);
884
  setBtnState('idle');
 
887
  if (error.name === 'AbortError') {
888
  msgText.textContent += " [Detenido]";
889
  } else {
890
+ msgText.textContent = "Error de conexión. Intenta de nuevo.";
 
891
  msgText.style.color = "#ff8b8b";
 
892
  }
893
+ setBtnState('idle');
894
  }
895
  }
896
+
897
  function addMessage(text, sender) {
898
  const row = document.createElement('div');
899
  row.className = `msg-row ${sender}`;
 
904
  chatScroll.appendChild(row);
905
  scrollToBottom();
906
  }
907
+
908
  function typeWriter(element, text, speed = 12) {
909
  return new Promise(resolve => {
910
  let i = 0;
911
+ element.textContent = '';
912
  element.classList.add('typing-cursor');
913
+
914
  function type() {
915
  if (!isGenerating) {
916
  element.classList.remove('typing-cursor');
 
930
  type();
931
  });
932
  }
933
+
934
  function addActions(wrapperElement, textToCopy) {
935
  const actionsDiv = document.createElement('div');
936
  actionsDiv.className = 'bot-actions';
937
+
938
  const copyBtn = document.createElement('button');
939
  copyBtn.className = 'action-btn';
940
  copyBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>`;
941
  copyBtn.onclick = () => {
942
  navigator.clipboard.writeText(textToCopy);
943
+ copyBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M20 6L9 17l-5-5"></path></svg>`;
944
+ setTimeout(() => {
945
+ copyBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path></svg>`;
946
+ }, 1500);
947
  };
948
+
949
  const regenBtn = document.createElement('button');
950
  regenBtn.className = 'action-btn';
951
  regenBtn.innerHTML = `<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M23 4v6h-6"></path><path d="M1 20v-6h6"></path><path d="M3.51 9a9 9 0 0 1 14.85-3.36L23 10M1 14l4.64 4.36A9 9 0 0 0 20.49 15"></path></svg>`;
952
  regenBtn.onclick = () => {
953
  sendMessage(lastUserPrompt);
954
  };
955
+
956
  actionsDiv.appendChild(copyBtn);
957
  actionsDiv.appendChild(regenBtn);
958
  wrapperElement.appendChild(actionsDiv);
959
  requestAnimationFrame(() => actionsDiv.style.opacity = "1");
960
  scrollToBottom();
961
  }
962
+
963
  userInput.addEventListener('keydown', (e) => {
964
+ if (e.key === 'Enter' && !e.shiftKey) {
965
+ e.preventDefault();
966
+ handleBtnClick();
967
+ }
968
  });
969
+
970
  window.onload = () => userInput.focus();
971
  </script>
972
  </body>