From 5c6efb5c7547572eb6f3e58506c381f0fe996286 Mon Sep 17 00:00:00 2001 From: Mandar Deshpande Date: Tue, 12 Nov 2024 15:44:41 -0800 Subject: [PATCH] replace triton.ops dependencies in pytorch/ao (#1250) Summary: Pull Request resolved: https://github.com/pytorch/ao/pull/1250 `triton.ops` is moved to kernels directory with the 3.2 update. This change updates imports to be through explicit matmul and matmul_perf_model helper files copied to `pytorch/ao`. Reviewed By: bertmaher Differential Revision: D65678605 --- torchao/prototype/common/triton/__init__.py | 0 torchao/prototype/common/triton/matmul.py | 367 ++++++++++++++++++ .../common/triton/matmul_perf_model.py | 251 ++++++++++++ torchao/prototype/dora/kernels/common.py | 2 +- torchao/prototype/dora/kernels/smallk.py | 2 +- .../galore/kernels/adam_downproj_fused.py | 5 +- torchao/prototype/galore/kernels/matmul.py | 30 +- 7 files changed, 650 insertions(+), 7 deletions(-) create mode 100644 torchao/prototype/common/triton/__init__.py create mode 100644 torchao/prototype/common/triton/matmul.py create mode 100644 torchao/prototype/common/triton/matmul_perf_model.py diff --git a/torchao/prototype/common/triton/__init__.py b/torchao/prototype/common/triton/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/torchao/prototype/common/triton/matmul.py b/torchao/prototype/common/triton/matmul.py new file mode 100644 index 000000000..acd605336 --- /dev/null +++ b/torchao/prototype/common/triton/matmul.py @@ -0,0 +1,367 @@ +import torch + +from triton import Config, autotune, cdiv, heuristics, jit +from triton import language as tl +from .matmul_perf_model import early_config_prune, estimate_matmul_time + +_ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32] + + +def upcast_if_fp8(a): + if "fp8" in str(a): + return torch.float16 + return a + + +def get_higher_dtype(a, b): + a = upcast_if_fp8(a) + b = upcast_if_fp8(b) + if a is b: + return a + + assert a in _ordered_datatypes + assert b in _ordered_datatypes + + for d in _ordered_datatypes: + if a is d: + return b + if b is d: + return a + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": split_k, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=init_to_zero("C"), + ) + ) + return configs + + +@autotune( + configs=[ + # basic configs for compute-bound matmuls + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + # good for int8 + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + ] + + get_configs_io_bound(), + key=["M", "N", "K"], + prune_configs_by={ + "early_config_prune": early_config_prune, + "perf_model": estimate_matmul_time, + "top_k": 10, + }, +) +@heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + } +) +@jit +def _kernel( + A, + B, + C, + M, + N, + K, # + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, # + acc_dtype: tl.constexpr, # + input_precision: tl.constexpr, # + fp8_fast_accum: tl.constexpr, # + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, # + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + AB_DTYPE: tl.constexpr, # +): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + if AB_DTYPE is not None: + a = a.to(AB_DTYPE) + b = b.to(AB_DTYPE) + if fp8_fast_accum: + acc = tl.dot( + a, b, acc, out_dtype=acc_dtype, input_precision=input_precision + ) + else: + acc += tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +class _matmul(torch.autograd.Function): + kernel = _kernel + + _locks = {} + + @staticmethod + def _call(a, b, acc_dtype, input_precision, fp8_fast_accum, output_dtype): + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert ( + a.shape[1] == b.shape[0] + ), f"incompatible dimensions {a.shape} and {b.shape}" + M, K = a.shape + _, N = b.shape + + # common type between a and b + ab_dtype = get_higher_dtype(a.dtype, b.dtype) + + # allocates output + if output_dtype is None: + output_dtype = ab_dtype + + c = torch.empty((M, N), device=device, dtype=output_dtype) + + # Allowed types for acc_type given the types of a and b. + supported_acc_dtypes = { + torch.float16: (torch.float32, torch.float16), + torch.bfloat16: (torch.float32, torch.bfloat16), + torch.float32: (torch.float32,), + torch.int8: (torch.int32,), + } + + if acc_dtype is None: + acc_dtype = supported_acc_dtypes[ab_dtype][0] + else: + assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" + assert ( + acc_dtype in supported_acc_dtypes[a.dtype] + ), "acc_dtype not compatible with the type of a" + assert ( + acc_dtype in supported_acc_dtypes[b.dtype] + ), "acc_dtype not compatible with the type of b" + + def to_tl_type(ty): + return getattr(tl, str(ty).split(".")[-1]) + + acc_dtype = to_tl_type(acc_dtype) + ab_dtype = to_tl_type(ab_dtype) + output_dtype = to_tl_type(output_dtype) + + # Tensor cores support input with mixed float8 types. + if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [ + tl.float8e4nv, + tl.float8e5, + ]: + ab_dtype = None + # launch kernel + grid = lambda META: ( + cdiv(M, META["BLOCK_M"]) * cdiv(N, META["BLOCK_N"]), + META["SPLIT_K"], + ) + _kernel[grid]( + a, + b, + c, + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + acc_dtype=acc_dtype, # + input_precision=input_precision, # + fp8_fast_accum=fp8_fast_accum, # + GROUP_M=8, + AB_DTYPE=ab_dtype, + ) + return c + + @staticmethod + def forward( + ctx, + a, + b, + acc_dtype=None, + input_precision=None, + fp8_fast_accum=True, + output_dtype=None, + ): + return _matmul._call( + a, + b, + acc_dtype=acc_dtype, + input_precision=input_precision, + fp8_fast_accum=fp8_fast_accum, + output_dtype=output_dtype, + ) + + +matmul = _matmul.apply diff --git a/torchao/prototype/common/triton/matmul_perf_model.py b/torchao/prototype/common/triton/matmul_perf_model.py new file mode 100644 index 000000000..fd47ffc64 --- /dev/null +++ b/torchao/prototype/common/triton/matmul_perf_model.py @@ -0,0 +1,251 @@ + +# Source: https://github.com/triton-lang/kernels/blob/main/kernels/matmul_perf_model.py + +# This file is taken from the upstream triton-lang/kernels repo. +# Currently that repo does not have a license file, so disabling +# the license lint for now: +# @lint-ignore-every LICENSELINT + +# flake8: noqa +# pyre-ignore-all-errors +import functools +import heapq + +import torch + +from triton import cdiv +from triton.runtime import driver +from triton.testing import ( + get_dram_gbps, + get_max_simd_tflops, + get_max_tensorcore_tflops, + nvsmi, +) + + +@functools.lru_cache() +def get_clock_rate_in_khz(): + try: + return nvsmi(["clocks.max.sm"])[0] * 1e3 + except FileNotFoundError: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + return pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM) * 1e3 + + +def get_tensorcore_tflops(device, num_ctas, num_warps, dtype): + """return compute throughput in TOPS""" + total_warps = num_ctas * min(num_warps, 4) + if hasattr(driver, "active"): + num_subcores = ( + driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + ) # on recent GPUs + else: + num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs + + tflops = ( + min(num_subcores, total_warps) + / num_subcores + * get_max_tensorcore_tflops(dtype, get_clock_rate_in_khz(), device) + ) + return tflops + + +def get_simd_tflops(device, num_ctas, num_warps, dtype): + """return compute throughput in TOPS""" + total_warps = num_ctas * min(num_warps, 4) + if hasattr(driver, "active"): + num_subcores = ( + driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + ) # on recent GPUs + else: + num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs + + tflops = ( + min(num_subcores, total_warps) + / num_subcores + * get_max_simd_tflops(dtype, get_clock_rate_in_khz(), device) + ) + return tflops + + +def get_tflops(device, num_ctas, num_warps, dtype): + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8 and dtype == torch.float32: + return get_simd_tflops(device, num_ctas, num_warps, dtype) + return get_tensorcore_tflops(device, num_ctas, num_warps, dtype) + + +def estimate_matmul_time( + # backend, device, + num_warps, + num_stages, # + A, + B, + C, # + M, + N, + K, # + BLOCK_M, + BLOCK_N, + BLOCK_K, + SPLIT_K, # + debug=False, + **kwargs, # +): + """return estimated running time in ms + = max(compute, loading) + store""" + device = torch.cuda.current_device() + dtype = A.dtype + dtsize = A.element_size() + + num_cta_m = cdiv(M, BLOCK_M) + num_cta_n = cdiv(N, BLOCK_N) + num_cta_k = SPLIT_K + num_ctas = num_cta_m * num_cta_n * num_cta_k + + # If the input is smaller than the block size + M, N = max(M, BLOCK_M), max(N, BLOCK_N) + + # time to compute + total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS + tput = get_tflops(device, num_ctas, num_warps, dtype) + compute_ms = total_ops / tput + + # time to load data + if hasattr(driver, "active"): + num_sm = driver.active.utils.get_device_properties(device)["multiprocessor_count"] + else: + num_sm = driver.utils.get_device_properties(device)["multiprocessor_count"] + + + active_cta_ratio = min(1, num_ctas / num_sm) + active_cta_ratio_bw1 = min( + 1, num_ctas / 32 + ) # 32 active ctas are enough to saturate + active_cta_ratio_bw2 = max( + min(1, (num_ctas - 32) / (108 - 32)), 0 + ) # 32-108, remaining 5% + dram_bw = get_dram_gbps(device) * ( + active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05 + ) # in GB/s + l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) + # assume 80% of (following) loads are in L2 cache + load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1)) + load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1) + load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1)) + load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1) + # total + total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB + total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) + # loading time in ms + load_ms = total_dram / dram_bw + total_l2 / l2_bw + + # estimate storing time + store_bw = dram_bw * 0.6 # :o + store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB + if SPLIT_K == 1: + store_ms = store_c_dram / store_bw + else: + reduce_bw = store_bw + store_ms = store_c_dram / reduce_bw + # c.zero_() + zero_ms = M * N * 2 / (1024 * 1024) / store_bw + store_ms += zero_ms + + total_time_ms = max(compute_ms, load_ms) + store_ms + if debug: + print( + f"Total time: {total_time_ms}ms, compute time: {compute_ms}ms, " + f"loading time: {load_ms}ms, store time: {store_ms}ms, " + f"Activate CTAs: {active_cta_ratio*100}%" + ) + return total_time_ms + + +def early_config_prune(configs, named_args, **kwargs): + device = torch.cuda.current_device() + capability = torch.cuda.get_device_capability() + # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + dtsize = named_args["A"].element_size() + dtype = named_args["A"].dtype + + # 1. make sure we have enough smem + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages = ( + kw["BLOCK_M"], + kw["BLOCK_N"], + kw["BLOCK_K"], + config.num_stages, + ) + if hasattr(driver, "active"): + max_shared_memory = driver.active.utils.get_device_properties(device)[ + "max_shared_mem" + ] + else: + max_shared_memory = driver.utils.get_device_properties(device)[ + "max_shared_mem" + ] + + + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory <= max_shared_memory: + pruned_configs.append(config) + configs = pruned_configs + + # Some dtypes do not allow atomic_add + if dtype not in [torch.float16, torch.float32]: + configs = [config for config in configs if config.kwargs["SPLIT_K"] == 1] + + # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) + configs_map = {} + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = ( + kw["BLOCK_M"], + kw["BLOCK_N"], + kw["BLOCK_K"], + kw["SPLIT_K"], + config.num_warps, + config.num_stages, + ) + + key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) + if key in configs_map: + configs_map[key].append((config, num_stages)) + else: + configs_map[key] = [(config, num_stages)] + + pruned_configs = [] + for k, v in configs_map.items(): + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k + if capability[0] >= 8: + # compute cycles (only works for ampere GPUs) + mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16) + mma_cycles = mmas / min(4, num_warps) * 8 + + ldgsts_latency = 300 # Does this matter? + optimal_num_stages = ldgsts_latency / mma_cycles + + # nearest stages, prefer large #stages + nearest = heapq.nsmallest( + 2, + v, + key=lambda x: ( + 10 + abs(x[1] - optimal_num_stages) + if (x[1] - optimal_num_stages) < 0 + else x[1] - optimal_num_stages + ), + ) + + for n in nearest: + pruned_configs.append(n[0]) + else: # Volta & Turing only supports num_stages <= 2 + random_config = v[0][0] + random_config.num_stages = 2 + pruned_configs.append(random_config) + return pruned_configs diff --git a/torchao/prototype/dora/kernels/common.py b/torchao/prototype/dora/kernels/common.py index cd0950d4c..e1b85d238 100644 --- a/torchao/prototype/dora/kernels/common.py +++ b/torchao/prototype/dora/kernels/common.py @@ -5,7 +5,7 @@ import triton.language as tl # Re-exports -from triton.ops.matmul import ( +from torchao.prototype.common.triton.matmul import ( early_config_prune, estimate_matmul_time, get_configs_io_bound, diff --git a/torchao/prototype/dora/kernels/smallk.py b/torchao/prototype/dora/kernels/smallk.py index fc24ea223..b1d04878a 100644 --- a/torchao/prototype/dora/kernels/smallk.py +++ b/torchao/prototype/dora/kernels/smallk.py @@ -5,7 +5,7 @@ import torch import triton import triton.language as tl -from triton.ops.matmul import ( +from torchao.prototype.common.triton.matmul import ( estimate_matmul_time, get_configs_io_bound, get_higher_dtype, diff --git a/torchao/prototype/galore/kernels/adam_downproj_fused.py b/torchao/prototype/galore/kernels/adam_downproj_fused.py index 9049baa78..cfdbd4a03 100644 --- a/torchao/prototype/galore/kernels/adam_downproj_fused.py +++ b/torchao/prototype/galore/kernels/adam_downproj_fused.py @@ -3,12 +3,11 @@ import torch import triton import triton.language as tl -from triton.ops.matmul import get_higher_dtype, init_to_zero -from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time +from torchao.prototype.common.triton.matmul_perf_model import early_config_prune, estimate_matmul_time from .adam_step import BETA1, BETA2, EPS from .custom_autotune import Config, autotune -from .matmul import TRITON_ACC_TYPES +from .matmul import TRITON_ACC_TYPES, get_higher_dtype, init_to_zero from .matmul import get_autotuner as default_mm_autotuner from .matmul import get_mm_heuristics, to_tl_type diff --git a/torchao/prototype/galore/kernels/matmul.py b/torchao/prototype/galore/kernels/matmul.py index b183f7ed6..de7e77fb1 100644 --- a/torchao/prototype/galore/kernels/matmul.py +++ b/torchao/prototype/galore/kernels/matmul.py @@ -1,8 +1,7 @@ import torch import triton import triton.language as tl -from triton.ops.matmul import get_higher_dtype, init_to_zero -from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time +from torchao.prototype.common.triton.matmul_perf_model import early_config_prune, estimate_matmul_time from .custom_autotune import Config, autotune, heuristics @@ -15,6 +14,33 @@ } AUTOTUNER_TOP_K = 50 +_ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32] + + +def upcast_if_fp8(a): + if "fp8" in str(a): + return torch.float16 + return a + + +def get_higher_dtype(a, b): + a = upcast_if_fp8(a) + b = upcast_if_fp8(b) + if a is b: + return a + + assert a in _ordered_datatypes + assert b in _ordered_datatypes + + for d in _ordered_datatypes: + if a is d: + return b + if b is d: + return a + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() def set_tuner_top_k(k):