diff --git a/mindnlp/transformers/models/__init__.py b/mindnlp/transformers/models/__init__.py
index 722aa0f7d..0b35f57d1 100644
--- a/mindnlp/transformers/models/__init__.py
+++ b/mindnlp/transformers/models/__init__.py
@@ -145,6 +145,7 @@
     mctct,
     megatron_bert,
     mgp_str,
+    mimi,
     minicpm,
     minicpm3,
     mistral,
@@ -390,6 +391,7 @@
 from .mctct import *
 from .megatron_bert import *
 from .mgp_str import *
+from .mimi import *
 from .minicpm import *
 from .minicpm3 import *
 from .mistral import *
@@ -635,6 +637,7 @@
 __all__.extend(mctct.__all__)
 __all__.extend(megatron_bert.__all__)
 __all__.extend(mgp_str.__all__)
+__all__.extend(mimi.__all__)
 __all__.extend(minicpm.__all__)
 __all__.extend(minicpm3.__all__)
 __all__.extend(mistral.__all__)
@@ -750,4 +753,4 @@
 __all__.extend(umt5.__all__)
 __all__.extend(xmod.__all__)
 __all__.extend(fuyu.__all__)
-__all__.extend(yolos.__all__)
+__all__.extend(yolos.__all__)
\ No newline at end of file
diff --git a/mindnlp/transformers/models/auto/configuration_auto.py b/mindnlp/transformers/models/auto/configuration_auto.py
index 73d5851f2..d0e67900b 100644
--- a/mindnlp/transformers/models/auto/configuration_auto.py
+++ b/mindnlp/transformers/models/auto/configuration_auto.py
@@ -143,6 +143,7 @@
         ("mbart", "MBartConfig"),
         ("mctct", "MCTCTConfig"),
         ("megatron-bert", 'MegatronBertConfig'),
+        ("mimi", "MimiConfig"),
         ("minicpm", "MiniCPMConfig"),
         ("minicpm3", "MiniCPM3Config"),
         ("mistral", "MistralConfig"),
