Skip to content

Commit

Permalink
Add modeling tests (huggingface#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
saurabhdash2512 authored Mar 13, 2024
1 parent d959443 commit 82a0f3b
Show file tree
Hide file tree
Showing 2 changed files with 324 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/transformers/models/cohere/tokenization_cohere_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,9 @@ def default_chat_template(self):
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{ I am doing well! }}<|END_OF_TURN_TOKEN|>
Use add_generation_prompt to add a prompt for the model to generate a response:
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01", trust_remote_code=True)
>>> messages = [{"role": "user", "content": "Hello, how are you?"}]
>>> tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
<BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
Expand Down Expand Up @@ -740,4 +742,4 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
if token_ids_1 is not None:
output = output + bos_token_id + token_ids_1 + eos_token_id

return output
return output
320 changes: 320 additions & 0 deletions tests/models/cohere/test_modeling_cohere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,320 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Cohere model. """

import unittest

from parameterized import parameterized

from transformers import CohereConfig, is_torch_available
from transformers.testing_utils import (
require_torch,
slow,
torch_device,
)

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ids_tensor


if is_torch_available():
import torch

from transformers import CohereForCausalLM, CohereModel


class CohereModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_input_mask=True,
use_token_type_ids=False,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
pad_token_id=0,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.pad_token_id = pad_token_id
self.scope = scope

def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

input_mask = None
if self.use_input_mask:
input_mask = torch.tril(torch.ones(self.batch_size, self.seq_length)).to(torch_device)

token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)

config = self.get_config()

return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

def get_config(self):
return CohereConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
is_decoder=False,
initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
)

def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = CohereModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

def create_and_check_model_as_decoder(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.add_cross_attention = True
model = CohereModel(config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
)
result = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
)
result = model(input_ids, attention_mask=input_mask)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))

def create_and_check_for_causal_lm(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
model = CohereForCausalLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
encoder_hidden_states,
encoder_attention_mask,
):
config.is_decoder = True
config.add_cross_attention = True
model = CohereForCausalLM(config=config)
model.to(torch_device)
model.eval()

# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values

# create hypothetical multiple next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)["hidden_states"][0]

# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))

def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict


@require_torch
class CohereModelTest(unittest.TestCase):
all_model_classes = (CohereModel, CohereForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (CohereForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": CohereModel,
"text-generation": CohereForCausalLM,
}
if is_torch_available()
else {}
)
test_headmasking = False
test_pruning = False
fx_compatible = True

# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]

def setUp(self):
self.model_tester = CohereModelTester(self)
self.config_tester = ConfigTester(self, config_class=CohereConfig, hidden_size=37)

def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)

def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)

@unittest.skip("TODO @gante fix this")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass


@require_torch
class CohereIntegrationTest(unittest.TestCase):
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!")
@slow
def test_model_logits(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01", device_map="auto")
out = model(torch.tensor(input_ids).unsqueeze(0))
# # Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[0.5077, -2.5771, -1.1590, -2.6220, -1.7837, -2.4421, -1.3293, -2.2028]])
torch.testing.assert_close(out[0].mean(-1).cpu(), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
# slicing logits[0, 0, 0:30]
EXPECTED_SLICE = torch.tensor([ 1.8525, 5.0039, 2.7734, 3.6270, 0.9390, -0.4587, 3.4062, 0.9468, \
3.7324, 1.2344, 5.3047, 4.7266, 5.9414, 5.5195, 1.8047, 3.5215, \
1.5752, 3.7031, 6.2891, 3.4785, 2.0293, 4.2539, 2.8086, 4.7070, \
3.6953, 4.0391, 3.9766, 3.3066, 2.9395, 3.3105]) # fmt: skip
torch.testing.assert_close(out[0][0, 0, :30].cpu(), EXPECTED_SLICE, atol=1e-5, rtol=1e-5)

0 comments on commit 82a0f3b

Please sign in to comment.