@@ -121,28 +121,36 @@ class AscendAttentionState(Enum):
 
 @dataclass
 class AscendMetadata:
-    num_actual_tokens: int  # Number of tokens excluding padding.
-    # (batch_size, max_blocks_per_seq).
-    # Block addresses per sequence. (Seq id -> list of physical block)
-    block_tables: torch.Tensor
-    # (batch_size,). The sequence length per sequence. Sequence length means
-    # the computed tokens + new tokens None if it is a decoding.
-    query_start_loc: torch.Tensor
-    query_lens: torch.Tensor
-    seq_lens: torch.Tensor
-    # Maximum query length in the batch. None for decoding.
-    max_query_len: Optional[int] = None
-    # (num_tokens,). The indices of the token slots that input tokens will be
-    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
-    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
-    # in block 0, and 1st slot in block 1, respectively.
-    slot_mapping: torch.Tensor = None
+
+    # **************************** Basic Properties ****************************
+    attn_mask: Optional[torch.Tensor] = None
     # Current state of this attention run.
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
-    attn_mask: Optional[torch.Tensor] = None
 
-    # For logging.
-    num_input_tokens: int = 0  # Number of tokens including padding.
+    # Number of tokens excluding padding.
+    num_actual_tokens: int = 0
+
+    # The sequence length per sequence. Sequence length means the computed
+    # tokens + new tokens (is None if it is a decoding).
+    # (batch_size,)
+    seq_lens: torch.Tensor = None
+
+    query_start_loc: torch.Tensor = None
+    query_lens: torch.Tensor = None
+    # Maximum query length in the batch (None for decoding).
+    max_query_len: Optional[int] = None
+
+    # ********************** KV Cache Related Properties ***********************
+    # Block addresses per sequence (Seq id -> list of physical block).
+    # (batch_size, max_blocks_per_seq)
+    block_tables: torch.Tensor = None
+
+    # The indices of the token slots that input tokens will be stored into.
+    # E.g., if `slot_mapping` is [35, 2, 17] and the block size is 16, the
+    # three tokens are stored in the 3rd slot in block 2, 2nd slot in block 0,
+    # and 1st slot in block 1, respectively.
+    # (num_tokens,)
+    slot_mapping: torch.Tensor = None
 
 
 class AscendAttentionMetadataBuilder:
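
To make the `slot_mapping` comment concrete, here is a minimal sketch (illustrative only, not part of the diff; the helper name is invented) that splits flat slot indices into per-block coordinates, assuming the block size of 16 used in the comment's example:

import torch

def decode_slot_mapping(slot_mapping: torch.Tensor, block_size: int = 16):
    # Each flat slot index encodes (physical block id, offset inside that block).
    block_ids = slot_mapping // block_size
    offsets = slot_mapping % block_size
    return block_ids, offsets

# The comment's example: slots [35, 2, 17] with a block size of 16.
blocks, offsets = decode_slot_mapping(torch.tensor([35, 2, 17]))
print(blocks.tolist())   # [2, 0, 1] -> tokens land in physical blocks 2, 0 and 1
print(offsets.tolist())  # [3, 2, 1] -> at offsets 3, 2 and 1 within those blocks

The physical block ids produced this way are the same ones that `block_tables` stores per sequence, so the two fields together locate every token's KV cache slot.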