[DRAFT] fix: transformers 5.x compat (cache_position + kwargs naming)

#6
Files changed (1)
  1. modeling_dots_ocr.py +11 -3
modeling_dots_ocr.py CHANGED
@@ -80,7 +80,7 @@ class DotsOCRForCausalLM(Qwen2ForCausalLM):
80
  return_dict: Optional[bool] = None,
81
  use_cache: Optional[bool] = None,
82
  logits_to_keep: int = 0,
83
- **loss_kwargs,
84
  ) -> Union[Tuple, CausalLMOutputWithPast]:
85
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
86
  assert len(input_ids) >= 1, f"empty input_ids {input_ids.shape=} will cause gradnorm nan"
@@ -99,7 +99,7 @@ class DotsOCRForCausalLM(Qwen2ForCausalLM):
99
  output_hidden_states=output_hidden_states,
100
  # return_dict=return_dict,
101
  logits_to_keep=logits_to_keep,
102
- **loss_kwargs,
103
  )
104
 
105
  return outputs
@@ -125,7 +125,15 @@ class DotsOCRForCausalLM(Qwen2ForCausalLM):
125
  **kwargs,
126
  )
127
 
128
- if cache_position[0] == 0:
 
 
 
 
 
 
 
 
129
  model_inputs["pixel_values"] = pixel_values
130
 
131
  return model_inputs
 
80
  return_dict: Optional[bool] = None,
81
  use_cache: Optional[bool] = None,
82
  logits_to_keep: int = 0,
83
+ **kwargs,
84
  ) -> Union[Tuple, CausalLMOutputWithPast]:
85
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
86
  assert len(input_ids) >= 1, f"empty input_ids {input_ids.shape=} will cause gradnorm nan"
 
99
  output_hidden_states=output_hidden_states,
100
  # return_dict=return_dict,
101
  logits_to_keep=logits_to_keep,
102
+ **kwargs,
103
  )
104
 
105
  return outputs
 
125
  **kwargs,
126
  )
127
 
128
+ # Pass pixel_values only on the first generation step (prefill).
129
+ # Compatible with both transformers 4.x (cache_position available)
130
+ # and 5.x (cache_position removed, use past_key_values instead).
131
+ is_prefill = (
132
+ (cache_position is not None and cache_position[0] == 0)
133
+ or past_key_values is None
134
+ or (hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() == 0)
135
+ )
136
+ if is_prefill:
137
  model_inputs["pixel_values"] = pixel_values
138
 
139
  return model_inputs