@@ -110,6 +110,9 @@
 
 if is_310p():
     torch_npu.npu.set_compile_mode(jit_compile=False)
+    ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ
+else:
+    ACL_FORMAT = ACL_FORMAT_FRACTAL_ND
 
 
 @dataclass
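For readers coming to this hunk cold: it moves the ACL memory-format decision to import time. On Atlas 310P the FRACTAL_NZ layout is selected (and JIT compilation is disabled), while every other SoC keeps the plain ND layout. A minimal sketch of that pattern, assuming is_310p and the two format constants are the ones exposed by vllm_ascend.utils, as elsewhere in this file:

import torch_npu

# Assumption: these helpers/constants come from vllm_ascend.utils, as they do
# in the file this diff touches; the PR itself does not define them.
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                               is_310p)

if is_310p():
    # Atlas 310P: disable JIT compilation and prefer the NZ weight layout.
    torch_npu.npu.set_compile_mode(jit_compile=False)
    ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ
else:
    # Other SoCs keep the default ND layout.
    ACL_FORMAT = ACL_FORMAT_FRACTAL_ND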
@@ -2047,8 +2050,8 @@ def load_model(self) -> None:
                     if isinstance(module,
                                   (MergedColumnParallelLinear,
                                    QKVParallelLinear, RowParallelLinear)):
-                        module.weight.data = torch_npu.npu_format_cast(
-                            module.weight.data, ACL_FORMAT_FRACTAL_NZ)
+                        module.weight.data = self._convert_torch_format(
+                            module.weight.data)
             if self.drafter:
                 logger.info("Loading drafter model...")
                 if isinstance(self.drafter, EagleProposer):
@@ -2133,6 +2136,10 @@ def _get_torchair_lazy_compiled_model(self, batch_size: int):
                     ge_cache=False)
         return self.torchair_compiled_models[batch_size]
 
+    def _convert_torch_format(self, tensor):
+        tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
+        return tensor
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
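The two hunks above introduce a single conversion helper on the runner and route the weight cast in load_model through it, so call sites no longer hard-code ACL_FORMAT_FRACTAL_NZ. A condensed, hypothetical sketch of how the pieces fit together; the NPUModelRunner class name, the self.model attribute, and the module loop are assumptions drawn from context, not part of this diff:

import torch_npu

from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                               is_310p)

# Same module-level selection as in the first hunk of this diff.
ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ if is_310p() else ACL_FORMAT_FRACTAL_ND


class NPUModelRunner:  # class name assumed; the diff only shows its methods

    def _convert_torch_format(self, tensor):
        # Single conversion point: cast to the layout chosen for this SoC.
        return torch_npu.npu_format_cast(tensor, ACL_FORMAT)

    def load_model(self) -> None:
        ...  # model loading and surrounding guards elided
        for module in self.model.modules():
            if isinstance(module, (MergedColumnParallelLinear,
                                   QKVParallelLinear, RowParallelLinear)):
                # Only the large parallel-linear weights are re-laid out.
                module.weight.data = self._convert_torch_format(
                    module.weight.data)

One behavioural note: the per-call computation removed in the next hunk also consulted self.torchair_graph_enabled, while the new module-level ACL_FORMAT depends only on the SoC type.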
@@ -2141,9 +2148,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             cache size of each layer
         """
         self.kv_cache_config = kv_cache_config
-        import torch_npu
-        acl_format = ACL_FORMAT_FRACTAL_NZ if is_310p(
-        ) and not self.torchair_graph_enabled else ACL_FORMAT_FRACTAL_ND
         kv_caches: Dict[str, torch.Tensor] = {}
 
         def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
@@ -2202,7 +2206,6 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
                     kv_cache_spec.head_size)
                 dtype = kv_cache_spec.dtype
                 if self.model_config.is_deepseek_mla:
-
                     num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
                     rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
                     nope_dim = head_size - rope_dim
@@ -2218,10 +2221,8 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
                     nope_cache = torch.zeros(nope_cache_shape,
                                              dtype=dtype,
                                              device=self.device)
-                    rope_cache = torch_npu.npu_format_cast(
-                        rope_cache, acl_format)
-                    nope_cache = torch_npu.npu_format_cast(
-                        nope_cache, acl_format)
+                    rope_cache = self._convert_torch_format(rope_cache)
+                    nope_cache = self._convert_torch_format(nope_cache)
                 else:
 
                     # In order to transfer kv cache through the register_memory api from llmdatadist, the memory
@@ -2259,8 +2260,7 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
                     kv_cache = torch.zeros(cache_shape,
                                            dtype=dtype,
                                            device=self.device)
-                    kv_cache = torch_npu.npu_format_cast(
-                        kv_cache, acl_format)
+                    kv_cache = self._convert_torch_format(kv_cache)
                 else:
                     cache_size = math.prod(cache_shape)
                     cache_size_aligned = cache_size + alignment
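The last four hunks apply the same helper to the freshly allocated KV-cache buffers, so the layout decision is no longer recomputed inside initialize_kv_cache. A standalone sketch of that allocate-then-cast step for the DeepSeek-MLA branch; the cache shapes, dtype, and the module-level convert_torch_format stand-in are illustrative assumptions, not values taken from this diff:

import torch
import torch_npu

from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                               is_310p)

ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ if is_310p() else ACL_FORMAT_FRACTAL_ND


def convert_torch_format(tensor: torch.Tensor) -> torch.Tensor:
    """Module-level stand-in for the runner's _convert_torch_format helper."""
    return torch_npu.npu_format_cast(tensor, ACL_FORMAT)


# Illustrative MLA cache shape: (num_blocks, block_size, num_kv_heads, dim).
num_blocks, block_size, num_kv_heads = 128, 128, 1
rope_dim, nope_dim = 64, 512  # assumed head dims; the real ones come from the HF config

rope_cache = torch.zeros((num_blocks, block_size, num_kv_heads, rope_dim),
                         dtype=torch.bfloat16, device="npu")
nope_cache = torch.zeros((num_blocks, block_size, num_kv_heads, nope_dim),
                         dtype=torch.bfloat16, device="npu")

# As in the diff: allocate first, then cast both halves to the chosen format.
rope_cache = convert_torch_format(rope_cache)
nope_cache = convert_torch_format(nope_cache)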