Commit df09264

Merge remote-tracking branch 'origin/main' into wengshiy/embeddingbag_krnl

2 parents: 1c2c154 + 9056c46

86 files changed: +1790 -1524 lines changed

.github/scripts/torchao_model_releases/quantize_and_upload.py

Lines changed: 15 additions & 5 deletions
@@ -568,7 +568,7 @@ def _untie_weights_and_save_locally(model_id):
     """
 
 
-def quantize_and_upload(model_id, quant):
+def quantize_and_upload(model_id, quant, push_to_hub):
     _int8_int4_linear_config = Int8DynamicActivationIntxWeightConfig(
         weight_dtype=torch.int4,
         weight_granularity=PerGroup(32),
@@ -657,9 +657,13 @@ def quantize_and_upload(model_id, quant):
     card = ModelCard(content)
 
     # Push to hub
-    quantized_model.push_to_hub(quantized_model_id, safe_serialization=False)
-    tokenizer.push_to_hub(quantized_model_id)
-    card.push_to_hub(quantized_model_id)
+    if push_to_hub:
+        quantized_model.push_to_hub(quantized_model_id, safe_serialization=False)
+        tokenizer.push_to_hub(quantized_model_id)
+        card.push_to_hub(quantized_model_id)
+    else:
+        quantized_model.save_pretrained(quantized_model_id, safe_serialization=False)
+        tokenizer.save_pretrained(quantized_model_id)
 
     # Manual Testing
     prompt = "Hey, are you conscious? Can you talk to me?"
@@ -700,5 +704,11 @@ def quantize_and_upload(model_id, quant):
         type=str,
         help="Quantization method. Options are FP8, INT4, INT8_INT4, AWQ-INT4",
     )
+    parser.add_argument(
+        "--push_to_hub",
+        action="store_true",
+        default=False,
+        help="Flag to indicate whether push to huggingface hub or not",
+    )
     args = parser.parse_args()
-    quantize_and_upload(args.model_id, args.quant)
+    quantize_and_upload(args.model_id, args.quant, args.push_to_hub)
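
For context, the flag plumbing this diff introduces can be exercised in isolation. The sketch below is illustrative only: it stubs out the quantization and uses print statements in place of the Hugging Face push_to_hub()/save_pretrained() calls shown above, and the function and output-id naming are placeholders, not part of the release script.

# Minimal sketch: --push_to_hub is a store_true flag (default False) that
# selects between uploading to the Hub and saving locally.
import argparse


def quantize_and_release(model_id: str, quant: str, push_to_hub: bool) -> None:
    output_id = f"{model_id}-{quant}"  # placeholder naming scheme
    if push_to_hub:
        print(f"would push {output_id} to the Hugging Face Hub")
    else:
        print(f"would save {output_id} locally")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, required=True)
    parser.add_argument("--quant", type=str, default="FP8")
    parser.add_argument("--push_to_hub", action="store_true", default=False)
    args = parser.parse_args()
    quantize_and_release(args.model_id, args.quant, args.push_to_hub)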

.github/scripts/torchao_model_releases/release.sh

Lines changed: 8 additions & 3 deletions
@@ -14,6 +14,7 @@
 
 # Default quantization options
 default_quants=("FP8" "INT4" "INT8-INT4")
+push_to_hub=""
 # Parse arguments
 while [[ $# -gt 0 ]]; do
   case "$1" in
@@ -29,6 +30,10 @@ while [[ $# -gt 0 ]]; do
        shift
      done
      ;;
+    --push_to_hub)
+      push_to_hub="--push_to_hub"
+      shift
+      ;;
     *)
       echo "Unknown option: $1"
       exit 1
@@ -38,14 +43,14 @@ done
 # Use default quants if none specified
 if [[ -z "$model_id" ]]; then
   echo "Error: --model_id is required"
-  echo "Usage: $0 --model_id <model_id> [--quants <quant1> [quant2 ...]]"
+  echo "Usage: $0 --model_id <model_id> [--quants <quant1> [quant2 ...]] [--push_to_hub]"
   exit 1
 fi
 if [[ ${#quants[@]} -eq 0 ]]; then
   quants=("${default_quants[@]}")
 fi
 # Run the python command for each quantization option
 for quant in "${quants[@]}"; do
-  echo "Running: python quantize_and_upload.py --model_id $model_id --quant $quant"
-  python quantize_and_upload.py --model_id "$model_id" --quant "$quant"
+  echo "Running: python quantize_and_upload.py --model_id $model_id --quant $quant $push_to_hub"
+  python quantize_and_upload.py --model_id "$model_id" --quant "$quant" $push_to_hub
 done

.github/workflows/release_model.yml

Lines changed: 1 addition & 1 deletion
@@ -43,4 +43,4 @@ jobs:
           pip install .
           HF_MODEL_ID=${{ github.event.inputs.hf_model_id }}
           cd .github/scripts/torchao_model_releases
-          ./release.sh --model_id $HF_MODEL_ID
+          ./release.sh --model_id $HF_MODEL_ID --push_to_hub

benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 18 additions & 11 deletions
@@ -35,8 +35,8 @@ class ExperimentConfig:
 @dataclass(frozen=True)
 class ExperimentResult:
     bf16_us: float
-    fp8_us: float
-    fp8_speedup: float
+    scaled_us: float
+    scaled_speedup: float
 
 
 @dataclass(frozen=True)
@@ -48,8 +48,8 @@ class Experiment:
 def get_configs() -> List[ExperimentConfig]:
     # Llama4 shapes
     A_shapes = [(16640, 5120)]
-    B_shapes = [(1, 8192, 5120), (16, 8192, 5120), (128, 8192, 5120)]
-    recipes = [MoEScalingType.FP8_ROWWISE]
+    B_shapes = [(16, 8192, 5120)]
+    recipes = [MoEScalingType.MXFP8, MoEScalingType.FP8_ROWWISE]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
     for A_shape, B_shape, recipe, high_precision_dtype in itertools.product(
@@ -93,7 +93,8 @@ def run_experiment(
     # which represents the right operand.
     n_groups = config.B_shape[0]
     Mg = A.shape[0]
-    offs = generate_jagged_offs(n_groups, Mg, multiple_of=16)
+    token_group_alignment_size = 32 if config.recipe == MoEScalingType.MXFP8 else 16
+    offs = generate_jagged_offs(n_groups, Mg, multiple_of=token_group_alignment_size)
 
     labels = torch.ones(
         (A.shape[0], B_t.shape[-1]), device=device, dtype=torch.bfloat16
@@ -107,6 +108,7 @@ def run_experiment(
         offs,
         labels=labels,
         use_compile=args.compile,
+        fullgraph=False,
     )
     if args.profile:
         profile_fwd_bwd(
@@ -116,18 +118,20 @@ def run_experiment(
             offs,
             labels=labels,
            use_compile=args.compile,
+            fullgraph=False,
            profile_name="bf16_profile",
        )
 
     # benchmark scaled grouped mm with dynamic fp8 rowwise quant
-    fp8_us = bench_fwd_bwd_microseconds(
+    scaled_us = bench_fwd_bwd_microseconds(
         _scaled_grouped_mm,
         A,
         B_t,
         offs,
         scaling_type=config.recipe,
         labels=labels,
         use_compile=args.compile,
+        fullgraph=False,
     )
     if args.profile:
         profile_fwd_bwd(
@@ -139,22 +143,24 @@ def run_experiment(
             labels=labels,
             use_compile=args.compile,
             profile_name="scaled_profile",
+            fullgraph=False,
         )
 
     return ExperimentResult(
         bf16_us=round(bf16_us, 3),
-        fp8_us=round(fp8_us, 3),
-        fp8_speedup=round(bf16_us / fp8_us, 3),
+        scaled_us=round(scaled_us, 3),
+        scaled_speedup=round(bf16_us / scaled_us, 3),
     )
 
 
 def print_results(experiments: List[Experiment]):
     headers = [
         "A_shape",
         "B_shape",
+        "recipe",
         "bf16_time_us",
         "scaled_time_us",
-        "fp8_speedup",
+        "scaled_speedup",
     ]
     rows = []
     for experiment in experiments:
@@ -164,9 +170,10 @@ def print_results(experiments: List[Experiment]):
         [
             A_shape,
             B_shape,
+            experiment.config.recipe,
             experiment.result.bf16_us,
-            experiment.result.fp8_us,
-            f"{experiment.result.fp8_speedup}x",
+            experiment.result.scaled_us,
+            f"{experiment.result.scaled_speedup}x",
         ]
     )
     print(tabulate(rows, headers=headers))
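
A note on the alignment change above: the benchmark now sizes each token group to a multiple of 32 for MXFP8 (presumably matching MXFP8's 32-element scaling blocks) versus 16 for FP8 rowwise. The helper below is a standalone illustration of producing aligned group offsets; it is not torchao's generate_jagged_offs, and the even-split strategy is an assumption made for the example.

import torch


def aligned_group_offsets(n_groups: int, total_tokens: int, multiple_of: int) -> torch.Tensor:
    # Split total_tokens into n_groups contiguous groups whose sizes are all
    # multiples of `multiple_of`, returning cumulative end offsets (the format
    # a jagged/grouped GEMM typically expects).
    assert total_tokens % multiple_of == 0
    blocks = total_tokens // multiple_of
    sizes = torch.full((n_groups,), blocks // n_groups, dtype=torch.int64)
    sizes[: blocks % n_groups] += 1  # hand leftover blocks to the first groups
    return torch.cumsum(sizes * multiple_of, dim=0)


offs_rowwise = aligned_group_offsets(16, 16640, multiple_of=16)  # FP8 rowwise alignment
offs_mxfp8 = aligned_group_offsets(16, 16640, multiple_of=32)    # MXFP8 alignment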

scripts/clean_release_notes.py

Lines changed: 2 additions & 2 deletions
@@ -89,6 +89,7 @@
     "topic: performance": "Performance",
     "topic: documentation": "Documentation",
     "topic: for developer": "Developers",
+    "topic: not user facing": "Not User Facing",
 }
 
 
@@ -123,6 +124,7 @@ def clean_release_notes():
         "Performance": [],
         "Documentation": [],
         "Developers": [],
+        "Not User Facing": [],
     }
     with open(input_file, "r") as in_f, open(output_file, "a") as out_f:
         for line in in_f.readlines():
@@ -195,8 +197,6 @@ def get_commit_category(
     pr_number = parse_pr_number(commit_line)
     if pr_number in pr_number_to_label:
         label = pr_number_to_label[pr_number]
-        if label == "topic: not user facing":
-            return None
         if label in GITHUB_LABEL_TO_CATEGORY:
             return GITHUB_LABEL_TO_CATEGORY[label]
     elif any(x in commit_line.lower() for x in ["revert", "version.txt"]):

setup.py

Lines changed: 6 additions & 8 deletions
@@ -433,6 +433,7 @@ def get_extensions():
         extra_link_args.append("/DEBUG")
 
     rocm_sparse_marlin_supported = False
+    rocm_tiled_layout_supported = False
     if use_rocm:
         # naive search for hipblalst.h, if any found contain HIPBLASLT_ORDER_COL16 and VEC_EXT
         found_col16 = False
@@ -488,8 +489,11 @@ def get_extensions():
         # Define ROCm source directories
         rocm_source_dirs = [
             os.path.join(extensions_dir, "rocm", "swizzle"),
-            os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout"),
         ]
+        if rocm_tiled_layout_supported:
+            rocm_source_dirs.append(
+                os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout")
+            )
         if rocm_sparse_marlin_supported:
             rocm_source_dirs.extend([os.path.join(extensions_dir, "cuda", "sparse_marlin")])
 
@@ -512,14 +516,8 @@ def get_extensions():
         sources = [s for s in sources if s not in mxfp8_sources_to_exclude]
 
         # TOOD: Remove this and use what CUDA has once we fix all the builds.
+        # TODO: Add support for other ROCm GPUs
         if use_rocm:
-            # Add ROCm GPU architecture check
-            gpu_arch = None
-            if torch.cuda.is_available():
-                gpu_arch = torch.cuda.get_device_properties(0).name
-            if gpu_arch and gpu_arch != "gfx942":
-                print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
-                print("Currently only gfx942 is supported. Compiling only for gfx942.")
             extra_compile_args["nvcc"].append("--offload-arch=gfx942")
             sources += rocm_sources
         else:

test/dtypes/test_affine_quantized_float.py

Lines changed: 27 additions & 22 deletions
@@ -14,7 +14,8 @@
 import pytest
 import torch
 from torch._inductor.test_case import TestCase as InductorTestCase
-from torch.profiler import ProfilerActivity, profile
+from torch._inductor.utils import run_and_get_code
+from torch.testing import FileCheck
 from torch.testing._internal import common_utils
 
 from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl, preprocess_scale
@@ -766,32 +767,36 @@ def test_expected_kernels_on_gpu(self, granularity, float8_config_version):
             config,
         )
 
-        m = torch.compile(m, mode="default")
+        m = torch.compile(m)
         x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
-
-        # warm up
-        _ = m(x)
-        # capture trace
-        with profile(activities=[ProfilerActivity.CUDA]) as prof:
-            _ = m(x)
-
-        cuda_kernel_events = [x for x in prof.key_averages() if x.cuda_time > 0]
-
-        if granularity == PerTensor():
+        out, code = run_and_get_code(m, x)
+
+        # triton kernel call looks like:
+        # triton_per_fused__scaled_mm__to_copy_abs_amax_clamp_clone_div_expand_permute_transpose_unsqueeze_view_0.run(arg3_1, buf1, buf2, 128, 256, stream=stream0)
+        # scaled_mm call looks like:
+        # extern_kernels._scaled_mm(buf1, reinterpret_tensor(arg0_1, (256, 512), (1, 256), 0), buf2, reinterpret_tensor(arg1_1, (1, 512), (1, 1), 0), arg2_1, out_dtype=torch.bfloat16, use_fast_accum=True, out=buf3)
+        if granularity == PerRow():
+            # one triton kernel for quantizing the activation
+            FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(
+                code[0]
+            )
+            # one scaled_mm call
+            FileCheck().check("def call(").check_count(
+                "._scaled_mm(", 1, exactly=True
+            ).run(code[0])
+        else:
+            assert granularity == PerTensor(), "unsupported"
+            # three triton kernels for quantizing the activation:
             # kernel 1: x_max_tmp = max(x, ...)
             # kernel 2: x_max = max(x_max_tmp)
             # kernel 3: x_float8 = to_float8(x, x_max)
-            # kernel 4: gemm
-            assert len(cuda_kernel_events) == 4, (
-                f"too many cuda kernels: {cuda_kernel_events}"
-            )
-        else:
-            assert granularity == PerRow()
-            # kernel 1: x_float8 = to_float8(x)
-            # kernel 2: gemm
-            assert len(cuda_kernel_events) == 2, (
-                f"too many cuda kernels: {cuda_kernel_events}"
+            FileCheck().check("def call(").check_count(".run(", 3, exactly=True).run(
+                code[0]
             )
+            # one scaled_mm call
+            FileCheck().check("def call(").check_count(
+                "._scaled_mm(", 1, exactly=True
+            ).run(code[0])
 
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
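
The rewritten test above swaps profiler-event counting for inspection of the inductor-generated code. As a rough standalone illustration of that pattern (a toy function in place of the quantized model, and only the generic "def call(" wrapper check, which assumes a working inductor backend on the machine), run_and_get_code returns the compiled output together with the generated source strings, which FileCheck can then assert on:

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck


def f(x):
    return torch.relu(x) + 1


compiled_f = torch.compile(f)
x = torch.randn(32, 32)

# run_and_get_code runs the compiled function and also returns the inductor
# wrapper code it generated, as a list of source strings.
out, code = run_and_get_code(compiled_f, x)

# Assert on the generated code, analogous to how the test above checks triton
# kernel and _scaled_mm call counts.
FileCheck().check("def call(").run(code[0])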

test/dtypes/test_nf4.py

Lines changed: 3 additions & 3 deletions
@@ -43,7 +43,7 @@
     to_nf4,
 )
 from torchao.testing.utils import skip_if_rocm
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
+from torchao.utils import torch_version_at_least
 
 bnb_available = False
 
@@ -123,7 +123,7 @@ def test_backward_dtype_match(self, dtype: torch.dtype):
     @unittest.skipIf(not bnb_available, "Need bnb availble")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
-        TORCH_VERSION_AT_LEAST_2_7, reason="Failing in CI"
+        torch_version_at_least("2.7.0"), reason="Failing in CI"
     )  # TODO: fix this
     @skip_if_rocm("ROCm enablement in progress")
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@@ -150,7 +150,7 @@ def test_reconstruction_qlora_vs_bnb(self, dtype: torch.dtype):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @skip_if_rocm("ROCm enablement in progress")
     @unittest.skipIf(
-        TORCH_VERSION_AT_LEAST_2_7, reason="Failing in CI"
+        torch_version_at_least("2.7.0"), reason="Failing in CI"
     )  # TODO: fix this
     @parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
     def test_nf4_bnb_linear(self, dtype: torch.dtype):

test/integration/test_integration.py

Lines changed: 3 additions & 3 deletions
@@ -76,13 +76,13 @@
 )
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_7,
     benchmark_model,
     check_cpu_version,
     check_xpu_version,
     is_fbcode,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    torch_version_at_least,
     unwrap_tensor_subclass,
 )
 
@@ -1883,7 +1883,7 @@ def forward(self, x):
         model(x)
 
         api(model)
-        if not TORCH_VERSION_AT_LEAST_2_7:
+        if not torch_version_at_least("2.7.0"):
             unwrap_tensor_subclass(model)
 
         # running model
@@ -1942,7 +1942,7 @@ def forward(self, x):
         model(x)
 
         api(model)
-        if not TORCH_VERSION_AT_LEAST_2_7:
+        if not torch_version_at_least("2.7.0"):
             unwrap_tensor_subclass(model)
 
         # running model

test/integration/test_vllm.py

Lines changed: 2 additions & 2 deletions
@@ -17,9 +17,9 @@
 import torch
 
 from packaging import version
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_8
+from torchao.utils import torch_version_at_least
 
-if not TORCH_VERSION_AT_LEAST_2_8:
+if not torch_version_at_least("2.8.0"):
     pytest.skip("Requires PyTorch 2.8 or higher", allow_module_level=True)
 
 
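
These test changes, together with the test_nf4.py and test_integration.py diffs above, migrate from the fixed TORCH_VERSION_AT_LEAST_* constants to the callable torch_version_at_least("X.Y.Z") helper. For reference only, a check with the same shape can be written with packaging.version (which test_vllm.py already imports); this sketch is not torchao.utils' implementation and the pre-release handling is an assumption:

import torch
from packaging import version


def version_at_least(min_version: str) -> bool:
    # Compare the installed torch version against a minimum, ignoring dev/local
    # suffixes (e.g. "2.8.0a0" or "2.8.0+cu126" both count as 2.8.0 here).
    installed = version.parse(version.parse(torch.__version__).base_version)
    return installed >= version.parse(min_version)


print(version_at_least("2.8.0"))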
