forked from thomwolf/transformers
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d959443
commit 82a0f3b
Showing
2 changed files
with
324 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,320 @@ | ||
# coding=utf-8 | ||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" Testing suite for the PyTorch Cohere model. """ | ||
|
||
import unittest | ||
|
||
from parameterized import parameterized | ||
|
||
from transformers import CohereConfig, is_torch_available | ||
from transformers.testing_utils import ( | ||
require_torch, | ||
slow, | ||
torch_device, | ||
) | ||
|
||
from ...test_configuration_common import ConfigTester | ||
from ...test_modeling_common import ids_tensor | ||
|
||
|
||
if is_torch_available(): | ||
import torch | ||
|
||
from transformers import CohereForCausalLM, CohereModel | ||
|
||
|
||
class CohereModelTester: | ||
def __init__( | ||
self, | ||
parent, | ||
batch_size=13, | ||
seq_length=7, | ||
is_training=True, | ||
use_input_mask=True, | ||
use_token_type_ids=False, | ||
use_labels=True, | ||
vocab_size=99, | ||
hidden_size=32, | ||
num_hidden_layers=2, | ||
num_attention_heads=4, | ||
intermediate_size=37, | ||
hidden_act="gelu", | ||
hidden_dropout_prob=0.1, | ||
attention_probs_dropout_prob=0.1, | ||
max_position_embeddings=512, | ||
type_vocab_size=16, | ||
type_sequence_label_size=2, | ||
initializer_range=0.02, | ||
num_labels=3, | ||
num_choices=4, | ||
pad_token_id=0, | ||
scope=None, | ||
): | ||
self.parent = parent | ||
self.batch_size = batch_size | ||
self.seq_length = seq_length | ||
self.is_training = is_training | ||
self.use_input_mask = use_input_mask | ||
self.use_token_type_ids = use_token_type_ids | ||
self.use_labels = use_labels | ||
self.vocab_size = vocab_size | ||
self.hidden_size = hidden_size | ||
self.num_hidden_layers = num_hidden_layers | ||
self.num_attention_heads = num_attention_heads | ||
self.intermediate_size = intermediate_size | ||
self.hidden_act = hidden_act | ||
self.hidden_dropout_prob = hidden_dropout_prob | ||
self.attention_probs_dropout_prob = attention_probs_dropout_prob | ||
self.max_position_embeddings = max_position_embeddings | ||
self.type_vocab_size = type_vocab_size | ||
self.type_sequence_label_size = type_sequence_label_size | ||
self.initializer_range = initializer_range | ||
self.num_labels = num_labels | ||
self.num_choices = num_choices | ||
self.pad_token_id = pad_token_id | ||
self.scope = scope | ||
|
||
def prepare_config_and_inputs(self): | ||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) | ||
|
||
input_mask = None | ||
if self.use_input_mask: | ||
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device) | ||
|
||
token_type_ids = None | ||
if self.use_token_type_ids: | ||
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) | ||
|
||
sequence_labels = None | ||
token_labels = None | ||
choice_labels = None | ||
if self.use_labels: | ||
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) | ||
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) | ||
choice_labels = ids_tensor([self.batch_size], self.num_choices) | ||
|
||
config = self.get_config() | ||
|
||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels | ||
|
||
def get_config(self): | ||
return CohereConfig( | ||
vocab_size=self.vocab_size, | ||
hidden_size=self.hidden_size, | ||
num_hidden_layers=self.num_hidden_layers, | ||
num_attention_heads=self.num_attention_heads, | ||
intermediate_size=self.intermediate_size, | ||
hidden_act=self.hidden_act, | ||
hidden_dropout_prob=self.hidden_dropout_prob, | ||
attention_probs_dropout_prob=self.attention_probs_dropout_prob, | ||
max_position_embeddings=self.max_position_embeddings, | ||
type_vocab_size=self.type_vocab_size, | ||
is_decoder=False, | ||
initializer_range=self.initializer_range, | ||
pad_token_id=self.pad_token_id, | ||
) | ||
|
||
def create_and_check_model( | ||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels | ||
): | ||
model = CohereModel(config=config) | ||
model.to(torch_device) | ||
model.eval() | ||
result = model(input_ids, attention_mask=input_mask) | ||
result = model(input_ids) | ||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) | ||
|
||
def create_and_check_model_as_decoder( | ||
self, | ||
config, | ||
input_ids, | ||
token_type_ids, | ||
input_mask, | ||
sequence_labels, | ||
token_labels, | ||
choice_labels, | ||
encoder_hidden_states, | ||
encoder_attention_mask, | ||
): | ||
config.add_cross_attention = True | ||
model = CohereModel(config) | ||
model.to(torch_device) | ||
model.eval() | ||
result = model( | ||
input_ids, | ||
attention_mask=input_mask, | ||
encoder_hidden_states=encoder_hidden_states, | ||
encoder_attention_mask=encoder_attention_mask, | ||
) | ||
result = model( | ||
input_ids, | ||
attention_mask=input_mask, | ||
encoder_hidden_states=encoder_hidden_states, | ||
) | ||
result = model(input_ids, attention_mask=input_mask) | ||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) | ||
|
||
def create_and_check_for_causal_lm( | ||
self, | ||
config, | ||
input_ids, | ||
token_type_ids, | ||
input_mask, | ||
sequence_labels, | ||
token_labels, | ||
choice_labels, | ||
encoder_hidden_states, | ||
encoder_attention_mask, | ||
): | ||
model = CohereForCausalLM(config=config) | ||
model.to(torch_device) | ||
model.eval() | ||
result = model(input_ids, attention_mask=input_mask, labels=token_labels) | ||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) | ||
|
||
def create_and_check_decoder_model_past_large_inputs( | ||
self, | ||
config, | ||
input_ids, | ||
token_type_ids, | ||
input_mask, | ||
sequence_labels, | ||
token_labels, | ||
choice_labels, | ||
encoder_hidden_states, | ||
encoder_attention_mask, | ||
): | ||
config.is_decoder = True | ||
config.add_cross_attention = True | ||
model = CohereForCausalLM(config=config) | ||
model.to(torch_device) | ||
model.eval() | ||
|
||
# first forward pass | ||
outputs = model( | ||
input_ids, | ||
attention_mask=input_mask, | ||
encoder_hidden_states=encoder_hidden_states, | ||
encoder_attention_mask=encoder_attention_mask, | ||
use_cache=True, | ||
) | ||
past_key_values = outputs.past_key_values | ||
|
||
# create hypothetical multiple next token and extent to next_input_ids | ||
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) | ||
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) | ||
|
||
# append to next input_ids and | ||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) | ||
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) | ||
|
||
output_from_no_past = model( | ||
next_input_ids, | ||
attention_mask=next_attention_mask, | ||
encoder_hidden_states=encoder_hidden_states, | ||
encoder_attention_mask=encoder_attention_mask, | ||
output_hidden_states=True, | ||
)["hidden_states"][0] | ||
output_from_past = model( | ||
next_tokens, | ||
attention_mask=next_attention_mask, | ||
encoder_hidden_states=encoder_hidden_states, | ||
encoder_attention_mask=encoder_attention_mask, | ||
past_key_values=past_key_values, | ||
output_hidden_states=True, | ||
)["hidden_states"][0] | ||
|
||
# select random slice | ||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() | ||
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() | ||
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() | ||
|
||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) | ||
|
||
# test that outputs are equal for slice | ||
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) | ||
|
||
def prepare_config_and_inputs_for_common(self): | ||
config_and_inputs = self.prepare_config_and_inputs() | ||
( | ||
config, | ||
input_ids, | ||
token_type_ids, | ||
input_mask, | ||
sequence_labels, | ||
token_labels, | ||
choice_labels, | ||
) = config_and_inputs | ||
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} | ||
return config, inputs_dict | ||
|
||
|
||
@require_torch | ||
class CohereModelTest(unittest.TestCase): | ||
all_model_classes = (CohereModel, CohereForCausalLM) if is_torch_available() else () | ||
all_generative_model_classes = (CohereForCausalLM,) if is_torch_available() else () | ||
pipeline_model_mapping = ( | ||
{ | ||
"feature-extraction": CohereModel, | ||
"text-generation": CohereForCausalLM, | ||
} | ||
if is_torch_available() | ||
else {} | ||
) | ||
test_headmasking = False | ||
test_pruning = False | ||
fx_compatible = True | ||
|
||
# Need to use `0.8` instead of `0.9` for `test_cpu_offload` | ||
# This is because we are hitting edge cases with the causal_mask buffer | ||
model_split_percents = [0.5, 0.7, 0.8] | ||
|
||
def setUp(self): | ||
self.model_tester = CohereModelTester(self) | ||
self.config_tester = ConfigTester(self, config_class=CohereConfig, hidden_size=37) | ||
|
||
def test_model(self): | ||
config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||
self.model_tester.create_and_check_model(*config_and_inputs) | ||
|
||
def test_model_various_embeddings(self): | ||
config_and_inputs = self.model_tester.prepare_config_and_inputs() | ||
for type in ["absolute", "relative_key", "relative_key_query"]: | ||
config_and_inputs[0].position_embedding_type = type | ||
self.model_tester.create_and_check_model(*config_and_inputs) | ||
|
||
@unittest.skip("TODO @gante fix this") | ||
@parameterized.expand([(1, False), (1, True), (4, False)]) | ||
def test_new_cache_format(self, num_beams, do_sample): | ||
pass | ||
|
||
|
||
@require_torch | ||
class CohereIntegrationTest(unittest.TestCase): | ||
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!") | ||
@slow | ||
def test_model_logits(self): | ||
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] | ||
model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01", device_map="auto") | ||
out = model(torch.tensor(input_ids).unsqueeze(0)) | ||
# # Expected mean on dim = -1 | ||
EXPECTED_MEAN = torch.tensor([[0.5077, -2.5771, -1.1590, -2.6220, -1.7837, -2.4421, -1.3293, -2.2028]]) | ||
torch.testing.assert_close(out[0].mean(-1).cpu(), EXPECTED_MEAN, atol=1e-2, rtol=1e-2) | ||
# slicing logits[0, 0, 0:30] | ||
EXPECTED_SLICE = torch.tensor([ 1.8525, 5.0039, 2.7734, 3.6270, 0.9390, -0.4587, 3.4062, 0.9468, \ | ||
3.7324, 1.2344, 5.3047, 4.7266, 5.9414, 5.5195, 1.8047, 3.5215, \ | ||
1.5752, 3.7031, 6.2891, 3.4785, 2.0293, 4.2539, 2.8086, 4.7070, \ | ||
3.6953, 4.0391, 3.9766, 3.3066, 2.9395, 3.3105]) # fmt: skip | ||
torch.testing.assert_close(out[0][0, 0, :30].cpu(), EXPECTED_SLICE, atol=1e-5, rtol=1e-5) |