Add reproduction for SLAM-Omni #190

Merged
merged 173 commits into from
Jan 22, 2025
Changes from 1 commit
173 commits
d66af4a
init
cwx-worst-one Sep 18, 2024
dedd657
9.20
cwx-worst-one Sep 20, 2024
b6a8478
9.20
cwx-worst-one Sep 20, 2024
f751db8
9.21
cwx-worst-one Sep 21, 2024
d6d6d97
9.22
cwx-worst-one Sep 22, 2024
76b439a
9.23
cwx-worst-one Sep 23, 2024
0f3abac
Update finetune_cwx.sh script
cwx-worst-one Sep 23, 2024
3233ce5
Update finetune_cwx.sh script: Adjust dataset split size and enable f…
cwx-worst-one Sep 23, 2024
e4dde82
9.23
cwx-worst-one Sep 23, 2024
9f49036
9.23
cwx-worst-one Sep 23, 2024
bf0d4c6
9.23
cwx-worst-one Sep 23, 2024
1c19ec8
generate
cwx-worst-one Sep 23, 2024
96721e2
9.23 inference has some issues
cwx-worst-one Sep 23, 2024
d54a5eb
9.24
cwx-worst-one Sep 24, 2024
e755212
fix padding_token in snac_utils.py
cwx-worst-one Sep 24, 2024
b0e404f
9.24
cwx-worst-one Sep 24, 2024
6e89942
track for layer loss
cwx-worst-one Sep 25, 2024
4ea1cc7
docker
cwx-worst-one Sep 25, 2024
8d1be75
Update Dockerfile and pyproject.toml
cwx-worst-one Sep 25, 2024
0c3e407
Update pyproject.toml to allow direct references in metadata
cwx-worst-one Sep 25, 2024
cf3d4a6
Update requirements.txt
cwx-worst-one Sep 25, 2024
9481720
Update requirements.txt to remove unused dependencies
cwx-worst-one Sep 25, 2024
4d96076
Remove unused dependencies from requirements.txt
cwx-worst-one Sep 25, 2024
5eddade
1
cwx-worst-one Sep 25, 2024
11ab973
1
cwx-worst-one Sep 25, 2024
bea2e89
1
cwx-worst-one Sep 25, 2024
f0311b7
1
cwx-worst-one Sep 25, 2024
7ae8d7a
9.25
cwx-worst-one Sep 25, 2024
e87dc3d
9.25
cwx-worst-one Sep 25, 2024
bae1714
9.26
cwx-worst-one Sep 26, 2024
14faf30
9.26
cwx-worst-one Sep 26, 2024
38fba5b
9.26
cwx-worst-one Sep 26, 2024
efe6351
9.29
cwx-worst-one Sep 29, 2024
fdd82ec
9.29
cwx-worst-one Sep 29, 2024
7c973ba
9.29
cwx-worst-one Sep 29, 2024
621e92a
9.29
cwx-worst-one Sep 29, 2024
3fc8b61
9.30
cwx-worst-one Sep 30, 2024
6fc689d
9.30
cwx-worst-one Sep 30, 2024
a708b19
9.30
cwx-worst-one Sep 30, 2024
5e18879
9.30
cwx-worst-one Sep 30, 2024
59dcd16
10.7
cwx-worst-one Oct 7, 2024
7e84708
10.7
cwx-worst-one Oct 7, 2024
e4a6003
10.7
cwx-worst-one Oct 7, 2024
1362515
10.7
cwx-worst-one Oct 7, 2024
cef91e7
10.8
cwx-worst-one Oct 8, 2024
b29bb9f
10.8
cwx-worst-one Oct 8, 2024
df15ad0
10.8
cwx-worst-one Oct 8, 2024
a2fce40
10.8
cwx-worst-one Oct 8, 2024
d145773
10.8
cwx-worst-one Oct 8, 2024
ac6155c
debug !!!
cwx-worst-one Oct 9, 2024
141d87e
slam-omni v0
cwx-worst-one Oct 9, 2024
d8a1d54
10.9
cwx-worst-one Oct 9, 2024
195c10c
10.9
cwx-worst-one Oct 9, 2024
3fbe0c6
wonderful day!
cwx-worst-one Oct 9, 2024
888e209
10.10
cwx-worst-one Oct 10, 2024
edcf096
10.10
cwx-worst-one Oct 10, 2024
1e6d5f1
10.10
cwx-worst-one Oct 10, 2024
1871b47
10.10
cwx-worst-one Oct 10, 2024
2626a0a
10.10
cwx-worst-one Oct 10, 2024
f501eda
10.10
cwx-worst-one Oct 10, 2024
aabfd02
10.10
cwx-worst-one Oct 10, 2024
3b62b5e
10.11
cwx-worst-one Oct 11, 2024
0a4f6d1
10.12
cwx-worst-one Oct 12, 2024
0c15b6e
10.12
cwx-worst-one Oct 12, 2024
1dc1905
life
cwx-worst-one Oct 12, 2024
e1c311f
sadness
cwx-worst-one Oct 12, 2024
ed8bfb0
10.14
cwx-worst-one Oct 14, 2024
655a0f6
Merge pull request #151 from cwx-worst-one/main
cwx-worst-one Oct 14, 2024
ec86c78
Merge pull request #152 from X-LANCE/main
cwx-worst-one Oct 14, 2024
c1a270a
enjoy yourself
cwx-worst-one Oct 14, 2024
4b418b0
pride
cwx-worst-one Oct 14, 2024
c98f1c2
1
cwx-worst-one Oct 14, 2024
c72da32
whisper support
cwx-worst-one Oct 14, 2024
fc0983a
Gluttony
cwx-worst-one Oct 14, 2024
664e30f
Gluttony
cwx-worst-one Oct 14, 2024
7d5e142
Gluttony
cwx-worst-one Oct 14, 2024
b24808d
lust
cwx-worst-one Oct 15, 2024
832e9ad
[cwx-worst-one] add streaming inference
cwx-worst-one Oct 15, 2024
4372ac1
Merge pull request #155 from X-LANCE/s2s
cwx-worst-one Oct 16, 2024
ab22dcb
[cwx-worst-one] Agony
cwx-worst-one Oct 16, 2024
fee1421
sloth
cwx-worst-one Oct 18, 2024
cecd5d7
Merge branch 'dev-cwx-my' into dev-cwx
cwx-worst-one Oct 18, 2024
ff8d2ee
Merge pull request #158 from X-LANCE/dev-cwx
cwx-worst-one Oct 18, 2024
ba419af
agony
cwx-worst-one Oct 19, 2024
2d39d26
agony
cwx-worst-one Oct 19, 2024
ec10aa2
greed
cwx-worst-one Oct 22, 2024
18fc658
greed
cwx-worst-one Oct 22, 2024
bb478fe
greed
cwx-worst-one Oct 22, 2024
6b8f264
greed
cwx-worst-one Oct 22, 2024
9484d5e
greed
cwx-worst-one Oct 22, 2024
77a6daf
add cosyvoice
cwx-worst-one Oct 22, 2024
eaf32b4
lust
cwx-worst-one Oct 23, 2024
24c9c56
lust
cwx-worst-one Oct 23, 2024
8e79221
lust
cwx-worst-one Oct 23, 2024
2fdab67
lust
cwx-worst-one Oct 23, 2024
9421566
envy
cwx-worst-one Oct 25, 2024
1004538
sloth
cwx-worst-one Oct 26, 2024
ea1bb36
add group decoding for CosyVoice and add logging info
cwx-worst-one Oct 28, 2024
b010300
1
cwx-worst-one Oct 28, 2024
2dec8b0
update audio repetition penalty
cwx-worst-one Oct 28, 2024
9b419d8
update online inference of CV
cwx-worst-one Oct 29, 2024
9d9647f
add readme
cwx-worst-one Oct 30, 2024
cef5d6e
modify shell script
cwx-worst-one Oct 30, 2024
22b9fbf
fix wandb
cwx-worst-one Oct 30, 2024
1d15626
update generate
cwx-worst-one Oct 30, 2024
ee9be97
1
cwx-worst-one Oct 30, 2024
a2a6f1f
update
cwx-worst-one Nov 1, 2024
3152858
update generate
cwx-worst-one Nov 1, 2024
c1001ac
update generate
cwx-worst-one Nov 1, 2024
b760507
update generate
cwx-worst-one Nov 4, 2024
90b87ee
1
cwx-worst-one Nov 5, 2024
431ad19
update script
cwx-worst-one Nov 6, 2024
5a11d43
update
cwx-worst-one Nov 15, 2024
ae2d3f7
11.17
cwx-worst-one Nov 17, 2024
5a0a825
11.18
cwx-worst-one Nov 18, 2024
e57d15a
11.18
cwx-worst-one Nov 18, 2024
3b0fcbc
update prompt
cwx-worst-one Nov 20, 2024
5f93f6f
file clean
cwx-worst-one Nov 26, 2024
0eed70e
update re
cwx-worst-one Nov 26, 2024
e6e490d
file clean
cwx-worst-one Nov 26, 2024
54bc3a2
11.26
cwx-worst-one Nov 26, 2024
21e0362
update
cwx-worst-one Nov 26, 2024
ef4bb63
clean script
cwx-worst-one Nov 27, 2024
bb67424
clean
cwx-worst-one Nov 27, 2024
0a2c107
fix
cwx-worst-one Nov 27, 2024
beae32c
clean
cwx-worst-one Nov 28, 2024
9f24ac0
clean
cwx-worst-one Nov 28, 2024
8b634c9
clean
cwx-worst-one Nov 28, 2024
2184d2c
1
cwx-worst-one Dec 3, 2024
6d4942c
update
cwx-worst-one Dec 3, 2024
212993c
update
cwx-worst-one Dec 3, 2024
2453b58
clean
cwx-worst-one Dec 3, 2024
b792ee2
update
cwx-worst-one Dec 3, 2024
57ace40
update
cwx-worst-one Dec 4, 2024
9c197c1
update asr
cwx-worst-one Dec 7, 2024
e3f810e
add readme
cwx-worst-one Dec 7, 2024
4df702d
update
cwx-worst-one Dec 11, 2024
350a115
update
cwx-worst-one Dec 12, 2024
d055338
update
cwx-worst-one Dec 12, 2024
98a1bc8
update readme
cwx-worst-one Dec 17, 2024
ce7ccfd
update download
cwx-worst-one Dec 17, 2024
a1519a7
update
cwx-worst-one Dec 18, 2024
4f0edef
update requirement
cwx-worst-one Dec 18, 2024
ca1fb06
update readme
cwx-worst-one Dec 18, 2024
2cdd57b
clean
cwx-worst-one Dec 18, 2024
f47b449
clean
cwx-worst-one Dec 18, 2024
b3d0f6d
update script
cwx-worst-one Dec 18, 2024
13946c8
update
cwx-worst-one Dec 18, 2024
6f58c72
update
cwx-worst-one Dec 18, 2024
6c2aef5
update multi-round
cwx-worst-one Dec 19, 2024
07ca5be
update readme
cwx-worst-one Dec 20, 2024
2b2f33f
update readme
cwx-worst-one Dec 20, 2024
85191fb
update readme
cwx-worst-one Dec 22, 2024
bd5e5aa
Merge pull request #183 from X-LANCE/main
cwx-worst-one Dec 24, 2024
c1e26c1
feat: add encoder_fairseq_dir path to fine-tuning and inference scripts
cwx-worst-one Dec 24, 2024
42514c2
Revert "merge latest main branch"
cwx-worst-one Dec 24, 2024
a000404
Merge pull request #184 from X-LANCE/revert-183-main
cwx-worst-one Dec 24, 2024
569590c
fix: update num_latency_tokens to 0 and set OMP_NUM_THREADS in finetu…
cwx-worst-one Dec 24, 2024
58b1ed0
Revert "Revert "merge latest main branch""
cwx-worst-one Dec 24, 2024
207d299
Merge pull request #185 from X-LANCE/revert-184-revert-183-main
cwx-worst-one Dec 24, 2024
737cf76
update
cwx-worst-one Dec 27, 2024
2a401f4
fix: set num_latency_tokens to 0 in inference and pretraining scripts
cwx-worst-one Jan 17, 2025
b08b681
update
cwx-worst-one Jan 21, 2025
425d438
update main readme
cwx-worst-one Jan 21, 2025
26f3832
Update README.md
cwx-worst-one Jan 21, 2025
98797c2
Merge pull request #188 from X-LANCE/main
cwx-worst-one Jan 21, 2025
221e893
update
cwx-worst-one Jan 21, 2025
ec7fb64
Update README.md
cwx-worst-one Jan 22, 2025
229765a
Update README.md
cwx-worst-one Jan 22, 2025
4830fe1
update train_utils
cwx-worst-one Jan 22, 2025
a294995
update main readme
cwx-worst-one Jan 22, 2025
2435b83
test
cwx-worst-one Jan 22, 2025
c9b7d9e
delete useless files
cwx-worst-one Jan 22, 2025
9.21
cwx-worst-one committed Sep 21, 2024
commit f751db84cb3802aa9414fddc0488896700a438be
35 changes: 18 additions & 17 deletions examples/s2s/model/slam_model_s2s.py
@@ -30,8 +30,6 @@ def model_factory(train_config, model_config, **kwargs):
train_config, model_config, **kwargs
)

# TODO: add decoder projector and decoder

model = slam_model_s2s(
encoder,
llm,
@@ -83,13 +81,15 @@ def __init__(
**kwargs,
)

# TODO: add logic to resize the llm's lm_head and embedding to the new vocab size, then re-print the model size
# resize llm embedding layer
if self.model_config.vocab_config.total_vocabsize != self.llm.lm_head.weight.size(0):
self.llm.resize_token_embeddings(self.model_config.vocab_config.total_vocabsize)


def concat_whisper_feat(self, audio_feature, input_ids, T, task="A1A2"):
btz = len(T) # get the batch size
def concat_whisper_feat(self, audio_feature, input_ids, T, task = None):
btz = len(T)
for j in range(btz):
if task[j] != "T1T2" and task[j] != "T1A2":
if task is None or (task[j] != "T1T2" and task[j] != "T1A2"):
for i in range(7):
input_ids[j, i, 1 : T[j] + 1, :] = audio_feature[j][: T[j]].clone()
else:
@@ -111,7 +111,6 @@ def forward(self,
):
audio_mel = kwargs.get("audio_mel", None)
audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None) # 2x downsample for whisper
audio_length = kwargs.get("audio_length", None)

audio = kwargs.get("audio", None)
audio_mask = kwargs.get("audio_mask", None)
@@ -157,22 +156,24 @@ def forward(self,
else:
inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)

if audio_mel is not None or audio is not None:
inputs_embeds = self.concat_whisper_feat(encoder_outs, inputs_embeds, audio_length)

inputs_embeds = torch.mean(inputs_embeds, dim=1) # [btz, seq_length, emb_dim]
# if audio_mel is not None or audio is not None:
# inputs_embeds = self.concat_whisper_feat(encoder_outs, inputs_embeds, audio_length) # embed the audio feature into the input_embeds

if modality_mask is not None:
modality_mask_start_indices = (modality_mask == True).float().argmax(dim=1)
modality_lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()
modality_mask = modality_mask.unsqueeze(1).repeat(1, 7, 1) # [btz, 7, seq_length]
modality_mask_start_indices = (modality_mask == True).float().argmax(dim=2)
modality_lengths = torch.clamp(modality_mask.sum(dim=2), max=encoder_outs.shape[1]).tolist()

encoder_outs_pad = torch.zeros_like(inputs_embeds)
for i in range(encoder_outs.shape[0]):
encoder_outs_pad[
i, modality_mask_start_indices[i]:modality_mask_start_indices[i]+modality_lengths[i]
] = encoder_outs[i][:modality_lengths[i]]
for j in range(7):
start_idx = modality_mask_start_indices[i, j].item()
length = modality_lengths[i][j]
encoder_outs_pad[i, j, start_idx:start_idx+length] = encoder_outs[i, :length]

inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None])
inputs_embeds[:, :7, :, :] = encoder_outs_pad[:, :7, :, :] + inputs_embeds[:, :7, :, :] * (~modality_mask[:, :, :, None])

inputs_embeds = torch.mean(inputs_embeds, dim=1) # [btz, seq_length, emb_dim], average over the 8 layers

if kwargs.get("inference_mode", False):
return inputs_embeds, attention_mask
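This hunk moves the encoder-feature injection from a single embedding stream onto the 7 audio layers of the layered input_ids before averaging over all 8 layers. Below is a minimal, self-contained sketch of the new placement logic; the tensor shapes, the 7-audio-plus-1-text layer count, and the mask layout are assumptions read off this diff rather than verified against the rest of the repo.

```python
import torch

# Assumed shapes, mirroring this hunk: inputs_embeds is [btz, 8, seq, emb]
# (7 audio codebook layers + 1 text layer), encoder_outs is [btz, enc_len, emb].
btz, n_audio, seq_len, emb_dim, enc_len = 2, 7, 16, 8, 10
inputs_embeds = torch.randn(btz, n_audio + 1, seq_len, emb_dim)
encoder_outs = torch.randn(btz, enc_len, emb_dim)
modality_mask = torch.zeros(btz, seq_len, dtype=torch.bool)
modality_mask[:, 1:1 + enc_len] = True  # audio occupies positions 1..enc_len

# Repeat the mask across the 7 audio layers, as the updated forward() does.
layer_mask = modality_mask.unsqueeze(1).repeat(1, n_audio, 1)            # [btz, 7, seq]
start_idx = (layer_mask == True).float().argmax(dim=2)                   # [btz, 7]
lengths = torch.clamp(layer_mask.sum(dim=2), max=encoder_outs.shape[1])  # [btz, 7]

encoder_outs_pad = torch.zeros_like(inputs_embeds)
for i in range(btz):
    for j in range(n_audio):
        s, l = start_idx[i, j].item(), int(lengths[i, j].item())
        encoder_outs_pad[i, j, s:s + l] = encoder_outs[i, :l]

# Overwrite the masked positions of the audio layers (text layer untouched),
# then average over all 8 layers to get the final [btz, seq, emb] embedding.
inputs_embeds[:, :n_audio] = (
    encoder_outs_pad[:, :n_audio]
    + inputs_embeds[:, :n_audio] * (~layer_mask[:, :, :, None])
)
inputs_embeds = inputs_embeds.mean(dim=1)  # [btz, seq, emb]
```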
45 changes: 45 additions & 0 deletions examples/s2s/s2s_config.py
@@ -1,5 +1,45 @@
from dataclasses import dataclass, field
from typing import Optional, List

@dataclass
class VocabConfig:
text_vocabsize: int = 151936
text_specialtokens: int = 64
audio_vocabsize: int = 4096
audio_specialtokens: int = 64
total_vocabsize: int = 181120

padded_text_vocabsize: int = field(init=False)
padded_audio_vocabsize: int = field(init=False)

eot: int = field(init=False) # end of text token
pad_t: int = field(init=False) # padding text token
input_t: int = field(init=False) # input text token
answer_t: int = field(init=False) # answer text token
asr: int = field(init=False) # ASR token

eoa: int = field(init=False) # end of audio token
pad_a: int = field(init=False) # padding audio token
input_a: int = field(init=False) # input audio token
answer_a: int = field(init=False) # answer audio token
split: int = field(init=False) # split token

def __post_init__(self):
self.padded_text_vocabsize = self.text_vocabsize + self.text_specialtokens
self.padded_audio_vocabsize = self.audio_vocabsize + self.audio_specialtokens

self.eot = self.text_vocabsize
self.pad_t = self.text_vocabsize + 1
self.input_t = self.text_vocabsize + 2
self.answer_t = self.text_vocabsize + 3
self.asr = self.text_vocabsize + 4

self.eoa = self.audio_vocabsize
self.pad_a = self.audio_vocabsize + 1
self.input_a = self.audio_vocabsize + 2
self.answer_a = self.audio_vocabsize + 3
self.split = self.audio_vocabsize + 4

@dataclass
class ModelConfig:
file: str = "examples/s2s/model/slam_model_s2s.py:model_factory"
@@ -20,6 +60,7 @@ class ModelConfig:
encoder_type: str = field(default="finetune", metadata={
"help": "whether model is only pretrained or finetuned, used for models such as hubert"
})
vocab_config: VocabConfig = field(default_factory=VocabConfig)

@dataclass
class PeftConfig:
@@ -79,6 +120,8 @@ class TrainConfig:
})
freeze_encoder:bool = False



@dataclass
class DataConfig:
dataset: str = "speech_dataset_s2s"
@@ -106,6 +149,8 @@ class DataConfig:
manifest_format: str = field(default="datasets", metadata={ "help": "alternative: jsonl" })
split_size: float = 0.1

vocab_config: VocabConfig = field(default_factory=VocabConfig)

@dataclass
class FSDPConfig:
mixed_precision: bool = True
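VocabConfig.__post_init__ above derives every special-token id from the base vocab sizes. The following is a small sanity check of the resulting layout, computed from the defaults in this hunk; the import path and the 7-layer reading of total_vocabsize are assumptions, not stated in the diff.

```python
# Assumed import path; the dataclass is defined in examples/s2s/s2s_config.py in this PR.
from examples.s2s.s2s_config import VocabConfig

cfg = VocabConfig()
assert cfg.padded_text_vocabsize == 151936 + 64 == 152000
assert cfg.padded_audio_vocabsize == 4096 + 64 == 4160
# Observation (not stated in the diff): 181120 = 152000 + 7 * 4160,
# i.e. one padded text block plus one padded audio block per codebook layer.
assert cfg.total_vocabsize == 152000 + 7 * 4160 == 181120

# Text special tokens sit directly after the text vocab ...
assert (cfg.eot, cfg.pad_t, cfg.input_t, cfg.answer_t, cfg.asr) == (
    151936, 151937, 151938, 151939, 151940)
# ... and audio special tokens directly after the audio vocab (per layer, before any layer shift).
assert (cfg.eoa, cfg.pad_a, cfg.input_a, cfg.answer_a, cfg.split) == (
    4096, 4097, 4098, 4099, 4100)
```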
73 changes: 47 additions & 26 deletions src/slam_llm/datasets/speech_dataset_s2s.py
@@ -10,25 +10,25 @@
import librosa

# these tokens setting is from Mini-Omni
text_vocabsize = 151936
text_specialtokens = 64
audio_vocabsize = 4096
audio_specialtokens = 64
# text_vocabsize = 151936
# text_specialtokens = 64
# audio_vocabsize = 4096
# audio_specialtokens = 64

padded_text_vocabsize = text_vocabsize + text_specialtokens
padded_audio_vocabsize = audio_vocabsize + audio_specialtokens
# padded_text_vocabsize = text_vocabsize + text_specialtokens
# padded_audio_vocabsize = audio_vocabsize + audio_specialtokens

_eot = text_vocabsize
_pad_t = text_vocabsize + 1
_input_t = text_vocabsize + 2
_answer_t = text_vocabsize + 3
_asr = text_vocabsize + 4
# _eot = text_vocabsize
# _pad_t = text_vocabsize + 1
# _input_t = text_vocabsize + 2
# _answer_t = text_vocabsize + 3
# _asr = text_vocabsize + 4

_eoa = audio_vocabsize
_pad_a = audio_vocabsize + 1
_input_a = audio_vocabsize + 2
_answer_a = audio_vocabsize + 3
_split = audio_vocabsize + 4
# _eoa = audio_vocabsize
# _pad_a = audio_vocabsize + 1
# _input_a = audio_vocabsize + 2
# _answer_a = audio_vocabsize + 3
# _split = audio_vocabsize + 4


class SpeechDatasetJsonl(torch.utils.data.Dataset):
@@ -58,8 +58,29 @@ def __init__(self,
assert self.input_type in ["raw", "mel"], "input_type must be one of [raw, mel]"
assert self.manifest_format in ["datasets", "jsonl"], "manifest_format must be one of [datasets, jsonl]"

self.special_token_a = _answer_a
self.special_token_t = _answer_t
# vocab config
self.vocab_config = dataset_config.get("vocab_config", None)
self.text_vocabsize = self.vocab_config.text_vocabsize
self.text_specialtokens = self.vocab_config.text_specialtokens
self.audio_vocabsize = self.vocab_config.audio_vocabsize
self.audio_specialtokens = self.vocab_config.audio_specialtokens
self.padded_text_vocabsize = self.vocab_config.padded_text_vocabsize
self.padded_audio_vocabsize = self.vocab_config.padded_audio_vocabsize
self.total_vocabsize = self.vocab_config.total_vocabsize
self._eot = self.vocab_config.eot
self._pad_t = self.vocab_config.pad_t
self._input_t = self.vocab_config.input_t
self._answer_t = self.vocab_config.answer_t
self._asr = self.vocab_config.asr
self._eoa = self.vocab_config.eoa
self._pad_a = self.vocab_config.pad_a
self._input_a = self.vocab_config.input_a
self._answer_a = self.vocab_config.answer_a
self._split = self.vocab_config.split

self.special_token_a = self._answer_a
self.special_token_t = self._answer_t


self.data_list = []

@@ -126,22 +147,22 @@ def get_input_ids(self, length, special_token_a, special_token_t):
input_ids = []
for i in range(7):
input_ids_item = []
input_ids_item.append(layershift(_input_a, i))
input_ids_item += [layershift(_pad_a, i)] * length
input_ids_item += [(layershift(_eoa, i)), layershift(special_token_a, i)]
input_ids_item.append(layershift(self._input_a, i))
input_ids_item += [layershift(self._pad_a, i)] * length
input_ids_item += [(layershift(self._eoa, i)), layershift(special_token_a, i)]
input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
input_id_T = torch.tensor([_input_t] + [_pad_t] * length + [_eot, special_token_t])
input_id_T = torch.tensor([self._input_t] + [self._pad_t] * length + [self._eot, special_token_t])
input_ids.append(input_id_T.unsqueeze(0))
return input_ids

def get_answer_ids(self, length):
answer_ids = []
for i in range(7):
answer_ids_item = []
answer_ids_item += [layershift(_pad_a, i)] * length
answer_ids_item += [(layershift(_eoa, i))]
answer_ids_item += [layershift(self._pad_a, i)] * length
answer_ids_item += [(layershift(self._eoa, i))]
answer_ids.append(torch.tensor(answer_ids_item).unsqueeze(0))
answer_id_T = torch.tensor([_pad_t] * length + [_eot])
answer_id_T = torch.tensor([self._pad_t] * length + [self._eot])
answer_ids.append(answer_id_T.unsqueeze(0))
return answer_ids

@@ -201,7 +222,7 @@ def __getitem__(self, index):

answer_text = self.answer_template.format(target_text)
answer_text_ids = self.tokenizer.encode(answer_text) # [prompt,answer]
answer_text_ids.append(_eot) # [prompt,answer,eos]
answer_text_ids.append(self._eot) # [prompt,answer,eos]
answer_text_ids = torch.tensor(answer_text_ids, dtype=torch.int64)
answer_ids = self.get_answer_ids(target_audio_length) # NOTE: suppose audio length is always longer than text length
answer_ids[7] = torch.cat((answer_text_ids.unsqueeze(0), answer_ids[7][:,len(answer_text_ids):]),dim=1)
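get_input_ids and get_answer_ids build one token row per audio codebook layer plus a final text row. The layershift helper is not shown in this diff; the sketch below uses a hypothetical definition consistent with the Mini-Omni-style vocab layout implied by VocabConfig (padded text vocab first, then one padded audio block per layer), so the offsets are assumptions rather than the repository's actual implementation.

```python
import torch

# Hypothetical layershift, assumed from the vocab layout: an audio token `tok`
# in codebook layer `i` lives in its own 4160-wide block after the 152000 text ids.
PADDED_TEXT_VOCAB = 152000   # assumption: cfg.padded_text_vocabsize
PADDED_AUDIO_VOCAB = 4160    # assumption: cfg.padded_audio_vocabsize

def layershift(tok: int, layer: int) -> int:
    return PADDED_TEXT_VOCAB + layer * PADDED_AUDIO_VOCAB + tok

# Re-creating get_input_ids for length=2 with the VocabConfig defaults above.
length, input_a, pad_a, eoa, answer_a = 2, 4098, 4097, 4096, 4099
input_t, pad_t, eot, answer_t = 151938, 151937, 151936, 151939

input_ids = []
for i in range(7):                               # 7 audio codebook layers
    row = [layershift(input_a, i)]
    row += [layershift(pad_a, i)] * length
    row += [layershift(eoa, i), layershift(answer_a, i)]
    input_ids.append(torch.tensor(row).unsqueeze(0))
# final text row uses un-shifted text-space ids
input_ids.append(torch.tensor([input_t] + [pad_t] * length + [eot, answer_t]).unsqueeze(0))

print(input_ids[0])  # layer-0 audio row: ids in [152000, 156160)
print(input_ids[7])  # text row: ids in [0, 152000)
```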