
Commit 56d68c6

Adding ByteDance Seed Seed-OSS (#40272)
add seed oss
1 parent 8a6908c commit 56d68c6

File tree

12 files changed: +1232 -0 lines changed


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -675,6 +675,8 @@
       title: RoFormer
     - local: model_doc/rwkv
       title: RWKV
+    - local: model_doc/seed_oss
+      title: Seed-Oss
     - local: model_doc/splinter
       title: Splinter
     - local: model_doc/squeezebert

docs/source/en/model_doc/seed_oss.md

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
<!--
# Copyright 2025 Bytedance-Seed Ltd and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# SeedOss

## Overview

To be released with the official model launch.

### Model Details

To be released with the official model launch.

## Usage tips

To be released with the official model launch.

## SeedOssConfig

[[autodoc]] SeedOssConfig

## SeedOssModel

[[autodoc]] SeedOssModel
    - forward

## SeedOssForCausalLM

[[autodoc]] SeedOssForCausalLM
    - forward

## SeedOssForSequenceClassification

[[autodoc]] SeedOssForSequenceClassification
    - forward

## SeedOssForTokenClassification

[[autodoc]] SeedOssForTokenClassification
    - forward

## SeedOssForQuestionAnswering

[[autodoc]] SeedOssForQuestionAnswering
    - forward
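The usage sections of the new doc page are placeholders until the official launch, so the following is only a minimal, hedged generation sketch built from the classes registered elsewhere in this commit. The checkpoint id is the one referenced in the `SeedOssConfig` docstring further down and may not be available on the Hub yet.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint id taken from the SeedOssConfig docstring in this commit; it may not
# exist on the Hub until the official model launch.
checkpoint = "ByteDance-Seed/SeedOss-36B"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto", device_map="auto")

inputs = tokenizer("The Seed-OSS model is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```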

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -295,6 +295,7 @@
 from .sam_hq import *
 from .seamless_m4t import *
 from .seamless_m4t_v2 import *
+from .seed_oss import *
 from .segformer import *
 from .seggpt import *
 from .sew import *

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
@@ -348,6 +348,7 @@
         ("sam_vision_model", "SamVisionConfig"),
         ("seamless_m4t", "SeamlessM4TConfig"),
         ("seamless_m4t_v2", "SeamlessM4Tv2Config"),
+        ("seed_oss", "SeedOssConfig"),
         ("segformer", "SegformerConfig"),
         ("seggpt", "SegGptConfig"),
         ("sew", "SEWConfig"),
@@ -782,6 +783,7 @@
         ("sam_vision_model", "SamVisionModel"),
         ("seamless_m4t", "SeamlessM4T"),
         ("seamless_m4t_v2", "SeamlessM4Tv2"),
+        ("seed_oss", "SeedOss"),
         ("segformer", "SegFormer"),
         ("seggpt", "SegGPT"),
         ("sew", "SEW"),

src/transformers/models/auto/modeling_auto.py

Lines changed: 5 additions & 0 deletions
@@ -339,6 +339,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("sam_vision_model", "SamVisionModel"),
         ("seamless_m4t", "SeamlessM4TModel"),
         ("seamless_m4t_v2", "SeamlessM4Tv2Model"),
+        ("seed_oss", "SeedOssModel"),
         ("segformer", "SegformerModel"),
         ("seggpt", "SegGptModel"),
         ("sew", "SEWModel"),
@@ -718,6 +719,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("roc_bert", "RoCBertForCausalLM"),
         ("roformer", "RoFormerForCausalLM"),
         ("rwkv", "RwkvForCausalLM"),
+        ("seed_oss", "SeedOssForCausalLM"),
         ("smollm3", "SmolLM3ForCausalLM"),
         ("speech_to_text_2", "Speech2Text2ForCausalLM"),
         ("stablelm", "StableLmForCausalLM"),
@@ -1264,6 +1266,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("roberta-prelayernorm", "RobertaPreLayerNormForSequenceClassification"),
         ("roc_bert", "RoCBertForSequenceClassification"),
         ("roformer", "RoFormerForSequenceClassification"),
+        ("seed_oss", "SeedOssForSequenceClassification"),
         ("smollm3", "SmolLM3ForSequenceClassification"),
         ("squeezebert", "SqueezeBertForSequenceClassification"),
         ("stablelm", "StableLmForSequenceClassification"),
@@ -1352,6 +1355,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("roberta-prelayernorm", "RobertaPreLayerNormForQuestionAnswering"),
         ("roc_bert", "RoCBertForQuestionAnswering"),
         ("roformer", "RoFormerForQuestionAnswering"),
+        ("seed_oss", "SeedOssForQuestionAnswering"),
         ("smollm3", "SmolLM3ForQuestionAnswering"),
         ("splinter", "SplinterForQuestionAnswering"),
         ("squeezebert", "SqueezeBertForQuestionAnswering"),
@@ -1462,6 +1466,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("roberta-prelayernorm", "RobertaPreLayerNormForTokenClassification"),
         ("roc_bert", "RoCBertForTokenClassification"),
         ("roformer", "RoFormerForTokenClassification"),
+        ("seed_oss", "SeedOssForTokenClassification"),
         ("smollm3", "SmolLM3ForTokenClassification"),
         ("squeezebert", "SqueezeBertForTokenClassification"),
         ("stablelm", "StableLmForTokenClassification"),

src/transformers/models/seed_oss/__init__.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
# Copyright 2025 Bytedance-Seed Ltd and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_seed_oss import *
    from .modeling_seed_oss import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
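This is the standard lazy-module pattern used across `transformers`: the submodule replaces itself in `sys.modules` with a `_LazyModule`, so the configuration and modeling code is only imported when an attribute is first accessed. A small sketch of the effect, assuming the `__init__.py` above:

```python
from transformers.models import seed_oss

# seed_oss is a _LazyModule at this point; the attribute access below is what
# actually triggers the import of configuration_seed_oss.
config_cls = seed_oss.SeedOssConfig
print(config_cls.model_type)  # seed_oss
```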

src/transformers/models/seed_oss/configuration_seed_oss.py

Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@
# Copyright 2025 Bytedance-Seed Ltd and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SeedOss model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation


class SeedOssConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`SeedOssModel`]. It is used to instantiate a
    SeedOss model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the SeedOss-36B,
    e.g. [ByteDance-Seed/SeedOss-36B](https://huggingface.co/ByteDance-Seed/SeedOss-36B).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 155136):
            Vocabulary size of the SeedOss model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`SeedOssModel`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 27648):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 64):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 80):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If not specified, it defaults to `8`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 524288):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
            understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
            results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply a new rope
            type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update
            this value accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention computation.
                    If unspecified, it defaults to the value recommended by the implementation, using the `factor`
                    field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear ramp
                    function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear ramp
                    function. If unspecified, it defaults to 1.
                `short_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `long_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2.
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
        attention_bias (`bool`, *optional*, defaults to `True`):
            Whether to use a bias in the query, key, value layers during self-attention.
        attention_out_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the output projection layer during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        residual_dropout (`float`, *optional*, defaults to 0.1):
            Residual connection dropout value.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
        head_dim (`int`, *optional*, defaults to 128):
            The attention head dimension.

    ```python
    >>> from transformers import SeedOssModel, SeedOssConfig

    >>> # Initializing a SeedOss-36b style configuration
    >>> configuration = SeedOssConfig()

    >>> # Initializing a model from the SeedOss-36b style configuration
    >>> model = SeedOssModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "seed_oss"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `SeedOssModel`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=155136,
        hidden_size=4096,
        intermediate_size=27648,
        num_hidden_layers=64,
        num_attention_heads=80,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=524288,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=True,
        attention_out_bias=False,
        attention_dropout=0.1,
        residual_dropout=0.1,
        mlp_bias=False,
        head_dim=128,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_out_bias = attention_out_bias
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.mlp_bias = mlp_bias
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, copy it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


__all__ = ["SeedOssConfig"]
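To exercise the backward-compatibility branch at the end of `__init__` (a legacy `type` key is copied to `rope_type` before `rope_config_validation` runs), here is a small smoke-test sketch with deliberately tiny, hypothetical sizes rather than the 36B defaults:

```python
from transformers import SeedOssConfig

# Tiny, hypothetical sizes so the config is cheap to build; not the released defaults.
config = SeedOssConfig(
    vocab_size=1024,
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=32,
    max_position_embeddings=512,
    rope_scaling={"type": "dynamic", "factor": 2.0},  # legacy "type" key
)

# The legacy key was copied to "rope_type" and accepted by rope_config_validation.
print(config.rope_scaling["rope_type"])  # dynamic
print(config.head_dim)                   # 32
```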
