From f4c2ff9fd6f732ed34efa7ad518c46f7019c66b7 Mon Sep 17 00:00:00 2001
From: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Date: Wed, 21 Aug 2024 05:24:48 +0500
Subject: [PATCH] [PERF] Allow prologue fusion for `reduce` op (#426)

Allow prologue fusion for fp32 `reduce`. The `reduce` kernels use vectorized
loads, which currently rule out prologue fusion (it is possible in principle
but not implemented yet). For fp32 there is no vectorization, so fusion can be
enabled with a small modification to the `reduce` kernels themselves.

Motivation: in llama2, part of the computation runs in fp32, including
`pow` + `reduce`.

Performance improvement on llama2-7B: +0.241%
---
 .../ops/fusion/apply_prologue_epilogue.py |  3 +++
 python/hidet/graph/ops/reduce/reduce.py   | 27 ++++++++++++++-----
 tests/benchmarks/bench_transformer.py     | 14 ++++++----
 3 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/python/hidet/graph/ops/fusion/apply_prologue_epilogue.py b/python/hidet/graph/ops/fusion/apply_prologue_epilogue.py
index 8c50b6f47..cdf68e7ce 100644
--- a/python/hidet/graph/ops/fusion/apply_prologue_epilogue.py
+++ b/python/hidet/graph/ops/fusion/apply_prologue_epilogue.py
@@ -215,6 +215,9 @@ def extract(self) -> Tuple[Dict[Tensor, Prologue], Dict[Tensor, Epilogue], Dict[
         # extract epilogues
         epilogues: Dict[Tensor, Epilogue] = {}
         for task_output, tensor in zip(self.anchor_task.outputs, self.anchor_operator.outputs):
+            if self.tensor_map[task_output] in self.graph.outputs:
+                # this output does not have an epilogue, skip
+                continue
             axes = [var('i') for _ in range(len(task_output.shape))]
             value = var('value', task_output.type.dtype)
             bss = BufferStoreStmt(buf=task_output, indices=axes, value=value)
diff --git a/python/hidet/graph/ops/reduce/reduce.py b/python/hidet/graph/ops/reduce/reduce.py
index 7147ebc02..7932f89d3 100644
--- a/python/hidet/graph/ops/reduce/reduce.py
+++ b/python/hidet/graph/ops/reduce/reduce.py
@@ -47,15 +47,23 @@ def __init__(
 
     def allow_epilogue(self) -> bool:
         rank = len(self.inputs[0].shape)
-        if rank - 1 in self.dims:  # pylint: disable=simplifiable-if-statement
+        if rank - 1 in self.dims:
             # use self.cuda_schedule_reduce_by_warp
             return True
         else:
             # use self.cuda_schedule_reduce_by_default
-            return False
+            nbytes = self.inputs[0].type.dtype.nbytes
+            if nbytes == 4:  # pylint: disable=simplifiable-if-statement
+                return True
+            else:
+                return False
 
     def allow_prologue(self) -> bool:
-        return False
+        nbytes = self.inputs[0].type.dtype.nbytes
+        if nbytes == 4:  # pylint: disable=simplifiable-if-statement
+            return True
+        else:
+            return False
 
     def implement_cuda(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
         rank = len(self.inputs[0].shape)
@@ -138,7 +146,10 @@ def reduce_kernel(x: xdtype[x.shape], y: xdtype[y.shape]):
 
                 smem_staging = dynamic_shared_memory(byte_offset=0, dtype=accumulate_dtype)
                 rv = ro.initial_value(data_type(accumulate_dtype))
-                x_vectorized = tensor_pointer(vtype, shape=read_shape, init=cast(x, ~vtype))
+                if lanes == 1:
+                    x_vectorized = x
+                else:
+                    x_vectorized = tensor_pointer(vtype, shape=read_shape, init=cast(x, ~vtype))
 
                 # initialize staging shared memory
                 if perform_atomic_reduce:
@@ -291,8 +302,12 @@ def reduce_kernel(x: xdtype[x.shape], y: xdtype[y.shape]):
                 attrs.cuda.min_blocks = 1
                 attrs.cuda.dynamic_smem_bytes = smem_needed
 
-                x_vectorized = tensor_pointer(vtype, shape=x_vectorized_shape, init=cast(x, ~vtype))
-                y_vectorized = tensor_pointer(vtype, shape=y_vectorized_shape, init=cast(y, ~vtype))
+                if lanes == 1:
+                    x_vectorized = x
+                    y_vectorized = y
+                else:
+                    x_vectorized = tensor_pointer(vtype, shape=x_vectorized_shape, init=cast(x, ~vtype))
+                    y_vectorized = tensor_pointer(vtype, shape=y_vectorized_shape, init=cast(y, ~vtype))
                 rv = register_tensor(accumulate_dtype, [lanes])
                 for lane_id in grid(lanes, "u+"):
                     rv[lane_id] = ro.initial_value(accumulate_dtype)
diff --git a/tests/benchmarks/bench_transformer.py b/tests/benchmarks/bench_transformer.py
index a056f11f8..2d3ea2cf6 100644
--- a/tests/benchmarks/bench_transformer.py
+++ b/tests/benchmarks/bench_transformer.py
@@ -81,8 +81,8 @@ def get_full_model_name(model_name):
     return short_to_full_model_name[model_name]
 
 
-def bench_causal_lm(model_name, bs, genlen, dtype, backend, mode):
-    comp_backend = Backend(backend, mode, dtype)
+def bench_causal_lm(model_name, bs, genlen, dtype, backend, mode, cache):
+    comp_backend = Backend(backend, mode, dtype, cache)
 
     dtype = getattr(torch, dtype)
     model_name = get_full_model_name(model_name)
@@ -103,7 +103,7 @@ def bench_causal_lm(model_name, bs, genlen, dtype, backend, mode):
     END_OF_SENTENCE_ID = tokenizer.eos_token_id
 
     with torch.no_grad(), torch.autocast("cuda"):
-        _, torch_output = bench_gen_model(model, tokenizer, inputs, bs=bs, genlen=genlen, bench_iters=1, warmup_iters=0)
+        _, torch_output = bench_gen_model(model, tokenizer, inputs, bs=bs, genlen=genlen)
         # Temporary workaround for gpt-j
         # gpt-j initializes tensors during the first forwasd pass
         # which causes recompilation during the second forward pass
@@ -159,9 +159,11 @@ def bench_masked_lm(model_name, seqlen, bs, dtype, backend, mode):
     parser.add_argument('--dtype', type=str, default='float16', help='Specify precision. E.g., float32')
     parser.add_argument('--backend', type=str, default='hidet', help='torch.compile backend')
    parser.add_argument('--mode', type=str, default='max-autotune', help='torch.compile mode')
+    parser.add_argument('--cache', type=str, default='', help='')
+
     args = parser.parse_args()
 
-    model_name, dtype, backend, mode = args.model, args.dtype, args.backend, args.mode
+    model_name, dtype, backend, mode, cache = args.model, args.dtype, args.backend, args.mode, args.cache
 
     seqlen = SEQLEN_DEFAULT
     bs = BS_DEFAULT
@@ -179,6 +181,8 @@ def bench_masked_lm(model_name, seqlen, bs, dtype, backend, mode):
     if model_class[get_full_model_name(model_name)] == 'AutoModelForMaskedLM':
         latency = bench_masked_lm(model_name, seqlen, bs, dtype, backend, mode)
     elif model_class[get_full_model_name(model_name)] == 'AutoModelForCausalLM':
-        latency = bench_causal_lm(model_name, bs=bs, genlen=genlen, dtype=dtype, backend=backend, mode=mode)
+        latency = bench_causal_lm(
+            model_name, bs=bs, genlen=genlen, dtype=dtype, backend=backend, mode=mode, cache=cache
+        )
 
     print(latency)
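
For reference, the dtype gate that the patch adds to `allow_prologue` / `allow_epilogue`
boils down to the rule sketched below. This is a minimal standalone illustration, not
hidet's actual task classes: the free functions and the parameters `dtype_nbytes` and
`last_dim_reduced` are hypothetical names introduced only for this sketch. The rule
itself is taken from the diff and the commit message: fusion is only permitted for
4-byte element types (e.g. fp32), since that is the case where the reduce kernels fall
back to the scalar (lanes == 1) path instead of vectorized loads.

    # Standalone sketch of the fusion-gating rule added by this patch.
    # The names below are hypothetical and do not correspond to hidet APIs.

    def allow_prologue(dtype_nbytes: int) -> bool:
        # Prologue fusion is only enabled for 4-byte element types, where the
        # kernels read scalars (lanes == 1) and a fused prologue can be evaluated
        # per element.
        return dtype_nbytes == 4

    def allow_epilogue(last_dim_reduced: bool, dtype_nbytes: int) -> bool:
        # Reducing over the last dimension selects the warp schedule, which already
        # supported epilogues; otherwise the same 4-byte condition applies.
        return last_dim_reduced or dtype_nbytes == 4

    # Example: fp32 (4 bytes) allows prologue fusion; fp16 (2 bytes) does not yet.
    assert allow_prologue(4) is True
    assert allow_prologue(2) is False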