Frenchizer committed
Commit 6fd7731 · verified · 1 Parent(s): 57da7e7

Update app.py

Files changed (1):
  app.py +35 -28
app.py CHANGED
@@ -22,43 +22,50 @@ def gradio_predict(input_text):
         input_ids = tokenized_input["input_ids"].astype(np.int64)
         attention_mask = tokenized_input["attention_mask"].astype(np.int64)
 
-        # Create proper decoder_input_ids for autoregressive generation
+        # Initialize decoder input
         decoder_input_ids = np.array([[tokenizer.bos_token_id]], dtype=np.int64)
 
-        generated_ids = []
-        max_length = 128  # Maximum length of translation
+        print("Initial shapes:")
+        print(f"input_ids shape: {input_ids.shape}")
+        print(f"attention_mask shape: {attention_mask.shape}")
+        print(f"decoder_input_ids shape: {decoder_input_ids.shape}")
 
-        # Autoregressive generation
-        for _ in range(max_length):
-            outputs = session.run(
-                None,
-                {
-                    "input_ids": input_ids,
-                    "attention_mask": attention_mask,
-                    "decoder_input_ids": decoder_input_ids
-                }
-            )
-
-            # Get the next token prediction
-            next_token_logits = outputs[0][0, -1, :]
-            next_token = np.argmax(next_token_logits)
-
-            # Stop if we hit the EOS token
-            if next_token == tokenizer.eos_token_id:
-                break
-
-            # Append the predicted token
-            generated_ids.append(next_token)
+        # Run initial inference
+        outputs = session.run(
+            None,
+            {
+                "input_ids": input_ids,
+                "attention_mask": attention_mask,
+                "decoder_input_ids": decoder_input_ids
+            }
+        )
+
+        print("Output information:")
+        print(f"outputs type: {type(outputs)}")
+        print(f"outputs length: {len(outputs)}")
+        print(f"outputs[0] shape: {outputs[0].shape}")
+
+        # Get logits and convert to token ids
+        logits = outputs[0]
+        token_ids = np.argmax(logits[0], axis=-1)
+
+        # Find end of sequence
+        eos_token_id = tokenizer.eos_token_id
+        end_idx = np.where(token_ids == eos_token_id)[0]
+        if len(end_idx) > 0:
+            token_ids = token_ids[:end_idx[0]]
 
-            # Update decoder_input_ids for next iteration
-            decoder_input_ids = np.array([[tokenizer.bos_token_id] + generated_ids], dtype=np.int64)
+        print(f"token_ids shape: {token_ids.shape}")
+        print(f"token_ids: {token_ids}")
 
-        # Decode the generated sequence
-        translated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+        # Decode the sequence
+        translated_text = tokenizer.decode(token_ids, skip_special_tokens=True)
         return translated_text
 
     except Exception as e:
         print(f"Detailed error: {str(e)}")
+        import traceback
+        print(traceback.format_exc())
         return f"Error during translation: {str(e)}"
 
 # Gradio interface for the web app
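
For reference, here is the function as it reads after this commit, condensed into a self-contained sketch. The decoding logic between the context lines comes straight from the diff; the tokenizer/session setup and the tokenization call sit outside this hunk, so those lines (and the placeholder paths) are assumptions for illustration only, and the commit's debug prints are left out for brevity.

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

# ASSUMED setup -- defined elsewhere in app.py, outside this hunk;
# the paths here are placeholders, not taken from the commit.
tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")
session = ort.InferenceSession("path/to/model.onnx")

def gradio_predict(input_text):
    try:
        # Tokenization happens above the hunk; assumed to look like this.
        tokenized_input = tokenizer(input_text, return_tensors="np")
        input_ids = tokenized_input["input_ids"].astype(np.int64)
        attention_mask = tokenized_input["attention_mask"].astype(np.int64)

        # Initialize decoder input with the BOS token
        decoder_input_ids = np.array([[tokenizer.bos_token_id]], dtype=np.int64)

        # One forward pass replaces the old token-by-token loop
        outputs = session.run(
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "decoder_input_ids": decoder_input_ids
            }
        )

        # Greedy-pick a token id at every decoder position at once
        logits = outputs[0]                        # (batch, decoder_len, vocab)
        token_ids = np.argmax(logits[0], axis=-1)  # (decoder_len,)

        # Truncate at the first EOS token, if the model produced one
        end_idx = np.where(token_ids == tokenizer.eos_token_id)[0]
        if len(end_idx) > 0:
            token_ids = token_ids[:end_idx[0]]

        return tokenizer.decode(token_ids, skip_special_tokens=True)

    except Exception as e:
        import traceback
        print(traceback.format_exc())
        return f"Error during translation: {str(e)}"

One caveat worth noting: if the exported model returns logits only for the positions present in decoder_input_ids (the usual behavior for a seq2seq ONNX export), a single pass with a one-token decoder input yields exactly one predicted token rather than a full translation; the shape and token_ids prints added in this commit appear aimed at checking exactly that.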