fix volatile
drisspg committed Dec 19, 2024
Parent: 3262b8e · Commit: a4c66bb
Showing 2 changed files with 26 additions and 12 deletions.
benchmarks/fp8_dynamic_cast.py (19 additions, 6 deletions)

@@ -12,7 +12,7 @@
     dynamic_scaled_quant,
     eager_dynamic_scaled_quant,
 )
-from transformer_nuggets.utils import benchmark_torch_function_in_microseconds
+from transformer_nuggets.utils import benchmark_cuda_function_in_microseconds

 device = torch.device("cuda")
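The benchmark switches from the generic timer to a CUDA-event-based one. Below is a minimal sketch of what such a helper could look like, assuming the usual torch.cuda.Event pattern; the actual transformer_nuggets implementation may differ, and the function name here only mirrors the import above.

import torch

def benchmark_cuda_function_in_microseconds_sketch(fn, *args, n_iters=100, **kwargs):
    # Warmup so lazy init and compilation do not pollute the measurement.
    for _ in range(3):
        fn(*args, **kwargs)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iters):
        fn(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()  # wait until the end event has been recorded
    # elapsed_time returns milliseconds; convert to microseconds per call.
    return start.elapsed_time(end) * 1e3 / n_iters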

@@ -41,7 +41,9 @@ class Experiment:


 def get_configs() -> List[ExperimentConfig]:
-    sizes = [2**21, 2**22, 2**23, 2**24]
+    # We hang for anything bigger than this
+    # sizes = [2**21, 2**22, 2**23, 2**24]
+    sizes = [2**21, 2**22]
     high_precision_dtypes = [torch.float32]
     low_precision_dtypes = [torch.float8_e4m3fn, torch.float8_e5m2]
     configs = []
@@ -70,7 +72,18 @@ def correctness_check(hp_tensor, triton_tensor, config):
         config.low_precision_dtype,
     ).to(config.high_precision_dtype)

+    compiled_pytorch_fn = torch.compile(eager_dynamic_scaled_quant, fullgraph=True)
+    compiled_out = compiled_pytorch_fn(
+        hp_tensor,
+        config.low_precision_dtype,
+    ).to(config.high_precision_dtype)
+
+    print(f"Deviation between Triton and Nuggets: {torch.abs(nuggets_out - eager_out).max()}")
+    print(
+        f"Deviation between Eager and Compiled PyTorch: {torch.abs(eager_out - compiled_out).max()}"
+    )
+
     # Find the index of the maximum deviation
     max_dev_index = torch.abs(nuggets_out - eager_out).argmax().item()
     print(f"nuggets_out tensor value: {nuggets_out.flatten()[max_dev_index]:.4f}")
    print(f"eager_out tensor value: {eager_out.flatten()[max_dev_index]:.4f}")
@@ -83,21 +96,21 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     triton_hp_tensor = high_precision_tensor.clone()

     # Triton does different rounding as far as I can tell
-    if False:
+    if True:
         correctness_check(high_precision_tensor, triton_hp_tensor, config)

-    triton_time = benchmark_torch_function_in_microseconds(
+    triton_time = benchmark_cuda_function_in_microseconds(
         dynamic_scaled_quant,
         triton_hp_tensor,
         config.low_precision_dtype,
     )
-    pytorch_time = benchmark_torch_function_in_microseconds(
+    pytorch_time = benchmark_cuda_function_in_microseconds(
         eager_dynamic_scaled_quant,
         high_precision_tensor,
         config.low_precision_dtype,
     )
     compiled_pytorch_fn = torch.compile(eager_dynamic_scaled_quant, fullgraph=True)
-    compiled_pytorch_time = benchmark_torch_function_in_microseconds(
+    compiled_pytorch_time = benchmark_cuda_function_in_microseconds(
         compiled_pytorch_fn,
         high_precision_tensor,
         config.low_precision_dtype,
transformer_nuggets/fp8/scaled_quant.py (7 additions, 6 deletions)

@@ -126,8 +126,8 @@ def dynamic_scaled_cast(
     block_max = tl.max(tl.abs(inpt))
     tl.atomic_max(abs_max_ptr, block_max)
     # Spinlock global barrier
-    tl.atomic_add(spin_lock, 1)
-    while tl.load(spin_lock) < n_blocks:
+    tl.atomic_add(spin_lock, 1, sem="release")
+    while tl.load(spin_lock, volatile=True) < n_blocks:
         pass
     scale = max_val / (tl.clamp(tl.load(abs_max_ptr), -1e12, float("inf")))
     scaled_inpt = inpt * scale
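This hunk is the "fix volatile" of the commit title: the arrival increment gets release semantics so the preceding atomic_max is globally visible before the counter update, and the spin loop loads with volatile=True so the compiler re-reads the counter each iteration instead of caching a stale value. A self-contained sketch of the same grid-wide spinlock barrier, assuming current triton.language semantics; the kernel name and launch scaffolding are illustrative, not from the repo.

import torch
import triton
import triton.language as tl

@triton.jit
def absmax_scale_kernel_sketch(
    inpt_ptr, out_ptr, abs_max_ptr, spin_lock_ptr, numel, XBLOCK: tl.constexpr
):
    pid = tl.program_id(0)
    offs = pid * XBLOCK + tl.arange(0, XBLOCK)
    mask = offs < numel
    x = tl.load(inpt_ptr + offs, mask=mask, other=0.0)
    # Every block publishes its local abs-max into one global scratch cell.
    tl.atomic_max(abs_max_ptr, tl.max(tl.abs(x)))
    n_blocks = tl.cdiv(numel, XBLOCK)
    # Release ordering: the atomic_max above must be visible to other
    # blocks before this block signals its arrival at the barrier.
    tl.atomic_add(spin_lock_ptr, 1, sem="release")
    # volatile=True forces a fresh read on every loop iteration.
    while tl.load(spin_lock_ptr, volatile=True) < n_blocks:
        pass
    # After the barrier, every block sees the global abs-max.
    g_max = tl.load(abs_max_ptr)
    tl.store(out_ptr + offs, x / g_max, mask=mask)

# Launch: the whole grid must be resident on the GPU at once or the
# barrier deadlocks, consistent with the benchmark capping sizes at 2**22.
x = torch.randn(2**21, device="cuda")
out = torch.empty_like(x)
abs_max = torch.zeros((), device="cuda")
spin_lock = torch.zeros((), device="cuda", dtype=torch.int32)
XBLOCK = 16384
grid = (triton.cdiv(x.numel(), XBLOCK),)
absmax_scale_kernel_sketch[grid](x, out, abs_max, spin_lock, x.numel(), XBLOCK=XBLOCK)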
@@ -151,6 +151,7 @@ def dynamic_scaled_quant(
     out_tensor = torch.empty_like(inpt_tensor, dtype=fp8_dtype, device="cuda")
     numel = inpt_tensor.numel()
     grid = lambda meta: (triton.cdiv(numel, meta["XBLOCK"]),)
+    assert inpt_tensor.is_contiguous(), "Input tensor must be contiguous"
     tl_dtype = {torch.float8_e4m3fn: tl.float8e4nv, torch.float8_e5m2: tl.float8e5}[fp8_dtype]
     max_val = torch.finfo(fp8_dtype).max
     abs_max_scratch = torch.empty((), dtype=inpt_tensor.dtype, device="cuda")
@@ -161,7 +162,7 @@
         abs_max_scratch,
         spin_lock,
         numel,
-        4096,
+        16384,
         tl_dtype,
         max_val,
     )
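Raising XBLOCK from 4096 to 16384 cuts the grid size by 4x, which matters for the spinlock barrier above since every block must be co-resident for the barrier to complete. A quick arithmetic check (grounded only in triton.cdiv, which the launch code already uses):

import triton

numel = 2**22
for xblock in (4096, 16384):
    print(xblock, "->", triton.cdiv(numel, xblock), "blocks")
# 4096 -> 1024 blocks; 16384 -> 256 blocks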
@@ -177,8 +178,8 @@ def eager_dynamic_scaled_quant(
         a: Input tensor to quantize
         fp8_dtype: FP8 datatype to quantize to
     """
-    from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
+    from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated

     scale = tensor_to_scale(a, fp8_dtype)
-    a = a * scale
-    return to_fp8_saturated(a, fp8_dtype)
+    tensor_scaled = a.to(torch.float32) * scale
+    return to_fp8_saturated(tensor_scaled, fp8_dtype)
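Two changes here: the import moves to torchao (float8_experimental was upstreamed into torchao.float8), and the scale multiply now happens in float32 before the saturated cast, so the quantization itself does not lose precision to a low-precision intermediate. A usage sketch following the calls exactly as the diff uses them; the input tensor here is made up for illustration:

import torch
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated

a = torch.randn(2**21, device="cuda", dtype=torch.bfloat16)
fp8_dtype = torch.float8_e4m3fn

scale = tensor_to_scale(a, fp8_dtype)          # maps the tensor's abs-max onto fp8 max
tensor_scaled = a.to(torch.float32) * scale    # multiply in fp32, not bf16
fp8_tensor = to_fp8_saturated(tensor_scaled, fp8_dtype)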
