Skip to content

Commit e1f93f3

Browse files
authored
Fix conflict parameter name promote_dtye in FP8ComputeLegalize (#18334)
1 parent f971595 commit e1f93f3

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

include/tvm/tir/transform.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,11 +357,11 @@ TVM_DLL Pass BF16ComputeLegalize();
357357
/*!
358358
* \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32
359359
* before Ops, then add a cast back to fp8.
360-
* \param promote_dtype_str The data type used for type promotion, defaults to float16
360+
* \param promote_dtype The data type used for type promotion, defaults to float16
361361
* \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs
362362
* \return The pass.
363363
*/
364-
TVM_DLL Pass FP8ComputeLegalize(ffi::String promote_dtype_str = "float16");
364+
TVM_DLL Pass FP8ComputeLegalize(ffi::String promote_dtype = "float16");
365365

366366
/*!
367367
* \brief Legalize bf16 storage types to u16.

python/tvm/tir/transform/transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def BF16ComputeLegalize():
244244
return _ffi_api.BF16ComputeLegalize() # type: ignore
245245

246246

247-
def FP8ComputeLegalize(promote_dtype_str: str = "float32"):
247+
def FP8ComputeLegalize(promote_dtype: str = "float32"):
248248
"""Legalize fp8 compute Ops.
249249
250250
Parameters
@@ -257,7 +257,7 @@ def FP8ComputeLegalize(promote_dtype_str: str = "float32"):
257257
fpass : tvm.transform.Pass
258258
The result pass
259259
"""
260-
return _ffi_api.FP8ComputeLegalize(promote_dtype_str) # type: ignore
260+
return _ffi_api.FP8ComputeLegalize(promote_dtype) # type: ignore
261261

262262

263263
def BF16StorageLegalize():

src/tir/transforms/unsupported_dtype_legalize.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -780,13 +780,13 @@ TVM_FFI_STATIC_INIT_BLOCK() {
780780
refl::GlobalDef().def("tir.transform.BF16StorageLegalize", BF16StorageLegalize);
781781
}
782782

783-
Pass FP8ComputeLegalize(ffi::String promote_dtype_str) {
783+
Pass FP8ComputeLegalize(ffi::String promote_dtype) {
784784
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
785785
auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
786786
if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
787787
return f;
788788
}
789-
return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype_str))).Legalize(f);
789+
return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype))).Legalize(f);
790790
};
791791
return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {});
792792
}

0 commit comments

Comments
 (0)