Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Wav2Vec2 Adapter Weights to Flax #15521

Closed
wants to merge 34 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
48d8aef
Handle PyTorch to Flax conversion of 1D convolutions
sanchit-gandhi Feb 4, 2022
9e515c1
Add Wav2Vec2 Adapter Weights to Flax
sanchit-gandhi Feb 4, 2022
31be2f4
[deepspeed docs] Megatron-Deepspeed info (#15488)
stas00 Feb 4, 2022
ac6aa10
Standardize semantic segmentation models outputs (#15469)
sgugger Feb 4, 2022
8ce1330
[deepspeed docs] DeepSpeed ZeRO Inference (#15486)
stas00 Feb 4, 2022
e02bdce
Revert "Handle PyTorch to Flax conversion of 1D convolutions (#15519)…
patrickvonplaten Feb 7, 2022
5f1918a
[ASR pipeline] correct asr pipeline for seq2seq models (#15541)
patrickvonplaten Feb 7, 2022
c47d259
[torch_int_div] Correct true division in generation (#15498)
patrickvonplaten Feb 7, 2022
84eec9e
Add ConvNeXT (#15277)
NielsRogge Feb 7, 2022
75b13f8
[Trainer] Deeper length checks for IterableDatasetShard (#15539)
anton-l Feb 7, 2022
a459f7f
Add ASR CTC streaming example (#15309)
anton-l Feb 7, 2022
7a1412e
Wav2Vec2 models must either throw or deal with add_adapter (#15409)
FremyCompany Feb 7, 2022
8255163
Revert "Handle PyTorch to Flax conversion of 1D convolutions"
sanchit-gandhi Feb 7, 2022
6775b21
Remove Longformers from ONNX-supported models (#15273)
lewtun Feb 7, 2022
131e258
Fix TF T5/LED missing cross attn in return values (#15511)
ydshieh Feb 7, 2022
2850299
Comment update for 1 and 2D convolutions
sanchit-gandhi Feb 7, 2022
0f1c0b5
Merge branch 'flax-utils' into flax-wav2vec2
sanchit-gandhi Feb 7, 2022
6900a67
Add adapter arg
sanchit-gandhi Feb 7, 2022
ad1d3c4
Make TF Wav2Vec2 outputs the same as PT's version (#15530)
ydshieh Feb 7, 2022
c460cfe
Remove layer drop
sanchit-gandhi Feb 7, 2022
1e34903
Remove layer drop
sanchit-gandhi Feb 7, 2022
521bd9d
Remove layer drop
sanchit-gandhi Feb 7, 2022
552f8d3
Create a custom model guide (#15489)
stevhliu Feb 7, 2022
0fe17f3
FX tracing improvement (#14321)
michaelbenayoun Feb 7, 2022
87d08af
electra is added to onnx supported model (#15084)
aaron-dunamu Feb 8, 2022
0acd84f
[GPTJ] fix docs (#15558)
patil-suraj Feb 8, 2022
6a5472a
Force use_cache to be False in PyTorch (#15385)
ydshieh Feb 8, 2022
34da433
Add Wav2Vec2 Adapter Weights to Flax
sanchit-gandhi Feb 4, 2022
1b5cb64
Comment update for 1 and 2D convolutions
sanchit-gandhi Feb 7, 2022
b131023
Remove layer drop
sanchit-gandhi Feb 7, 2022
6c3be19
Correct Flax int div implementation
sanchit-gandhi Feb 8, 2022
05f64ec
Flax int div
sanchit-gandhi Feb 8, 2022
269f02a
Correct Flax int div
sanchit-gandhi Feb 8, 2022
c5fd053
flake reformat
sanchit-gandhi Feb 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/transformers/modeling_flax_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:
pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
return renamed_pt_tuple_key, pt_tensor

# conv1d layer
sanchit-gandhi marked this conversation as resolved.
Show resolved Hide resolved
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 3 and not is_key_or_prefix_key_in_dict(pt_tuple_key):
pt_tensor = pt_tensor.transpose(2, 1, 0)
return renamed_pt_tuple_key, pt_tensor

# linear layer
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
if pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
Expand Down
88 changes: 86 additions & 2 deletions src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,75 @@ def __call__(self, hidden_states, mask_time_indices=None, deterministic=True, te
return codevectors, perplexity


class FlaxWav2Vec2Adapter(nn.Module):
    """Adapter head: an optional linear down-projection (with layer norm) followed by a
    stack of strided convolutional adapter layers."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # A projection is only needed when the encoder's hidden size differs from the
        # adapter's expected output size; otherwise both submodules stay None.
        needs_projection = self.config.output_hidden_size != self.config.hidden_size
        if needs_projection:
            self.proj = nn.Dense(
                self.config.output_hidden_size,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
                dtype=self.dtype,
            )
            self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        else:
            self.proj = self.proj_layer_norm = None

        self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        # Down-project and normalize first when the feature dimensions require it.
        if self.proj is not None and self.proj_layer_norm is not None:
            hidden_states = self.proj_layer_norm(self.proj(hidden_states))

        return self.layers(hidden_states)


class FlaxWav2Vec2AdapterLayer(nn.Module):
    """One adapter layer: a strided 1D convolution that doubles the channel count,
    followed by a GLU that gates the channels back down to ``output_hidden_size``."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        # The GLU consumes half of the channels as gates, so the conv emits 2x features.
        self.conv = nn.Conv(
            features=2 * cfg.output_hidden_size,
            kernel_size=(cfg.adapter_kernel_size,),
            strides=(cfg.adapter_stride,),
            padding=((1, 1),),
            kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # Convolve, then gate along the feature axis (axis 2).
        convolved = self.conv(hidden_states)
        return nn.glu(convolved, axis=2)


class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
    """Sequential stack of ``config.num_adapter_layers`` FlaxWav2Vec2AdapterLayer modules."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Explicit integer names keep parameter keys aligned with the PyTorch checkpoint.
        self.layers = [
            FlaxWav2Vec2AdapterLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.num_adapter_layers)
        ]

    def __call__(self, hidden_states, deterministic=True):
        # Bug fix: the previous implementation implemented layer drop via
        # ``np.random.random()`` and an undefined ``self.layerdrop`` attribute —
        # an AttributeError at trace time, and ``np.random`` is not traceable
        # under jax.jit anyway. Layer drop is removed for the adapter; every
        # layer is always applied. ``deterministic`` is kept only for interface
        # compatibility with callers.
        for conv_layer in self.layers:
            hidden_states = conv_layer(hidden_states)

        return hidden_states


class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
Expand Down Expand Up @@ -840,7 +909,9 @@ def __call__(
rngs=rngs,
)

def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
def _get_feat_extract_output_lengths(
    self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
    """
    Computes the output lengths of the feature-extractor convolutions for the given
    input lengths, delegating to the underlying Flax module.

    Args:
        input_lengths: raw audio lengths (scalar int or array of ints).
        add_adapter: whether to also account for the adapter's strided convolutions;
            ``None`` lets the module fall back to ``config.add_adapter``.
    """
    # Bug fix: ``add_adapter`` was accepted here but silently dropped, so the
    # adapter's downsampling was never reflected in the computed lengths.
    return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
sanchit-gandhi marked this conversation as resolved.
Show resolved Hide resolved


Expand All @@ -860,6 +931,8 @@ def setup(self):
else:
raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")

self.adapter = FlaxWav2Vec2Adapter(self.config) if self.config.add_adapter else None

def __call__(
self,
input_values,
Expand Down Expand Up @@ -905,6 +978,9 @@ def __call__(

hidden_states = encoder_outputs[0]

if self.adapter is not None:
hidden_states = self.adapter(hidden_states)

if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]

Expand Down Expand Up @@ -1021,11 +1097,15 @@ def __call__(

return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
def _get_feat_extract_output_lengths(
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
"""
Computes the output length of the convolutional layers
"""

add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
Expand All @@ -1034,6 +1114,10 @@ def _conv_out_length(input_length, kernel_size, stride):
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

return input_lengths


Expand Down