| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import unittest |
| |
|
| | import numpy as np |
| |
|
| | from transformers import is_flax_available |
| | from transformers.testing_utils import require_flax |
| |
|
| | from .test_modeling_flax_common import ids_tensor |
| |
|
| |
|
| | if is_flax_available(): |
| | import jax |
| | import jax.numpy as jnp |
| | from transformers.generation_flax_logits_process import ( |
| | FlaxLogitsProcessorList, |
| | FlaxTemperatureLogitsWarper, |
| | FlaxTopKLogitsWarper, |
| | FlaxTopPLogitsWarper, |
| | ) |
| |
|
| |
|
@require_flax
class LogitsProcessorTest(unittest.TestCase):
    """Tests for the Flax temperature / top-k / top-p warpers and ``FlaxLogitsProcessorList``."""

    def _get_uniform_logits(self, batch_size: int, length: int):
        """Return a float NumPy array of shape ``(batch_size, length)`` filled with ``1 / length``."""
        return np.full((batch_size, length), 1.0) / length

    def test_temperature_dist_warper(self):
        """Temperature 0.5 should sharpen a non-uniform row, 1.3 should flatten it; a uniform row is unaffected."""
        input_ids = None
        length = 20

        scores = self._get_uniform_logits(batch_size=2, length=length)

        # Row 1 is made non-uniform: one peak at index 5 and one valley at index 10.
        base = 1 / length
        scores[1, 5] = base + 0.1
        scores[1, 10] = base - 0.4

        # Reference distribution without any temperature scaling.
        reference_probs = jax.nn.softmax(scores, axis=-1)

        sharpen = FlaxTemperatureLogitsWarper(temperature=0.5)
        smoothen = FlaxTemperatureLogitsWarper(temperature=1.3)

        # Copies are passed so each warper sees the same, untouched input scores.
        sharp_probs = jax.nn.softmax(sharpen(input_ids, scores.copy()), axis=-1)
        smooth_probs = jax.nn.softmax(smoothen(input_ids, scores.copy()), axis=-1)

        # Row 0 is still uniform, so neither temperature may change its distribution.
        for warped_probs in (sharp_probs, smooth_probs):
            self.assertTrue(jnp.allclose(reference_probs[0, :], warped_probs[0, :], atol=1e-3))

        ref_row = reference_probs[1]
        sharp_row = sharp_probs[1]
        smooth_row = smooth_probs[1]

        # Sharpening: the peak gets higher and the valley gets lower.
        self.assertLess(ref_row.max(), sharp_row.max())
        self.assertGreater(ref_row.min(), sharp_row.min())

        # Smoothing: the peak gets lower and the valley gets higher.
        self.assertGreater(ref_row.max(), smooth_row.max())
        self.assertLess(ref_row.min(), smooth_row.min())

    def test_top_k_dist_warper(self):
        """Top-k keeps exactly the k largest logits per row, subject to ``min_tokens_to_keep``."""
        input_ids = None
        vocab_size = 10
        batch_size = 2

        # Every row starts as the ramp 0..vocab_size-1 (an independent, writable int array).
        ramp = np.tile(np.arange(vocab_size), (batch_size, 1))
        # Lift the first half of row 1 above the rest, so its three largest logits sit at indices 2, 3 and 4.
        half = vocab_size // 2
        ramp[1:, :half] += vocab_size

        top_k_warp = FlaxTopKLogitsWarper(3)
        filtered = top_k_warp(input_ids, ramp)

        # Filtered positions are expected to be infinite; only the top-3 positions stay finite.
        self.assertListEqual(jnp.isinf(filtered[0]).tolist(), [True] * 7 + [False] * 3)
        self.assertListEqual(jnp.isinf(filtered[1]).tolist(), [True] * 2 + [False] * 3 + [True] * 5)

        # Edge case: top_k=1 but min_tokens_to_keep=3. The test expects 3 of the 5 tokens to survive,
        # leaving exactly 2 positions equal to filter_value (0.0) per row.
        length = 5
        safety_warp = FlaxTopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3)

        ramp = np.tile(np.arange(length), (batch_size, 1))
        filtered = safety_warp(input_ids, ramp)

        self.assertListEqual((filtered == 0.0).sum(axis=-1).tolist(), [2, 2])

    def test_top_p_dist_warper(self):
        """Top-p keeps the smallest high-probability prefix reaching ``top_p``, subject to ``min_tokens_to_keep``."""
        input_ids = None
        vocab_size = 10
        batch_size = 2

        # Log-probabilities of two small, hand-picked distributions.
        dist = np.log(np.array([[0.3, 0.1, 0.1, 0.5], [0.15, 0.3, 0.3, 0.25]]))

        top_p_warp = FlaxTopPLogitsWarper(0.7)
        filtered_dist = np.exp(top_p_warp(input_ids, dist))

        # With top_p=0.7, row 0 should keep {0.5, 0.3} and row 1 should keep {0.3, 0.3, 0.25};
        # every dropped token must come out with (near-)zero probability.
        expected = np.array([[0.3, 0.0, 0.0, 0.5], [0.0, 0.3, 0.3, 0.25]])
        self.assertTrue(np.allclose(filtered_dist, expected, atol=1e-3))

        # Edge case with negative logits: a ramp shifted to be centred around zero.
        ramp = np.tile(np.arange(vocab_size), (batch_size, 1)) - vocab_size // 2

        # Make row 1 extremely peaked by scaling its logits by 100.
        ramp[1] = ramp[1] * 100.0

        # filter_value=0.0 makes dropped positions countable via `!= 0.0`.
        top_p_warp = FlaxTopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0)
        filtered_dist = top_p_warp(input_ids, ramp)

        # Row 0 is expected to keep three tokens. Row 1 is expected to keep two, the floor set by min_tokens_to_keep.
        self.assertListEqual((filtered_dist != 0.0).sum(axis=-1).tolist(), [3, 2])

    def test_processor_list(self):
        """Chaining the warpers by hand must match applying them through ``FlaxLogitsProcessorList``."""
        batch_size = 4
        sequence_length = 10
        vocab_size = 15

        input_ids = ids_tensor((batch_size, sequence_length), vocab_size)
        input_ids_comp = input_ids.copy()

        scores = self._get_uniform_logits(batch_size, vocab_size)
        scores_comp = scores.copy()

        warpers = [
            FlaxTemperatureLogitsWarper(temperature=0.5),
            FlaxTopKLogitsWarper(3),
            FlaxTopPLogitsWarper(0.8),
        ]

        # Path 1: apply each warper one after another.
        for warper in warpers:
            scores = warper(input_ids, scores)

        # Path 2: apply the same warpers, in the same order, through the processor list.
        processor = FlaxLogitsProcessorList(warpers)
        scores_comp = processor(input_ids, scores_comp)

        # Both paths must agree on the resulting scores.
        self.assertTrue(jnp.allclose(scores, scores_comp, atol=1e-3))

        # The token ids must be left exactly as they were before processing.
        self.assertListEqual(input_ids.tolist(), input_ids_comp.tolist())
| |
|