|
17 | 17 | # Only Neox style true scenario is supported for now |
18 | 18 | IS_NEOX_STYLE = [True] |
19 | 19 | DTYPES = [torch.half] |
20 | | -HEAD_SIZES = [64, 96, 128, 256] |
| 20 | +HEAD_SIZES = [64, 64, 96, 128, 256] |
21 | 21 | ROTARY_DIMS = [None, 32] # None means rotary dim == head size |
22 | 22 | NUM_HEADS = [17] # Arbitrary values for testing |
23 | 23 | BATCH_SIZES = [5] # Arbitrary values for testing |
24 | 24 | SEQ_LENS = [11, 4096] # Arbitrary values for testing |
| 25 | +NUM_TOKENS = [10, 21] |
25 | 26 | SEEDS = [0] |
26 | 27 | DEVICES = [f"npu:{0}"] |
27 | 28 | # Set tolerance to 1 for quant ops |
@@ -198,3 +199,146 @@ def test_rotary_embedding_quant_with_leading_dim( |
198 | 199 | ref_key, |
199 | 200 | atol=DEFAULT_ATOL, |
200 | 201 | rtol=DEFAULT_RTOL) |
| 202 | + |
| 203 | + |
| 204 | +class ModelWithRotaryEmbedding(nn.Module):
| 205 | + |
| 206 | + def __init__( |
| 207 | + self, |
| 208 | + hidden_size: int, |
| 209 | + num_heads: int, |
| 210 | + head_size: int, |
| 211 | + rotary_dim: int, |
| 212 | + max_position_embeddings: int, |
| 213 | + base: int, |
| 214 | + is_neox_style: bool, |
| 215 | + dtype: torch.dtype, |
| 216 | + ) -> None: |
| 217 | + super().__init__() |
| 218 | + self.qkv_proj = nn.Linear(hidden_size, num_heads * head_size * 3) |
| 219 | + self.rope = RotaryEmbedding( |
| 220 | + head_size=head_size, |
| 221 | + rotary_dim=rotary_dim, |
| 222 | + max_position_embeddings=max_position_embeddings, |
| 223 | + base=base, |
| 224 | + is_neox_style=is_neox_style, |
| 225 | + dtype=dtype, |
| 226 | + ) |
| 227 | + self.o_proj = nn.Linear(num_heads * head_size, hidden_size) |
| 228 | + |
| 229 | + def forward( |
| 230 | + self, |
| 231 | + positions: torch.Tensor, |
| 232 | + hidden_states: torch.Tensor, |
| 233 | + offsets: Optional[torch.Tensor] = None, |
| 234 | + ) -> torch.Tensor: |
| 235 | +        # Simulate a simple attention layer to test whether it can be seamlessly captured into an aclgraph.
| 236 | + qkv = self.qkv_proj(hidden_states) |
| 237 | + q, k, v = qkv.chunk(3, dim=-1) |
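|  | +        # Apply rotary position embedding to q and k with the custom kernel,
|  | +        # using the rope module's precomputed cos/sin cache.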
| 238 | + query, key = torch.ops._C.rotary_embedding( |
| 239 | + positions, |
| 240 | + q, |
| 241 | + k, |
| 242 | + self.rope.head_size, |
| 243 | + self.rope.cos_sin_cache, |
| 244 | + self.rope.is_neox_style, |
| 245 | + ) |
| 246 | + query = query.view(q.shape) |
| 247 | + key = key.view(k.shape) |
| 248 | + o = self.o_proj(query) |
| 249 | + return o |
| 250 | + |
| 251 | + |
| 252 | +# The first graph replay seems to have accuracy issues when pytest is run directly
| 253 | +# on the ops folder, so add a warmup graph replay as a workaround.
| 254 | +ACL_GRAPH_FIRST_RUN = True
| 255 | + |
| 256 | + |
| 257 | +@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) |
| 258 | +@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
| 259 | +@pytest.mark.parametrize("num_heads", NUM_HEADS) |
| 260 | +@pytest.mark.parametrize("head_size", HEAD_SIZES) |
| 261 | +@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) |
| 262 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 263 | +@pytest.mark.parametrize("seed", SEEDS) |
| 264 | +@pytest.mark.parametrize("device", DEVICES) |
| 265 | +@torch.inference_mode() |
| 266 | +def test_capture_rotary_embedding_in_aclgraph( |
| 267 | + is_neox_style: bool, |
| 268 | + num_tokens: int, |
| 269 | + num_heads: int, |
| 270 | + head_size: int, |
| 271 | + rotary_dim: int, |
| 272 | + dtype: torch.dtype, |
| 273 | + seed: int, |
| 274 | + device: str, |
| 275 | + max_position_embeddings: int = 8192, |
| 276 | + base: int = 10000, |
| 277 | +): |
| 278 | + """Test if the rotary embedding can be captured in aclgraph.""" |
| 279 | + torch.manual_seed(seed) |
| 280 | + torch.set_default_device(device) |
| 281 | + if rotary_dim is None: |
| 282 | + rotary_dim = head_size |
| 283 | +    model = ModelWithRotaryEmbedding(
| 284 | + hidden_size=num_heads * head_size, |
| 285 | + num_heads=num_heads, |
| 286 | + head_size=head_size, |
| 287 | + rotary_dim=rotary_dim, |
| 288 | + max_position_embeddings=max_position_embeddings, |
| 289 | + base=base, |
| 290 | + is_neox_style=is_neox_style, |
| 291 | + dtype=dtype, |
| 292 | + ) |
| 293 | + |
| 294 | + def custom_op_checking_backend(gm: torch.fx.GraphModule, example_input): |
| 295 | +        # Validate that the rotary_embedding custom kernel is indeed inside the
| 296 | +        # graph by string matching.
| 297 | + graph = str(gm.graph) |
| 298 | + assert "_C.rotary_embedding" in graph |
| 299 | + return gm |
| 300 | + |
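|  | +    # Static input buffers: graph capture is expected to bind to these exact
|  | +    # tensors, so new inputs are later copied into them in place before replay.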
| 301 | + static_positions = torch.randint(0, max_position_embeddings, |
| 302 | + (num_tokens, )) |
| 303 | + static_hidden_states = torch.randn(num_tokens, |
| 304 | + num_heads * head_size, |
| 305 | + dtype=dtype, |
| 306 | + device="npu") |
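|  | +    # The custom backend runs on the first compiled call and asserts that the
|  | +    # rotary_embedding custom op is present in the traced FX graph.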
| 307 | + compiled_model = torch.compile(model, backend=custom_op_checking_backend) |
| 308 | + stream = torch.npu.Stream() |
| 309 | + stream.wait_stream(torch.npu.current_stream()) |
| 310 | + with torch.npu.stream(stream): |
| 311 | +        # Warm up the compiled FX graph before capture.
| 312 | + for i in range(3): |
| 313 | + static_output = compiled_model(static_positions, |
| 314 | + static_hidden_states, |
| 315 | + offsets=None) |
| 316 | + stream.wait_stream(torch.npu.current_stream()) |
| 317 | + |
| 318 | + aclgraph = torch.npu.NPUGraph() |
| 319 | + |
| 320 | + with torch.npu.graph(aclgraph): |
| 321 | + # Capture the model in aclgraph. |
| 322 | + static_output = compiled_model(static_positions, static_hidden_states) |
| 323 | +    # Fill the static input buffers with new random values before replay.
| 324 | + random_filled_positions = torch.randint(0, |
| 325 | + max_position_embeddings, |
| 326 | + (num_tokens, ), |
| 327 | + device="npu") |
| 328 | + random_filled_hidden_states = torch.randn(num_tokens, |
| 329 | + num_heads * head_size, |
| 330 | + dtype=dtype, |
| 331 | + device="npu") |
| 332 | + static_positions.copy_(random_filled_positions) |
| 333 | + static_hidden_states.copy_(random_filled_hidden_states) |
| 334 | + |
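|  | +    # Replay the captured graph; its result becomes visible through the captured
|  | +    # static_output tensor.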
| 335 | + aclgraph.replay() |
| 336 | +    global ACL_GRAPH_FIRST_RUN
| 337 | +    if ACL_GRAPH_FIRST_RUN:
| 338 | +        ACL_GRAPH_FIRST_RUN = False
| 339 | + return |
| 340 | + output_reference = model(static_positions, static_hidden_states) |
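|  | +    # Compare the replayed graph output against eager execution of the model
|  | +    # on the same refilled inputs.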
| 341 | + torch.testing.assert_close(static_output, |
| 342 | + output_reference, |
| 343 | + atol=DEFAULT_ATOL, |
| 344 | + rtol=DEFAULT_RTOL) |