Commits (107), all by pcmoritz:
Jan 18, 2026:
12cc124  add cutlass ragged dot
70dce5f  update
d70e010  update
ac85d10  update
106e4ae  update
94eb625  add backward
cf80c97  update
7f1fe1d  update
1d9df09  use grouped gemm
527c1a0  update
656756f  update
aee36c7  update
cfb3404  Merge branch 'tx-ragged-dot-cutlass' of github.com:pcmoritz/SkyRL int…
6e9ead9  update
63914e7  update
f92d00d  update
ad0bfee  fix
70cba86  update
3f4dd25  optimize

Jan 19, 2026:
3f6669d  fixes
7b22f86  optimize
accff8e  try to use clusters
f1fb36c  update schedule
b1c48f4  try tile size
046a033  update
5b14a8a  optimize
23a74e5  optimize
4c86409  simplify
2dcce20  simplify
e731efe  add lto
8ce60ce  fix
2a005e2  update
083c150  update
7370e39  add flags
f2f7e6c  update
bbb6004  replace backward
20bb382  add proper backward
0b51d0b  optimize
80a276f  update
fdfff39  update
3923e99  update
0acef25  optimize
cc26ba2  simplify
9b72ae7  update
1f1e728  try without caching
1717190  clean up
3199c0d  update
1071a5a  update
688c962  simplify
de0b140  simplify
9586e57  simplify
5a96b68  simplify
11fa3c4  update
6698906  update
5f9c413  cleanup
e23037d  simplify
56968ba  simplify
0a184e4  unify code
4292fc8  update
0392cac  update
eb5e004  update
1b19afa  update build
7aa2fe7  update
be21295  update
5b67b4c  update
a29c3ef  simplify
ce86680  update
8220f6e  update
86e3ad9  optimize
d16a825  try masking
e346551  revert
b758b2b  update
2e2fff0  update
a2a390d  add benchmark script
729d816  add tuning script
5c1bed3  fix
5cf3521  fix
2fb40bc  update tile size
c23c9e7  fine grained sweeping (works before)
2b1c3d6  optimize
2282958  Revert to state at 2fb40bc2 (update tile size)
f4cfed2  use kernel for everything
3feab96  update
f03a1e0  update
89b452f  update
2b997fa  revert

Jan 20, 2026:
ddf1ab7  update
fc8c75b  fix
b0e14f3  update
79b0d44  update
1a7485f  add tests for lora
8fa5c2e  update
f6a6a92  update
2db8415  update
ae19ae9  update
988cc06  update tiles
0167744  update
3dfcc14  update
50357b4  update
58956a4  update
447305d  update
c1f1ab1  update
fa814f1  update
f72b285  update
2ef0d7e  update
59be86f  update

Jan 21, 2026:
ce1eba9  update
skyrl-tx/benchmarks/bench_ragged_dot.py (260 additions, 0 deletions)
@@ -0,0 +1,260 @@
"""Benchmark ragged_dot CUTLASS kernel with Qwen3-30B-A3B MoE and LoRA shapes."""

import argparse
import time

import jax
import jax.numpy as jnp
from jax import lax

from tx.ffi import ragged_dot_ffi, ragged_dot_ffi_available


# Preset configurations for different workloads
PRESETS = {
"moe": {
"description": "MoE expert layer (Qwen3-30B-A3B)",
"num_tokens": 8192,
"num_groups": 128, # num_experts
"k_dim": 2048, # hidden_size
"n_dim": 768, # intermediate_size
},
"lora": {
"description": "LoRA adapter layer",
"num_tokens": 8192,
"num_groups": 32, # max_lora_adapters
"k_dim": 8, # lora_rank
"n_dim": 4096, # output features
},
"lora-moe": {
"description": "LoRA on MoE experts (combined groups)",
"num_tokens": 8192,
"num_groups": 1024, # num_experts * max_lora_adapters (128 * 8, capped at kernel limit)
"k_dim": 8, # lora_rank
"n_dim": 768, # intermediate_size
},
}


def generate_group_sizes(num_tokens: int, num_groups: int, key: jax.Array) -> jax.Array:
"""Generate random group sizes that sum to num_tokens."""
# Random assignment of tokens to groups
assignments = jax.random.randint(key, (num_tokens,), 0, num_groups)
return jnp.bincount(assignments, length=num_groups).astype(jnp.int32)
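
# Illustrative example (hypothetical values): num_tokens=10 and num_groups=4
# could yield jnp.array([3, 0, 5, 2], dtype=int32); the entries always sum to num_tokens.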


def benchmark_forward(
num_tokens: int,
k_dim: int,
n_dim: int,
num_groups: int,
num_warmup: int = 5,
num_iters: int = 20,
use_ffi: bool = True,
):
"""Benchmark forward pass: lhs[M, K] @ rhs[G, K, N] -> out[M, N]."""
key = jax.random.PRNGKey(42)
k1, k2, k3 = jax.random.split(key, 3)

lhs = jax.random.normal(k1, (num_tokens, k_dim), dtype=jnp.bfloat16)
rhs = jax.random.normal(k2, (num_groups, k_dim, n_dim), dtype=jnp.bfloat16)
group_sizes = generate_group_sizes(num_tokens, num_groups, k3)
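# group_offset is required by the FFI entry point; a single [0] appears to mean
# the kernel starts at the first group (the lax.ragged_dot baseline is called without it).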
group_offset = jnp.array([0], dtype=jnp.int32)

if use_ffi:
fn = lambda: ragged_dot_ffi(lhs, rhs, group_sizes, group_offset)
else:
fn = lambda: lax.ragged_dot(lhs, rhs, group_sizes)

# Warmup
for _ in range(num_warmup):
out = fn()
out.block_until_ready()

# Benchmark
start = time.perf_counter()
for _ in range(num_iters):
out = fn()
out.block_until_ready()
elapsed = time.perf_counter() - start

# FLOPs: 2 * M * K * N (matmul FLOPs)
flops = 2 * num_tokens * k_dim * n_dim
tflops = (flops * num_iters / elapsed) / 1e12

return elapsed / num_iters, tflops
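
# Hypothetical correctness sketch, not part of the original benchmark: assuming the FFI
# kernel is built and matches lax.ragged_dot semantics, the two outputs should agree up
# to bfloat16 accumulation-order differences, so compare loosely in float32.
def check_forward_correctness(num_tokens: int, k_dim: int, n_dim: int, num_groups: int) -> bool:
    key = jax.random.PRNGKey(0)
    k1, k2, k3 = jax.random.split(key, 3)
    lhs = jax.random.normal(k1, (num_tokens, k_dim), dtype=jnp.bfloat16)
    rhs = jax.random.normal(k2, (num_groups, k_dim, n_dim), dtype=jnp.bfloat16)
    group_sizes = generate_group_sizes(num_tokens, num_groups, k3)
    group_offset = jnp.array([0], dtype=jnp.int32)
    out_ffi = ragged_dot_ffi(lhs, rhs, group_sizes, group_offset)
    out_ref = lax.ragged_dot(lhs, rhs, group_sizes)
    # Loose tolerances to absorb bfloat16 rounding differences between the kernels.
    return bool(jnp.allclose(out_ffi.astype(jnp.float32), out_ref.astype(jnp.float32), atol=1e-1, rtol=1e-2))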


def benchmark_backward(
num_tokens: int,
k_dim: int,
n_dim: int,
num_groups: int,
num_warmup: int = 5,
num_iters: int = 20,
use_ffi: bool = True,
):
"""Benchmark backward pass through ragged_dot."""
key = jax.random.PRNGKey(42)
k1, k2, k3 = jax.random.split(key, 3)

lhs = jax.random.normal(k1, (num_tokens, k_dim), dtype=jnp.bfloat16)
rhs = jax.random.normal(k2, (num_groups, k_dim, n_dim), dtype=jnp.bfloat16)
group_sizes = generate_group_sizes(num_tokens, num_groups, k3)
group_offset = jnp.array([0], dtype=jnp.int32)

if use_ffi:
def forward(lhs, rhs):
return ragged_dot_ffi(lhs, rhs, group_sizes, group_offset).sum()
else:
def forward(lhs, rhs):
return lax.ragged_dot(lhs, rhs, group_sizes).sum()
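
# The .sum() above reduces the output to a scalar so jax.grad can differentiate
# it with respect to both operands (argnums=(0, 1)).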

grad_fn = jax.grad(forward, argnums=(0, 1))

# Warmup
for _ in range(num_warmup):
d_lhs, d_rhs = grad_fn(lhs, rhs)
d_lhs.block_until_ready()
d_rhs.block_until_ready()

# Benchmark
start = time.perf_counter()
for _ in range(num_iters):
d_lhs, d_rhs = grad_fn(lhs, rhs)
d_lhs.block_until_ready()
d_rhs.block_until_ready()
elapsed = time.perf_counter() - start

# Backward FLOPs: d_lhs = grad @ rhs.T (2*M*N*K) + d_rhs = lhs.T @ grad (2*K*M*N)
# Total: 4 * M * K * N
flops = 4 * num_tokens * k_dim * n_dim
tflops = (flops * num_iters / elapsed) / 1e12

return elapsed / num_iters, tflops


def run_benchmark_suite(
num_tokens: int,
k_dim: int,
n_dim: int,
num_groups: int,
num_warmup: int,
num_iters: int,
run_forward: bool,
run_backward: bool,
):
"""Run the benchmark suite with the given configuration."""
if run_forward:
print("Forward Pass (lhs[M,K] @ rhs[G,K,N] -> out[M,N])")
print("-" * 60)

if ragged_dot_ffi_available():
ffi_time, ffi_tflops = benchmark_forward(
num_tokens, k_dim, n_dim, num_groups, num_warmup, num_iters, use_ffi=True
)
print(f" CUTLASS FFI: {ffi_time*1000:8.3f} ms {ffi_tflops:8.2f} TFLOPS")

jax_time, jax_tflops = benchmark_forward(
num_tokens, k_dim, n_dim, num_groups, num_warmup, num_iters, use_ffi=False
)
print(f" JAX ragged: {jax_time*1000:8.3f} ms {jax_tflops:8.2f} TFLOPS")

if ragged_dot_ffi_available():
print(f" Speedup: {jax_time/ffi_time:.2f}x")
print()

if run_backward:
print("Backward Pass (grad wrt lhs and rhs)")
print("-" * 60)

if ragged_dot_ffi_available():
ffi_time, ffi_tflops = benchmark_backward(
num_tokens, k_dim, n_dim, num_groups, num_warmup, num_iters, use_ffi=True
)
print(f" CUTLASS FFI: {ffi_time*1000:8.3f} ms {ffi_tflops:8.2f} TFLOPS")

jax_time, jax_tflops = benchmark_backward(
num_tokens, k_dim, n_dim, num_groups, num_warmup, num_iters, use_ffi=False
)
print(f" JAX ragged: {jax_time*1000:8.3f} ms {jax_tflops:8.2f} TFLOPS")

if ragged_dot_ffi_available():
print(f" Speedup: {jax_time/ffi_time:.2f}x")
print()


def main():
parser = argparse.ArgumentParser(description="Benchmark ragged_dot CUTLASS kernel")
parser.add_argument("--preset", choices=list(PRESETS.keys()), help="Use a preset configuration (moe, lora, lora-moe)")
parser.add_argument("--all-presets", action="store_true", help="Run all preset configurations")
parser.add_argument("--num-tokens", type=int, default=8192, help="Number of tokens (M)")
parser.add_argument("--num-groups", type=int, default=128, help="Number of groups (G) - experts or adapters")
parser.add_argument("--k-dim", type=int, default=2048, help="K dimension - hidden_size (MoE) or lora_rank (LoRA)")
parser.add_argument("--n-dim", type=int, default=768, help="N dimension - output features")
parser.add_argument("--num-warmup", type=int, default=5, help="Warmup iterations")
parser.add_argument("--num-iters", type=int, default=20, help="Benchmark iterations")
parser.add_argument("--backward-only", action="store_true", help="Only benchmark backward pass")
parser.add_argument("--forward-only", action="store_true", help="Only benchmark forward pass")
args = parser.parse_args()

print("Ragged Dot Benchmark")
print("=" * 60)
print(f"CUTLASS FFI available: {ragged_dot_ffi_available()}")
print(f"JAX backend: {jax.default_backend()}")
print(f"Devices: {jax.device_count()}")
print()

run_forward = not args.backward_only
run_backward = not args.forward_only

if args.all_presets:
# Run all presets
for preset_name, preset in PRESETS.items():
print("=" * 60)
print(f"Preset: {preset_name} - {preset['description']}")
print("=" * 60)
print(f"Config:")
print(f" num_tokens (M): {preset['num_tokens']}")
print(f" num_groups (G): {preset['num_groups']}")
print(f" k_dim (K): {preset['k_dim']}")
print(f" n_dim (N): {preset['n_dim']}")
print(f" warmup/iters: {args.num_warmup}/{args.num_iters}")
print()
run_benchmark_suite(
preset['num_tokens'], preset['k_dim'], preset['n_dim'], preset['num_groups'],
args.num_warmup, args.num_iters, run_forward, run_backward
)
elif args.preset:
# Use a specific preset
preset = PRESETS[args.preset]
print(f"Preset: {args.preset} - {preset['description']}")
print()
print(f"Config:")
print(f" num_tokens (M): {preset['num_tokens']}")
print(f" num_groups (G): {preset['num_groups']}")
print(f" k_dim (K): {preset['k_dim']}")
print(f" n_dim (N): {preset['n_dim']}")
print(f" warmup/iters: {args.num_warmup}/{args.num_iters}")
print()
run_benchmark_suite(
preset['num_tokens'], preset['k_dim'], preset['n_dim'], preset['num_groups'],
args.num_warmup, args.num_iters, run_forward, run_backward
)
else:
# Use custom config from args
print(f"Config:")
print(f" num_tokens (M): {args.num_tokens}")
print(f" num_groups (G): {args.num_groups}")
print(f" k_dim (K): {args.k_dim}")
print(f" n_dim (N): {args.n_dim}")
print(f" warmup/iters: {args.num_warmup}/{args.num_iters}")
print()
run_benchmark_suite(
args.num_tokens, args.k_dim, args.n_dim, args.num_groups,
args.num_warmup, args.num_iters, run_forward, run_backward
)


if __name__ == "__main__":
main()
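
Usage note: with a CUDA-capable JAX installation and the tx FFI extension built, the benchmark can be run directly, for example as python skyrl-tx/benchmarks/bench_ragged_dot.py --all-presets, or for a single shape via --preset lora; custom shapes are set with --num-tokens, --num-groups, --k-dim, and --n-dim. When ragged_dot_ffi_available() returns False, only the lax.ragged_dot baseline is timed.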