@@ -121,28 +121,40 @@ class AscendAttentionState(Enum):
 
 @dataclass
 class AscendMetadata:
-    num_actual_tokens: int  # Number of tokens excluding padding.
-    # (batch_size, max_blocks_per_seq).
-    # Block addresses per sequence. (Seq id -> list of physical block)
-    block_tables: torch.Tensor
-    # (batch_size,). The sequence length per sequence. Sequence length means
-    # the computed tokens + new tokens None if it is a decoding.
-    query_start_loc: torch.Tensor
-    query_lens: torch.Tensor
-    seq_lens: torch.Tensor
-    # max value of number of tokens across dp group
-    max_num_tokens_across_dp: int = 0
-    # Maximum query length in the batch. None for decoding.
-    max_query_len: Optional[int] = None
-    # (num_tokens,). The indices of the token slots that input tokens will be
-    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
-    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
-    # in block 0, and 1st slot in block 1, respectively.
-    slot_mapping: torch.Tensor = None
+    # **************************** Basic Properties ****************************
+    attn_mask: Optional[torch.Tensor] = None
     # Current state of this attention run.
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
-    attn_mask: Optional[torch.Tensor] = None
+
+    # Number of tokens excluding padding.
+    num_actual_tokens: int = 0
+
+    # The sequence length per sequence. Sequence length means the computed
+    # tokens + new tokens (is None if it is a decoding).
+    # (batch_size,)
+    seq_lens: torch.Tensor = None
+
+    query_start_loc: torch.Tensor = None
+    query_lens: torch.Tensor = None
+    # Maximum query length in the batch (None for decoding).
+    max_query_len: Optional[int] = None
+
+    # ********************** KV Cache Related Properties ***********************
+    # Block addresses per sequence (Seq id -> list of physical block).
+    # (batch_size, max_blocks_per_seq)
+    block_tables: torch.Tensor = None
+
+    # The indices of the token slots that input tokens will be stored into.
+    # E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the
+    # three tokens are stored in the 3rd slot in block 2, 2nd slot in block 0,
+    # and 1st slot in block 1, respectively.
+    # (num_tokens,)
+    slot_mapping: torch.Tensor = None
+
+    # ************************* DP Related Properties **************************
     with_prefill_across_dp: bool = False
+    # Maximum number of tokens across dp group
+    max_num_tokens_across_dp: int = 0
 
 
 class AscendAttentionMetadataBuilder:
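
For reference, a minimal sketch (not part of this diff) of the block/offset arithmetic that the `slot_mapping` comment describes: each slot index decomposes into a physical block index and an in-block offset via integer division and modulo. The `block_size = 16` and the example indices `[35, 2, 17]` are taken directly from that comment; the variable names below are illustrative only.

```python
import torch

# Block size and example slot indices from the `slot_mapping` docstring above.
block_size = 16
slot_mapping = torch.tensor([35, 2, 17])

# Which physical block each token's KV entry lands in.
block_indices = slot_mapping // block_size
# Position of the token inside that block.
block_offsets = slot_mapping % block_size

print(block_indices.tolist())  # [2, 0, 1]
print(block_offsets.tolist())  # [3, 2, 1]
```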