Commit fd12761

Support PaddlePaddle with compatible API and tvm-ffi

1 parent 3697b39 commit fd12761

6 files changed: +71 -24 lines changed

flashinfer/fp4_quantization.py

Lines changed: 6 additions & 3 deletions

@@ -180,7 +180,8 @@ def fp4_quantize_sm100(
         - Scale factors tensor with shape determined by layout and sf_vec_size
     """
     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)
     out_val = torch.empty(
         (*input.shape[:-1], input.shape[-1] // 2),
         dtype=torch.uint8,
@@ -480,9 +481,11 @@ def fp4_quantize(

     assert input.shape[-1] % sf_vec_size == 0
     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)
     # get input device sm version
-    major, minor = get_compute_capability(input.device)
+    # major, minor = get_compute_capability(input.device)
+    major, minor = get_compute_capability(input.place)
     x_q, sf = get_fp4_quantization_module(f"{major}{minor}").fp4_quantize_sm100(
         input,
         global_scale,

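Note on the device-to-place switch: under the PaddlePaddle compatible API, the tensors reaching these functions are Paddle tensors, which report their device via .place rather than torch's .device attribute. A minimal sketch of the difference (illustrative, assuming a CUDA build of Paddle):

    import paddle

    x = paddle.ones([4, 8], dtype="float16")
    print(x.place)  # e.g. Place(gpu:0); Paddle tensors carry a Place, not a torch.device
    # hence device_support_pdl(input.place) above instead of device_support_pdl(input.device)
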
flashinfer/fused_moe/core.py

Lines changed: 31 additions & 11 deletions

@@ -20,7 +20,10 @@
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
-import tvm_ffi
+import paddle
+
+with paddle.compat.use_torch_proxy_guard(enable=False):
+    import tvm_ffi

 from ..artifacts import ArtifactPath, MetaInfoHash
 from ..autotuner import (
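Note on the guarded import: Paddle's compatible API can install a proxy so that "import torch" resolves to Paddle's torch-shaped shim. tvm_ffi needs to bind against the real modules, so its import is wrapped in use_torch_proxy_guard(enable=False), which switches the proxy off inside the block. A rough sketch of the pattern, using only the paddle.compat entry points that appear in this commit:

    import paddle

    paddle.compat.enable_torch_proxy()  # as in setup.py: "import torch" now yields the shim

    with paddle.compat.use_torch_proxy_guard(enable=False):
        import tvm_ffi  # imported while the proxy is temporarily disabled
    # the proxy is active again once the with-block exits
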
@@ -463,11 +466,15 @@ def __init__(
             use_mxfp8_act_scaling,
         )

+        def paddle_dtype_to_tvm_ffi_dtype(dtype: paddle.dtype):
+            dtype_str = str(dtype).split(".", 1)[-1]
+            return tvm_ffi.dtype(dtype_str)
+
         if instance_key not in MoERunner.runner_dict:
             MoERunner.runner_dict[instance_key] = module.init(
-                x_dtype,
-                weight_dtype,
-                output_dtype,
+                paddle_dtype_to_tvm_ffi_dtype(x_dtype),
+                paddle_dtype_to_tvm_ffi_dtype(weight_dtype),
+                paddle_dtype_to_tvm_ffi_dtype(output_dtype),
                 use_deepseek_fp8_block_scale,
                 use_w4_group_scaling,
                 use_mxfp8_act_scaling,
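Note on the dtype helper: module.init crosses the tvm-ffi boundary, which does not understand paddle.dtype objects. str() on a Paddle dtype renders as, e.g., "paddle.float16", so splitting on the first "." leaves the bare type name for tvm_ffi.dtype to parse (assuming, as the helper does, that tvm_ffi.dtype accepts numpy-style dtype strings):

    import paddle

    dtype_str = str(paddle.float16).split(".", 1)[-1]
    print(dtype_str)  # "float16", ready to pass to tvm_ffi.dtype(dtype_str)
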
@@ -565,7 +572,8 @@ def cutlass_fused_moe(
     enable_pdl: Optional[bool] = None,
 ) -> List[torch.Tensor]:
     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)
     tuner = AutoTuner.get()
     MoERunner.refine_tuning_config(tune_max_num_tokens)

@@ -623,17 +631,22 @@ def cutlass_fused_moe(
         else moe_runner.fused_moe_runner.run_moe
     )
     num_active_experts_per_node = torch.empty(
-        (1,), dtype=torch.int32, device=input.device
+        # (1,), dtype=torch.int32, device=input.device
+        (1,),
+        dtype=torch.int32,
+        device=input.place,
     )
     experts_to_token_score = torch.empty(
         (fc2_expert_weights.shape[0], input.shape[0]),
         dtype=torch.float32,
-        device=input.device,
+        # device=input.device,
+        device=input.place,
     )
     active_expert_global_ids = torch.empty(
         (fc2_expert_weights.shape[0],),
         dtype=torch.int32,
-        device=input.device,
+        # device=input.device,
+        device=input.place,
     )
     min_latency_output = (
         [
@@ -897,7 +910,8 @@ def cutlass_fused_moe(
         raise NotImplementedError("min latency mode not yet implemented for Blackwell.")

     if enable_pdl is None:
-        enable_pdl = device_support_pdl(input.device)
+        # enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(input.place)

     num_rows = input.shape[0]
     if min_latency_mode:
@@ -906,10 +920,16 @@ def cutlass_fused_moe(
         output_shape = (num_rows, hidden_size)

     if output is None:
-        output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
+        # output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
+        output = torch.empty(output_shape, dtype=output_dtype, device=input.place)
     else:
         check_shape_dtype_device(
-            output, output_shape, output_dtype, input.device, "output"
+            # output, output_shape, output_dtype, input.device, "output"
+            output,
+            output_shape,
+            output_dtype,
+            input.place,
+            "output",
         )

     major, minor = torch.cuda.get_device_capability()

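Note on the allocations: the torch.empty call sites above now pass input.place directly as device=. With the torch proxy active, torch.empty is served by Paddle, which accepts its own Place objects, so no conversion to torch.device is needed. An illustrative sketch (assumes the proxy is enabled, as in the AOT branch of setup.py below):

    import paddle

    paddle.compat.enable_torch_proxy()
    import torch  # Paddle's shim

    x = torch.ones((4, 8))  # backed by a Paddle tensor
    y = torch.empty((1,), dtype=torch.int32, device=x.place)
    print(y.place)  # same place as x
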
flashinfer/jit/core.py

Lines changed: 4 additions & 1 deletion

@@ -1,7 +1,10 @@
 import dataclasses
 import logging
 import os
-import tvm_ffi
+import paddle
+
+with paddle.compat.use_torch_proxy_guard(enable=False):
+    import tvm_ffi
 from contextlib import nullcontext
 from pathlib import Path
 from typing import Dict, List, Optional, Sequence, Union

flashinfer/jit/cpp_ext.py

Lines changed: 4 additions & 1 deletion

@@ -10,7 +10,10 @@
 from pathlib import Path
 from typing import List, Optional

-import tvm_ffi
+import paddle
+
+with paddle.compat.use_torch_proxy_guard(enable=False):
+    import tvm_ffi
 import torch

 from . import env as jit_env

flashinfer/utils.py

Lines changed: 14 additions & 7 deletions

@@ -16,13 +16,12 @@

 import functools
 import math
+import os
 from enum import Enum
 from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union

 import torch
 import torch.version
-from torch.torch_version import TorchVersion
-from torch.torch_version import __version__ as torch_version

 from .jit import gen_jit_spec, env as jit_env

@@ -231,6 +230,7 @@ def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:

 @functools.cache
 def get_compute_capability(device: torch.device) -> Tuple[int, int]:
+    return torch.device.cuda.get_device_capability(device.gpu_device_id())
     if device.type != "cuda":
         raise ValueError("device must be a cuda device")
     return torch.cuda.get_device_capability(device.index)
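Note: the newly added first line returns immediately, so the cuda-only validation below becomes dead code under the compatible API. The device argument is now a Paddle place, whose integer index comes from gpu_device_id() rather than torch.device.index. Paddle exposes the same capability query natively, which is what the proxied call resolves to; a sketch (assuming a CUDA build of Paddle):

    import paddle

    major, minor = paddle.device.cuda.get_device_capability()
    print(major, minor)  # e.g. (9, 0) on Hopper
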
@@ -249,7 +249,13 @@ def _check_cached_qkv_data_type(
     )


-if TorchVersion(torch_version) < TorchVersion("2.4"):
+def use_paddle_compatible_api() -> bool:
+    return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"]
+
+
+if use_paddle_compatible_api() or torch.torch_version.TorchVersion(
+    torch.torch_version.__version__
+) < torch.torch_version.TorchVersion("2.4"):

     def register_custom_op(
         name: str,
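Note on the switch: PADDLE_COMPATIBLE_API is the single opt-in flag for all of these code paths, and it must be set before flashinfer is first imported because the branch above is evaluated at module import time. Illustrative usage:

    import os

    os.environ["PADDLE_COMPATIBLE_API"] = "1"  # "1", "on", and "true" are accepted

    import flashinfer.utils  # the paddle-compatible register_custom_op branch is taken
    assert flashinfer.utils.use_paddle_compatible_api()
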
@@ -492,15 +498,16 @@ def check_shape_dtype_device(
     expected_device: Optional[torch.device],
     name: str,
 ) -> None:
-    if expected_shape and x.shape != torch.Size(expected_shape):
+    if expected_shape and tuple(x.shape) != torch.Size(expected_shape):
         raise ValueError(
             f"Invalid shape of {name}: expected {expected_shape}, got {x.shape}"
         )
     if expected_dtype and x.dtype != expected_dtype:
         raise ValueError(
             f"Invalid dtype of {name}: expected {expected_dtype}, got {x.dtype}"
         )
-    if expected_device and x.device != expected_device:
+    # if expected_device and x.device != expected_device:
+    if expected_device and x.place != expected_device:
         raise ValueError(
             f"Invalid device of {name}: expected {expected_device}, got {x.device}"
         )
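Note on tuple(x.shape): Paddle tensors report .shape as a plain Python list, while torch.Size is a tuple subclass, and a list never compares equal to a tuple. Without the normalization the shape check would reject every tensor under the compatible API. A quick illustration:

    assert [2, 3] != (2, 3)  # list vs tuple: never equal
    assert tuple([2, 3]) == (2, 3)  # normalizing to a tuple restores the comparison
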
@@ -548,8 +555,8 @@ def set_log_level(lvl_str: str) -> None:


 def device_support_pdl(device: torch.device) -> bool:
-    if device.type != "cuda":
-        return False
+    # if device.type != "cuda":
+    #     return False
     major, _ = get_compute_capability(device)
     return major >= 9

setup.py

Lines changed: 12 additions & 1 deletion

@@ -31,6 +31,10 @@
 enable_aot = aot_ops_package_dir.is_dir() and any(aot_ops_package_dir.iterdir())


+def use_paddle_compatible_api() -> bool:
+    return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"]
+
+
 def write_if_different(path: Path, content: str) -> None:
     if path.exists() and path.read_text() == content:
         return
@@ -83,7 +87,6 @@ def generate_build_meta(aot_build_meta: dict) -> None:
 cmdclass: Mapping[str, type[setuptools.Command]] = {}
 install_requires = [
     "numpy",
-    "torch",
     "ninja",
     "requests",
     "nvidia-ml-py",
@@ -95,9 +98,17 @@ def generate_build_meta(aot_build_meta: dict) -> None:
     "packaging>=24.2",
     "nvidia-cudnn-frontend>=1.13.0",
 ]
+if not use_paddle_compatible_api():
+    install_requires.append("torch")
+
 generate_build_meta({})

 if enable_aot:
+    if use_paddle_compatible_api():
+        import paddle
+
+        paddle.compat.enable_torch_proxy()
+
     import torch

     cuda_version = get_cuda_version()

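Usage note: the Paddle path is opt-in at build time, e.g. PADDLE_COMPATIBLE_API=1 pip install -e . (illustrative invocation). Under that flag setup.py stops listing torch as a dependency, and the AOT branch enables Paddle's torch proxy before the subsequent import torch, so ahead-of-time compilation runs against the shim; without the flag, behavior is unchanged and stock torch is required.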