Skip to content

Commit

Permalink
working generation
Browse files Browse the repository at this point in the history
  • Loading branch information
ylacombe committed Sep 20, 2024
1 parent 34b6e24 commit 50f9eb8
Show file tree
Hide file tree
Showing 2 changed files with 430 additions and 567 deletions.
26 changes: 12 additions & 14 deletions src/transformers/models/moshi/configuration_moshi.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class MoshiConfig(PretrainedConfig):
Example:
```python
```python # TODO(YL): update
>>> from transformers import (
... MoshiConfig,
... EncodecConfig,
Expand Down Expand Up @@ -189,21 +189,24 @@ def __init__(self,
self.depth_head_dim = depth_head_dim or depth_hidden_size // depth_num_attention_heads
self.depth_num_key_value_heads = depth_num_key_value_heads if depth_num_key_value_heads is not None else depth_num_attention_heads

super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

if "audio_encoder" not in kwargs:
audio_encoder_config = kwargs.pop("audio_encoder", None)
if audio_encoder_config is None:
raise ValueError("Config has to be initialized with audio_encoder config")

audio_encoder_config = kwargs.pop("audio_encoder")

audio_encoder_model_type = audio_encoder_config.pop("model_type")

self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)

if self.num_codebooks > self.audio_encoder.num_codebooks:
raise ValueError(f"`num_codebooks={num_codebooks}` is greater than the maximum number of codebooks that the audio encoder can deal with ({self.audio_encoder.num_codebooks}). Please lower it.")

self.audio_vocab_size = self.audio_encoder.codebook_size if audio_vocab_size is None else audio_vocab_size

super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)


@property
def sampling_rate(self):
return self.audio_encoder.sampling_rate

@classmethod
def from_audio_encoder_config(
Expand All @@ -213,17 +216,12 @@ def from_audio_encoder_config(
):
r"""
Instantiate a [`MoshiConfig`] (or a derived class) from an audio encoder configuration.
Returns:
[`MoshiConfig`]: An instance of a configuration object
"""

return cls(
audio_encoder=audio_encoder_config.to_dict(),
**kwargs,
)

@property
# This is a property because you might want to change the codec model on the fly
def sampling_rate(self):
return self.audio_encoder.sampling_rate
Loading

0 comments on commit 50f9eb8

Please sign in to comment.