Skip to content

Commit

Permalink
fix: multilingual midel convert to tflite get wrong token
Browse files Browse the repository at this point in the history
  • Loading branch information
kent.sc.hung authored and Aya committed Aug 25, 2024
1 parent 0a7af19 commit b4073dd
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/generation/tf_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def _force_token(generation_idx):
batch_size = scores.shape[0]
current_token = self.force_token_array[generation_idx]

new_scores = tf.ones_like(scores, dtype=scores.dtype) * -float("inf")
new_scores = tf.zeros_like(scores, dtype=scores.dtype) + tf.constant([scores.dtype.min])
indices = tf.stack((tf.range(batch_size), tf.tile([current_token], [batch_size])), axis=1)
updates = tf.zeros((batch_size,), dtype=scores.dtype)
new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates)
Expand Down

0 comments on commit b4073dd

Please sign in to comment.