@@ -2,7 +2,6 @@
 import inspect
 import logging
 import threading
-import warnings
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Union
 
@@ -363,193 +362,6 @@ def mlir_to_functions(op) -> None:
     return functions
 
 
-def parse_ttir(ttir, kwargs):
-    """
-    Given a Triton emitted TTIR text, this function lexes and parses the
-    code using a minimal grammar defined inside. During the lexing/parsing,
-    we drop any constant value and type information as they are not
-    necessary to us.
-    Being able to choose what we need makes this not a general purpose TTIR
-    parser which further makes parsing much simpler.
-    """
-    # TODO(oulgen):
-    # - Support closures (e.g. "tt.reduce")
-
-    try:
-        import lark  # type: ignore[import-not-found]
-        from lark import Lark, Transformer, v_args
-    except ModuleNotFoundError:
-        warnings.warn(
-            "Using slow path for user-defined Triton kernels. `pip install lark` to fix this."
-        )
-        raise
-
-    # Ops looks like one of the following forms:
-    #
-    # %14 = tt.addptr %13, %4 : tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32>
-    # tt.store %14, %12, %5 {cache = 1 : i32, evict = 1 : i32} : tensor<4xf32>
-    # %15 = "tt.atomic_rmw"(%14, %12, %5) <{atomic_rmw_op = 5 : i32, scope = 1 : i32, sem = 4 : i32}> : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xf32>, tensor<4xi1>) -> tensor<4xf32>  # noqa: B950
-    grammar = """
-    start: (module_block | loc_line)+
-
-    loc_line: "#loc" /.+/ NEWLINE
-
-    module_block: "module" "{" func_block+ "}" LOC
-
-    func_block: "tt.func" ("public"|"private") FN_NAME "(" /.+/ NEWLINE stmt* "}" LOC -> process_func
-
-    ?stmt: op | if | for | while | condition_stmt | label_stmt | cf_stmt
-
-    if: [assign_lhs "="] "scf.if" args rest stmt* "}" "else" "{" stmt* "}" LOC -> process_if
-    for: [assign_lhs "="] "scf.for" args rest stmt* "}" divisibility_annot? LOC -> process_for
-    while: [assign_lhs "="] "scf.while" args rest stmt* "}" "do" "{" stmt* "}" LOC -> process_while
-
-    condition_stmt: "scf.condition" "(" arg ")" args rest
-    label_stmt: LABEL ":" "// pred:" LABEL
-        | LABEL "(" /.+/ NEWLINE
-    cf_stmt: "cf" "." NAME /.+/ NEWLINE
-
-    op: OP_NAME LOC
-      | [assign_lhs "="] OP_NAME [FN_NAME] args rest? -> process_op
-
-    ?rest: (":" | "{" | "\\"" | "->" | "<" | "=") /.+/ NEWLINE
-    divisibility_annot: "{" "tt.divisibility_arg1" /[^}]+/ "}"
-
-    args: | "(" ")" | "("? arg ("," arg)* ")"?
-
-    ?arg: INTERMEDIATE
-        | INTERMEDIATE_CONSTANT
-        | CONSTANT
-        | PARAM
-        | "[" args "]"
-        | arg_with_index
-
-    ?arg_with_index: arg "#" DIGIT+
-
-    ?assign_lhs: (INTERMEDIATE | INTERMEDIATE_CONSTANT) [":" DIGIT+]
-
-    PARAM.5: "%arg" DIGIT+
-    INTERMEDIATE.4: "%" DIGIT+
-    INTERMEDIATE_CONSTANT.3: "%" NAME
-    CONSTANT: FLOAT | DIGIT+ | NAME ("<" DIGIT+ ">")?
-    LABEL: "^bb" DIGIT+
-
-    NAME: (LETTER | DIGIT | "_")+
-    NON_CF_NAME: /(?!(cf))/ NAME
-    FN_NAME: "@" (NAME | ESCAPED_STRING)
-    OP_NAME: "\\""? NON_CF_NAME ("." NAME)+ "\\""?
-
-    LOC.5: "loc(#loc" DIGIT* ")"
-
-    %import common.LETTER
-    %import common.DIGIT
-    %import common.WS
-    %import common.NEWLINE
-    %import common.ESCAPED_STRING
-    %import common.FLOAT
-    %ignore WS
-    """
-
-    next_fake_intermediate = 0
-
-    def convert(token):
-        if isinstance(token, lark.tree.Tree):
-            if token.data == "args":
-                res = []
-                for a in token.children:
-                    c = convert(a)
-                    if isinstance(c, list):
-                        res.extend(c)
-                    else:
-                        res.append(c)
-                return res
-            elif token.data in {"assign_lhs", "arg_with_index"}:
-                # Drop length/index qualifier
-                return convert(token.children[0])
-            else:
-                raise AssertionError(f"Tree node with {token.data}")
-
-        if token is None or (
-            isinstance(token, lark.lexer.Token)
-            and token.type in ("CONSTANT", "INTERMEDIATE_CONSTANT")
-        ):
-            nonlocal next_fake_intermediate
-            next_fake_intermediate -= 1
-            return Intermediate(next_fake_intermediate)
-
-        assert isinstance(token, lark.lexer.Token)
-
-        if token.type == "INTERMEDIATE":
-            return Intermediate(int(token.value[len("%") :]))
-        if token.type == "PARAM":
-            return Param(int(token.value[len("%arg") :]))
-
-        raise AssertionError(f"{type(token.type)} => {token.value} invalid")
-
-    # In alternative representation, function names are quoted.
-    # It should be possible to move this into the grammar alltogether.
-    def convert_name(token):
-        if token is None:
-            return None
-        s = token.value
-        if len(s) > 2 and s[0] == '"' and s[-1] == '"':
-            return s[1:-1]
-        return s
-
-    functions: Dict[str, Dict[Intermediate, List[Op]]] = {}
-
-    def extend_dict_list(d1, d2):
-        for key, values in d2.items():
-            d1[key].extend(values)
-
-    @v_args(inline=True)
-    class TransformOps(Transformer):
-        def process_op(self, ret, op_name, fn_name, args, *rest):
-            return Op(
-                convert_name(op_name),
-                convert_name(fn_name),
-                convert(args),
-                convert(ret),
-            )
-
-        def process_func(self, name, _args, *stmts):
-            ops: Dict[Intermediate, List[Op]] = defaultdict(list)
-            for e in stmts:
-                if isinstance(e, Op):
-                    ops[e.ret].append(e)
-                elif isinstance(e, dict):
-                    extend_dict_list(ops, e)
-            functions[name.value] = ops
-
-        def _process_scf(self, ret, stmts):
-            ret = convert(ret)
-            ops: Dict[Intermediate, List[Op]] = defaultdict(list)
-            for e in stmts:
-                if isinstance(e, Op):
-                    if e.name == "scf.yield":
-                        ops[ret].append(Op(e.name, None, e.args, ret))
-                    else:
-                        ops[e.ret].append(e)
-                elif isinstance(e, dict):
-                    extend_dict_list(ops, e)
-            return ops
-
-        def process_if(self, ret, _args, _rest, *stmts):
-            return self._process_scf(ret, stmts)
-
-        def process_for(self, ret, _args, _rest, *stmts):
-            return self._process_scf(ret, stmts)
-
-        def process_while(self, ret, _args, _rest, *stmts):
-            return self._process_scf(ret, stmts)
-
-    parser = Lark(
-        grammar, parser="lalr", maybe_placeholders=True, transformer=TransformOps()
-    )
-    parser.parse(ttir)
-    return functions
-
-
 class MemoizeWithCycleCheck:
     def __init__(self, fn):
         self.fn = fn
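For context on the fallback being deleted above: `parse_ttir` defined a small Lark grammar and collapsed the parse tree with a `Transformer` during LALR parsing, keeping only SSA dataflow facts and discarding types and constants. Below is a minimal, self-contained sketch of that technique with a toy grammar; the grammar, the op shapes, and the `CollectOps` class are illustrative stand-ins, not the real TTIR grammar. Requires `pip install lark`.

```python
from lark import Lark, Transformer, v_args

# Toy grammar: statements like "%2 = tt.addptr %0, %1" or "tt.store %2, %3".
# Types, constants, and attributes are deliberately absent, as in parse_ttir.
GRAMMAR = r"""
start: stmt+
stmt: [RESULT "="] OP_NAME args NEWLINE -> process_op
args: RESULT ("," RESULT)*

RESULT: /%\d+/
OP_NAME: /[a-z_]+(\.[a-z_]+)+/

%import common.NEWLINE
%import common.WS_INLINE
%ignore WS_INLINE
"""

@v_args(inline=True)
class CollectOps(Transformer):
    def process_op(self, ret, op_name, args, _newline):
        # Keep only the dataflow facts: op name, result SSA name, argument SSA names.
        return (str(op_name), str(ret) if ret is not None else None,
                [str(a) for a in args.children])

parser = Lark(GRAMMAR, parser="lalr", maybe_placeholders=True, transformer=CollectOps())
tree = parser.parse("%2 = tt.addptr %0, %1\ntt.store %2, %3\n")
print(tree.children)
# [('tt.addptr', '%2', ['%0', '%1']), ('tt.store', None, ['%2', '%3'])]
```

Passing the transformer into `Lark(..., transformer=...)` applies it during the LALR parse, which is how `parse_ttir` avoided materializing a full parse tree before reduction.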
@@ -637,20 +449,10 @@ def identify_mutated_tensors(kernel, kwargs):
     ttir_module = None
     functions = None
     try:
-        from torch._dynamo import config
-
-        if not config.optimize_user_defined_triton_kernels:
-            raise ValueError("optimize_user_defined_triton_kernels is False")
-
         ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)
 
-        # extract functions from TTIR
-        if hasattr(ttir_module, "walk"):
-            # use MLIR bindings exposed by Triton code
-            functions = ttir_to_functions(ttir_module)
-        else:
-            # parse string representation of Triton IR
-            functions = parse_ttir(str(ttir_module), kwargs)
+        # extract functions from TTIR using MLIR bindings exposed by Triton code
+        functions = ttir_to_functions(ttir_module)
 
         assert functions is not None
         kernel_name = next(iter(functions.keys()))
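Either way the functions table is built, the downstream analysis walks a mapping of function name to `Dict[Intermediate, List[Op]]`. The snippet below is a rough, self-contained sketch of that idea, using namedtuple stand-ins for this file's `Param`/`Intermediate`/`Op` classes; it is not the real `identify_mutated_tensors` logic, just the core back-chaining step under those assumptions.

```python
from collections import namedtuple

# Simplified stand-ins for the Param / Intermediate / Op classes in this file.
Param = namedtuple("Param", ["idx"])
Intermediate = namedtuple("Intermediate", ["idx"])
Op = namedtuple("Op", ["name", "fn_name", "args", "ret"])

def mutated_params(ops_by_ret):
    """Indices of %arg parameters that some tt.store writes through."""
    mutated = set()

    def roots(val):
        # A Param is a root; an Intermediate resolves through its defining ops.
        if isinstance(val, Param):
            yield val.idx
        else:
            for op in ops_by_ret.get(val, []):
                for arg in op.args:
                    yield from roots(arg)

    for ops in ops_by_ret.values():
        for op in ops:
            if op.name == "tt.store":
                # The first operand of tt.store is the destination pointer.
                mutated.update(roots(op.args[0]))
    return mutated

# %2 = tt.addptr %arg0, %1 ; tt.store %2, %arg1
# (ops with no result get a fake negative Intermediate, as in this file)
ops = {
    Intermediate(2): [Op("tt.addptr", None, [Param(0), Intermediate(1)], Intermediate(2))],
    Intermediate(-1): [Op("tt.store", None, [Intermediate(2), Param(1)], Intermediate(-1))],
}
print(mutated_params(ops))  # {0}: only the pointer argument is mutated
```

The real analysis additionally handles `scf.*` control flow, function calls, and a wider set of mutating ops; this sketch only chases `tt.store` destinations back through `tt.addptr` to `%arg` parameters.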