Skip to content

Commit

Permalink
dfttest2.py: added support for hip backends
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Mar 10, 2024
1 parent 9017199 commit f7b9413
Showing 1 changed file with 60 additions and 7 deletions.
67 changes: 60 additions & 7 deletions dfttest2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.3.3"
__version__ = "0.4.0"

from dataclasses import dataclass
import math
Expand Down Expand Up @@ -31,7 +31,17 @@ class CPU:
class GCC:
pass

backendT = typing.Union[Backend.cuFFT, Backend.NVRTC, Backend.CPU, Backend.GCC]
@dataclass(frozen=False)
class hipFFT:
device_id: int = 0
in_place: bool = True

@dataclass(frozen=False)
class HIPRTC:
device_id: int = 0
num_streams: int = 1

backendT = typing.Union[Backend.cuFFT, Backend.NVRTC, Backend.CPU, Backend.GCC, Backend.hipFFT, Backend.HIPRTC]


def init_backend(backend: backendT) -> backendT:
Expand All @@ -43,6 +53,10 @@ def init_backend(backend: backendT) -> backendT:
backend = Backend.CPU()
elif backend is Backend.GCC: # type: ignore
backend = Backend.GCC()
elif backend is Backend.hipFFT: # type: ignore
backend = Backend.hipFFT()
elif backend is Backend.HIPRTC: # type: ignore
backend = Backend.HIPRTC()
return backend


Expand Down Expand Up @@ -258,7 +272,7 @@ def DFTTest2(
zero_mean = zmean
backend = init_backend(backend)

if isinstance(backend, (Backend.CPU, Backend.NVRTC, Backend.GCC)):
if isinstance(backend, (Backend.CPU, Backend.NVRTC, Backend.GCC, Backend.HIPRTC)):
if radius not in range(4):
raise ValueError("invalid radius (tbsize)")
if block_size != 16:
Expand Down Expand Up @@ -336,6 +350,10 @@ def DFTTest2(
rdft = core.dfttest2_cpu.RDFT
elif isinstance(backend, Backend.GCC):
rdft = core.dfttest2_gcc.RDFT
elif isinstance(backend, Backend.hipFFT):
rdft = core.dfttest2_hip.RDFT
elif isinstance(backend, Backend.HIPRTC):
rdft = core.dfttest2_hiprtc.RDFT
else:
raise TypeError("unknown backend")

Expand Down Expand Up @@ -386,6 +404,10 @@ def DFTTest2(
to_single = core.dfttest2_cuda.ToSingle
elif isinstance(backend, Backend.NVRTC):
to_single = core.dfttest2_nvrtc.ToSingle
elif isinstance(backend, Backend.hipFFT):
to_single = core.dfttest2_hip.ToSingle
elif isinstance(backend, Backend.HIPRTC):
to_single = core.dfttest2_hiprtc.ToSingle
else:
raise TypeError("unknown backend")

Expand Down Expand Up @@ -487,6 +509,29 @@ def DFTTest2(
device_id=backend.device_id,
num_streams=backend.num_streams
)
if isinstance(backend, Backend.hipFFT):
return core.dfttest2_hip.DFTTest(
clip,
kernel=kernel,
radius=radius,
block_size=block_size,
block_step=block_step,
planes=planes,
in_place=backend.in_place,
device_id=backend.device_id
)
elif isinstance(backend, Backend.HIPRTC):
return core.dfttest2_hiprtc.DFTTest(
clip,
kernel=kernel,
radius=radius,
block_size=block_size,
block_step=block_step,
planes=planes,
in_place=False,
device_id=backend.device_id,
num_streams=backend.num_streams
)
else:
raise TypeError("unknown backend")

Expand All @@ -503,14 +548,21 @@ def select_backend(
if sbsize == 16 and tbsize in [1, 3, 5, 7]:
if hasattr(core, "dfttest2_nvrtc"):
return Backend.NVRTC()
elif hasattr(core, "dfttest2_hiprtc"):
return Backend.HIPRTC()
elif hasattr(core, "dfttest2_cuda"):
return Backend.cuFFT()
elif hasattr(core, "dfttest2_hip"):
return Backend.hipFFT()
elif hasattr(core, "dfttest2_cpu"):
return Backend.CPU()
else:
return Backend.GCC()
else:
return Backend.cuFFT()
if hasattr(core, "dfttest2_cuda"):
return Backend.cuFFT()
else:
return Backend.hipFFT()


FREQ = float
Expand Down Expand Up @@ -702,12 +754,13 @@ def DFTTest(
backend: Backend implementation to use.
All available backends can be found in the dfttest2.Backend "namespace":
dfttest2.Backend.{CPU, cuFFT, NVRTC, GCC}
dfttest2.Backend.{CPU, cuFFT, NVRTC, GCC, hipFFT, HIPRTC}
The CPU, NVRTC and GCC backends require sbsize=16.
The cuFFT and NVRTC backend require a CUDA-enabled system.
The cuFFT and NVRTC backends require a CUDA-enabled system.
The hipFFT and HIPRTC backends require a CUDA-enabled system.
Speed: NVRTC >> cuFFT > CPU == GCC
Speed: NVRTC == HIPRTC >> cuFFT > hipFFT > CPU == GCC
"""

if (
Expand Down

0 comments on commit f7b9413

Please sign in to comment.