Frenchizer committed on
Commit
76720d2
·
verified ·
1 Parent(s): 10fd17d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -15
app.py CHANGED
@@ -18,23 +18,18 @@ def gradio_predict(input_text):
18
  max_length=512
19
  )
20
 
21
- # Get shapes from actual input
22
- batch_size = tokenized_input["input_ids"].shape[0] # Should be 1
23
- seq_length = tokenized_input["input_ids"].shape[1] # Should be 512
24
-
25
- # Prepare inputs with correct shapes
26
  input_ids = tokenized_input["input_ids"].astype(np.int64)
27
  attention_mask = tokenized_input["attention_mask"].astype(np.int64)
28
 
29
- # Create decoder_input_ids with matching shape
30
- # Usually starts with pad_token_id or bos_token_id
31
- decoder_input_ids = np.full((batch_size, seq_length), tokenizer.pad_token_id, dtype=np.int64)
32
- decoder_input_ids[:, 0] = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else tokenizer.pad_token_id
33
 
34
- print("Debug shapes:")
35
- print(f"input_ids shape: {input_ids.shape}")
36
- print(f"attention_mask shape: {attention_mask.shape}")
37
- print(f"decoder_input_ids shape: {decoder_input_ids.shape}")
38
 
39
  # Run inference
40
  outputs = session.run(
@@ -46,12 +41,27 @@ def gradio_predict(input_text):
46
  }
47
  )
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # Decode output
50
- translated_text = tokenizer.decode(outputs[0][0], skip_special_tokens=True)
51
  return translated_text
52
 
53
  except Exception as e:
54
- print(f"Detailed error: {str(e)}") # This will show in the Space's logs
 
 
55
  return f"Error during translation: {str(e)}"
56
 
57
  # Gradio interface for the web app
 
18
  max_length=512
19
  )
20
 
21
+ # Prepare inputs
 
 
 
 
22
  input_ids = tokenized_input["input_ids"].astype(np.int64)
23
  attention_mask = tokenized_input["attention_mask"].astype(np.int64)
24
 
25
+ # Initialize decoder_input_ids with start token
26
+ decoder_input_ids = np.zeros((1, 512), dtype=np.int64)
27
+ decoder_input_ids[:, 0] = tokenizer.bos_token_id or tokenizer.pad_token_id
 
28
 
29
+ print("Input values:")
30
+ print(f"First few input_ids: {input_ids[0][:10]}")
31
+ print(f"First few attention_mask: {attention_mask[0][:10]}")
32
+ print(f"First few decoder_input_ids: {decoder_input_ids[0][:10]}")
33
 
34
  # Run inference
35
  outputs = session.run(
 
41
  }
42
  )
43
 
44
+ print("Output shape and type:")
45
+ print(f"Output type: {type(outputs)}")
46
+ print(f"Output[0] type: {type(outputs[0])}")
47
+ print(f"Output[0] shape: {outputs[0].shape}")
48
+
49
+ # Process outputs more carefully
50
+ output_ids = outputs[0]
51
+ if isinstance(output_ids, np.ndarray):
52
+ output_ids = output_ids[0] # Take first sequence
53
+ # Convert to list of integers if needed
54
+ if isinstance(output_ids, np.ndarray):
55
+ output_ids = output_ids.tolist()
56
+
57
  # Decode output
58
+ translated_text = tokenizer.decode(output_ids, skip_special_tokens=True)
59
  return translated_text
60
 
61
  except Exception as e:
62
+ print(f"Detailed error: {str(e)}")
63
+ import traceback
64
+ print(traceback.format_exc())
65
  return f"Error during translation: {str(e)}"
66
 
67
  # Gradio interface for the web app