@@ -362,6 +363,7 @@
         ("mega", "MEGA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("megatron-bert", "MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("mgp-str", "MGP_STR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("mimi", "MIMI_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("minicpm3", "MINICPM3_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("mistral", "MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("mixtral", "MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -621,6 +623,7 @@
         ("megatron-bert", "Megatron-BERT"),
         ("megatron_gpt2", "Megatron-GPT2"),
         ("mgp-str", "MGP-STR"),
+        ("mimi", "Mimi"),
         ("minicpm", "MiniCPM"),
         ("minicpm3", "MiniCPM3"),
         ("mistral", "Mistral"),
diff --git a/mindnlp/transformers/models/mini/__init__.py b/mindnlp/transformers/models/mini/__init__.py
new file mode 100644
index 000000000..61fce3f73
--- /dev/null
+++ b/mindnlp/transformers/models/mini/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Mimi Model.
+"""
+from . import modeling_mimi, configuration_mimi
+from .modeling_mimi import *
+from .configuration_mimi import *
+
+__all__ = []
+__all__.extend(modeling_mimi.__all__)
+__all__.extend(configuration_mimi.__all__)
\ No newline at end of file
diff --git a/mindnlp/transformers/models/mini/configuration_mimi.py b/mindnlp/transformers/models/mini/configuration_mimi.py
new file mode 100644
index 000000000..be62a24a0
--- /dev/null
+++ b/mindnlp/transformers/models/mini/configuration_mimi.py
@@ -0,0 +1,238 @@
+# coding=utf-8
+# Copyright 2024 Meta Platforms, Inc. and affiliates, and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Mimi model configuration"""
+
+import math
+
+import numpy as np
+
+from mindnlp.utils import logging
+from ...configuration_utils import PretrainedConfig
+
+
+
+logger = logging.get_logger(__name__)
+
+
+class MimiConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of an [`MimiModel`]. It is used to instantiate a
+    Mimi model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the
+    [kyutai/mimi](https://huggingface.co/kyutai/mimi) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        sampling_rate (`int`, *optional*, defaults to 24000):
+            The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
+        frame_rate (`float`, *optional*, defaults to 12.5):
+            Framerate of the model.
+        audio_channels (`int`, *optional*, defaults to 1):
+            Number of channels in the audio data. Either 1 for mono or 2 for stereo.
+        hidden_size (`int`, *optional*, defaults to 512):
+            Intermediate representation dimension.
+        num_filters (`int`, *optional*, defaults to 64):
+            Number of convolution kernels of first `MimiConv1d` down sampling layer.
+        num_residual_layers (`int`,  *optional*, defaults to 1):
+            Number of residual layers.
+        upsampling_ratios (`Sequence[int]`, *optional*):
+            Kernel size and stride ratios. The encoder uses downsampling ratios instead of upsampling ratios, hence it
+            will use the ratios in the reverse order to the ones specified here that must match the decoder order.
+            If not specified, will defaults to `[8, 6, 5, 4]`
+        kernel_size (`int`, *optional*, defaults to 7):
+            Kernel size for the initial convolution.
+        last_kernel_size (`int`, *optional*, defaults to 3):
+            Kernel size for the last convolution layer.
+        residual_kernel_size (`int`, *optional*, defaults to 3):
+            Kernel size for the residual layers.
+        dilation_growth_rate (`int`, *optional*, defaults to 2):
+            How much to increase the dilation with each layer.
+        use_causal_conv (`bool`, *optional*, defaults to `True`):
+            Whether to use fully causal convolution.
+        pad_mode (`str`, *optional*, defaults to `"constant"`):
+            Padding mode for the convolutions.
+        compress (`int`, *optional*, defaults to 2):
+            Reduced dimensionality in residual branches.
+        trim_right_ratio (`float`, *optional*, defaults to 1.0):
+            Ratio for trimming at the right of the transposed convolution under the `use_causal_conv = True` setup. If
+            equal to 1.0, it means that all the trimming is done at the right.
+        codebook_size (`int`, *optional*, defaults to 2048):
+            Number of discret codes in each codebooks.
+        codebook_dim (`int`, *optional*, defaults to 256):
+            Dimension of the unquantized codebook vectors. If not defined, uses `hidden_size`.
+        num_quantizers (`int`, *optional*, defaults to 32):
+            Number of quantizer channels, or codebooks, in the quantizer.
+        use_conv_shortcut (`bool`, *optional*, defaults to `False`):
+            Whether to use a convolutional layer as the 'skip' connection in the `MimiResnetBlock` block. If False,
+            an identity function will be used, giving a generic residual connection.
+        vector_quantization_hidden_dimension (`int`, *optional*, defaults to 256):
+            Intermediate representation dimension in the residual vector quantization space.
+        num_semantic_quantizers (`int`, *optional*, defaults to 1):
+            Number of semantic quantizer channels, or codebooks, in the semantic quantizer. Must be lower than `num_quantizers`.
+        upsample_groups (`int`, *optional*, defaults to 512):
+            If `frame_rate!=encodec_frame_rate`, indicates the number of groups used in the upsampling operation to go from one rate to another.
+        num_hidden_layers (`int`, *optional*, defaults to 8):
+            Number of hidden layers in the Transformer models.
+        intermediate_size (`int`, *optional*, defaults to 2048):
+            Dimension of the MLP representations.
+        num_attention_heads (`int`, *optional*, defaults to 8):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 8):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details checkout [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
+        head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
+            The attention head dimension.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 8000):
+            The maximum sequence length that this model might ever be used with. Mimi's sliding window attention
+            allows sequence of up to 8000 tokens.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the LayerNorm normalization layers.
+        use_cache (`bool`, *optional*, defaults to `False`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        sliding_window (`int`, *optional*, defaults to 250):
+            Sliding window attention window size. If not specified, will default to `250`.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        layer_scale_initial_scale (`float`, *optional*, defaults to 0.01):
+            Initiale scale of the residual rescaling operation done in the Transformer models.
+        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+    Example:
+
+    ```python
+    >>> from transformers import MimiModel, MimiConfig
+
+    >>> # Initializing a "kyutai/mimi" style configuration
+    >>> configuration = MimiConfig()
+
+    >>> # Initializing a model (with random weights) from the "kyutai/mimi" style configuration
+    >>> model = MimiModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "mimi"
+
+    def __init__(
+        self,
+        sampling_rate=24_000,
+        frame_rate=12.5,
+        audio_channels=1,
+        hidden_size=512,
+        num_filters=64,
+        num_residual_layers=1,
+        upsampling_ratios=None,
+        kernel_size=7,
+        last_kernel_size=3,
+        residual_kernel_size=3,
+        dilation_growth_rate=2,
+        use_causal_conv=True,
+        pad_mode="constant",
+        compress=2,
+        trim_right_ratio=1.0,
+        codebook_size=2048,
+        codebook_dim=256,
+        num_quantizers=32,
+        use_conv_shortcut=False,
+        vector_quantization_hidden_dimension=256,
+        num_semantic_quantizers=1,
+        upsample_groups=512,
+        num_hidden_layers=8,
+        intermediate_size=2048,
+        num_attention_heads=8,
+        num_key_value_heads=8,
+        head_dim=None,
+        hidden_act="gelu",
+        max_position_embeddings=8000,
+        initializer_range=0.02,
+        norm_eps=1e-5,
+        use_cache=False,
+        rope_theta=10000.0,
+        sliding_window=250,
+        attention_dropout=0.0,
+        layer_scale_initial_scale=0.01,
+        attention_bias=False,
+        **kwargs,
+    ):
+        self.sampling_rate = sampling_rate
+        self.frame_rate = frame_rate
+        self.audio_channels = audio_channels
+        self.hidden_size = hidden_size
+        self.num_filters = num_filters
+        self.num_residual_layers = num_residual_layers
+        self.upsampling_ratios = upsampling_ratios if upsampling_ratios else [8, 6, 5, 4]
+        self.kernel_size = kernel_size
+        self.last_kernel_size = last_kernel_size
+        self.residual_kernel_size = residual_kernel_size
+        self.dilation_growth_rate = dilation_growth_rate
+        self.use_causal_conv = use_causal_conv
+        self.pad_mode = pad_mode
+        self.compress = compress
+        self.trim_right_ratio = trim_right_ratio
+        self.codebook_size = codebook_size
+        self.codebook_dim = codebook_dim if codebook_dim is not None else hidden_size
+        self.num_quantizers = num_quantizers
+        self.use_conv_shortcut = use_conv_shortcut
+        self.vector_quantization_hidden_dimension = vector_quantization_hidden_dimension
+        self.upsample_groups = upsample_groups
+        self.num_hidden_layers = num_hidden_layers
+        self.intermediate_size = intermediate_size
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.max_position_embeddings = max_position_embeddings
+        self.initializer_range = initializer_range
+        self.norm_eps = norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.sliding_window = sliding_window
+        self.attention_dropout = attention_dropout
+        self.head_dim = head_dim or hidden_size // num_attention_heads
+        self.layer_scale_initial_scale = layer_scale_initial_scale
+        self.attention_bias = attention_bias
+
+        if num_semantic_quantizers >= self.num_quantizers:
+            raise ValueError(
+                f"The number of semantic quantizers should be lower than the total number of quantizers {self.num_quantizers}, but is currently {num_semantic_quantizers}."
+            )
+        self.num_semantic_quantizers = num_semantic_quantizers
+        super().__init__(**kwargs)
+
+    @property
+    def encodec_frame_rate(self) -> int:
+        hop_length = np.prod(self.upsampling_ratios)
+        return math.ceil(self.sampling_rate / hop_length)
+
+    @property
+    def num_codebooks(self) -> int:
+        # alias to num_quantizers
+        return self.num_quantizers
+
+
+__all__ = ["MimiConfig"]
diff --git a/mindnlp/transformers/models/mini/modeling_mimi.py b/mindnlp/transformers/models/mini/modeling_mimi.py
new file mode 100644
index 000000000..15a3bddd2
--- /dev/null
+++ b/mindnlp/transformers/models/mini/modeling_mimi.py
@@ -0,0 +1,1799 @@
+# coding=utf-8
+# Copyright 2024 Kyutai, and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Mindspore Mimi model."""
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import mindspore as ms
+from mindspore.common.initializer import initializer, TruncatedNormal
+from mindnlp.core import nn, ops
+from ....core.autograd import no_grad
+
+from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_outputs import BaseModelOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from ...modeling_utils import PreTrainedModel
+from ....common.activations import ACT2FN
+from ....amp import autocast
+from ....utils import (
+    ModelOutput,
+    logging,
+)
+# from ....utils import (
+#     ModelOutput,
+#     add_start_docstrings,
+#     add_start_docstrings_to_model_forward,
+#     is_flash_attn_2_available,
+#     is_flash_attn_greater_or_equal_2_10,
+#     logging,
+#     replace_return_docstrings,
+# )
+from .configuration_mimi import MimiConfig
+
+
+# if is_flash_attn_2_available():
+#     from ...modeling_flash_attention_utils import _flash_attention_forward
+
+logger = logging.get_logger(__name__)
+
+
+# General docstring
+_CONFIG_FOR_DOC = "MimiConfig"
+
+
+@dataclass
+class MimiOutput(ModelOutput):
+    """
+    Args:
+        audio_codes (`torch.LongTensor`  of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
+            Discret code embeddings computed using `model.encode`.
+        audio_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*)
+            Decoded audio values, obtained using the decoder part of Mimi.
+        encoder_past_key_values (`Cache`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer.
+            This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+            The model will output the same cache format that is fed as input.
+            If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
+            have their past key value states given to this model).
+        decoder_past_key_values (`Cache`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer.
+            This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+            The model will output the same cache format that is fed as input.
+            If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
+            have their past key value states given to this model).
+    """
+
+    audio_codes: ms.Tensor = None
+    audio_values: ms.Tensor = None
+    encoder_past_key_values: Optional[Union[Cache, List[ms.Tensor]]] = None
+    decoder_past_key_values: Optional[Union[Cache, List[ms.Tensor]]] = None
+
+
+@dataclass
+class MimiEncoderOutput(ModelOutput):
+    """
+    Args:
+        audio_codes (`torch.LongTensor`  of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
+            Discret code embeddings computed using `model.encode`.
+        encoder_past_key_values (`Cache`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer.
+            This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+            The model will output the same cache format that is fed as input.
+            If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
+            have their past key value states given to this model).
+    """
+
+    audio_codes: ms.Tensor = None
+    encoder_past_key_values: Optional[Union[Cache, List[ms.Tensor]]] = None
+
+
+@dataclass
+class MimiDecoderOutput(ModelOutput):
+    """
+    Args:
+        audio_values (`torch.FloatTensor`  of shape `(batch_size, segment_length)`, *optional*):
+            Decoded audio values, obtained using the decoder part of Mimi.
+        decoder_past_key_values (`Cache`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer.
+            This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+            The model will output the same cache format that is fed as input.
+            If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
+            have their past key value states given to this model).
+    """
+
+    audio_values: ms.Tensor = None
+    decoder_past_key_values: Optional[Union[Cache, List[ms.Tensor]]] = None
+
+
+class MimiConv1d(nn.Module):
+    """Conv1d with asymmetric or causal padding and normalization."""
+
+    def __init__(
+        self,
+        config,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        pad_mode=None,
+        bias: bool = True,
+    ):
+        super().__init__()
+        self.causal = config.use_causal_conv
+        self.pad_mode = config.pad_mode if pad_mode is None else pad_mode
+
+        # warn user on unusual setup between dilation and stride
+        if stride > 1 and dilation > 1:
+            logger.warning(
+                "MimiConv1d has been initialized with stride > 1 and dilation > 1"
+                f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
+            )
+
+        self.conv = nn.Conv1d(
+            in_channels, out_channels, kernel_size, stride, dilation=dilation, groups=groups, bias=bias
+        )
+
+        kernel_size = self.conv.kernel_size[0]
+        stride = ms.tensor(self.conv.stride[0], dtype=ms.int64)
+        dilation = self.conv.dilation[0]
+
+        # Effective kernel size with dilations.
+        kernel_size = ms.tensor((kernel_size - 1) * dilation + 1, dtype=ms.int64)
+
+        self.register_buffer("stride", stride, persistent=False)
+        self.register_buffer("kernel_size", kernel_size, persistent=False)
+        self.register_buffer("padding_total", ms.tensor(kernel_size - stride, dtype=ms.int64), persistent=False)
+
+        # Asymmetric padding required for odd strides
+        self.padding_right = self.padding_total // 2
+        self.padding_left = self.padding_total - self.padding_right
+
+    def apply_weight_norm(self):
+        weight_norm = nn.utils.weight_norm
+        if hasattr(nn.utils.parametrizations, "weight_norm"):
+            weight_norm = nn.utils.parametrizations.weight_norm
+
+        weight_norm(self.conv)
+
+    def remove_weight_norm(self):
+        nn.utils.remove_weight_norm(self.conv)
+
+    # Copied from transformers.models.encodec.modeling_encodec.EncodecConv1d._get_extra_padding_for_conv1d
+    def _get_extra_padding_for_conv1d(
+        self,
+        hidden_states: ms.Tensor,
+    ) -> ms.Tensor:
+        """See `pad_for_conv1d`."""
+        length = hidden_states.shape[-1]
+        n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
+        #向上取整
+        n_frames = ops.ceil(n_frames).to(ms.int64) - 1
+        ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
+
+        return ideal_length - length
+
+    @staticmethod
+    # Copied from transformers.models.encodec.modeling_encodec.EncodecConv1d._pad1d
+    def _pad1d(hidden_states: ms.Tensor, paddings: Tuple[int, int], mode: str = "zero", value: float = 0.0):
+        """Tiny wrapper around torch.nn.functional.pad, just to allow for reflect padding on small input.
+        If this is the case, we insert extra 0 padding to the right before the reflection happens.
+        """
+        length = hidden_states.shape[-1]
+        paddings = (int(paddings[0]), int(paddings[1]))
+        padding_left, padding_right = paddings
+        if not mode == "reflect":
+            return nn.functional.pad(hidden_states, paddings, mode, value)
+        # Failed calling ConstantPadND with "ConstantPadND()(input=Tensor, padding=tuple<Tensor, Tensor>, value=float)".
+        # The valid calling should be: 
+        # "ConstantPadND()(input=<Tensor>, padding=<list of int, Tensor, tuple of int>, value=<Number>)".
+        max_pad = max(padding_left, padding_right)
+        extra_pad = 0
+        if length <= max_pad:
+            extra_pad = max_pad - length + 1
+            hidden_states = nn.functional.pad(hidden_states, (0, extra_pad))
+        padded = nn.functional.pad(hidden_states, paddings, mode, value)
+        end = padded.shape[-1] - extra_pad
+        return padded[..., :end]
+
+    def forward(self, hidden_states):
+        extra_padding = self._get_extra_padding_for_conv1d(hidden_states)
+
+        if self.causal:
+            # Left padding for causal
+            hidden_states = self._pad1d(hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode)
+        else:
+            hidden_states = self._pad1d(
+                hidden_states, (self.padding_left, self.padding_right + extra_padding), mode=self.pad_mode
+            )
+
+        hidden_states = self.conv(hidden_states)
+        return hidden_states
+
+
+class MimiConvTranspose1d(nn.Module):
+    """ConvTranspose1d with asymmetric or causal padding and normalization."""
+
+    def __init__(
+        self,
+        config,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        groups: int = 1,
+        bias=True,
+    ):
+        super().__init__()
+        self.causal = config.use_causal_conv
+        self.trim_right_ratio = config.trim_right_ratio
+        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, groups=groups, bias=bias)
+
+        if not (self.causal or self.trim_right_ratio == 1.0):
+            raise ValueError("`trim_right_ratio` != 1.0 only makes sense for causal convolutions")
+
+        kernel_size = self.conv.kernel_size[0]
+        stride = self.conv.stride[0]
+        padding_total = kernel_size - stride
+
+        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+        # removed at the very end, when keeping only the right length for the output,
+        # as removing it here would require also passing the length at the matching layer
+        # in the encoder.
+        if self.causal:
+            # Trim the padding on the right according to the specified ratio
+            # if trim_right_ratio = 1.0, trim everything from right
+            self.padding_right = math.ceil(padding_total * self.trim_right_ratio)
+        else:
+            # Asymmetric padding required for odd strides
+            self.padding_right = padding_total // 2
+
+        self.padding_left = padding_total - self.padding_right
+
+    def apply_weight_norm(self):
+        weight_norm = nn.utils.weight_norm
+        if hasattr(nn.utils.parametrizations, "weight_norm"):
+            weight_norm = nn.utils.parametrizations.weight_norm
+
+        weight_norm(self.conv)
+
+    def remove_weight_norm(self):
+        nn.utils.remove_weight_norm(self.conv)
+
+    def forward(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
+
+        # unpad
+        end = hidden_states.shape[-1] - self.padding_right
+        hidden_states = hidden_states[..., self.padding_left : end]
+        return hidden_states
+
+
+# Copied from transformers.models.encodec.modeling_encodec.EncodecResnetBlock with Encodec->Mimi,EnCodec->Mimi
+class MimiResnetBlock(nn.Module):
+    """
+    Residual block from SEANet model as used by Mimi.
+    """
+
+    def __init__(self, config: MimiConfig, dim: int, dilations: List[int]):
+        super().__init__()
+        kernel_sizes = (config.residual_kernel_size, 1)
+        if len(kernel_sizes) != len(dilations):
+            raise ValueError("Number of kernel sizes should match number of dilations")
+
+        hidden = dim // config.compress
+        block = []
+        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+            in_chs = dim if i == 0 else hidden
+            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+            block += [nn.ELU()]
+            block += [MimiConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)]
+        self.block = nn.ModuleList(block)
+        #即它不对输入数据进行任何变换,直接将输入作为输出返回
+        if config.use_conv_shortcut:
+            self.shortcut = MimiConv1d(config, dim, dim, kernel_size=1)
+        else:
+            self.shortcut = nn.Identity()
+
+    def forward(self, hidden_states):
+        residual = hidden_states
+        for layer in self.block:
+            hidden_states = layer(hidden_states)
+
+        return self.shortcut(residual) + hidden_states
+
+
+class MimiEncoder(nn.Module):
+    """SEANet encoder as used by Mimi."""
+
+    def __init__(self, config: MimiConfig):
+        super().__init__()
+        model = [MimiConv1d(config, config.audio_channels, config.num_filters, config.kernel_size)]
+        scaling = 1
+
+        # Downsample to raw audio scale
+        for ratio in reversed(config.upsampling_ratios):
+            current_scale = scaling * config.num_filters
+            # Add residual layers
+            for j in range(config.num_residual_layers):
+                model += [MimiResnetBlock(config, current_scale, [config.dilation_growth_rate**j, 1])]
+            # Add downsampling layers
+            model += [nn.ELU()]
+            model += [MimiConv1d(config, current_scale, current_scale * 2, kernel_size=ratio * 2, stride=ratio)]
+            scaling *= 2
+
+        model += [nn.ELU()]
+        model += [MimiConv1d(config, scaling * config.num_filters, config.hidden_size, config.last_kernel_size)]
+
+        self.layers = nn.ModuleList(model)
+
+    # Copied from transformers.models.encodec.modeling_encodec.EncodecEncoder.forward
+    def forward(self, hidden_states):
+        for layer in self.layers:
+            hidden_states = layer(hidden_states)
+        return hidden_states
+
+
+class MimiLayerScale(nn.Module):
+    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
+    This rescales diagonally the residual outputs close to 0, with a learnt scale.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        channels = config.hidden_size
+        initial_scale = config.layer_scale_initial_scale
+        self.scale = nn.Parameter(ops.full((channels,), initial_scale), requires_grad=True)
+
+    def forward(self, x: ms.Tensor):
+        return self.scale * x
+
+
+# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mimi
+class MimiRotaryEmbedding(nn.Module):
+    def __init__(self, config: MimiConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    def _dynamic_frequency_update(self, position_ids, device):
+        """
+        dynamic RoPE layers should recompute `inv_freq` in the following situations:
+        1 - growing beyond the cached sequence length (allow scaling)
+        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        """
+        seq_len = ops.max(position_ids) + 1
+        if seq_len > self.max_seq_len_cached:  # growth
+            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
+            self.max_seq_len_cached = seq_len
+
+        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.max_seq_len_cached = self.original_max_seq_len
+
+    @no_grad()
+    def forward(self, x, position_ids):
+        if "dynamic" in self.rope_type:
+            device = ms.get_context('device_target')
+            self._dynamic_frequency_update(position_ids, device=device)
+
+        # Core RoPE block
+        #TypeError: expand() takes 2 positional arguments but 4 were given
+        inv_freq_expanded = self.inv_freq[None, :, None].float().broadcast_to((position_ids.shape[0], -1, 1))
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+        device_type = ms.get_context('device_target')
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with autocast(dtype=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose((0, 2, 1))
+            emb = ops.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+
+        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return ops.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class MimiMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.activation_fn = ACT2FN[config.hidden_act]
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+
+    # Copied from transformers.models.clip.modeling_clip.CLIPMLP.forward
+    def forward(self, hidden_states: ms.Tensor) -> ms.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].broadcast_to((batch, num_key_value_heads, n_rep, slen, head_dim))
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+# copied from transformers.models.gemma.modeling_gemma.GemmaAttention with Gemma->Mimi
+# no longer copied after attention refactors
+class MimiAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: MimiConfig, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        self.attention_dropout = config.attention_dropout
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = config.head_dim
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.is_causal = True
+        self.scaling = 1 / math.sqrt(config.head_dim)
+
+        if self.hidden_size % self.num_heads != 0:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
+        self.rotary_emb = MimiRotaryEmbedding(config)
+        self.sliding_window = config.sliding_window  # Ignore copy
+
+    def forward(
+        self,
+        hidden_states: ms.Tensor,
+        attention_mask: Optional[ms.Tensor] = None,
+        position_ids: Optional[ms.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[ms.Tensor] = None,
+    ) -> Tuple[ms.Tensor, Optional[ms.Tensor], Optional[Tuple[ms.Tensor]]]:
+
+        hidden_states = ms.Tensor(hidden_states)
+        bsz, q_len, _ = hidden_states.shape
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        query_states = query_states.transpose((0, 2, 1, 3))
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        key_states = key_states.transpose((0, 2, 1, 3))
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+        value_states = value_states.transpose((0, 2, 1, 3))
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = ops.matmul(query_states, key_states.transpose((0,1,3,2))) * self.scaling
+
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=ms.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+        attn_output = ops.matmul(attn_weights, value_states)
+
+        if attn_output.shape != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.shape}"
+            )
+
+        attn_output = attn_output.transpose((0, 2, 1,3)).contiguous()
+
+        attn_output = attn_output.view(bsz, q_len, -1)
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi
+# TODO cyril: modular
+# class MimiFlashAttention2(MimiAttention):
+#     """
+#     Mimi flash attention module. This module inherits from `MimiAttention` as the weights of the module stays
+#     untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+#     flash attention and deal with padding tokens in case the input contains any of them.
+#     """
+
+#     def __init__(self, *args, **kwargs):
+#         super().__init__(*args, **kwargs)
+
+#         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+#         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+#         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+#         self._flash_attn_uses_top_left_mask = False
+
+#     def forward(
+#         self,
+#         hidden_states: ms.Tensor,
+#         attention_mask: Optional[ms.Tensor] = None,
+#         position_ids: Optional[ms.Tensor] = None,
+#         past_key_value: Optional[Cache] = None,
+#         output_attentions: bool = False,
+#         use_cache: bool = False,
+#         cache_position: Optional[ms.Tensor] = None,
+#     ) -> Tuple[ms.Tensor, Optional[ms.Tensor], Optional[Tuple[ms.Tensor]]]:
+#         if isinstance(past_key_value, StaticCache):
+#             raise ValueError(
+#                 "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+#                 "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+#             )
+
+#         output_attentions = False
+
+#         bsz, q_len, _ = hidden_states.size()
+
+#         query_states = self.q_proj(hidden_states)
+#         key_states = self.k_proj(hidden_states)
+#         value_states = self.v_proj(hidden_states)
+
+#         # Flash attention requires the input to have the shape
+#         # batch_size x seq_length x head_dim x hidden_dim
+#         # therefore we just need to keep the original shape
+#         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+#         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+#         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+#         cos, sin = self.rotary_emb(value_states, position_ids)
+#         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+#         if past_key_value is not None:
+#             # sin and cos are specific to RoPE models; cache_position needed for the static cache
+#             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+#             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+#         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+#         # to be able to avoid many of these transpose/reshape/view.
+#         query_states = query_states.transpose(1, 2)
+#         key_states = key_states.transpose(1, 2)
+#         value_states = value_states.transpose(1, 2)
+
+#         dropout_rate = self.attention_dropout if self.training else 0.0
+
+#         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+#         # therefore the input hidden states gets silently casted in float32. Hence, we need
+#         # cast them back in the correct dtype just to be sure everything works as expected.
+#         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+#         # in fp32. (MimiRMSNorm handles it correctly)
+
+#         input_dtype = query_states.dtype
+#         if input_dtype == ms.float32:
+#             if ops.is_autocast_enabled():
+#                 target_dtype = ops.get_default_dtype()#get_autocast_gpu_dtype()
+#             # Handle the case where the model is quantized
+#             elif hasattr(self.config, "_pre_quantization_dtype"):
+#                 target_dtype = self.config._pre_quantization_dtype
+#             else:
+#                 target_dtype = self.q_proj.weight.dtype
+
+#             logger.warning_once(
+#                 f"The input hidden states seems to be silently casted in float32, this might be related to"
+#                 f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+#                 f" {target_dtype}."
+#             )
+
+#             query_states = query_states.to(target_dtype)
+#             key_states = key_states.to(target_dtype)
+#             value_states = value_states.to(target_dtype)
+
+        # attn_output = _flash_attention_forward(
+        #     query_states,
+        #     key_states,
+        #     value_states,
+        #     attention_mask,
+        #     q_len,
+        #     position_ids=position_ids,
+        #     dropout=dropout_rate,
+        #     sliding_window=getattr(self, "sliding_window", None),
+        #     is_causal=self.is_causal,
+        #     use_top_left_mask=self._flash_attn_uses_top_left_mask,
+        # )
+
+        # attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        # attn_output = self.o_proj(attn_output)
+
+        # if not output_attentions:
+        #     attn_weights = None
+
+        # return attn_output, attn_weights, past_key_value
+
+
+# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Mimi
+# TODO cyril: modular
+class MimiSdpaAttention(MimiAttention):
+    """
+    Mimi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+    `MimiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+    SDPA API.
+    """
+
+    # Adapted from MimiAttention.forward
+    def forward(
+        self,
+        hidden_states: ms.Tensor,
+        attention_mask: Optional[ms.Tensor] = None,
+        position_ids: Optional[ms.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[ms.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[ms.Tensor, Optional[ms.Tensor], Optional[Tuple[ms.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "MimiModel is using MimiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+        bsz, q_len, _ = hidden_states.shape
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose((0, 2, 1, 3))
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose((0, 2, 1, 3))
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose((0, 2, 1, 3))
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        causal_mask = attention_mask
+        if attention_mask is not None:
+            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if causal_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
+        attn_output = nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+
+        attn_output = attn_output.transpose((0, 2, 1)).contiguous()
+        attn_output = attn_output.view(bsz, q_len, -1)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+#"flash_attention_2": MimiFlashAttention2,
+MIMI_ATTENTION_CLASSES = {
+    "eager": MimiAttention,
+    "sdpa": MimiSdpaAttention,
+}
+
+
+class MimiTransformerLayer(nn.Module):
+    def __init__(self, config: MimiConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = MIMI_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+
+        self.mlp = MimiMLP(config)
+        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
+        self.self_attn_layer_scale = MimiLayerScale(config)
+        self.mlp_layer_scale = MimiLayerScale(config)
+
+    def forward(
+        self,
+        hidden_states: ms.Tensor,
+        attention_mask: Optional[ms.Tensor] = None,
+        position_ids: Optional[ms.Tensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[ms.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[ms.Tensor, Optional[Tuple[ms.Tensor, ms.Tensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*):
+                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+                query_sequence_length, key_sequence_length)` if default attention is used.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+                Indices depicting the position of the input sequence tokens in the sequence
+            kwargs (`dict`, *optional*):
+                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+                into the model
+        """
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        hidden_states = residual + self.self_attn_layer_scale(hidden_states)
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + self.mlp_layer_scale(hidden_states)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class MimiTransformerModel(nn.Module):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MimiTransformerLayer`]
+    Args:
+        config: MimiConfig
+    """
+
+    def __init__(self, config: MimiConfig):
+        super().__init__()
+
+        self.layers = nn.ModuleList(
+            [MimiTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+        )
+        self._attn_implementation = config._attn_implementation
+
+        self.gradient_checkpointing = False
+        self.config = config
+
+    def forward(
+        self,
+        hidden_states: ms.Tensor = None,
+        attention_mask: Optional[ms.Tensor] = None,
+        position_ids: Optional[ms.Tensor] = None,
+        past_key_values: Optional[Union[Cache, List[ms.Tensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[ms.Tensor] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Embedded representation that will be contextualized by the model
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+                [What are attention masks?](../glossary#attention-mask)
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+                If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+                `past_key_values`).
+                If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+                and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+                information on the default strategy.
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+                config.n_positions - 1]`.
+                [What are position IDs?](../glossary#position-ids)
+            past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+                Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+                blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+                returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+                Two formats are allowed:
+                - a [`~cache_utils.Cache`] instance;
+                - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+                cache format.
+                The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+                legacy cache format will be returned.
+                If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+                have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+                of shape `(batch_size, sequence_length)`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+                `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+                tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+                more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+            )
+            use_cache = False
+
+        if use_cache and not isinstance(past_key_values, Cache):
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = ops.arange(
+                past_seen_tokens, past_seen_tokens + hidden_states.shape[1]
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = None
+        if attention_mask is not None:
+            causal_mask = self._update_causal_mask(
+                attention_mask, hidden_states, cache_position, past_key_values, output_attentions
+            )
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    causal_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                    cache_position,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=causal_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+    # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask with Phi3->Mimi
+    def _update_causal_mask(
+        self,
+        attention_mask: ms.Tensor,
+        input_tensor: ms.Tensor,
+        cache_position: ms.Tensor,
+        past_key_values: Cache,
+        output_attentions: bool,
+    ):
+        if self.config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and past_key_values is not None:
+                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.shape[0]
+                if is_padding_right:
+                    raise ValueError(
+                        "You are attempting to perform batched generation with padding_side='right'"
+                        " this may lead to unexpected behaviour for Flash Attention version of Mimi. Make sure to "
+                        " call `tokenizer.padding_side  = 'left'` before tokenizing the input. "
+                    )
+            if attention_mask is not None and 0.0 in attention_mask:
+                return attention_mask
+            return None
+
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_static_cache = isinstance(past_key_values, StaticCache)
+        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if (
+            self.config._attn_implementation == "sdpa"
+            and not (using_static_cache or using_sliding_window_cache)
+            and not output_attentions
+        ):
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                sliding_window=self.config.sliding_window,
+                is_training=self.training,
+            ):
+                return None
+
+        dtype, device = input_tensor.dtype, ms.get_context('device_target')
+        min_dtype = ops.finfo(dtype).min
+        sequence_length = input_tensor.shape[1]
+        # SlidingWindowCache or StaticCache
+        if using_sliding_window_cache or using_static_cache:
+            target_length = past_key_values.get_max_cache_shape()
+        # DynamicCache or no cache
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, ms.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
+
+        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+            attention_mask,
+            sequence_length=sequence_length,
+            target_length=target_length,
+            dtype=dtype,
+            device=device,
+            cache_position=cache_position,
+            batch_size=input_tensor.shape[0],
+            config=self.config,
+            past_key_values=past_key_values,
+        )
+        #mindspore.get_context('device_target') == 'Ascend'
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and ms.get_context('device_target') == "Ascend"
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+    @staticmethod
+    # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Mimi
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: ms.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: ms.dtype,
+        device: str,
+        cache_position: ms.Tensor,
+        batch_size: int,
+        config: MimiConfig,
+        past_key_values: Cache,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to plcae the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`torch.Tensor`):
+                Batch size.
+            config (`MimiConfig`):
+                The model's configuration class
+            past_key_values (`Cache`):
+                The cache class that is being used currently to generate
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = ops.finfo(dtype).min
+            causal_mask = ops.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            diagonal_attend_mask = ops.arange(target_length) > cache_position.reshape(-1, 1)
+            if config.sliding_window is not None:
+                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
+                # the check is needed to verify is current checkpoint was trained with sliding window or not
+                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
+                    sliding_attend_mask = ops.arange(target_length) <= (
+                        cache_position.reshape(-1, 1) - config.sliding_window
+                    )
+                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
+            causal_mask *= diagonal_attend_mask
+            causal_mask = causal_mask[None, None, :, :].broadcast_to((batch_size, 1, -1, -1))
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                if attention_mask.shape[-1] > target_length:
+                    attention_mask = attention_mask[:, :target_length]
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+        return causal_mask
+
+
+class MimiDecoder(nn.Module):
+    """SEANet decoder as used by Mimi."""
+
+    def __init__(self, config: MimiConfig):
+        super().__init__()
+        scaling = int(2 ** len(config.upsampling_ratios))
+        model = [MimiConv1d(config, config.hidden_size, scaling * config.num_filters, config.kernel_size)]
+
+        # Upsample to raw audio scale
+        for ratio in config.upsampling_ratios:
+            current_scale = scaling * config.num_filters
+            # Add upsampling layers
+            model += [nn.ELU()]
+            model += [
+                MimiConvTranspose1d(config, current_scale, current_scale // 2, kernel_size=ratio * 2, stride=ratio)
+            ]
+            # Add residual layers
+            for j in range(config.num_residual_layers):
+                model += [MimiResnetBlock(config, current_scale // 2, (config.dilation_growth_rate**j, 1))]
+            scaling //= 2
+
+        # Add final layers  
+        model += [nn.ELU()]
+        model += [MimiConv1d(config, config.num_filters, config.audio_channels, config.last_kernel_size)]
+        self.layers = nn.ModuleList(model)
+
+    # Copied from transformers.models.encodec.modeling_encodec.EncodecDecoder.forward MimiModel
+    def forward(self, hidden_states):
+        for layer in self.layers:
+            hidden_states = layer(hidden_states)
+        return hidden_states
+
+
+class MimiEuclideanCodebook(nn.Module):
+    """Codebook with Euclidean distance."""
+
+    def __init__(self, config: MimiConfig, epsilon: float = 1e-5):
+        super().__init__()
+        embed = ops.zeros(config.codebook_size, config.codebook_dim)
+
+        self.codebook_size = config.codebook_size
+
+        self.register_buffer("initialized", ms.Tensor([True]))
+        self.register_buffer("cluster_usage", ops.ones(config.codebook_size))
+        self.register_buffer("embed_sum", embed)
+        self._embed = None
+        self.epsilon = epsilon
+
+    @property
+    def embed(self) -> ms.Tensor:
+        if self._embed is None:
+            self._embed = self.embed_sum / self.cluster_usage.clamp(min=self.epsilon)[:, None]
+        return self._embed
+
+    def quantize(self, hidden_states):
+        # Projects each vector in `hidden_states` over the nearest centroid and return its index.
+        # `hidden_states` should be `[N, D]` with `N` the number of input vectors and `D` the dimension.
+        dists = ops.cdist(hidden_states[None], self.embed[None], p=2)[0]
+        embed_ind = dists.argmin(axis=-1)
+        return embed_ind
+
+    # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.encode
+    def encode(self, hidden_states):
+        shape = hidden_states.shape
+        # pre-process
+        hidden_states = hidden_states.reshape((-1, shape[-1]))
+        # quantize
+        embed_ind = self.quantize(hidden_states)
+        # post-process
+        embed_ind = embed_ind.view(*shape[:-1])
+        return embed_ind
+
+    # Copied from transformers.models.encodec.modeling_encodec.EncodecEuclideanCodebook.decode
+    def decode(self, embed_ind):
+        quantize = nn.functional.embedding(embed_ind, self.embed)
+        return quantize
+
+
+# Copied from transformers.models.encodec.modeling_encodec.EncodecVectorQuantization with Encodec->Mimi
+class MimiVectorQuantization(nn.Module):
+    """
+    Vector quantization implementation. Currently supports only euclidean distance.
+    """
+
+    def __init__(self, config: MimiConfig):
+        super().__init__()
+        self.codebook = MimiEuclideanCodebook(config)
+
+    def encode(self, hidden_states):
+        hidden_states = hidden_states.permute(0, 2, 1)
+        embed_in = self.codebook.encode(hidden_states)
+        return embed_in
+
+    def decode(self, embed_ind):
+        quantize = self.codebook.decode(embed_ind)
+        quantize = quantize.permute(0, 2, 1)
+        return quantize
+
+
+class MimiResidualVectorQuantizer(nn.Module):
+    """Residual Vector Quantizer."""
+
+    def __init__(self, config: MimiConfig, num_quantizers: int = None):
+        super().__init__()
+        self.codebook_size = config.codebook_size
+        self.frame_rate = config.frame_rate
+        self.num_quantizers = num_quantizers if num_quantizers is not None else config.num_quantizers
+        self.layers = nn.ModuleList([MimiVectorQuantization(config) for _ in range(self.num_quantizers)])
+
+        self.input_proj = None
+        self.output_proj = None
+        if config.vector_quantization_hidden_dimension != config.hidden_size:
+            self.input_proj = nn.Conv1d(
+                config.hidden_size, config.vector_quantization_hidden_dimension, 1, bias=False
+            )
+            self.output_proj = nn.Conv1d(
+                config.vector_quantization_hidden_dimension, config.hidden_size, 1, bias=False
+            )
+
+    def encode(self, embeddings: ms.Tensor, num_quantizers: Optional[int] = None) -> ms.Tensor:
+        """
+        Encode a given input tensor with the specified frame rate at the given number of quantizers / codebooks. The RVQ encode method sets
+        the appropriate number of quantizers to use and returns indices for each quantizer.
+        """
+        if self.input_proj is not None:
+            embeddings = self.input_proj(embeddings)
+
+        num_quantizers = num_quantizers if num_quantizers is not None else self.num_quantizers
+
+        residual = embeddings
+        all_indices = []
+        for layer in self.layers[:num_quantizers]:
+            indices = layer.encode(residual)
+            quantized = layer.decode(indices)
+            residual = residual - quantized
+            all_indices.append(indices)
+        out_indices = ops.stack(all_indices)
+        return out_indices
+
+    def decode(self, codes: ms.Tensor) -> ms.Tensor:
+
+        """Decode the given codes of shape [B, K, T] to the quantized representation."""
+        device = ms.get_context('device_target')
+        quantized_out = ms.tensor(0.0)
+        codes = codes.transpose((1,0,2))
+        for i, indices in enumerate(codes):
+            layer = self.layers[i]
+            quantized = layer.decode(indices)
+            quantized_out = quantized_out + quantized
+
+        if self.output_proj is not None:
+            quantized_out = self.output_proj(quantized_out)
+        return quantized_out
+
+
+class MimiSplitResidualVectorQuantizer(nn.Module):
+    """Split Residual Vector Quantizer."""
+
+    def __init__(self, config: MimiConfig):
+        super().__init__()
+        self.codebook_size = config.codebook_size
+        self.frame_rate = config.frame_rate
+        self.max_num_quantizers = config.num_quantizers
+
+        self.num_semantic_quantizers = config.num_semantic_quantizers
+        self.num_acoustic_quantizers = config.num_quantizers - config.num_semantic_quantizers
+
+        self.semantic_residual_vector_quantizer = MimiResidualVectorQuantizer(config, self.num_semantic_quantizers)
+        self.acoustic_residual_vector_quantizer = MimiResidualVectorQuantizer(config, self.num_acoustic_quantizers)
+
+    def encode(self, embeddings: ms.Tensor, num_quantizers: Optional[float] = None) -> ms.Tensor:
+        """
+        Encode a given input tensor with the specified frame rate at the given number of quantizers / codebooks. The RVQ encode method sets
+        the appropriate number of quantizers to use and returns indices for each quantizer.
+        """
+
+        num_quantizers = self.max_num_quantizers if num_quantizers is None else num_quantizers
+
+        if num_quantizers > self.max_num_quantizers:
+            raise ValueError(
+                f"The number of quantizers (i.e codebooks) asked should be lower than the total number of quantizers {self.max_num_quantizers}, but is currently {num_quantizers}."
+            )
+
+        if num_quantizers < self.num_semantic_quantizers:
+            raise ValueError(
+                f"The number of quantizers (i.e codebooks) asked should be higher than the number of semantic quantizers {self.num_semantic_quantizers}, but is currently {num_quantizers}."
+            )
+
+        # codes is [K, B, T], with T frames, K nb of codebooks.
+        codes = self.semantic_residual_vector_quantizer.encode(embeddings)
+
+        if num_quantizers > self.num_semantic_quantizers:
+            acoustic_codes = self.acoustic_residual_vector_quantizer.encode(
+                embeddings, num_quantizers=num_quantizers - self.num_semantic_quantizers
+            )
+            codes = ops.cat([codes, acoustic_codes], dim=0)
+
+        return codes
+
+    def decode(self, codes: ms.Tensor) -> ms.Tensor:
+        """Decode the given codes to the quantized representation."""
+
+        # The first num_semantic_quantizers codebooks are decoded using the semantic RVQ
+        quantized_out = self.semantic_residual_vector_quantizer.decode(codes[:, : self.num_semantic_quantizers])
+
+        # The rest of the codebooks are decoded using the acoustic RVQ
+        if codes.shape[1] > self.num_semantic_quantizers:
+            quantized_out += self.acoustic_residual_vector_quantizer.decode(codes[:, self.num_semantic_quantizers :])
+        return quantized_out
+
+
+class MimiPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = MimiConfig
+    base_model_prefix = "mimi"
+    main_input_name = "input_values"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["MimiDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_cache_class = True
+    _supports_static_cache = True
+
+    # Copied from transformers.models.encodec.modeling_encodec.EncodecPreTrainedModel._init_weights
+    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm, nn.GroupNorm,nn.Conv1d, nn.Embedding,nn.LSTM]):
+        """Initialize the weights"""
+        if isinstance(module, nn.Linear):
+            module.weight.assign_value(
+                initializer(
+                    TruncatedNormal(sigma=self.config.initializer_range, mean=0.0),
+                    module.weight.shape,
+                    module.weight.dtype,
+                )
+            )
+            if module.bias is not None:
+                module.bias.assign_value(
+                    initializer(
+                        "zeros",
+                        module.bias.shape,
+                        module.bias.dtype,
+                    )
+                )
+        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+            module.bias.assign_value(
+                initializer(
+                    "zeros",
+                    module.bias.shape,
+                    module.bias.dtype,
+                )
+            )
+            module.weight.assign_value(
+                initializer("ones", module.weight.shape, module.weight.dtype)
+            )
+        elif isinstance(module, nn.Conv1d):
+            nn.init.kaiming_normal_(module.weight)
+            if module.bias is not None:
+                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+                nn.init.uniform_(module.bias, a=-k, b=k)
+        elif isinstance(module, nn.Embedding):
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight[module.padding_idx] = 0
+        elif isinstance(module, nn.LSTM):
+            for name, param in module.named_parameters():
+                if "weight" in name:
+                    nn.init.xavier_uniform_(param)
+                elif "bias" in name:
+                    nn.init.constant_(param, 0.0)
+
+
+
+
+
+MIMI_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+    Parameters:
+        config ([`MimiConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+MIMI_INPUTS_DOCSTRING = r"""
+    Args:
+        input_values (`torch.FloatTensor` of shape `(batch_size, channels, sequence_length)`, *optional*):
+            Raw audio input converted to Float.
+        padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0
+            for *masked*.
+        num_quantizers (`int`, *optional*):
+            Number of quantizers (i.e codebooks) to use. By default, all quantizers are used.
+        audio_codes (`torch.LongTensor`  of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
+            Discret code embeddings computed using `model.encode`.
+        encoder_past_key_values (`Cache`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer.
+            This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+            The model will output the same cache format that is fed as input.
+            If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
+            have their past key value states given to this model).
+        decoder_past_key_values (`Cache`, *optional*):
+            Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer.
+            This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+            The model will output the same cache format that is fed as input.
+            If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
+            have their past key value states given to this model).
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+# @add_start_docstrings(
+#     "The Mimi neural audio codec model.",
+#     MIMI_START_DOCSTRING,
+# )
+class MimiModel(MimiPreTrainedModel):
+    def __init__(self, config: MimiConfig):
+        super().__init__(config)
+        self.config = config
+
+        self.encoder = MimiEncoder(config)
+        self.encoder_transformer = MimiTransformerModel(config)
+
+        self.downsample = None
+        self.upsample = None
+        if config.frame_rate != config.encodec_frame_rate:
+            self.downsample = MimiConv1d(
+                config,
+                config.hidden_size,
+                config.hidden_size,
+                kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate),
+                stride=2,
+                bias=False,
+                pad_mode="replicate",
+            )
+
+            self.upsample = MimiConvTranspose1d(
+                config,
+                config.hidden_size,
+                config.hidden_size,
+                kernel_size=2 * int(config.encodec_frame_rate / config.frame_rate),
+                stride=2,
+                bias=False,
+                groups=config.upsample_groups,
+            )
+
+        self.decoder_transformer = MimiTransformerModel(config)
+        self.decoder = MimiDecoder(config)
+
+        self.quantizer = MimiSplitResidualVectorQuantizer(config)
+
+        self.bits_per_codebook = int(math.log2(self.config.codebook_size))
+        if 2**self.bits_per_codebook != self.config.codebook_size:
+            raise ValueError("The codebook_size must be a power of 2.")
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    def _encode_frame(
+        self,
+        input_values: ms.Tensor,
+        num_quantizers: int,
+        padding_mask: int,
+        past_key_values: Optional[Union[Cache, List[ms.Tensor]]] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Tuple[ms.Tensor, Optional[ms.Tensor]]:
+        """
+        Encodes the given input using the underlying VQVAE. The padding mask is required to compute the correct scale.
+        """
+        embeddings = self.encoder(input_values)
+        embeddings = embeddings.transpose((0, 2, 1))
+        encoder_outputs = self.encoder_transformer(
+            embeddings, past_key_values=past_key_values, return_dict=return_dict
+        )
+        if return_dict:
+            past_key_values = encoder_outputs.get("past_key_values")
+        elif len(encoder_outputs) > 1:
+            past_key_values = encoder_outputs[1]
+        embeddings = encoder_outputs[0].transpose((0, 2, 1))
+        embeddings = self.downsample(embeddings)
+
+        codes = self.quantizer.encode(embeddings, num_quantizers)
+        codes = codes.transpose((1,0, 2))
+        return codes, past_key_values
+
+    def encode(
+        self,
+        input_values: ms.Tensor,
+        padding_mask: ms.Tensor = None,
+        num_quantizers: Optional[float] = None,
+        encoder_past_key_values: Optional[Union[Cache, List[ms.Tensor]]] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[ms.Tensor, Optional[ms.Tensor]], MimiEncoderOutput]:
+        """
+        Encodes the input audio waveform into discrete codes.
+        Args:
+            input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
+                Float values of the input audio waveform.
+            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
+                Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0
+                for *masked*.
+            num_quantizers (`int`, *optional*):
+                Number of quantizers (i.e codebooks) to use. By default, all quantizers are used.
+            encoder_past_key_values (`Cache`, *optional*):
+                Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the encoder transformer.
+                This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+                The model will output the same cache format that is fed as input.
+                If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
+                have their past key value states given to this model).
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        Returns:
+            `codebook` of shape `[batch_size, num_codebooks, frames]`, the discrete encoded codes for the input audio waveform.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+        num_quantizers = self.config.num_quantizers if num_quantizers is None else num_quantizers
+
+        if num_quantizers > self.config.num_quantizers:
+            raise ValueError(
+                f"The number of quantizers (i.e codebooks) asked should be lower than the total number of quantizers {self.config.num_quantizers}, but is currently {num_quantizers}."
+            )
+
+        _, channels, input_length = input_values.shape
+
+        if channels < 1 or channels > 2:
+            raise ValueError(f"Number of audio channels must be 1 or 2, but got {channels}")
+
+        if padding_mask is None:
+            padding_mask = torch.ones_like(input_values).bool()
+
+        encoded_frames, encoder_past_key_values = self._encode_frame(
+            input_values,
+            num_quantizers,
+            padding_mask.bool(),
+            past_key_values=encoder_past_key_values,
+            return_dict=return_dict,
+        )
+
+
+        if not return_dict:
+            return (
+                encoded_frames,
+                encoder_past_key_values,
+            )
+
+        return MimiEncoderOutput(encoded_frames, encoder_past_key_values)
+
+    def _decode_frame(
+        self,
+        codes: ms.Tensor,
+        past_key_values: Optional[Union[Cache, List[ms.Tensor]]] = None,
+        return_dict: Optional[bool] = None,
+    ) -> ms.Tensor:
+        embeddings = self.quantizer.decode(codes)
+        embeddings = self.upsample(embeddings)
+        decoder_outputs = self.decoder_transformer(
+            embeddings.transpose((0, 2, 1)), past_key_values=past_key_values, return_dict=return_dict
+        )
+        if return_dict:
+            past_key_values = decoder_outputs.get("past_key_values")
+        elif len(decoder_outputs) > 1:
+            past_key_values = decoder_outputs[1]
+        embeddings = decoder_outputs[0].transpose((0, 2, 1))
+        outputs = self.decoder(embeddings)
+        return outputs, past_key_values
+
+    def decode(
+        self,
+        audio_codes: ms.Tensor,
+        padding_mask: Optional[ms.Tensor] = None,
+        decoder_past_key_values: Optional[Union[Cache, List[ms.Tensor]]] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[ms.Tensor, ms.Tensor], MimiDecoderOutput]:
+        """
+        Decodes the given frames into an output audio waveform.
+        Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
+        trimmed.
+        Args:
+            audio_codes (`torch.LongTensor`  of shape `(batch_size, num_quantizers, codes_length)`, *optional*):
+                Discret code embeddings computed using `model.encode`.
+            padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
+                Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0
+                for *masked*.
+            decoder_past_key_values (`Cache`, *optional*):
+                Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding of the decoder transformer.
+                This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+                The model will output the same cache format that is fed as input.
+                If `past_key_values` are used, the user can optionally input only the last `audio_values` or `audio_codes (those that don't
+                have their past key value states given to this model).
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+        audio_values, decoder_past_key_values = self._decode_frame(
+            audio_codes, past_key_values=decoder_past_key_values, return_dict=return_dict
+        )
+
+        # truncate based on padding mask
+        if padding_mask is not None and padding_mask.shape[-1] < audio_values.shape[-1]:
+            audio_values = audio_values[..., : padding_mask.shape[-1]]
+
+        if not return_dict:
+            return (
+                audio_values,
+                decoder_past_key_values,
+            )
+        return MimiDecoderOutput(audio_values, decoder_past_key_values)
+
+    # @add_start_docstrings_to_model_forward(MIMI_INPUTS_DOCSTRING)
+    # @replace_return_docstrings(output_type=MimiOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_values: ms.Tensor,
+        padding_mask: Optional[ms.Tensor] = None,
+        num_quantizers: Optional[int] = None,
+        audio_codes: Optional[ms.Tensor] = None,
+        encoder_past_key_values: Optional[Union[Cache, List[ms.Tensor]]] = None,
+        decoder_past_key_values: Optional[Union[Cache, List[ms.Tensor]]] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[ms.Tensor, ms.Tensor], MimiOutput]:
+        r"""
+        Returns:
+        Examples:
+        ```python
+        >>> from mindnlp.transformers import MimiModel, MimiConfig
+        >>> from mindnlp.transformers import AutoFeatureExtractor
+        >>> # Initializing a "kyutai/mimi" style configuration
+        >>> configuration = MimiConfig()
+        >>> # Initializing a model (with random weights) from the "kyutai/mimi" style configuration
+        >>> model = MimiModel(configuration)
+        >>> # Accessing the model configuration
+        >>> configuration = model.config
+        >>> from datasets import load_dataset
+        >>> dataset = load_dataset("hf-internal-testing/ashraq-esc50-1-dog-example")
+        >>> audio_sample = dataset["train"]["audio"][0]["array"]
+        >>> model_id = "kyutai/mimi"
+        >>> model = MimiModel.from_pretrained(model_id, mirror="modelscope")
+        >>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
+        >>> inputs = feature_extractor(raw_audio=audio_sample, return_tensors="ms")
+        >>> outputs = model(**inputs)
+        >>> audio_codes = outputs.audio_codes
+        >>> audio_values = outputs.audio_values
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.return_dict
+
+        if padding_mask is None:
+            padding_mask = ops.ones_like(input_values).bool()
+        if audio_codes is None:
+            encoder_outputs = self.encode(
+                input_values, padding_mask, num_quantizers, encoder_past_key_values, return_dict=return_dict
+            )
+
+            audio_codes = encoder_outputs[0]
+            if return_dict:
+                encoder_past_key_values = encoder_outputs.get("past_key_values")
+            elif len(encoder_outputs) > 1:
+                encoder_past_key_values = encoder_outputs[1]
+        decoder_outputs = self.decode(audio_codes, padding_mask, decoder_past_key_values, return_dict=return_dict)
+        audio_values = decoder_outputs[0]
+        if return_dict:
+            decoder_past_key_values = decoder_outputs.get("past_key_values")
+        elif len(decoder_outputs) > 1:
+            decoder_past_key_values = decoder_outputs[1]
+
+        if not return_dict:
+            return (audio_codes, audio_values, encoder_past_key_values, decoder_past_key_values)
+
+        return MimiOutput(
+            audio_codes=audio_codes,
+            audio_values=audio_values,
+            encoder_past_key_values=encoder_past_key_values,
+            decoder_past_key_values=decoder_past_key_values,
+        )
+
+
+__all__ = ["MimiModel", "MimiPreTrainedModel"]
\ No newline at end of file
diff --git a/tests/transformers/models/mini/__init__.py b/tests/transformers/models/mini/__init__.py
new file mode 100644
index 000000000..61fce3f73
--- /dev/null
+++ b/tests/transformers/models/mini/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Mimi Model.
+"""
+from . import modeling_mimi, configuration_mimi
+from .modeling_mimi import *
+from .configuration_mimi import *
+
+__all__ = []
+__all__.extend(modeling_mimi.__all__)
+__all__.extend(configuration_mimi.__all__)
\ No newline at end of file
diff --git a/tests/transformers/models/mini/unit_test_mini.py b/tests/transformers/models/mini/unit_test_mini.py
new file mode 100644
index 000000000..a44a2591d
--- /dev/null
+++ b/tests/transformers/models/mini/unit_test_mini.py
@@ -0,0 +1,805 @@
+import inspect
+import os
+import tempfile
+import unittest
+
+import numpy as np
+from datasets import Audio, load_dataset
+from parameterized import parameterized
+from pytest import mark
+from mindspore import Tensor,ops,nn,dtype
+from mindnlp.transformers import AutoFeatureExtractor, MimiConfig
+from mindnlp.utils.testing_utils import (
+    is_flaky,
+    require_flash_attn,
+    is_mindspore_available,
+    slow,
+)
+
+
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel
+
+def allclose(valueA,valueB,atol,rtol):
+    return ops.all(ops.abs(valueA-valueB)<atol)
+
+def assert_close(valueA,valueB,atol=1e-5,rtol=1e-5):
+    return allclose(valueA,valueB,atol,rtol)
+
+if is_mindspore_available():
+    import mindspore as ms
+
+    from mindnlp.transformers import MimiModel
+
+def prepare_inputs_dict(
+    config,
+    input_ids=None,
+    input_values=None,
+    decoder_input_ids=None,
+    attention_mask=None,
+    decoder_attention_mask=None,
+    head_mask=None,
+    decoder_head_mask=None,
+    cross_attn_head_mask=None,
+):
+    if input_ids is not None:
+        encoder_dict = {"input_ids": input_ids}
+    else:
+        encoder_dict = {"input_values": input_values}
+
+    decoder_dict = {"decoder_input_ids": decoder_input_ids} if decoder_input_ids is not None else {}
+
+    return {**encoder_dict, **decoder_dict}
+
+@require_mindspore
+class MimiModelTester:
+    def __init__(
+        self,
+        parent,
+        batch_size=5,
+        num_channels=1,
+        is_training=False,
+        intermediate_size=40,
+        hidden_size=32,
+        num_filters=8,
+        num_residual_layers=1,
+        upsampling_ratios=[8, 4],
+        codebook_size=64,
+        vector_quantization_hidden_dimension=64,
+        codebook_dim=64,
+        upsample_groups=32,
+        num_hidden_layers=2,
+        num_attention_heads=2,
+        num_key_value_heads=2,
+        sliding_window=4,
+        use_cache=False,
+    ):
+        self.parent = parent
+        self.batch_size = batch_size
+        self.num_channels = num_channels
+        self.is_training = is_training
+        self.intermediate_size = intermediate_size
+        self.hidden_size = hidden_size
+        self.num_filters = num_filters
+        self.num_residual_layers = num_residual_layers
+        self.upsampling_ratios = upsampling_ratios
+        self.codebook_size = codebook_size
+        self.vector_quantization_hidden_dimension = vector_quantization_hidden_dimension
+        self.codebook_dim = codebook_dim
+        self.upsample_groups = upsample_groups
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.sliding_window = sliding_window
+        self.use_cache = use_cache
+
+    def prepare_config_and_inputs(self):
+        input_values = floats_tensor([self.batch_size, self.num_channels, self.intermediate_size], scale=1.0)
+        config = self.get_config()
+        inputs_dict = {"input_values": input_values}
+        return config, inputs_dict
+
+    def prepare_config_and_inputs_for_common(self):
+        config, inputs_dict = self.prepare_config_and_inputs()
+        return config, inputs_dict
+
+    def prepare_config_and_inputs_for_model_class(self, model_class):
+        config, inputs_dict = self.prepare_config_and_inputs()
+        inputs_dict["audio_codes"] = ids_tensor([self.batch_size, 1, self.num_channels], self.codebook_size).type(
+            ms.int32
+        )
+
+        return config, inputs_dict
+
+    def get_config(self):
+        return MimiConfig(
+            audio_channels=self.num_channels,
+            chunk_in_sec=None,
+            hidden_size=self.hidden_size,
+            num_filters=self.num_filters,
+            num_residual_layers=self.num_residual_layers,
+            upsampling_ratios=self.upsampling_ratios,
+            codebook_size=self.codebook_size,
+            vector_quantization_hidden_dimension=self.vector_quantization_hidden_dimension,
+            upsample_groups=self.upsample_groups,
+            num_hidden_layers=self.num_hidden_layers,
+            num_attention_heads=self.num_attention_heads,
+            num_key_value_heads=self.num_key_value_heads,
+            sliding_window=self.sliding_window,
+            codebook_dim=self.codebook_dim,
+            use_cache=self.use_cache,
+        )
+
+    def create_and_check_model_forward(self, config, inputs_dict):
+        model = MimiModel(config=config)
+
+        input_values = inputs_dict["input_values"]
+        result = model(input_values)
+        self.parent.assertEqual(
+            result.audio_values.shape, (self.batch_size, self.num_channels, self.intermediate_size)
+        )
+
+@require_mindspore
+class MimiModelTest(ModelTesterMixin, unittest.TestCase):
+    all_model_classes = (MimiModel,) if is_mindspore_available() else ()
+    is_encoder_decoder = True
+    test_pruning = False
+    test_headmasking = False
+    test_resize_embeddings = False
+    test_torchscript = False
+
+    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
+        # model does support returning hidden states
+        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
+        if "output_attentions" in inputs_dict:
+            inputs_dict.pop("output_attentions")
+        if "output_hidden_states" in inputs_dict:
+            inputs_dict.pop("output_hidden_states")
+        return inputs_dict
+
+    def setUp(self):
+        self.model_tester = MimiModelTester(self)
+        self.config_tester = ConfigTester(
+            self, config_class=MimiConfig, hidden_size=37, common_properties=[], has_text_modality=False
+        )
+
+    def test_config(self):
+        self.config_tester.run_common_tests()
+
+    def test_model_forward(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model_forward(*config_and_inputs)
+
+    def test_forward_signature(self):
+        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            signature = inspect.signature(model.forward)
+            # signature.parameters is an OrderedDict => so arg_names order is deterministic
+            arg_names = [*signature.parameters.keys()]
+
+            expected_arg_names = ["input_values", "padding_mask", "num_quantizers"]
+            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
+
+    @unittest.skip(reason="The MimiModel does not have `inputs_embeds` logics")
+    def test_inputs_embeds(self):
+        pass
+
+    @unittest.skip(reason="The MimiModel does not have `inputs_embeds` logics")
+    def test_model_get_set_embeddings(self):
+        pass
+
+    @unittest.skip(reason="The MimiModel does not have the usual `attention` logic")
+    def test_retain_grad_hidden_states_attentions(self):
+        pass
+
+    @unittest.skip(reason="The MimiModel does not have the usual `attention` logic")
+    def test_torchscript_output_attentions(self):
+        pass
+
+    @unittest.skip(reason="The MimiModel does not have the usual `hidden_states` logic")
+    def test_torchscript_output_hidden_state(self):
+        pass
+
+    def _create_and_check_torchscript(self, config, inputs_dict):
+        if not self.test_torchscript:
+            self.skipTest(reason="test_torchscript is set to False")
+
+        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
+        configs_no_init.torchscript = True
+        configs_no_init.return_dict = False
+        for model_class in self.all_model_classes:
+            model = model_class(config=configs_no_init)
+            inputs = self._prepare_for_class(inputs_dict, model_class)
+
+            main_input_name = model_class.main_input_name
+
+            try:
+                main_input = inputs[main_input_name]
+                model(main_input)
+                traced_model = ms.jit.trace(model, main_input)
+            except RuntimeError:
+                self.fail("Couldn't trace module.")
+
+            with tempfile.TemporaryDirectory() as tmp_dir_name:
+                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
+
+                try:
+                    ms.jit.save(traced_model, pt_file_name)
+                except Exception:
+                    self.fail("Couldn't save module.")
+
+                try:
+                    loaded_model = ms.jit.load(pt_file_name)
+                except Exception:
+                    self.fail("Couldn't load module.")
+
+
+
+
+            model_state_dict = model.state_dict()
+            loaded_model_state_dict = loaded_model.state_dict()
+
+            non_persistent_buffers = {}
+            for key in loaded_model_state_dict.keys():
+                if key not in model_state_dict.keys():
+                    non_persistent_buffers[key] = loaded_model_state_dict[key]
+
+            loaded_model_state_dict = {
+                key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
+            }
+
+            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
+
+            model_buffers = list(model.buffers())
+            for non_persistent_buffer in non_persistent_buffers.values():
+                found_buffer = False
+                for i, model_buffer in enumerate(model_buffers):
+                    if ops.equal(non_persistent_buffer, model_buffer):
+                        found_buffer = True
+                        break
+
+                self.assertTrue(found_buffer)
+                model_buffers.pop(i)
+
+            model_buffers = list(model.buffers())
+            for non_persistent_buffer in non_persistent_buffers.values():
+                found_buffer = False
+                for i, model_buffer in enumerate(model_buffers):
+                    if ops.equal(non_persistent_buffer, model_buffer):
+                        found_buffer = True
+                        break
+
+                self.assertTrue(found_buffer)
+                model_buffers.pop(i)
+
+            models_equal = True
+            for layer_name, p1 in model_state_dict.items():
+                if layer_name in loaded_model_state_dict:
+                    p2 = loaded_model_state_dict[layer_name]
+                    if p1.data.ne(p2.data).sum() > 0:
+                        models_equal = False
+
+            self.assertTrue(models_equal)
+
+            # Avoid memory leak. Without this, each call increase RAM usage by ~20MB.
+            # (Even with this call, there are still memory leak by ~0.04MB)
+            self.clear_mindspore_ore_ore_jit_class_registry()
+
+    @unittest.skip(reason="The MimiModel does not have the usual `attention` logic")
+    def test_attention_outputs(self):
+        pass
+
+    @unittest.skip(reason="The MimiModel does not have the usual `hidden_states` logic")
+    def test_hidden_states_output(self):
+        pass
+
+    def test_determinism(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        def check_determinism(first, second):
+            # outputs are not tensors but list (since each sequence don't have the same frame_length)
+            out_1 = first.cpu().numpy()
+            out_2 = second.cpu().numpy()
+            out_1 = out_1[~np.isnan(out_1)]
+            out_2 = out_2[~np.isnan(out_2)]
+            max_diff = np.amax(np.abs(out_1 - out_2))
+            self.assertLessEqual(max_diff, 1e-5)
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            with ms.no_grad():
+                first = model(**self._prepare_for_class(inputs_dict, model_class))[0]
+                second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
+
+            if isinstance(first, tuple) and isinstance(second, tuple):
+                for tensor1, tensor2 in zip(first, second):
+                    check_determinism(tensor1, tensor2)
+            else:
+                check_determinism(first, second)
+
+    def test_model_outputs_equivalence(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        def set_nan_tensor_to_zero(t):
+            t[t != t] = 0
+            return t
+
+        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
+            with ms.no_grad():
+                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
+                dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs)
+
+                self.assertTrue(isinstance(tuple_output, tuple))
+                self.assertTrue(isinstance(dict_output, dict))
+
+                for tuple_value, dict_value in zip(tuple_output, dict_output.values()):
+                    self.assertTrue(
+                        allclose(
+                            set_nan_tensor_to_zero(tuple_value), set_nan_tensor_to_zero(dict_value), atol=1e-5
+                        ),
+                        msg=(
+                            "Tuple and dict output are not equal. Difference:"
+                            f" {ops.max(ops.abs(tuple_value - dict_value))}. Tuple has `nan`:"
+                            f" {ops.isnan(tuple_value).any()} and `inf`: {ops.isinf(tuple_value)}. Dict has"
+                            f" `nan`: {ops.isnan(dict_value).any()} and `inf`: {ops.isinf(dict_value)}."
+                        ),
+                    )
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
+            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
+            check_equivalence(model, tuple_inputs, dict_inputs)
+
+    def test_initialization(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        configs_no_init = _config_zero_init(config)
+        for model_class in self.all_model_classes:
+            model = model_class(config=configs_no_init)
+            for name, param in model.named_parameters():
+                uniform_init_parms = ["conv", "input_proj", "output_proj"]
+                if param.requires_grad:
+                    if any(x in name for x in uniform_init_parms):
+                        self.assertTrue(
+                            -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
+                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+                        )
+
+    def test_identity_shortcut(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
+        config.use_conv_shortcut = False
+        self.model_tester.create_and_check_model_forward(config, inputs_dict)
+
+    # Overwrite to use `audio_values` as the tensors to compare.
+    # TODO: Try to do this in the parent class.
+    @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
+    @require_mindspore_sdpa
+    def test_eager_matches_sdpa_inference(self, mindspore_dtype: str):
+
+        if not self.has_attentions:
+            self.skipTest(reason="Model architecture does not support attentions")
+
+        if not self.all_model_classes[0]._supports_sdpa:
+            self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
+
+        # Not sure whether it's fine to put mindspore.XXX in a decorator if mindspore is not available so hacking it here instead.
+        if ms_dtype == "float16":
+            ms_dtype = ms.float16
+        elif ms_dtype == "bfloat16":
+            ms_dtype = ms.float16
+        elif ms_dtype == "float32":
+            ms_dtype = ms.float32
+
+        atols = {
+            ("cpu", False, ms.float32): 1e-6,
+            ("cpu", False, ms.float16): 1e-2,
+            ("cpu", True, ms.float32): 1e-6,
+            ("cpu", True, ms.float16): 1e-2,
+            ("Ascend", False, ms.float32): 1e-6,
+            ("Ascend", False, ms.float16): 1e-2,
+            ("Ascend", False, ms.float16): 5e-3,
+            ("Ascend", True, ms.float32): 1e-6,
+            ("Ascend", True, ms.float16): 1e-2,
+            ("Ascend", True, ms.float16): 5e-3,
+        }
+        rtols = {
+            ("cpu", False, ms.float32): 1e-4,
+            ("cpu", False, ms.float16): 1e-2,
+            ("cpu", True, ms.float32): 1e-4,
+            ("cpu", True, ms.float16): 1e-2,
+            ("Ascend", False, ms.float32): 1e-4,
+            ("Ascend", False, ms.float16): 1e-2,
+            ("Ascend", False, ms.float16): 5e-3,
+            ("Ascend", True, ms.float32): 1e-4,
+            ("Ascend", True, ms.float16): 3e-2,
+            ("Ascend", True, ms.float16): 5e-3,
+        }
+
+        def get_mean_reldiff(failcase, x, ref, atol, rtol):
+            return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}"
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+            # FIXME: we deactivate boolean mask for models using "use_mask_token" in their constructors.
+            # These models support masking only in the case `use_mask_token=True`. Otherwise they cannot consume an input mask.
+            # This means that the class needs to be instantiated much later, after `use_mask` is set, which means a significant refactor of the code.
+            # However masking there is not done at any layers that matters (i.e self-attention), therefore we can safely deactivate it.
+            deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters
+
+            is_encoder_decoder = model.config.is_encoder_decoder
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model_sdpa = model_class.from_pretrained(tmpdirname, mindspore_dtype=mindspore_dtype)
+                model_sdpa = model_sdpa
+
+                self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
+
+                model_eager = model_class.from_pretrained(
+                    tmpdirname,
+                    mindspore_dtype=mindspore_dtype,
+                    attn_implementation="eager",
+                )
+                model_eager = model_eager
+
+                self.assertTrue(model_eager.config._attn_implementation == "eager")
+
+                for name, submodule in model_eager.named_modules():
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+                        raise ValueError("The eager model should not have SDPA attention layers")
+
+                has_sdpa = False
+                for name, submodule in model_sdpa.named_modules():
+                    class_name = submodule.__class__.__name__
+                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
+                        has_sdpa = True
+                        break
+                if not has_sdpa and model_sdpa.config.model_type != "falcon":
+                    raise ValueError("The SDPA model should have SDPA attention layers")
+
+                # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model,
+                # but it would be nicer to have an efficient way to use parameterized.expand
+                fail_cases = []
+                for padding_side in ["left", "right"]:
+                    for use_mask in [False, True]:
+                        for output_attentions in [True, False]:
+                            can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters
+                            if not (self.has_attentions and can_output_attn) and output_attentions:
+                                continue
+                            for batch_size in [7]:
+                                dummy_input = inputs_dict[model.main_input_name]
+
+                                if dummy_input.dtype in [ms.float32, ms.float16, ms.float16]:
+                                    dummy_input = dummy_input.to(mindspore_dtype)
+
+                                dummy_input = dummy_input[:batch_size]
+                                if dummy_input.shape[0] != batch_size:
+                                    if dummy_input.dtype in [ms.float32, ms.float16, ms.float16]:
+                                        extension = ops.rand(
+                                            batch_size - dummy_input.shape[0],
+                                            *dummy_input.shape[1:],
+                                            dtype=mindspore_dtype,
+                                        )
+                                        dummy_input = ops.cat((dummy_input, extension), dim=0)
+                                    else:
+                                        extension = ops.randint(
+                                            high=5,
+                                            size=(batch_size - dummy_input.shape[0], *dummy_input.shape[1:]),
+                                            dtype=dummy_input.dtype,
+                                        )
+                                        dummy_input = ops.cat((dummy_input, extension), dim=0)
+
+                                if not use_mask:
+                                    dummy_attention_mask = None
+                                else:
+                                    dummy_attention_mask = inputs_dict.get("attention_mask", None)
+                                    if dummy_attention_mask is None:
+                                        if is_encoder_decoder:
+                                            seqlen = inputs_dict.get("decoder_input_ids", dummy_input).shape[-1]
+                                        else:
+                                            seqlen = dummy_input.shape[-1]
+                                        dummy_attention_mask = (
+                                            dtype(ops.ones(batch_size, seqlen),(ms.int64))
+                                        )
+
+                                    dummy_attention_mask = dummy_attention_mask[:batch_size]
+                                    if dummy_attention_mask.shape[0] != batch_size:
+                                        extension = ops.ones(
+                                            batch_size - dummy_attention_mask.shape[0],
+                                            *dummy_attention_mask.shape[1:],
+                                            dtype=dummy_attention_mask.dtype,
+                                        )
+                                        dummy_attention_mask = ops.cat((dummy_attention_mask, extension), dim=0)
+
+                                    dummy_attention_mask[:] = 1
+                                    if padding_side == "left":
+                                        dummy_attention_mask[-1, :2] = 0
+                                        dummy_attention_mask[-1, 2:] = 1
+                                    elif padding_side == "right":
+                                        dummy_attention_mask[-1, -2:] = 0
+                                        dummy_attention_mask[-1, :-2] = 1
+
+                                for enable_kernels in [False, True]:
+                                    failcase = f"padding_side={padding_side}, use_mask={use_mask}, batch_size={batch_size}, enable_kernels={enable_kernels}"
+                                    if is_encoder_decoder:
+                                        decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[
+                                            :batch_size
+                                        ]
+                                        if decoder_input_ids.shape[0] != batch_size:
+                                            extension = ops.ones(
+                                                batch_size - decoder_input_ids.shape[0],
+                                                *decoder_input_ids.shape[1:],
+                                                dtype=decoder_input_ids.dtype,
+                                            )
+                                            decoder_input_ids = ops.cat((decoder_input_ids, extension), dim=0)
+
+                                        # TODO: never an `attention_mask` arg here?
+                                        processed_inputs = {
+                                            model.main_input_name: dummy_input,
+                                            "decoder_input_ids": decoder_input_ids,
+                                            "decoder_attention_mask": dummy_attention_mask,
+                                            "output_hidden_states": True,
+                                        }
+                                    else:
+                                        processed_inputs = {
+                                            model.main_input_name: dummy_input,
+                                            "output_hidden_states": True,
+                                        }
+
+                                        # Otherwise fails for e.g. WhisperEncoderModel
+                                        if "attention_mask" in inspect.signature(model_eager.forward).parameters:
+                                            processed_inputs["attention_mask"] = dummy_attention_mask
+
+                                        if (
+                                            self.has_attentions
+                                            and "output_attentions" in inspect.signature(model_sdpa.forward).parameters
+                                        ):
+                                            processed_inputs["output_attentions"] = output_attentions
+                                    if not deactivate_mask and (
+                                        "bool_masked_pos" in inspect.signature(model_eager.forward).parameters
+                                    ):
+                                        dummy_mask = ops.ones((self.model_tester.num_masks,))
+
+                                        # In case of additional token (like class) we define a custom `mask_length`
+                                        if hasattr(self.model_tester, "mask_length"):
+                                            mask_length = self.model_tester.mask_length - dummy_mask.size(0)
+                                        else:
+                                            mask_length = self.model_tester.seq_length - dummy_mask.size(0)
+                                        dummy_mask = ops.cat([dummy_mask, ms.zeros(mask_length)])
+                                        dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool()
+                                        processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos
+
+                                    if "noise" in inspect.signature(model_eager.forward).parameters:
+                                        np.random.seed(2)
+                                        num_patches = int(
+                                            (self.model_tester.image_size // self.model_tester.patch_size) ** 2
+                                        )
+                                        noise = np.random.uniform(size=(batch_size, num_patches))
+                                        processed_inputs["noise"] = Tensor(noise)
+
+                                    # TODO: test gradients as well (& for FA2 as well!)
+                                    with ms.no_grad():
+                                        with sdpa_kernel(
+                                            enable_flash=enable_kernels,
+                                            enable_math=True,
+                                            enable_mem_efficient=enable_kernels,
+                                        ):
+                                            prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
+                                            outputs_eager = model_eager(**prepared_inputs)
+                                            outputs_sdpa = model_sdpa(**prepared_inputs)
+
+                                    # Ignore copy
+                                    logits_eager = outputs_eager.audio_values
+                                    # Ignore copy
+                                    logits_sdpa = outputs_sdpa.audio_values
+
+                                    if mindspore_device in ["cpu", "Ascend"]:
+                                        atol = atols[mindspore_device, enable_kernemindspore_rch_dtype]
+                                        rtol = rtols[mindspore_device, enable_kernemindspore_rch_dtype]
+                                    elif mindspore_device == "xpu":
+                                        # As of PyTorch 2.5 XPU backend supports only mindspore.nn.attention.SDPBackend.MATH
+                                        # which is implemented on PyTorch level using aten operators and is
+                                        # device agnostic with respect to implementation of each aten operator.
+                                        atol = atols["Ascend", False, mindspore_dtype]
+                                        rtol = rtols["Ascend", False, mindspore_dtype]
+                                    else:
+                                        atol = 1e-7
+                                        rtol = 1e-4
+
+                                    # Masked tokens output slightly deviates - we don't mind that.
+                                    if use_mask:
+                                        _logits_sdpa = ms.zeros_like(input=logits_sdpa)
+                                        _logits_eager = ms.zeros_like(input=logits_eager)
+
+                                        _logits_sdpa[:-1] = logits_sdpa[:-1]
+                                        _logits_eager[:-1] = logits_eager[:-1]
+
+                                        if padding_side == "left":
+                                            _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:]
+                                            _logits_eager[-1:, 2:] = logits_eager[-1:, 2:]
+
+                                        elif padding_side == "right":
+                                            _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2]
+                                            _logits_eager[-1:, 2:] = logits_eager[-1:, :-2]
+
+                                        logits_sdpa = _logits_sdpa
+                                        logits_eager = _logits_eager
+
+                                    results = [
+                                        allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol)
+                                        for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager)
+                                    ]
+                                    # If 80% batch elements have matched results, it's fine
+                                    if np.mean(results) < 0.8:
+                                        fail_cases.append(
+                                            get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol)
+                                        )
+
+                self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases))
+
+    @require_flash_attn
+    @require_mindspore_gpu
+    @mark.flash_attn_test
+    @slow
+    @is_flaky()
+    def test_flash_attn_2_inference_equivalence(self):
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            model = model_class(config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                model.save_pretrained(tmpdirname)
+                model_fa = model_class.from_pretrained(
+                    tmpdirname, mindspore_dtype=ms.float16, attn_implementation="flash_attention_2"
+                )
+
+                model = model_class.from_pretrained(tmpdirname, mindspore_dtype=ms.float16)
+
+                dummy_input = inputs_dict[model.main_input_name][:1]
+                if dummy_input.dtype in [ms.float32, ms.float16]:
+                    dummy_input = dummy_input.to(ms.float16)
+
+                outputs = model(dummy_input)
+                outputs_fa = model_fa(dummy_input)
+
+                logits = outputs[1]
+                logits_fa = outputs_fa[1]
+
+                assert allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+    @unittest.skip(reason="The MimiModel does not support right padding")
+    def test_flash_attn_2_inference_equivalence_right_padding(self):
+        pass
+
+    @unittest.skip(reason="The MimiModel does not have support dynamic compile yet")
+    def test_sdpa_can_compile_dynamic(self):
+        pass
+
+# Copied from transformers.tests.encodec.test_modeling_encodec.normalize
+def normalize(arr):
+    norm = np.linalg.norm(arr)
+    normalized_arr = arr / norm
+    return normalized_arr
+
+
+# Copied from transformers.tests.encodec.test_modeling_encodec.compute_rmse
+def compute_rmse(arr1, arr2):
+    arr1_normalized = normalize(arr1)
+    arr2_normalized = normalize(arr2)
+    return np.sqrt(((arr1_normalized - arr2_normalized) ** 2).mean())
+
+
+@slow
+@require_mindspore
+class MimiIntegrationTest(unittest.TestCase):
+    def test_integration_using_cache_decode(self):
+        expected_rmse = {
+            "8": 0.0018785292,
+            "32": 0.0012330565,
+        }
+
+        librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+        model_id = "kyutai/mimi"
+
+        model = MimiModel.from_pretrained(model_id, use_cache=True)
+        processor = AutoFeatureExtractor.from_pretrained(model_id)
+
+        librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
+        audio_sample = librispeech_dummy[-1]["audio"]["array"]
+
+        inputs = processor(
+            raw_audio=audio_sample,
+            sampling_rate=processor.sampling_rate,
+            return_tensors="ms",
+        )
+
+        for num_codebooks, expected_rmse in expected_rmse.items():
+            with ms.no_grad():
+                # use max bandwith for best possible reconstruction
+                encoder_outputs = model.encode(inputs["input_values"], num_quantizers=int(num_codebooks))
+
+                audio_codes = encoder_outputs[0]
+
+                decoder_outputs_first_part = model.decode(audio_codes[:, :, : audio_codes.shape[2] // 2])
+                decoder_outputs_second_part = model.decode(
+                    audio_codes[:, :, audio_codes.shape[2] // 2 :],
+                    decoder_past_key_values=decoder_outputs_first_part.decoder_past_key_values,
+                )
+
+                audio_output_entire_context = model.decode(audio_codes)[0]
+                audio_output_concat_context = ops.cat(
+                    [decoder_outputs_first_part[0], decoder_outputs_second_part[0]], dim=2
+                )
+
+            # make sure audios are more or less equal
+            # the RMSE of two random gaussian noise vectors with ~N(0, 1) is around 1.0
+            rmse = compute_rmse(
+                audio_output_concat_context.squeeze().cpu().numpy(),
+                audio_output_entire_context.squeeze().cpu().numpy(),
+            )
+            self.assertTrue(rmse < 1e-3)
+
+    def test_integration(self):
+        expected_rmses = {
+            "8": 0.0018785292,
+            "32": 0.0012330565,
+        }
+        expected_codesums = {
+            "8": 426176,
+            "32": 1795819,
+        }
+        librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+        model_id = "kyutai/mimi"
+
+        processor = AutoFeatureExtractor.from_pretrained(model_id)
+
+        librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
+        audio_sample = librispeech_dummy[-1]["audio"]["array"]
+
+        inputs = processor(
+            raw_audio=audio_sample,
+            sampling_rate=processor.sampling_rate,
+            return_tensors="ms",
+        )
+
+        for use_cache in [False, True]:
+            model = MimiModel.from_pretrained(model_id, use_cache=use_cache)
+            for num_codebooks, expected_rmse in expected_rmses.items():
+                with ms.no_grad():
+                    # use max bandwith for best possible reconstruction
+                    encoder_outputs = model.encode(inputs["input_values"], num_quantizers=int(num_codebooks))
+
+                    audio_code_sums = encoder_outputs[0].sum().cpu().item()
+
+                    # make sure audio encoded codes are correct
+                    # assert relative difference less than a threshold, because `audio_code_sums` varies a bit
+                    # depending on torch version
+                    self.assertTrue(
+                        np.abs(audio_code_sums - expected_codesums[num_codebooks]) <= (3e-3 * audio_code_sums)
+                    )
+
+                    input_values_dec = model.decode(encoder_outputs[0], padding_mask=inputs["padding_mask"])[0]
+                    input_values_enc_dec = model(
+                        inputs["input_values"], inputs["padding_mask"], num_quantizers=int(num_codebooks)
+                    )[1]
+
+                # make sure forward and decode gives same result
+                assert_close(input_values_dec, input_values_enc_dec)
+
+                # make sure shape matches
+                self.assertTrue(inputs["input_values"].shape == input_values_enc_dec.shape)
+
+                arr = inputs["input_values"][0].cpu().numpy()
+                arr_enc_dec = input_values_enc_dec[0].cpu().numpy()
+
+                # make sure audios are more or less equal
+                # the RMSE of two random gaussian noise vectors with ~N(0, 1) is around 1.0
+                rmse = compute_rmse(arr, arr_enc_dec)
+                self.assertTrue(np.abs(rmse - expected_rmse) < 1e-5)
\ No newline at end of file