From 82a0f3b635d9fe1d7fba72c5f34f8a5f8eea21ee Mon Sep 17 00:00:00 2001
From: Saurabh Dash <111897126+saurabhdash2512@users.noreply.github.com>
Date: Thu, 14 Mar 2024 01:03:36 +0530
Subject: [PATCH] Add modeling tests (#9)

---
 .../models/cohere/tokenization_cohere_fast.py |   6 +-
 tests/models/cohere/test_modeling_cohere.py   | 320 ++++++++++++++++++
 2 files changed, 324 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/cohere/tokenization_cohere_fast.py b/src/transformers/models/cohere/tokenization_cohere_fast.py
index b066218b81b4f1..2c0f63c0a50393 100644
--- a/src/transformers/models/cohere/tokenization_cohere_fast.py
+++ b/src/transformers/models/cohere/tokenization_cohere_fast.py
@@ -198,7 +198,9 @@ def default_chat_template(self):
         <|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{ I am doing well! }}<|END_OF_TURN_TOKEN|>
 
         Use add_generation_prompt to add a prompt for the model to generate a response:
-
+
+        >>> from transformers import AutoTokenizer
+        >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01", trust_remote_code=True)
         >>> messages = [{"role": "user", "content": "Hello, how are you?"}]
         >>> tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         <|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
@@ -740,4 +742,4 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
         if token_ids_1 is not None:
             output = output + bos_token_id + token_ids_1 + eos_token_id
 
-        return output
+        return output
\ No newline at end of file
diff --git a/tests/models/cohere/test_modeling_cohere.py b/tests/models/cohere/test_modeling_cohere.py
index e69de29bb2d1d6..988e4469b97cc5 100644
--- a/tests/models/cohere/test_modeling_cohere.py
+++ b/tests/models/cohere/test_modeling_cohere.py
@@ -0,0 +1,320 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Testing suite for the PyTorch Cohere model. """
""" + +import unittest + +from parameterized import parameterized + +from transformers import CohereConfig, is_torch_available +from transformers.testing_utils import ( + require_torch, + slow, + torch_device, +) + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ids_tensor + + +if is_torch_available(): + import torch + + from transformers import CohereForCausalLM, CohereModel + + +class CohereModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=False, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + pad_token_id=0, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.pad_token_id = pad_token_id + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return CohereConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + ) + + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = CohereModel(config=config) + model.to(torch_device) + model.eval() + result 
+        result = model(input_ids)
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+    def create_and_check_model_as_decoder(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        encoder_hidden_states,
+        encoder_attention_mask,
+    ):
+        config.add_cross_attention = True
+        model = CohereModel(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(
+            input_ids,
+            attention_mask=input_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+        )
+        result = model(
+            input_ids,
+            attention_mask=input_mask,
+            encoder_hidden_states=encoder_hidden_states,
+        )
+        result = model(input_ids, attention_mask=input_mask)
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+    def create_and_check_for_causal_lm(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        encoder_hidden_states,
+        encoder_attention_mask,
+    ):
+        model = CohereForCausalLM(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
+    def create_and_check_decoder_model_past_large_inputs(
+        self,
+        config,
+        input_ids,
+        token_type_ids,
+        input_mask,
+        sequence_labels,
+        token_labels,
+        choice_labels,
+        encoder_hidden_states,
+        encoder_attention_mask,
+    ):
+        config.is_decoder = True
+        config.add_cross_attention = True
+        model = CohereForCausalLM(config=config)
+        model.to(torch_device)
+        model.eval()
+
+        # first forward pass
+        outputs = model(
+            input_ids,
+            attention_mask=input_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            use_cache=True,
+        )
+        past_key_values = outputs.past_key_values
+
+        # create hypothetical multiple next tokens and extend to next_input_ids
+        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
+        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
+
+        # append to next input_ids and attention mask
+        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
+
+        output_from_no_past = model(
+            next_input_ids,
+            attention_mask=next_attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_hidden_states=True,
+        )["hidden_states"][0]
+        output_from_past = model(
+            next_tokens,
+            attention_mask=next_attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            output_hidden_states=True,
+        )["hidden_states"][0]
+
+        # select random slice
+        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
+        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
+
+        # test that outputs are equal for slice
+        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+    def prepare_config_and_inputs_for_common(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+        (
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+        ) = config_and_inputs
+        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
+        return config, inputs_dict
+
+
+@require_torch
+class CohereModelTest(unittest.TestCase):
+    all_model_classes = (CohereModel, CohereForCausalLM) if is_torch_available() else ()
+    all_generative_model_classes = (CohereForCausalLM,) if is_torch_available() else ()
+    pipeline_model_mapping = (
+        {
+            "feature-extraction": CohereModel,
+            "text-generation": CohereForCausalLM,
+        }
+        if is_torch_available()
+        else {}
+    )
+    test_headmasking = False
+    test_pruning = False
+    fx_compatible = True
+
+    # Need to use `0.8` instead of `0.9` for `test_cpu_offload`
+    # This is because we are hitting edge cases with the causal_mask buffer
+    model_split_percents = [0.5, 0.7, 0.8]
+
+    def setUp(self):
+        self.model_tester = CohereModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=CohereConfig, hidden_size=37)
+
+    def test_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model(*config_and_inputs)
+
+    def test_model_various_embeddings(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        for type in ["absolute", "relative_key", "relative_key_query"]:
+            config_and_inputs[0].position_embedding_type = type
+            self.model_tester.create_and_check_model(*config_and_inputs)
+
+    @unittest.skip("TODO @gante fix this")
+    @parameterized.expand([(1, False), (1, True), (4, False)])
+    def test_new_cache_format(self, num_beams, do_sample):
+        pass
+
+
+@require_torch
+class CohereIntegrationTest(unittest.TestCase):
+    @unittest.skip("Logits are not exactly the same; once we fix the instabilities, we will update!")
+    @slow
+    def test_model_logits(self):
+        input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
+        model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01", device_map="auto")
+        out = model(torch.tensor(input_ids).unsqueeze(0))
+        # Expected mean on dim = -1
+        EXPECTED_MEAN = torch.tensor([[0.5077, -2.5771, -1.1590, -2.6220, -1.7837, -2.4421, -1.3293, -2.2028]])
+        torch.testing.assert_close(out[0].mean(-1).cpu(), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
+        # slicing logits[0, 0, 0:30]
+        EXPECTED_SLICE = torch.tensor([1.8525, 5.0039, 2.7734, 3.6270, 0.9390, -0.4587, 3.4062, 0.9468,
+                                       3.7324, 1.2344, 5.3047, 4.7266, 5.9414, 5.5195, 1.8047, 3.5215,
+                                       1.5752, 3.7031, 6.2891, 3.4785, 2.0293, 4.2539, 2.8086, 4.7070,
+                                       3.6953, 4.0391, 3.9766, 3.3066, 2.9395, 3.3105])  # fmt: skip
+        torch.testing.assert_close(out[0][0, 0, :30].cpu(), EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
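
Usage note: a minimal sketch of driving the model these tests cover. The checkpoint name and the chat-template call are taken from this patch; the generation settings (`max_new_tokens=32`, greedy decoding) are illustrative assumptions, not part of the change.

    import torch
    from transformers import AutoTokenizer, CohereForCausalLM

    # Checkpoint referenced by CohereIntegrationTest above.
    tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
    model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01", device_map="auto")

    # Build a chat prompt as the default_chat_template docstring in this patch shows.
    messages = [{"role": "user", "content": "Hello, how are you?"}]
    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

    # Greedy decoding settings here are assumptions for illustration only.
    with torch.no_grad():
        gen_tokens = model.generate(input_ids.to(model.device), max_new_tokens=32, do_sample=False)
    print(tokenizer.decode(gen_tokens[0], skip_special_tokens=True))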
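
To exercise the new suite locally (paths from this patch; the `RUN_SLOW=1` convention is standard in transformers, not something this patch adds):

    # fast tests only
    python -m pytest tests/models/cohere/test_modeling_cohere.py -v

    # also run @slow tests such as CohereIntegrationTest (needs the full checkpoint and enough accelerator memory)
    RUN_SLOW=1 python -m pytest tests/models/cohere/test_modeling_cohere.py -v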