Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transpose 2d v1 #434

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/scripts/bench/bench_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ def bench_conv2d(params: str, *args, **kwargs) -> float:
g = g.cuda_graph()
return bench_torch_model(lambda: g.run_async(), [])

def bench_transpose2d(params: str, *args, **kwargs) -> float:
x_shape = params
x_shape = [int(s) for s in x_shape.split('x')]
x = hidet.symbol(x_shape, dtype='float32', device='cuda')
o = hidet.ops.transpose(x)
g = hidet.trace_from(o, inputs=[x])
g = hidet.graph.optimize(g)
g = g.cuda_graph()
return bench_torch_model(lambda: g.run_async(), [])

def bench_conv2d_gemm_f16(params: str, *args, **kwargs) -> float:
x_shape, w_shape = params.split(',')
x_shape = [int(s) for s in x_shape.split('x')]
Expand Down Expand Up @@ -101,6 +111,7 @@ def bench_reduce(params: str, *args, **kwargs) -> float:
'matmul_f16': bench_matmul_f16,
'batch_matmul': bench_batch_matmul,
'conv2d': bench_conv2d,
'transpose2d' : bench_transpose2d,
'conv2d_gemm_f16': bench_conv2d_gemm_f16,
'attn': bench_attn,
'attn_mask_add': bench_attn_mask_add,
Expand Down
90 changes: 89 additions & 1 deletion python/hidet/graph/ops/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
from typing import List, Optional, Union, Sequence
from hidet.ir.type import DataType, data_type
from hidet.ir.expr import Expr, Constant, if_then_else, convert, cast as ir_cast, is_constant
from hidet.ir.expr import Int
from hidet.ir.expr import Int, logical_and
from hidet.ir.layout import RowMajorLayout
from hidet.ir.utils import index_deserialize, index_serialize
from hidet.utils import prod
from .utils import Task, InverseMap, Operator, Tensor, TensorNode, compute, input_like, normalize_dim, can_broadcast
from .utils import TensorInput, normalize_slice
from hidet.ir.module import IRModule
from hidet.ir.library import tune
from hidet.utils.py import cdiv


def is_true(x: Union[Expr, bool]) -> bool:
Expand Down Expand Up @@ -320,6 +323,84 @@ def fmap(*indices):
super().__init__(name='tile', inputs=[data], outputs=[out])


class TransposeTask2D(Task):
def __init__(self, input: TensorNode):
self.input_shape = input.shape
self.input_dtype = input.type.dtype
self.output_shape = [self.input_shape[1], self.input_shape[0]]

output = compute(name='output', shape=self.output_shape, fcompute=lambda i, j: input[j, i])

super().__init__(name='transpose2d', inputs=[input], outputs=[output], attributes={})

def allow_prologue(self) -> bool:
return False

def allow_epilogue(self) -> bool:
return True

def implement_cuda(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
return tune.extract_ir_modules(self.cuda_schedule_threads_coarsening_transpose)

@tune.space(1, coarsen_factor_row=[1, 2, 3, 4], coarsen_factor_col=[1, 2, 3, 4])
def cuda_schedule_threads_coarsening_transpose(self, coarsen_factor_row=1, coarsen_factor_col=1) -> IRModule:
# pylint: disable=unused-variable
import hidet
from hidet.lang.cuda import blockIdx, threadIdx, syncthreads
from hidet.lang.cuda import shared_tensor
from hidet.lang import attrs

input, output = self.inputs[0], self.outputs[0]
tile_size_baseline = 32
numElementsPerThread_row, numElementsPerThread_col = coarsen_factor_row, coarsen_factor_col
blockSize_row = min(cdiv(self.input_shape[0], numElementsPerThread_row), tile_size_baseline)
blockSize_col = min(cdiv(self.input_shape[1], numElementsPerThread_col), tile_size_baseline)
numElementsPerBlock_row = numElementsPerThread_row * blockSize_row
numElementsPerBlock_col = numElementsPerThread_col * blockSize_col

sharedMemSize_row, sharedMemSize_col = numElementsPerBlock_row, numElementsPerBlock_col
if blockSize_row % tile_size_baseline == 0 and blockSize_col % tile_size_baseline == 0:
sharedMemSize_col += 1
block_size = (blockSize_row, blockSize_col)
grid_size = (
cdiv(self.input_shape[0], numElementsPerBlock_row),
cdiv(self.input_shape[1], numElementsPerBlock_col),
)
with hidet.script_module() as module:

@hidet.script
def transpose_kernel(
input: self.input_dtype[self.input_shape], output: self.input_dtype[self.output_shape]
):
attrs.cuda.grid_dim = grid_size
attrs.cuda.block_dim = block_size
tile = shared_tensor(self.input_dtype, shape=[sharedMemSize_row, sharedMemSize_col])

for kx in range(coarsen_factor_row):
for ky in range(coarsen_factor_col):
tx = threadIdx.x + blockSize_row * kx
ty = threadIdx.y + blockSize_col * ky
xIndex = blockIdx.x * numElementsPerBlock_row + tx
yIndex = blockIdx.y * numElementsPerBlock_col + ty

if xIndex < self.input_shape[0] and yIndex < self.input_shape[1]:
tile[tx, ty] = input[xIndex, yIndex]

syncthreads()
for kx in range(coarsen_factor_row):
for ky in range(coarsen_factor_col):
tx = threadIdx.x + blockSize_row * kx
ty = threadIdx.y + blockSize_col * ky
xIndex = blockIdx.y * numElementsPerBlock_col + tx
yIndex = blockIdx.x * numElementsPerBlock_row + ty

if xIndex < self.output_shape[0] and yIndex < self.output_shape[1]:
output[xIndex, yIndex] = tile[ty, tx]

ir_module = module.ir_module()
return ir_module


class ReshapeOp(Operator):
def __init__(self, x: Tensor, shape):
task = ReshapeTask(input_like(x, 'x'), shape)
Expand Down Expand Up @@ -536,8 +617,15 @@ def flatten(x: Tensor, start_dim=0, end_dim=-1) -> Tensor:
return FlattenOp(x, start_dim, end_dim).outputs[0]


class TransposeOp2D(Operator):
def __init__(self, input: Tensor):
super().__init__(inputs=[input], attributes={}, task=TransposeTask2D(input_like(input, 'input')))


def transpose(x: Tensor, axes: Optional[Sequence[int]] = None) -> Tensor:
rank = len(x.shape)
if rank == 2:
TransposeOp2D(x).outputs[0]
if axes is None:
axes = list(reversed(range(rank)))
axes = [normalize_dim(dim, rank) for dim in axes]
Expand Down
5 changes: 5 additions & 0 deletions tests/operators/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ def test_transpose(shape, axes):
check_transform(shape, lambda x: np.transpose(x, axes), lambda x: ops.transpose(x, axes))


@pytest.mark.parametrize("shape", [[33, 44], [1, 100], [100, 1], [10, 20], [20, 10], [100, 200], [2000, 3000]])
def test_transpose_2d(shape):
check_transform(shape, lambda x: np.transpose(x), lambda x: ops.transpose(x))


@pytest.mark.parametrize(
"shapes, dtype, axis",
[
Expand Down
Loading