
Commit a631461

oulgen authored and pytorchmergebot committed
Delete Lark (#123689)
Now that we are using MLIR bindings inside triton, let's delete the Lark parser.

Pull Request resolved: #123689
Approved by: https://github.com/jansel
1 parent 8d9af8b commit a631461
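
The rationale above is that Triton now hands back the emitted TTIR as a structured MLIR module, so its operations can be visited directly instead of re-parsing the textual IR with a Lark grammar. The sketch below is illustrative and not part of this commit: it assumes only that the TTIR module object exposes walk(callback), as implied by the hasattr(ttir_module, "walk") check removed in the diff, and that an operation handle exposes its name via a get_name() accessor, which is an assumption here. The real extraction logic is ttir_to_functions / mlir_to_functions in torch/_higher_order_ops/triton_kernel_wrap.py.

from collections import Counter


def count_ttir_ops(ttir_module) -> Counter:
    """Tally TTIR op names (e.g. "tt.load", "tt.store") by walking the
    structured MLIR module instead of parsing its string form."""
    counts: Counter = Counter()

    def visit(op) -> None:
        # get_name() is assumed to return the op's fully qualified name;
        # the deleted Lark path recovered the same information by running
        # a hand-written grammar over str(ttir_module).
        counts[op.get_name()] += 1

    ttir_module.walk(visit)
    return counts

Given a module produced by generate_ttir(kernel, kwargs), count_ttir_ops(module) would summarize the kernel's ops without any text parsing, which is the property the mutation analysis changed below relies on.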

File tree

6 files changed (+3, -234 lines)

  .ci/docker/requirements-ci.txt
  requirements.txt
  test/inductor/test_triton_kernels.py
  torch/_dynamo/config.py
  torch/_higher_order_ops/triton_kernel_wrap.py
  torch/testing/_internal/triton_utils.py


.ci/docker/requirements-ci.txt

Lines changed: 0 additions & 5 deletions
@@ -52,11 +52,6 @@ junitparser==2.1.1
 #Pinned versions: 2.1.1
 #test that import:
 
-lark==0.12.0
-#Description: parser
-#Pinned versions: 0.12.0
-#test that import:
-
 librosa>=0.6.2 ; python_version < "3.11"
 #Description: A python package for music and audio analysis
 #Pinned versions: >=0.6.2

requirements.txt

Lines changed: 0 additions & 1 deletion
@@ -18,4 +18,3 @@ fsspec
 setuptools ; python_version >= "3.12"
 packaging
 optree>=0.9.1
-lark

test/inductor/test_triton_kernels.py

Lines changed: 1 addition & 13 deletions
@@ -7,8 +7,6 @@
 import torch._dynamo.testing
 
 import torch._inductor.test_case
-from torch._dynamo import config
-from torch._dynamo.testing import make_test_cls_with_patches
 
 from torch._higher_order_ops.triton_kernel_wrap import (
     generate_ttir,
@@ -1225,7 +1223,6 @@ def f(x, y):
 
 def make_mutation_test(fn):
     @requires_cuda
-    @requires_lark
     @skipIfRocm
     def test_fn(self):
         from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors
@@ -1776,7 +1773,7 @@ def fwd_kernel(
     )
 
 
-if HAS_CUDA and HAS_LARK:
+if HAS_CUDA:
     t = torch.randn(4)
     tt = torch.randn(4, 1)
     tests = [
@@ -1921,15 +1918,6 @@ def fwd_kernel(
 
 common_utils.instantiate_parametrized_tests(KernelTests)
 
-no_opt_test_class = make_test_cls_with_patches(
-    KernelTests,
-    "NoOptimization",
-    "_no_optimizations",
-    (config, "optimize_user_defined_triton_kernels", False),
-)
-
-globals()[no_opt_test_class.__name__] = no_opt_test_class
-no_opt_test_class.__module__ = __name__
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests

torch/_dynamo/config.py

Lines changed: 0 additions & 3 deletions
@@ -393,9 +393,6 @@ def default_debug_dir_root():
 # enable/disable dynamo tracing for `torch.func` transforms
 capture_func_transforms = True
 
-# enable/disable user-defined triton kernel optimizations
-optimize_user_defined_triton_kernels = True
-
 # If to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode).
 log_compilation_metrics = True
 

torch/_higher_order_ops/triton_kernel_wrap.py

Lines changed: 2 additions & 200 deletions
@@ -2,7 +2,6 @@
 import inspect
 import logging
 import threading
-import warnings
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Union
 
@@ -363,193 +362,6 @@ def mlir_to_functions(op) -> None:
     return functions
 
 
-def parse_ttir(ttir, kwargs):
-    """
-    Given a Triton emitted TTIR text, this function lexes and parses the
-    code using a minimal grammar defined inside. During the lexing/parsing,
-    we drop any constant value and type information as they are not
-    necessary to us.
-    Being able to choose what we need makes this not a general purpose TTIR
-    parser which further makes parsing much simpler.
-    """
-    # TODO(oulgen):
-    # - Support closures (e.g. "tt.reduce")
-
-    try:
-        import lark  # type: ignore[import-not-found]
-        from lark import Lark, Transformer, v_args
-    except ModuleNotFoundError:
-        warnings.warn(
-            "Using slow path for user-defined Triton kernels. `pip install lark` to fix this."
-        )
-        raise
-
-    # Ops looks like one of the following forms:
-    #
-    # %14 = tt.addptr %13, %4 : tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32>
-    # tt.store %14, %12, %5 {cache = 1 : i32, evict = 1 : i32} : tensor<4xf32>
-    # %15 = "tt.atomic_rmw"(%14, %12, %5) <{atomic_rmw_op = 5 : i32, scope = 1 : i32, sem = 4 : i32}> : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xf32>, tensor<4xi1>) -> tensor<4xf32>  # noqa: B950
-    grammar = """
-        start: (module_block | loc_line)+
-
-        loc_line: "#loc" /.+/ NEWLINE
-
-        module_block: "module" "{" func_block+ "}" LOC
-
-        func_block: "tt.func" ("public"|"private") FN_NAME "(" /.+/ NEWLINE stmt* "}" LOC -> process_func
-
-        ?stmt: op | if | for | while | condition_stmt | label_stmt | cf_stmt
-
-        if: [assign_lhs "="] "scf.if" args rest stmt* "}" "else" "{" stmt* "}" LOC -> process_if
-        for: [assign_lhs "="] "scf.for" args rest stmt* "}" divisibility_annot? LOC -> process_for
-        while: [assign_lhs "="] "scf.while" args rest stmt* "}" "do" "{" stmt* "}" LOC -> process_while
-
-        condition_stmt: "scf.condition" "(" arg ")" args rest
-        label_stmt: LABEL ":" "// pred:" LABEL
-            | LABEL "(" /.+/ NEWLINE
-        cf_stmt: "cf" "." NAME /.+/ NEWLINE
-
-        op: OP_NAME LOC
-            | [assign_lhs "="] OP_NAME [FN_NAME] args rest? -> process_op
-
-        ?rest: (":" | "{" | "\\"" | "->" | "<" | "=") /.+/ NEWLINE
-        divisibility_annot: "{" "tt.divisibility_arg1" /[^}]+/ "}"
-
-        args: | "(" ")" | "("? arg ("," arg)* ")"?
-
-        ?arg: INTERMEDIATE
-            | INTERMEDIATE_CONSTANT
-            | CONSTANT
-            | PARAM
-            | "[" args "]"
-            | arg_with_index
-
-        ?arg_with_index: arg "#" DIGIT+
-
-        ?assign_lhs: (INTERMEDIATE | INTERMEDIATE_CONSTANT) [":" DIGIT+]
-
-        PARAM.5: "%arg" DIGIT+
-        INTERMEDIATE.4: "%" DIGIT+
-        INTERMEDIATE_CONSTANT.3: "%" NAME
-        CONSTANT: FLOAT | DIGIT+ | NAME ("<" DIGIT+ ">")?
-        LABEL: "^bb" DIGIT+
-
-        NAME: (LETTER | DIGIT | "_")+
-        NON_CF_NAME: /(?!(cf))/ NAME
-        FN_NAME: "@" (NAME | ESCAPED_STRING)
-        OP_NAME: "\\""? NON_CF_NAME ("." NAME)+ "\\""?
-
-        LOC.5: "loc(#loc" DIGIT* ")"
-
-        %import common.LETTER
-        %import common.DIGIT
-        %import common.WS
-        %import common.NEWLINE
-        %import common.ESCAPED_STRING
-        %import common.FLOAT
-        %ignore WS
-    """
-
-    next_fake_intermediate = 0
-
-    def convert(token):
-        if isinstance(token, lark.tree.Tree):
-            if token.data == "args":
-                res = []
-                for a in token.children:
-                    c = convert(a)
-                    if isinstance(c, list):
-                        res.extend(c)
-                    else:
-                        res.append(c)
-                return res
-            elif token.data in {"assign_lhs", "arg_with_index"}:
-                # Drop length/index qualifier
-                return convert(token.children[0])
-            else:
-                raise AssertionError(f"Tree node with {token.data}")
-
-        if token is None or (
-            isinstance(token, lark.lexer.Token)
-            and token.type in ("CONSTANT", "INTERMEDIATE_CONSTANT")
-        ):
-            nonlocal next_fake_intermediate
-            next_fake_intermediate -= 1
-            return Intermediate(next_fake_intermediate)
-
-        assert isinstance(token, lark.lexer.Token)
-
-        if token.type == "INTERMEDIATE":
-            return Intermediate(int(token.value[len("%") :]))
-        if token.type == "PARAM":
-            return Param(int(token.value[len("%arg") :]))
-
-        raise AssertionError(f"{type(token.type)} => {token.value} invalid")
-
-    # In alternative representation, function names are quoted.
-    # It should be possible to move this into the grammar alltogether.
-    def convert_name(token):
-        if token is None:
-            return None
-        s = token.value
-        if len(s) > 2 and s[0] == '"' and s[-1] == '"':
-            return s[1:-1]
-        return s
-
-    functions: Dict[str, Dict[Intermediate, List[Op]]] = {}
-
-    def extend_dict_list(d1, d2):
-        for key, values in d2.items():
-            d1[key].extend(values)
-
-    @v_args(inline=True)
-    class TransformOps(Transformer):
-        def process_op(self, ret, op_name, fn_name, args, *rest):
-            return Op(
-                convert_name(op_name),
-                convert_name(fn_name),
-                convert(args),
-                convert(ret),
-            )
-
-        def process_func(self, name, _args, *stmts):
-            ops: Dict[Intermediate, List[Op]] = defaultdict(list)
-            for e in stmts:
-                if isinstance(e, Op):
-                    ops[e.ret].append(e)
-                elif isinstance(e, dict):
-                    extend_dict_list(ops, e)
-            functions[name.value] = ops
-
-        def _process_scf(self, ret, stmts):
-            ret = convert(ret)
-            ops: Dict[Intermediate, List[Op]] = defaultdict(list)
-            for e in stmts:
-                if isinstance(e, Op):
-                    if e.name == "scf.yield":
-                        ops[ret].append(Op(e.name, None, e.args, ret))
-                    else:
-                        ops[e.ret].append(e)
-                elif isinstance(e, dict):
-                    extend_dict_list(ops, e)
-            return ops
-
-        def process_if(self, ret, _args, _rest, *stmts):
-            return self._process_scf(ret, stmts)
-
-        def process_for(self, ret, _args, _rest, *stmts):
-            return self._process_scf(ret, stmts)
-
-        def process_while(self, ret, _args, _rest, *stmts):
-            return self._process_scf(ret, stmts)
-
-    parser = Lark(
-        grammar, parser="lalr", maybe_placeholders=True, transformer=TransformOps()
-    )
-    parser.parse(ttir)
-    return functions
-
-
 class MemoizeWithCycleCheck:
     def __init__(self, fn):
         self.fn = fn
@@ -637,20 +449,10 @@ def identify_mutated_tensors(kernel, kwargs):
     ttir_module = None
     functions = None
    try:
-        from torch._dynamo import config
-
-        if not config.optimize_user_defined_triton_kernels:
-            raise ValueError("optimize_user_defined_triton_kernels is False")
-
         ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)
 
-        # extract functions from TTIR
-        if hasattr(ttir_module, "walk"):
-            # use MLIR bindings exposed by Triton code
-            functions = ttir_to_functions(ttir_module)
-        else:
-            # parse string representation of Triton IR
-            functions = parse_ttir(str(ttir_module), kwargs)
+        # extract functions from TTIR using MLIR bindings exposed by Triton code
+        functions = ttir_to_functions(ttir_module)
 
         assert functions is not None
         kernel_name = next(iter(functions.keys()))
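
For reference, here is a minimal usage sketch of the entry point edited above; it is not taken from this commit. identify_mutated_tensors(kernel, kwargs) is the signature visible in the hunk, while the toy kernel, shapes, and argument names below are illustrative (they mirror the style of the tests in test_triton_kernels.py) and require a working Triton/CUDA setup.

import torch
import triton
import triton.language as tl

from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors


@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Elementwise add: reads in_ptr0 and in_ptr1, writes only out_ptr.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


t = torch.randn(4)
kwargs = {
    "in_ptr0": t,
    "in_ptr1": t,
    "out_ptr": t,
    "n_elements": 4,
    "BLOCK_SIZE": 4,
}
# When the TTIR analysis succeeds, only the written-to buffer should be
# reported, e.g. ["out_ptr"].
print(identify_mutated_tensors(add_kernel, kwargs))

Dynamo runs this analysis when it encounters a user-defined Triton kernel, so that only the buffers the kernel actually mutates need special handling during functionalization.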

torch/testing/_internal/triton_utils.py

Lines changed: 0 additions & 12 deletions
@@ -5,18 +5,6 @@
 from torch.testing._internal.inductor_utils import HAS_CUDA
 
 
-def has_lark():
-    try:
-        import lark  # noqa: F401
-
-        return True
-    except ModuleNotFoundError:
-        return False
-
-
-HAS_LARK = has_lark()
-
-requires_lark = unittest.skipUnless(HAS_LARK, "requires lark")
 requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
 
 if HAS_CUDA:
