Skip to content

Commit

Permalink
[Flax BERT] Update deprecated 'split' method (huggingface#28012)
Browse files Browse the repository at this point in the history
* [Flax BERT] Update deprecated 'split' method

* fix copies
  • Loading branch information
sanchit-gandhi authored and iantbutler01 committed Dec 16, 2023
1 parent c8c9eb7 commit 4335027
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/bert/modeling_flax_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,7 +1569,7 @@ def __call__(
hidden_states = outputs[0]

logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/roberta/modeling_flax_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,7 +1344,7 @@ def __call__(
hidden_states = outputs[0]

logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1365,7 +1365,7 @@ def __call__(
hidden_states = outputs[0]

logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1359,7 +1359,7 @@ def __call__(
hidden_states = outputs[0]

logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)

Expand Down

0 comments on commit 4335027

Please sign in to comment.