diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000..3810e7f
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,26 @@
+# Use the official Python 3.12 base image
+FROM python:3.12-slim
+
+# Set the working directory
+WORKDIR /app
+
+# Copy the requirements file into the container
+COPY requirements.txt .
+
+# Install the dependencies
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy the rest of the application code into the container
+COPY . .
+
+# Expose the port FastAPI will run on
+EXPOSE 8000
+
+# Define environment variables
+ENV TMP_DIR=/app/tmp
+
+# Create the temporary directory
+RUN mkdir -p $TMP_DIR
+
+# Set the command to run the FastAPI server
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
\ No newline at end of file
diff --git a/docker/README.md b/docker/README.md
new file mode 100644
index 0000000..52a52fa
--- /dev/null
+++ b/docker/README.md
@@ -0,0 +1,49 @@
+An OpenAI-compatible audio transcription API server powered by SenseVoice (`iic/SenseVoiceSmall`).
+
+
+## Build Docker Image
+
+```bash
+docker-compose build
+```
+
+or
+
+```bash
+docker build -t sensevoice-openai-server .
+```
+
+## Start Server
+
+### Docker Compose
+The compose file mounts `~/.cache` (where FunASR/ModelScope downloads the models) to `/root/.cache` inside the container; change the `volumes` entry in `docker-compose.yml` if your model cache lives somewhere else.
+
+```bash
+docker-compose up -d
+```
+
+### Docker
+```bash
+docker run -d -p 8000:8000 -v "/your/cache/dir:/root/.cache" sensevoice-openai-server
+```
+
+## Usage
+
+```bash
+curl http://127.0.0.1:8000/v1/audio/transcriptions \
+  -H "Content-Type: multipart/form-data" \
+  -F model="iic/SenseVoiceSmall" \
+  -F file="@/path/to/file/openai.mp3"
+```
+
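+The server responds with a JSON object of the form `{"text": "..."}`, matching the OpenAI transcription response, so the official Python client works as-is:
+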
+```python
+from openai import OpenAI
+client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="anything")
+
+audio_file = open("/path/to/file/audio.mp3", "rb")
+transcription = client.audio.transcriptions.create(
+    model="iic/SenseVoiceSmall",
+    file=audio_file
+)
+print(transcription.text)
+```
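+
+An optional `language` form field (default `auto`) can be sent as well; SenseVoiceSmall supports `zh`, `en`, `yue`, `ja` and `ko`. For example:
+
+```bash
+curl http://127.0.0.1:8000/v1/audio/transcriptions \
+  -H "Content-Type: multipart/form-data" \
+  -F language="zh" \
+  -F file="@/path/to/file/openai.mp3"
+```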
\ No newline at end of file
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
new file mode 100644
index 0000000..273aec5
--- /dev/null
+++ b/docker/docker-compose.yml
@@ -0,0 +1,16 @@
+# @Author: Bi Ying
+# @Date:   2024-07-10 21:10:22
+# @Last Modified by:   Bi Ying
+# @Last Modified time: 2024-07-10 21:20:22
+version: '3.8'
+
+services:
+  sensevoice-openai-server:
+    image: sensevoice-openai-server
+    build: .
+    ports:
+      - "8000:8000"
+    volumes:
+      - ~/.cache:/root/.cache
+    environment:
+      TMP_DIR: /app/tmp
\ No newline at end of file
diff --git a/docker/main.py b/docker/main.py
new file mode 100644
index 0000000..01c8979
--- /dev/null
+++ b/docker/main.py
@@ -0,0 +1,212 @@
+# @Author: Bi Ying
+# @Date:   2024-07-10 17:22:55
+import os
+import shutil
+from pathlib import Path
+from typing import Union
+
+import torch
+import torchaudio
+import numpy as np
+from funasr import AutoModel
+from fastapi import FastAPI, Form, UploadFile, File, HTTPException, status
+
+
+app = FastAPI()
+
+# Temporary directory for uploaded files, configurable via the TMP_DIR
+# environment variable (set in the Dockerfile / docker-compose.yml).
+TMP_DIR = os.environ.get("TMP_DIR", "./tmp")
+Path(TMP_DIR).mkdir(parents=True, exist_ok=True)
+
+# Initialize the model
+model_name = "iic/SenseVoiceSmall"
+model = AutoModel(
+    model=model_name,
+    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+    vad_kwargs={"max_single_segment_time": 30000},
+    trust_remote_code=True,
+)
+
+emo_dict = {
+    "<|HAPPY|>": "😊",
+    "<|SAD|>": "😔",
+    "<|ANGRY|>": "😡",
+    "<|NEUTRAL|>": "",
+    "<|FEARFUL|>": "😰",
+    "<|DISGUSTED|>": "🤢",
+    "<|SURPRISED|>": "😮",
+}
+
+event_dict = {
+    "<|BGM|>": "🎼",
+    "<|Speech|>": "",
+    "<|Applause|>": "👏",
+    "<|Laughter|>": "😀",
+    "<|Cry|>": "😭",
+    "<|Sneeze|>": "🤧",
+    "<|Breath|>": "",
+    "<|Cough|>": "🤧",
+}
+
+emoji_dict = {
+    "<|nospeech|><|Event_UNK|>": "❓",
+    "<|zh|>": "",
+    "<|en|>": "",
+    "<|yue|>": "",
+    "<|ja|>": "",
+    "<|ko|>": "",
+    "<|nospeech|>": "",
+    "<|HAPPY|>": "😊",
+    "<|SAD|>": "😔",
+    "<|ANGRY|>": "😡",
+    "<|NEUTRAL|>": "",
+    "<|BGM|>": "🎼",
+    "<|Speech|>": "",
+    "<|Applause|>": "👏",
+    "<|Laughter|>": "😀",
+    "<|FEARFUL|>": "😰",
+    "<|DISGUSTED|>": "🤢",
+    "<|SURPRISED|>": "😮",
+    "<|Cry|>": "😭",
+    "<|EMO_UNKNOWN|>": "",
+    "<|Sneeze|>": "🤧",
+    "<|Breath|>": "",
+    "<|Cough|>": "😷",
+    "<|Sing|>": "",
+    "<|Speech_Noise|>": "",
+    "<|withitn|>": "",
+    "<|woitn|>": "",
+    "<|GBG|>": "",
+    "<|Event_UNK|>": "",
+}
+
+lang_dict = {
+    "<|zh|>": "<|lang|>",
+    "<|en|>": "<|lang|>",
+    "<|yue|>": "<|lang|>",
+    "<|ja|>": "<|lang|>",
+    "<|ko|>": "<|lang|>",
+    "<|nospeech|>": "<|lang|>",
+}
+
+emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
+event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷"}
+
+
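+# Remove the model's special tokens from one segment; optionally append the
+# dominant emotion emoji and prepend an emoji for each detected audio event.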
+def format_str_v2(text: str, show_emo=True, show_event=True):
+    sptk_dict = {}
+    for sptk in emoji_dict:
+        sptk_dict[sptk] = text.count(sptk)
+        text = text.replace(sptk, "")
+
+    emo = "<|NEUTRAL|>"
+    for e in emo_dict:
+        if sptk_dict[e] > sptk_dict[emo]:
+            emo = e
+    if show_emo:
+        text = text + emo_dict[emo]
+
+    for e in event_dict:
+        if sptk_dict[e] > 0 and show_event:
+            text = event_dict[e] + text
+
+    for emoji in emo_set.union(event_set):
+        text = text.replace(" " + emoji, emoji)
+        text = text.replace(emoji + " ", emoji)
+
+    return text.strip()
+
+
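+# Split the raw model output on language tags, clean each piece with format_str_v2,
+# then merge the pieces while dropping emotion/event emoji repeated across segments.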
+def format_str_v3(text: str, show_emo=True, show_event=True):
+    def get_emo(s):
+        return s[-1] if s[-1] in emo_set else None
+
+    def get_event(s):
+        return s[0] if s[0] in event_set else None
+
+    text = text.replace("<|nospeech|><|Event_UNK|>", "❓")
+    for lang in lang_dict:
+        text = text.replace(lang, "<|lang|>")
+    parts = [format_str_v2(part, show_emo, show_event).strip(" ") for part in text.split("<|lang|>")]
+    new_s = " " + parts[0]
+    cur_ent_event = get_event(new_s)
+    for i in range(1, len(parts)):
+        if len(parts[i]) == 0:
+            continue
+        if get_event(parts[i]) == cur_ent_event and get_event(parts[i]) is not None:
+            parts[i] = parts[i][1:]
+        cur_ent_event = get_event(parts[i])
+        if get_emo(parts[i]) is not None and get_emo(parts[i]) == get_emo(new_s):
+            new_s = new_s[:-1]
+        new_s += parts[i].strip().lstrip()
+    new_s = new_s.replace("The.", " ")
+    return new_s.strip()
+
+
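+# Run SenseVoice on the audio. A (sample_rate, ndarray) tuple is normalized to
+# float32, downmixed to mono and resampled to 16 kHz before decoding with VAD merging.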
+def model_inference(input_wav, language, fs=16000, show_emo=True, show_event=True):
+    language = "auto" if len(language) < 1 else language
+
+    if isinstance(input_wav, tuple):
+        fs, input_wav = input_wav
+        input_wav = input_wav.astype(np.float32) / np.iinfo(np.int16).max
+        if len(input_wav.shape) > 1:
+            input_wav = input_wav.mean(-1)
+        if fs != 16000:
+            resampler = torchaudio.transforms.Resample(fs, 16000)
+            input_wav_t = torch.from_numpy(input_wav).to(torch.float32)
+            input_wav = resampler(input_wav_t[None, :])[0, :].numpy()
+
+    if len(input_wav) == 0:
+        raise ValueError("The provided audio is empty.")
+
+    merge_vad = True
+    text = model.generate(
+        input=input_wav,
+        cache={},
+        language=language,
+        use_itn=True,
+        batch_size_s=0,
+        merge_vad=merge_vad,
+    )
+
+    text = text[0]["text"]
+    text = format_str_v3(text, show_emo, show_event)
+
+    return text
+
+
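+# OpenAI-compatible transcription endpoint: expects a multipart `file` upload and
+# an optional `language` form field (defaults to "auto").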
+@app.post("/v1/audio/transcriptions")
+async def transcriptions(file: Union[UploadFile, None] = File(default=None), language: str = Form(default="auto")):
+    if file is None:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Bad Request, no file provided")
+
+    filename = file.filename
+    fileobj = file.file
+    tmp_file = Path(TMP_DIR) / filename
+
+    with open(tmp_file, "wb+") as upload_file:
+        shutil.copyfileobj(fileobj, upload_file)
+
+    # Keep the audio data in int32 format and reduce it to a one-dimensional array
+    waveform, sample_rate = torchaudio.load(tmp_file)
+    waveform = (waveform * np.iinfo(np.int32).max).to(dtype=torch.int32).squeeze()
+    if len(waveform.shape) > 1:
+        waveform = waveform.float().mean(axis=0)  # Convert multi-channel audio to mono
+    input_wav = (sample_rate, waveform.numpy())
+
+    result = model_inference(input_wav=input_wav, language=language, show_emo=False)
+
+    # Delete the temporary file
+    tmp_file.unlink()
+
+    return {"text": result}
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=8000)
diff --git a/docker/model.py b/docker/model.py
new file mode 100644
index 0000000..f8dcbfa
--- /dev/null
+++ b/docker/model.py
@@ -0,0 +1,862 @@
+# @Author: Bi Ying
+# @Date:   2024-07-10 20:09:12
+import time
+from typing import Optional
+import torch
+import torch.nn.functional as F
+from torch import nn
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.train_utils.device_funcs import force_gatherable
+
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.ctc.ctc import CTC
+
+from funasr.register import tables
+
+
+class SinusoidalPositionEncoder(torch.nn.Module):
+    """ """
+
+    def __init__(self, d_model=80, dropout_rate=0.1):
+        super().__init__()
+
+    def encode(self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32):
+        batch_size = positions.size(0)
+        positions = positions.type(dtype)
+        device = positions.device
+        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (depth / 2 - 1)
+        inv_timescales = torch.exp(torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment))
+        inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
+        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
+        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
+        return encoding.type(dtype)
+
+    def forward(self, x):
+        batch_size, timesteps, input_dim = x.size()
+        positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
+        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
+
+        return x + position_encoding
+
+
+class PositionwiseFeedForward(torch.nn.Module):
+    """Positionwise feed forward layer.
+
+    Args:
+        idim (int): Input dimenstion.
+        hidden_units (int): The number of hidden units.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
+        """Construct an PositionwiseFeedForward object."""
+        super(PositionwiseFeedForward, self).__init__()
+        self.w_1 = torch.nn.Linear(idim, hidden_units)
+        self.w_2 = torch.nn.Linear(hidden_units, idim)
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.activation = activation
+
+    def forward(self, x):
+        """Forward function."""
+        return self.w_2(self.dropout(self.activation(self.w_1(x))))
+
+
+class MultiHeadedAttentionSANM(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(
+        self,
+        n_head,
+        in_feat,
+        n_feat,
+        dropout_rate,
+        kernel_size,
+        sanm_shfit=0,
+        lora_list=None,
+        lora_rank=8,
+        lora_alpha=16,
+        lora_dropout=0.1,
+    ):
+        """Construct an MultiHeadedAttention object."""
+        super().__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        # self.linear_q = nn.Linear(n_feat, n_feat)
+        # self.linear_k = nn.Linear(n_feat, n_feat)
+        # self.linear_v = nn.Linear(n_feat, n_feat)
+
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+        self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
+        # padding
+        left_padding = (kernel_size - 1) // 2
+        if sanm_shfit > 0:
+            left_padding = left_padding + sanm_shfit
+        right_padding = kernel_size - 1 - left_padding
+        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+
+    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
+        b, t, d = inputs.size()
+        if mask is not None:
+            mask = torch.reshape(mask, (b, -1, 1))
+            if mask_shfit_chunk is not None:
+                mask = mask * mask_shfit_chunk
+            inputs = inputs * mask
+
+        x = inputs.transpose(1, 2)
+        x = self.pad_fn(x)
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+        x += inputs
+        x = self.dropout(x)
+        if mask is not None:
+            x = x * mask
+        return x
+
+    def forward_qkv(self, x):
+        """Transform query, key and value.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, in_feat), projected to
+                query, key and value by a single linear layer.
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Unreshaped value tensor (#batch, time, n_head * d_k).
+
+        """
+        b, t, d = x.size()
+        q_k_v = self.linear_q_k_v(x)
+        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
+        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q_h, k_h, v_h, v
+
+    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            if mask_att_chunk_encoder is not None:
+                mask = mask * mask_att_chunk_encoder
+
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+
+            min_value = -float("inf")  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+        return att_outs + fsmn_memory
+
+    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        if chunk_size is not None and look_back > 0 or look_back == -1:
+            if cache is not None:
+                k_h_stride = k_h[:, :, : -(chunk_size[2]), :]
+                v_h_stride = v_h[:, :, : -(chunk_size[2]), :]
+                k_h = torch.cat((cache["k"], k_h), dim=2)
+                v_h = torch.cat((cache["v"], v_h), dim=2)
+
+                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
+                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
+                if look_back != -1:
+                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]) :, :]
+                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]) :, :]
+            else:
+                cache_tmp = {
+                    "k": k_h[:, :, : -(chunk_size[2]), :],
+                    "v": v_h[:, :, : -(chunk_size[2]), :],
+                }
+                cache = cache_tmp
+        fsmn_memory = self.forward_fsmn(v, None)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, None)
+        return att_outs + fsmn_memory, cache
+
+
+class LayerNorm(nn.LayerNorm):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, input):
+        output = F.layer_norm(
+            input.float(),
+            self.normalized_shape,
+            self.weight.float() if self.weight is not None else None,
+            self.bias.float() if self.bias is not None else None,
+            self.eps,
+        )
+        return output.type_as(input)
+
+
+def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
+    if maxlen is None:
+        maxlen = lengths.max()
+    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
+    matrix = torch.unsqueeze(lengths, dim=-1)
+    mask = row_vector < matrix
+    mask = mask.detach()
+
+    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
+
+
+class EncoderLayerSANM(nn.Module):
+    def __init__(
+        self,
+        in_size,
+        size,
+        self_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+        stochastic_depth_rate=0.0,
+    ):
+        """Construct an EncoderLayer object."""
+        super(EncoderLayerSANM, self).__init__()
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(in_size)
+        self.norm2 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.in_size = in_size
+        self.size = size
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear = nn.Linear(size + size, size)
+        self.stochastic_depth_rate = stochastic_depth_rate
+        self.dropout_rate = dropout_rate
+
+    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+        """Compute encoded features.
+
+        Args:
+            x_input (torch.Tensor): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+
+        """
+        skip_layer = False
+        # with stochastic depth, residual connection `x + f(x)` becomes
+        # `x <- x + 1 / (1 - p) * f(x)` at training time.
+        stoch_layer_coeff = 1.0
+        if self.training and self.stochastic_depth_rate > 0:
+            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+        if skip_layer:
+            if cache is not None:
+                x = torch.cat([cache, x], dim=1)
+            return x, mask
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        if self.concat_after:
+            x_concat = torch.cat(
+                (
+                    x,
+                    self.self_attn(
+                        x,
+                        mask,
+                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_att_chunk_encoder=mask_att_chunk_encoder,
+                    ),
+                ),
+                dim=-1,
+            )
+            if self.in_size == self.size:
+                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+            else:
+                x = stoch_layer_coeff * self.concat_linear(x_concat)
+        else:
+            if self.in_size == self.size:
+                x = residual + stoch_layer_coeff * self.dropout(
+                    self.self_attn(
+                        x,
+                        mask,
+                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_att_chunk_encoder=mask_att_chunk_encoder,
+                    )
+                )
+            else:
+                x = stoch_layer_coeff * self.dropout(
+                    self.self_attn(
+                        x,
+                        mask,
+                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_att_chunk_encoder=mask_att_chunk_encoder,
+                    )
+                )
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
+
+    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+        """Compute encoded features.
+
+        Args:
+            x_input (torch.Tensor): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+
+        """
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        if self.in_size == self.size:
+            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+            x = residual + attn
+        else:
+            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + self.feed_forward(x)
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        return x, cache
+
+
+@tables.register("encoder_classes", "SenseVoiceEncoderSmall")
+class SenseVoiceEncoderSmall(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
+    https://arxiv.org/abs/2006.01713
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        tp_blocks: int = 0,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        stochastic_depth_rate: float = 0.0,
+        input_layer: Optional[str] = "conv2d",
+        pos_enc_class=SinusoidalPositionEncoder,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 1,
+        padding_idx: int = -1,
+        kernel_size: int = 11,
+        sanm_shfit: int = 0,
+        selfattention_layer_type: str = "sanm",
+        **kwargs,
+    ):
+        super().__init__()
+        self._output_size = output_size
+
+        self.embed = SinusoidalPositionEncoder()
+
+        self.normalize_before = normalize_before
+
+        positionwise_layer = PositionwiseFeedForward
+        positionwise_layer_args = (
+            output_size,
+            linear_units,
+            dropout_rate,
+        )
+
+        encoder_selfattn_layer = MultiHeadedAttentionSANM
+        encoder_selfattn_layer_args0 = (
+            attention_heads,
+            input_size,
+            output_size,
+            attention_dropout_rate,
+            kernel_size,
+            sanm_shfit,
+        )
+        encoder_selfattn_layer_args = (
+            attention_heads,
+            output_size,
+            output_size,
+            attention_dropout_rate,
+            kernel_size,
+            sanm_shfit,
+        )
+
+        self.encoders0 = nn.ModuleList(
+            [
+                EncoderLayerSANM(
+                    input_size,
+                    output_size,
+                    encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+                    positionwise_layer(*positionwise_layer_args),
+                    dropout_rate,
+                )
+                for i in range(1)
+            ]
+        )
+        self.encoders = nn.ModuleList(
+            [
+                EncoderLayerSANM(
+                    output_size,
+                    output_size,
+                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                    positionwise_layer(*positionwise_layer_args),
+                    dropout_rate,
+                )
+                for i in range(num_blocks - 1)
+            ]
+        )
+
+        self.tp_encoders = nn.ModuleList(
+            [
+                EncoderLayerSANM(
+                    output_size,
+                    output_size,
+                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                    positionwise_layer(*positionwise_layer_args),
+                    dropout_rate,
+                )
+                for i in range(tp_blocks)
+            ]
+        )
+
+        self.after_norm = LayerNorm(output_size)
+
+        self.tp_norm = LayerNorm(output_size)
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+    ):
+        """Embed positions in tensor."""
+        masks = sequence_mask(ilens, device=ilens.device)[:, None, :]
+
+        xs_pad *= self.output_size() ** 0.5
+
+        xs_pad = self.embed(xs_pad)
+
+        # forward encoder1
+        for layer_idx, encoder_layer in enumerate(self.encoders0):
+            encoder_outs = encoder_layer(xs_pad, masks)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+        for layer_idx, encoder_layer in enumerate(self.encoders):
+            encoder_outs = encoder_layer(xs_pad, masks)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+        xs_pad = self.after_norm(xs_pad)
+
+        # forward encoder2
+        olens = masks.squeeze(1).sum(1).int()
+
+        for layer_idx, encoder_layer in enumerate(self.tp_encoders):
+            encoder_outs = encoder_layer(xs_pad, masks)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+        xs_pad = self.tp_norm(xs_pad)
+        return xs_pad, olens
+
+
+@tables.register("model_classes", "SenseVoiceSmall")
+class SenseVoiceSmall(nn.Module):
+    """CTC-attention hybrid Encoder-Decoder model"""
+
+    def __init__(
+        self,
+        specaug: str = None,
+        specaug_conf: dict = None,
+        normalize: str = None,
+        normalize_conf: dict = None,
+        encoder: str = None,
+        encoder_conf: dict = None,
+        ctc_conf: dict = None,
+        input_size: int = 80,
+        vocab_size: int = -1,
+        ignore_id: int = -1,
+        blank_id: int = 0,
+        sos: int = 1,
+        eos: int = 2,
+        length_normalized_loss: bool = False,
+        **kwargs,
+    ):
+        super().__init__()
+
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+        if normalize is not None:
+            normalize_class = tables.normalize_classes.get(normalize)
+            normalize = normalize_class(**normalize_conf)
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(input_size=input_size, **encoder_conf)
+        encoder_output_size = encoder.output_size()
+
+        if ctc_conf is None:
+            ctc_conf = {}
+        ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)
+
+        self.blank_id = blank_id
+        self.sos = sos if sos is not None else vocab_size - 1
+        self.eos = eos if eos is not None else vocab_size - 1
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.specaug = specaug
+        self.normalize = normalize
+        self.encoder = encoder
+        self.error_calculator = None
+
+        self.ctc = ctc
+
+        self.length_normalized_loss = length_normalized_loss
+        self.encoder_output_size = encoder_output_size
+
+        self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
+        self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
+        self.textnorm_dict = {"withitn": 14, "woitn": 15}
+        self.textnorm_int_dict = {25016: 14, 25017: 15}
+        self.embed = torch.nn.Embedding(7 + len(self.lid_dict) + len(self.textnorm_dict), input_size)
+
+        self.criterion_att = LabelSmoothingLoss(
+            size=self.vocab_size,
+            padding_idx=self.ignore_id,
+            smoothing=kwargs.get("lsm_weight", 0.0),
+            normalize_length=self.length_normalized_loss,
+        )
+
+    @staticmethod
+    def from_pretrained(model: str = None, **kwargs):
+        from funasr import AutoModel
+
+        model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)
+
+        return model, kwargs
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ):
+        """Encoder + Decoder + Calc loss
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+        # import pdb;
+        # pdb.set_trace()
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+
+        batch_size = speech.shape[0]
+
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text)
+
+        loss_ctc, cer_ctc = None, None
+        loss_rich, acc_rich = None, None
+        stats = dict()
+
+        loss_ctc, cer_ctc = self._calc_ctc_loss(
+            encoder_out[:, 4:, :], encoder_out_lens - 4, text[:, 4:], text_lengths - 4
+        )
+
+        loss_rich, acc_rich = self._calc_rich_ce_loss(encoder_out[:, :4, :], text[:, :4])
+
+        loss = loss_ctc
+        # Collect total loss stats
+        stats["loss"] = torch.clone(loss.detach()) if loss_ctc is not None else None
+        stats["loss_rich"] = torch.clone(loss_rich.detach()) if loss_rich is not None else None
+        stats["acc_rich"] = acc_rich
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + 1).sum())
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        **kwargs,
+    ):
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                ind: int
+        """
+
+        # Data augmentation
+        if self.specaug is not None and self.training:
+            speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+        # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+        if self.normalize is not None:
+            speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+        lids = torch.LongTensor(
+            [
+                [self.lid_int_dict[int(lid)] if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict else 0]
+                for lid in text[:, 0]
+            ]
+        ).to(speech.device)
+        language_query = self.embed(lids)
+
+        styles = torch.LongTensor([[self.textnorm_int_dict[int(style)]] for style in text[:, 3]]).to(speech.device)
+        style_query = self.embed(styles)
+        speech = torch.cat((style_query, speech), dim=1)
+        speech_lengths += 1
+
+        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(speech.size(0), 1, 1)
+        input_query = torch.cat((language_query, event_emo_query), dim=1)
+        speech = torch.cat((input_query, speech), dim=1)
+        speech_lengths += 3
+
+        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
+
+        return encoder_out, encoder_out_lens
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        # Calc CTC loss
+        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+        # Calc CER using CTC
+        cer_ctc = None
+        if not self.training and self.error_calculator is not None:
+            ys_hat = self.ctc.argmax(encoder_out).data
+            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+        return loss_ctc, cer_ctc
+
+    def _calc_rich_ce_loss(
+        self,
+        encoder_out: torch.Tensor,
+        ys_pad: torch.Tensor,
+    ):
+        decoder_out = self.ctc.ctc_lo(encoder_out)
+        # 2. Compute attention loss
+        loss_rich = self.criterion_att(decoder_out, ys_pad.contiguous())
+        acc_rich = th_accuracy(
+            decoder_out.view(-1, self.vocab_size),
+            ys_pad.contiguous(),
+            ignore_label=self.ignore_id,
+        )
+
+        return loss_rich, acc_rich
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = ["wav_file_tmp_name"],
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+        meta_data = {}
+        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(
+                data_in,
+                fs=frontend.fs,
+                audio_fs=kwargs.get("fs", 16000),
+                data_type=kwargs.get("data_type", "sound"),
+                tokenizer=tokenizer,
+            )
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(
+                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+            )
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+        language = kwargs.get("language", "auto")
+        language_query = self.embed(
+            torch.LongTensor([[self.lid_dict[language] if language in self.lid_dict else 0]]).to(speech.device)
+        ).repeat(speech.size(0), 1, 1)
+
+        use_itn = kwargs.get("use_itn", False)
+        textnorm = kwargs.get("text_norm", None)
+        if textnorm is None:
+            textnorm = "withitn" if use_itn else "woitn"
+        textnorm_query = self.embed(torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)).repeat(
+            speech.size(0), 1, 1
+        )
+        speech = torch.cat((textnorm_query, speech), dim=1)
+        speech_lengths += 1
+
+        event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(speech.size(0), 1, 1)
+        input_query = torch.cat((language_query, event_emo_query), dim=1)
+        speech = torch.cat((input_query, speech), dim=1)
+        speech_lengths += 3
+
+        # Encoder
+        encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        # Greedy CTC decoding (argmax) over the encoder output; no beam search is used
+        ctc_logits = self.ctc.log_softmax(encoder_out)
+
+        results = []
+        b, n, d = encoder_out.size()
+        if isinstance(key[0], (list, tuple)):
+            key = key[0]
+        if len(key) < b:
+            key = key * b
+        for i in range(b):
+            x = ctc_logits[i, : encoder_out_lens[i].item(), :]
+            yseq = x.argmax(dim=-1)
+            yseq = torch.unique_consecutive(yseq, dim=-1)
+
+            ibest_writer = None
+            if kwargs.get("output_dir") is not None:
+                if not hasattr(self, "writer"):
+                    self.writer = DatadirWriter(kwargs.get("output_dir"))
+                ibest_writer = self.writer["1best_recog"]
+
+            mask = yseq != self.blank_id
+            token_int = yseq[mask].tolist()
+
+            # Change integer-ids to tokens
+            text = tokenizer.decode(token_int)
+
+            result_i = {"key": key[i], "text": text}
+            results.append(result_i)
+
+            if ibest_writer is not None:
+                ibest_writer["text"][key[i]] = text
+
+        return results, meta_data
+
+    def export(self, **kwargs):
+        from .export_meta import export_rebuild_model
+
+        if "max_seq_len" not in kwargs:
+            kwargs["max_seq_len"] = 512
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
diff --git a/docker/requirements.txt b/docker/requirements.txt
new file mode 100644
index 0000000..60d828a
--- /dev/null
+++ b/docker/requirements.txt
@@ -0,0 +1,7 @@
+fastapi
+uvicorn
+torch
+torchaudio
+numpy
+funasr
+modelscope
\ No newline at end of file