fix most tests
zucchini-nlp committed Sep 16, 2024
1 parent 07da55c commit f51ccec
Showing 13 changed files with 210 additions and 139 deletions.
33 changes: 18 additions & 15 deletions src/transformers/cache_utils.py
@@ -318,10 +318,10 @@ class DynamicCache(Cache):
```
"""

def __init__(self, config: PretrainedConfig) -> None:
def __init__(self, num_hidden_layers: int) -> None:
super().__init__()
self.key_cache: List[torch.Tensor] = [[] for _ in range(config.num_hidden_layers)]
self.value_cache: List[torch.Tensor] = [[] for _ in range(config.num_hidden_layers)]
self.key_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
self.value_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen

def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
@@ -409,7 +409,7 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
backward compatibility."""
cache = cls()
cache = cls(len(past_key_values))
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
@@ -431,23 +431,23 @@ def crop(self, max_length: int):
self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

def batch_split(self, full_batch_size: int, split_size: int, config: PretrainedConfig) -> List["DynamicCache"]:
def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: int) -> List["DynamicCache"]:
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
out = []
for i in range(0, full_batch_size, split_size):
current_split = DynamicCache(config)
current_split = DynamicCache(num_hidden_layers)
current_split._seen_tokens = self._seen_tokens
current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
out.append(current_split)
return out

@classmethod
def from_batch_splits(cls, splits: List["DynamicCache"], config: PretrainedConfig) -> "DynamicCache":
def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int) -> "DynamicCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`"""
cache = cls(config)
cache = cls(num_hidden_layers)
for idx in range(len(splits[0])):
layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0)
@@ -1342,7 +1342,10 @@ def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
) -> "EncoderDecoderCache":
"""Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
cache = cls(
self_attention_cache=DynamicCache(len(past_key_values)),
cross_attention_cache=DynamicCache(len(past_key_values)),
)
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx][:2]
@@ -1398,25 +1401,25 @@ def crop(self, maximum_length: int):
self.self_attention_cache.crop(maximum_length)

def batch_split(
self, full_batch_size: int, split_size: int, config: PretrainedConfig
self, full_batch_size: int, split_size: int, num_hidden_layers: int
) -> "List[EncoderDecoderCache]":
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
self.check_dynamic_cache(self.batch_split.__name__)
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size, config)
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size, config)
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers)
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers)

out = []
for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
out.append(EncoderDecoderCache(self_attn, cross_attn))
return out

@classmethod
def from_batch_splits(cls, splits: List["EncoderDecoderCache"], config: PretrainedConfig) -> "EncoderDecoderCache":
def from_batch_splits(cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int) -> "EncoderDecoderCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`"""
self_attention_cache = DynamicCache(config)
cross_attention_cache = DynamicCache(config)
self_attention_cache = DynamicCache(num_hidden_layers)
cross_attention_cache = DynamicCache(num_hidden_layers)
for idx in range(len(splits[0])):
layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
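As a rough illustration of the revised `DynamicCache` API above (not part of the commit; the layer count and tensor shapes below are made up), the constructor and the batch helpers now take the number of hidden layers directly instead of a `PretrainedConfig`:

```python
import torch
from transformers.cache_utils import DynamicCache

num_hidden_layers = 2  # would normally come from config.get_text_config().num_hidden_layers

# The cache is pre-sized by layer count rather than by the full config object.
cache = DynamicCache(num_hidden_layers)
for layer_idx in range(num_hidden_layers):
    # Illustrative [batch, num_heads, seq_len, head_dim] tensors.
    key = torch.randn(2, 4, 5, 8)
    value = torch.randn(2, 4, 5, 8)
    cache.update(key, value, layer_idx)

# batch_split / from_batch_splits mirror the signatures changed in this hunk.
splits = cache.batch_split(full_batch_size=2, split_size=1, num_hidden_layers=num_hidden_layers)
restored = DynamicCache.from_batch_splits(splits, num_hidden_layers=num_hidden_layers)
print(len(restored), restored.key_cache[0].shape)  # 2 layers, keys of shape [2, 4, 5, 8]
```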
28 changes: 16 additions & 12 deletions src/transformers/generation/utils.py
@@ -1147,7 +1147,7 @@ def _validate_assistant(self, assistant_model):
"Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
)

if not self.config.vocab_size == assistant_model.config.vocab_size:
if not self.config.get_text_config().vocab_size == assistant_model.config.get_text_config().vocab_size:
raise ValueError("Make sure the main and assistant model use the same tokenizer")

def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
@@ -1575,12 +1575,11 @@ def _prepare_cache_for_generation(
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
# keeps copying the cache thus using much more memory
else:
num_hidden_layers = self.config.get_text_config().num_hidden_layers
model_kwargs[cache_name] = (
DynamicCache(self.config.get_text_config())
DynamicCache(num_hidden_layers)
if not requires_cross_attention_cache
else EncoderDecoderCache(
DynamicCache(self.config.get_text_config()), DynamicCache(self.config.get_text_config())
)
else EncoderDecoderCache(DynamicCache(num_hidden_layers), DynamicCache(num_hidden_layers))
)

def _supports_num_logits_to_keep(self) -> bool:
@@ -3070,7 +3069,9 @@ def _temporary_reorder_cache(self, past_key_values, beam_idx):
"legacy tuple format or `DynamicCache`"
)
past_key_values = self._reorder_cache(past_key_values, beam_idx)
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values = DynamicCache.from_legacy_cache(
past_key_values,
)
# Standard code path: use the `Cache.reorder_cache`
else:
past_key_values.reorder_cache(beam_idx)
@@ -4260,7 +4261,7 @@ def _ranking_fast(
return selected_idx


def _split(data, full_batch_size: int, config: PretrainedConfig, split_size: int = None):
def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int = None):
"""
Takes care of three cases:
1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim
@@ -4278,7 +4279,7 @@ def _split(data, full_batch_size: int, config: PretrainedConfig, split_size: int
elif isinstance(data, DynamicCache) or (
isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache)
):
return data.batch_split(full_batch_size, config, split_size)
return data.batch_split(full_batch_size, split_size, num_hidden_layers)
elif isinstance(data, tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0], tuple):
@@ -4331,17 +4332,19 @@ def _split_model_inputs(
keys_to_ignore = ["cache_position", "encoder_outputs", "num_logits_to_keep"]
non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]

num_hidden_layers = config.get_text_config().num_hidden_layers

# we split the tensors and tuples of tensors
data_split_list = [
{k: _split(model_input[k], full_batch_size, config, split_size)[i] for k in non_bool_keys}
{k: _split(model_input[k], full_batch_size, num_hidden_layers, split_size)[i] for k in non_bool_keys}
for i in range(full_batch_size // split_size)
]
# bool values are the same and replicated for each split
bool_data = {k: model_input[k] for k in bool_keys}
# encoder_outputs is a ModelOutput object and should be split by its own
if "encoder_outputs" in model_input:
encoder_outputs_split = _split_model_inputs(
model_input["encoder_outputs"], split_size, full_batch_size, config=config
model_input["encoder_outputs"], split_size, full_batch_size, num_hidden_layers=num_hidden_layers
)
data_split_list = [
{**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list)
@@ -4370,6 +4373,7 @@ def stack_model_outputs(model_outputs: List[ModelOutput], config: PretrainedConf

# Infer the class from the first object in the list
model_output_cls = type(model_outputs[0])
num_hidden_layers = config.get_text_config().num_hidden_layers

# Ensure all objects are of the same type
if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
@@ -4386,9 +4390,9 @@ def _concat(data):
return torch.cat(data, dim=0)
# New cache format
elif isinstance(data[0], DynamicCache):
return DynamicCache.from_batch_splits(data, config=config)
return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
elif isinstance(data[0], EncoderDecoderCache):
return EncoderDecoderCache.from_batch_splits(data, config=config)
return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
elif isinstance(data[0], tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0][0], tuple):
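A small sketch of the pattern these hunks converge on (the config values here are hypothetical, and this is not code from the commit): the layer count and vocabulary size are read through `get_text_config()`, which returns the nested text config for composite models and the config itself for text-only ones, and the resulting integer is what the cache now receives:

```python
from transformers import LlamaConfig
from transformers.cache_utils import DynamicCache

# Hypothetical text-only config; a vision-language config would return its
# nested text sub-config from get_text_config() instead of itself.
config = LlamaConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=4, vocab_size=1000)

num_hidden_layers = config.get_text_config().num_hidden_layers
vocab_size = config.get_text_config().vocab_size  # same accessor as the assistant-model check

# The generation helpers now pass this integer around instead of the config object.
cache = DynamicCache(num_hidden_layers)
print(num_hidden_layers, vocab_size, len(cache))  # 2 1000 2
```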
1 change: 1 addition & 0 deletions src/transformers/models/mllama/configuration_mllama.py
@@ -266,6 +266,7 @@ def __init__(
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.initializer_range = initializer_range
self.use_cache = use_cache
self.rope_theta = rope_theta
self.use_scaled_rope = use_scaled_rope
@@ -238,7 +238,13 @@ def write_model(
n_heads_vision = 16
n_heads_per_shard_vision = n_heads_vision // num_shards
dims_per_head_vision = dim_vision // n_heads_vision
rope_scaling = {"rope_type": "llama3", "factor": 8.0, "low_freq_factor": 1.0, "high_freq_factor": 4.0, "original_max_position_embeddings": 8192}
rope_scaling = {
"rope_type": "llama3",
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
}
max_position_embeddings = 16_384

if params.get("n_kv_heads", None) is not None:
@@ -268,7 +274,6 @@ def write_model(
cross_layer_shift = list(range(cross_attention_frequency - 1, n_total_layers, cross_attention_frequency + 1))
attn_layer_shift = [k for k in range(n_total_layers) if k not in cross_layer_shift]


state_dict = {}
for key in all_keys:
# Sharded
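To make the layer-interleaving arithmetic from the second hunk concrete, here is a standalone rerun with made-up numbers (the real frequency and layer count come from the checkpoint's params): one slot in every block of `cross_attention_frequency + 1` layers is reserved for cross-attention.

```python
# Made-up values purely to show how the indices are derived.
cross_attention_frequency = 4
n_total_layers = 40  # self-attention layers plus the interleaved cross-attention layers

# Same expressions as in write_model above.
cross_layer_shift = list(range(cross_attention_frequency - 1, n_total_layers, cross_attention_frequency + 1))
attn_layer_shift = [k for k in range(n_total_layers) if k not in cross_layer_shift]

print(cross_layer_shift)      # [3, 8, 13, 18, 23, 28, 33, 38]
print(len(attn_layer_shift))  # 32 remaining self-attention layers
```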
6 changes: 1 addition & 5 deletions src/transformers/models/mllama/image_processing_mllama.py
@@ -535,7 +535,7 @@ class MllamaImageProcessor(BaseImageProcessor):
The maximum number of tiles to split the image into.
"""

model_input_names = ["pixel_values", "num_tiles", "aspect_ratios", "aspect_ratio_ids", "aspect_ratio_mask"]
model_input_names = ["pixel_values", "num_tiles", "aspect_ratio_ids", "aspect_ratio_mask"]

def __init__(
self,
@@ -734,20 +734,16 @@ def preprocess(

images, num_tiles = pack_images(batch_images, max_image_tiles)

# TODO: aspect ratios not be needed when ids are supported in modeling code
aspect_ratios = pack_aspect_ratios(batch_aspect_ratios, pad_value=1)
aspect_ratio_ids = convert_aspect_ratios_to_ids(batch_aspect_ratios, max_image_tiles=max_image_tiles)
aspect_ratio_mask = build_aspect_ratio_mask(batch_aspect_ratios, max_image_tiles=max_image_tiles)

# images (np.ndarray) with shape (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
# aspect_ratios (np.ndarray) with shape (batch_size, max_num_images, 2) - aspect ratios for each image, padded to max_num_images with 1
# aspect_ratio_ids (np.ndarray) with shape (batch_size, max_num_images) - aspect ratio ids for each image, padded to max_num_images with 0
# num_tiles (List[List[int]]) with (batch_size, num_images_in_batch) - real number of tiles for each image, not padded
# aspect_ratio_mask (np.ndarray) with shape (batch_size, max_num_images, max_image_tiles) - number of tiles for each image, padded to max_num_images with 0
encoded_inputs = BatchFeature(
data={
"pixel_values": images,
"aspect_ratios": aspect_ratios,
"aspect_ratio_ids": aspect_ratio_ids,
"aspect_ratio_mask": aspect_ratio_mask,
},
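For reference, a quick way to confirm the trimmed processor outputs (a sketch that only relies on the class attribute shown in the hunk above):

```python
from transformers.models.mllama.image_processing_mllama import MllamaImageProcessor

# Aspect ratios are now carried only as ids plus a mask; the padded
# "aspect_ratios" array is no longer produced or listed as a model input.
print(MllamaImageProcessor.model_input_names)
# ['pixel_values', 'num_tiles', 'aspect_ratio_ids', 'aspect_ratio_mask']
```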