[PERF] Allow prologue fusion for reduce op (#426)
Allow prologue fusion for fp32 `reduce`.

`reduce` uses vectorized loads and calculations that currently prevent prologue fusion (fusing through the vectorized path is possible but not implemented yet). For fp32 there is no vectorization, so fusion can be enabled with a small modification to the `reduce` kernels themselves.

Motivation: in llama2, part of the computation runs in fp32, including a `pow` + `reduce` sequence.

Performance improvement on llama2-7B: +0.241%.
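
A rough sketch of the targeted pattern, for context (the shapes and the choice of `mean` as the reduce are assumptions, not taken from the benchmark; whether this exact graph fuses depends on how it is partitioned):

import torch
import hidet  # importing hidet is assumed to make the 'hidet' torch.compile backend available

# fp32 elementwise pow feeding a reduce, similar to the pow + reduce part of llama2;
# with this change the pow can be folded into the reduce kernel's prologue.
def pow_then_reduce(x: torch.Tensor) -> torch.Tensor:
    return torch.pow(x, 2).mean(dim=-1)

x = torch.randn(4, 4096, dtype=torch.float32, device='cuda')
fn = torch.compile(pow_then_reduce, backend='hidet')
y = fn(x)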
vadiklyutiy committed Dec 20, 2024
1 parent 646f7e7 commit 6606477
Showing 3 changed files with 33 additions and 11 deletions.
3 changes: 3 additions & 0 deletions python/hidet/graph/ops/fusion/apply_prologue_epilogue.py
@@ -215,6 +215,9 @@ def extract(self) -> Tuple[Dict[Tensor, Prologue], Dict[Tensor, Epilogue], Dict[
# extract epilogues
epilogues: Dict[Tensor, Epilogue] = {}
for task_output, tensor in zip(self.anchor_task.outputs, self.anchor_operator.outputs):
if self.tensor_map[task_output] in self.graph.outputs:
# this output does not have an epilogue, skip
continue
axes = [var('i') for _ in range(len(task_output.shape))]
value = var('value', task_output.type.dtype)
bss = BufferStoreStmt(buf=task_output, indices=axes, value=value)
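A simplified sketch of the condition the new guard encodes (not hidet's actual classes; names are illustrative):

# An epilogue is the chain of ops that consumes an anchor output inside the
# fused subgraph. If the anchor output is itself a graph output, nothing
# follows it, so extraction skips it -- which is what the guard above does.
def has_epilogue(anchor_output_tensor, graph_outputs) -> bool:
    return anchor_output_tensor not in graph_outputs
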
27 changes: 21 additions & 6 deletions python/hidet/graph/ops/reduce/reduce.py
@@ -47,15 +47,23 @@ def __init__(

def allow_epilogue(self) -> bool:
rank = len(self.inputs[0].shape)
if rank - 1 in self.dims: # pylint: disable=simplifiable-if-statement
if rank - 1 in self.dims:
# use self.cuda_schedule_reduce_by_warp
return True
else:
# use self.cuda_schedule_reduce_by_default
return False
nbytes = self.inputs[0].type.dtype.nbytes
if nbytes == 4: # pylint: disable=simplifiable-if-statement
return True
else:
return False

def allow_prologue(self) -> bool:
return False
nbytes = self.inputs[0].type.dtype.nbytes
if nbytes == 4: # pylint: disable=simplifiable-if-statement
return True
else:
return False
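
A hedged restatement of the gating (the 4-byte threshold is read from the diff; the link to vector lanes is the reasoning from the commit message):

# 4-byte dtypes (e.g. float32/int32) are read as scalars (lanes == 1) by the
# reduce schedules, so fusion is safe to allow; narrower dtypes such as float16
# go through the vectorized path, where fusion is not implemented yet.
def fusion_allowed(dtype_nbytes: int) -> bool:
    return dtype_nbytes == 4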

def implement_cuda(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
rank = len(self.inputs[0].shape)
@@ -138,7 +146,10 @@ def reduce_kernel(x: xdtype[x.shape], y: xdtype[y.shape]):

smem_staging = dynamic_shared_memory(byte_offset=0, dtype=accumulate_dtype)
rv = ro.initial_value(data_type(accumulate_dtype))
x_vectorized = tensor_pointer(vtype, shape=read_shape, init=cast(x, ~vtype))
if lanes == 1:
x_vectorized = x
else:
x_vectorized = tensor_pointer(vtype, shape=read_shape, init=cast(x, ~vtype))

# initialize staging shared memory
if perform_atomic_reduce:
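
This hunk is the kernel-side piece: when `lanes == 1` the kernel indexes the original input tensor instead of a reinterpret-cast vector pointer, which is what lets a fused prologue be inlined at the point of the read. An illustration only (not hidet IR, names simplified):

# With scalar reads the fused prologue can be substituted at each load; a
# reinterpret-cast to a packed vector type would hide the per-element access
# that the prologue rewriting relies on.
def load_input(x, index, lanes, prologue=lambda v: v):
    if lanes == 1:
        return prologue(x[index])  # scalar read: prologue applied inline
    raise NotImplementedError("vectorized reads: prologue fusion not supported yet")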
@@ -291,8 +302,12 @@ def reduce_kernel(x: xdtype[x.shape], y: xdtype[y.shape]):
attrs.cuda.min_blocks = 1
attrs.cuda.dynamic_smem_bytes = smem_needed

x_vectorized = tensor_pointer(vtype, shape=x_vectorized_shape, init=cast(x, ~vtype))
y_vectorized = tensor_pointer(vtype, shape=y_vectorized_shape, init=cast(y, ~vtype))
if lanes == 1:
x_vectorized = x
y_vectorized = y
else:
x_vectorized = tensor_pointer(vtype, shape=x_vectorized_shape, init=cast(x, ~vtype))
y_vectorized = tensor_pointer(vtype, shape=y_vectorized_shape, init=cast(y, ~vtype))
rv = register_tensor(accumulate_dtype, [lanes])
for lane_id in grid(lanes, "u+"):
rv[lane_id] = ro.initial_value(accumulate_dtype)
14 changes: 9 additions & 5 deletions tests/benchmarks/bench_transformer.py
@@ -81,8 +81,8 @@ def get_full_model_name(model_name):
return short_to_full_model_name[model_name]


def bench_causal_lm(model_name, bs, genlen, dtype, backend, mode):
comp_backend = Backend(backend, mode, dtype)
def bench_causal_lm(model_name, bs, genlen, dtype, backend, mode, cache):
comp_backend = Backend(backend, mode, dtype, cache)

dtype = getattr(torch, dtype)
model_name = get_full_model_name(model_name)
@@ -103,7 +103,7 @@ def bench_causal_lm(model_name, bs, genlen, dtype, backend, mode):
END_OF_SENTENCE_ID = tokenizer.eos_token_id

with torch.no_grad(), torch.autocast("cuda"):
_, torch_output = bench_gen_model(model, tokenizer, inputs, bs=bs, genlen=genlen, bench_iters=1, warmup_iters=0)
_, torch_output = bench_gen_model(model, tokenizer, inputs, bs=bs, genlen=genlen)
# Temporary workaround for gpt-j
# gpt-j initializes tensors during the first forward pass
# which causes recompilation during the second forward pass
@@ -159,9 +159,11 @@ def bench_masked_lm(model_name, seqlen, bs, dtype, backend, mode):
parser.add_argument('--dtype', type=str, default='float16', help='Specify precision. E.g., float32')
parser.add_argument('--backend', type=str, default='hidet', help='torch.compile backend')
parser.add_argument('--mode', type=str, default='max-autotune', help='torch.compile mode')
parser.add_argument('--cache', type=str, default='', help='')

args = parser.parse_args()

model_name, dtype, backend, mode = args.model, args.dtype, args.backend, args.mode
model_name, dtype, backend, mode, cache = args.model, args.dtype, args.backend, args.mode, args.cache

seqlen = SEQLEN_DEFAULT
bs = BS_DEFAULT
@@ -179,6 +181,8 @@ def bench_masked_lm(model_name, seqlen, bs, dtype, backend, mode):
if model_class[get_full_model_name(model_name)] == 'AutoModelForMaskedLM':
latency = bench_masked_lm(model_name, seqlen, bs, dtype, backend, mode)
elif model_class[get_full_model_name(model_name)] == 'AutoModelForCausalLM':
latency = bench_causal_lm(model_name, bs=bs, genlen=genlen, dtype=dtype, backend=backend, mode=mode)
latency = bench_causal_lm(
model_name, bs=bs, genlen=genlen, dtype=dtype, backend=backend, mode=mode, cache=cache
)

print(latency)
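
The benchmark now threads a `cache` string through to `Backend`; a hypothetical sketch of how it might be consumed (only the call signature is taken from the diff; the use of hidet.option.cache_dir is an assumption):

class Backend:
    def __init__(self, backend: str, mode: str, dtype: str, cache: str = ''):
        self.backend, self.mode, self.dtype = backend, mode, dtype
        if cache:
            import hidet
            hidet.option.cache_dir(cache)  # redirect hidet's compilation cache (assumed)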
