Frenchizer committed · verified
Commit ce28322 · 1 Parent(s): 6fd7731

Update app.py

Files changed (1): app.py (+5 -13)
app.py CHANGED
```diff
@@ -22,8 +22,9 @@ def gradio_predict(input_text):
     input_ids = tokenized_input["input_ids"].astype(np.int64)
     attention_mask = tokenized_input["attention_mask"].astype(np.int64)
 
-    # Initialize decoder input
-    decoder_input_ids = np.array([[tokenizer.bos_token_id]], dtype=np.int64)
+    # Use a specific token ID for decoder start (for Helsinki-NLP models)
+    decoder_start_token_id = 59513  # This is the typical start token for Helsinki-NLP models
+    decoder_input_ids = np.array([[decoder_start_token_id]], dtype=np.int64)
 
     print("Initial shapes:")
     print(f"input_ids shape: {input_ids.shape}")
```
```diff
@@ -40,24 +41,15 @@ def gradio_predict(input_text):
         }
     )
 
-    print("Output information:")
-    print(f"outputs type: {type(outputs)}")
-    print(f"outputs length: {len(outputs)}")
-    print(f"outputs[0] shape: {outputs[0].shape}")
-
     # Get logits and convert to token ids
     logits = outputs[0]
     token_ids = np.argmax(logits[0], axis=-1)
 
-    # Find end of sequence
-    eos_token_id = tokenizer.eos_token_id
-    end_idx = np.where(token_ids == eos_token_id)[0]
+    # Find end of sequence (using pad token since eos might also be None)
+    end_idx = np.where(token_ids == tokenizer.pad_token_id)[0]
     if len(end_idx) > 0:
         token_ids = token_ids[:end_idx[0]]
 
-    print(f"token_ids shape: {token_ids.shape}")
-    print(f"token_ids: {token_ids}")
-
     # Decode the sequence
     translated_text = tokenizer.decode(token_ids, skip_special_tokens=True)
     return translated_text
```
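The decode here is a single greedy pass, so the truncation step simply cuts the argmax sequence at the first pad token. A self-contained numpy sketch of that step, with made-up token values:

```python
import numpy as np

pad_token_id = 59513  # same ID used as decoder start above
token_ids = np.array([212, 9, 4081, 30, 0, 59513, 59513])  # invented example output

# Keep everything before the first pad token, if one is present.
end_idx = np.where(token_ids == pad_token_id)[0]
if len(end_idx) > 0:
    token_ids = token_ids[:end_idx[0]]

print(token_ids)  # [ 212    9 4081   30    0]
```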
 
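For context, here is a minimal end-to-end sketch of the inference path the file appears to implement after this commit. The ONNX model path and the session's input names are assumptions inferred from the fragments above, not confirmed by the source:

```python
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")  # illustrative
session = ort.InferenceSession("model.onnx")  # hypothetical exported seq2seq model

def gradio_predict(input_text):
    tokenized_input = tokenizer(input_text, return_tensors="np")
    input_ids = tokenized_input["input_ids"].astype(np.int64)
    attention_mask = tokenized_input["attention_mask"].astype(np.int64)
    decoder_input_ids = np.array([[59513]], dtype=np.int64)  # Helsinki-NLP decoder start

    outputs = session.run(
        None,
        {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
        },
    )

    # Single forward pass: take the argmax over the vocabulary at each position.
    logits = outputs[0]
    token_ids = np.argmax(logits[0], axis=-1)

    # Truncate at the first pad token, then decode.
    end_idx = np.where(token_ids == tokenizer.pad_token_id)[0]
    if len(end_idx) > 0:
        token_ids = token_ids[:end_idx[0]]

    return tokenizer.decode(token_ids, skip_special_tokens=True)
```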