@@ -224,8 +224,6 @@ class Mamba2Layer(ABC):
     Inherit from this class if you implement a custom Mamba2 layer.
     """
 
-    chunk_size: int
-
     # Contains the KV cache (mamba state) for the layer
     # in the shape specified by `self.get_state_shape`.
     # The outer list is for v0 PP virtual engine. Though this code path
@@ -257,22 +255,21 @@ class MambaMixer2(Mamba2Layer, CustomOp):
257255 """
258256
259257 def __init__ (
260- self ,
261- hidden_size : int ,
262- ssm_state_size : int ,
263- conv_kernel_size : int ,
264- intermediate_size : int ,
265- use_conv_bias : bool ,
266- use_bias : bool ,
267- n_groups : int = 1 ,
268- num_heads : int = 128 ,
269- head_dim : int = 64 ,
270- rms_norm_eps : float = 1e-5 ,
271- activation : str = "silu" ,
272- use_rms_norm : bool = True ,
273- quant_config : Optional [QuantizationConfig ] = None ,
274- prefix : str = "" ,
275- chunk_size : int = - 1 , # the chunk size used by v1
258+ self ,
259+ hidden_size : int ,
260+ ssm_state_size : int ,
261+ conv_kernel_size : int ,
262+ intermediate_size : int ,
263+ use_conv_bias : bool ,
264+ use_bias : bool ,
265+ n_groups : int = 1 ,
266+ num_heads : int = 128 ,
267+ head_dim : int = 64 ,
268+ rms_norm_eps : float = 1e-5 ,
269+ activation : str = "silu" ,
270+ use_rms_norm : bool = True ,
271+ quant_config : Optional [QuantizationConfig ] = None ,
272+ prefix : str = "" ,
276273 ):
277274 super ().__init__ ()
278275
@@ -454,10 +451,7 @@ def __init__(
         # of Attention + v0 PP.
         # The inner tuple is (conv_state, ssm_state)
         self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
-        assert chunk_size != -1, "chunk_size must be set for v1"
 
-        # NOTE: chunk_size may be -1 for models without v1 support
-        self.chunk_size = chunk_size
         self.prefix = prefix
 
     def forward_native(
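
For context, a minimal sketch of a call site after this change: `chunk_size` is simply dropped from the constructor call, and the layer no longer stores it. The module path and argument values below are illustrative assumptions, not taken from this PR, and constructing the layer still requires an initialized vLLM runtime (distributed/TP groups), so this is a sketch rather than a standalone script.

# Hypothetical call site after this PR (illustrative only): no chunk_size argument.
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2  # assumed module path

def build_mixer(hidden_size: int, prefix: str) -> MambaMixer2:
    # All values other than the signature itself are made-up examples.
    return MambaMixer2(
        hidden_size=hidden_size,
        ssm_state_size=128,              # illustrative value
        conv_kernel_size=4,              # illustrative value
        intermediate_size=2 * hidden_size,
        use_conv_bias=True,
        use_bias=False,
        n_groups=1,
        num_heads=128,
        head_dim=64,
        rms_norm_eps=1e-5,
        activation="silu",
        use_rms_norm=True,
        quant_config=None,
        prefix=prefix,                   # e.g. "model.layers.0.mixer"
    )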