13 changes: 11 additions & 2 deletions python/tvm/dlight/gpu/general_reduction.py
@@ -90,7 +90,10 @@ def f_layout_mapping(*iters):
sch.transform_block_layout(block_infos[-1].block_rv, index_map)

try:
# TODO: fix num_leading_s = 0 case
# Handle the case where num_leading_s = 0
if num_leading_s == 0:
num_leading_s = 1 # Use at least one spatial dimension for blockIdx.x

assert num_trailing_r > 0
for block in block_infos[1:-1]:
assert block.dom_kind() == dom_kind
@@ -100,7 +103,13 @@ def f_layout_mapping(*iters):
return None

loops = sch.get_loops(block_infos[-1].block_rv)
bx = sch.fuse(*loops[:num_leading_s])

# Ensure we have at least one spatial dimension for blockIdx.x
if num_leading_s > 0:
bx = sch.fuse(*loops[:num_leading_s])
else:
bx = loops[0] # Use the first loop as blockIdx.x

r_loop, tx = sch.split(loops[-1], [None, len_tx])
sch.reorder(tx, r_loop)
sch.bind(bx, "blockIdx.x")
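For readers unfamiliar with the scheduling primitives used above, the following is a minimal standalone sketch, not the dlight rule itself: the workload, shapes, and len_tx value are illustrative assumptions. It shows the fuse / split / reorder / bind pattern that the code above applies to a reduction block with leading spatial loops and a trailing reduction loop.

import tvm
from tvm.script import tir as T

@T.prim_func
def row_sum(A: T.Buffer((8, 16, 1024), "float32"), B: T.Buffer((8, 16), "float32")):
    # Two leading spatial loops (i, j) and one trailing reduction loop (k).
    for i, j, k in T.grid(8, 16, 1024):
        with T.block("B"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                B[vi, vj] = T.float32(0)
            B[vi, vj] = B[vi, vj] + A[vi, vj, vk]

sch = tvm.tir.Schedule(row_sum)
loops = sch.get_loops(sch.get_block("B"))
len_tx = 256                                       # illustrative thread count
bx = sch.fuse(*loops[:2])                          # fuse leading spatial loops -> blockIdx.x
r_loop, tx = sch.split(loops[-1], [None, len_tx])  # split the trailing reduction loop
sch.reorder(tx, r_loop)
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")
sch.mod.show()                                     # inspect the scheduled TIR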
47 changes: 46 additions & 1 deletion python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -227,9 +227,54 @@ def _round(self, node: fx.Node) -> relax.Expr:
return self.block_builder.emit(relax.op.round(arg))

def _softmax(self, node: fx.Node) -> relax.Var:
"""
For large tensors where softmax is applied along a non-last dimension,
we transpose to move the softmax dimension to the end, apply softmax
over the last dimension, and then transpose back to the original shape.
"""
x = self.env[node.args[0]]
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
return self.block_builder.emit(relax.op.nn.softmax(x, dim))
input_shape = x.struct_info.shape
input_ndim = len(input_shape)

if dim < 0:
# Ensure dim is expressed as a positive index
dim = input_ndim + dim

# Check if softmax is applied over a non-last dimension of a large tensor (> 1024)
# The smallest power of 2 that doesn't work on an NVIDIA GeForce RTX
# 4090 is 8192. Using 1024 here to be safe.
is_large_non_last_dim = False
large_size_threshold = 1024

if dim != input_ndim - 1: # Not the last dimension
# Check if any dimension is large
for size in input_shape:
if hasattr(size, "value") and size.value > large_size_threshold:
is_large_non_last_dim = True
break

if is_large_non_last_dim:
# Special handling for large tensors with non-last dimension softmax

# Get dimension ordering for transpose
dims = list(range(input_ndim))
dims.append(dims.pop(dim))

# Transpose
x_transposed = self.block_builder.emit(relax.op.permute_dims(x, dims))

# Apply softmax on last dimension
softmax_result = self.block_builder.emit(relax.op.nn.softmax(x_transposed, -1))

# Transpose back to original shape
inv_dims = [-1] * len(dims)
for i, d in enumerate(dims):
inv_dims[d] = i
return self.block_builder.emit(relax.op.permute_dims(softmax_result, inv_dims))
else:
# Regular softmax
return self.block_builder.emit(relax.op.nn.softmax(x, dim))

def _selu(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
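As a sanity check on the rewrite above, here is a standalone NumPy-only sketch (the function name and shapes are illustrative, not part of the patch) showing that permuting the softmax axis to the end, applying softmax over the last axis, and then applying the inverse permutation reproduces softmax over the original axis:

import numpy as np

def softmax_via_last_dim(x: np.ndarray, dim: int) -> np.ndarray:
    ndim = x.ndim
    dim = dim % ndim                        # express dim as a positive index
    perm = list(range(ndim))
    perm.append(perm.pop(dim))              # move the softmax axis to the end
    xt = np.transpose(x, perm)
    e = np.exp(xt - xt.max(axis=-1, keepdims=True))
    st = e / e.sum(axis=-1, keepdims=True)
    inv = [0] * ndim                        # invert the permutation
    for i, d in enumerate(perm):
        inv[d] = i
    return np.transpose(st, inv)

x = np.random.rand(2, 3, 8, 5).astype("float32")
ref = np.exp(x - x.max(axis=2, keepdims=True))
ref = ref / ref.sum(axis=2, keepdims=True)
np.testing.assert_allclose(softmax_via_last_dim(x, 2), ref, rtol=1e-5, atol=1e-5)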
71 changes: 71 additions & 0 deletions tests/python/relax/test_from_exported_to_cuda.py
@@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm import relax
import tvm.testing
import numpy as np
import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program
from torch.nn import Softmax


def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev):
"""
This util ensures that a torch module can successfully be exported to TVM
using torch.export and that the resulting IR program gives the same result
as PyTorch when run on CUDA.
"""
raw_data_for_tvm = raw_data.copy() # In case the data is modified
torch_data = torch.from_numpy(raw_data)
example_args = (torch_data,)

with torch.no_grad():
exported_program = export(torch_module, example_args)
mod_from_torch = from_exported_program(exported_program, keep_params_as_input=True)

tvm_mod, tvm_params = relax.frontend.detach_params(mod_from_torch)

relax_pipeline = relax.get_default_pipeline(tvm.target.Target.from_device(tvm.cuda()))
ex = relax.build(tvm_mod, target=target, relax_pipeline=relax_pipeline)
vm = relax.VirtualMachine(ex, dev)

gpu_data = tvm.nd.array(raw_data_for_tvm, dev)
gpu_params = [tvm.nd.array(p, dev) for p in tvm_params["main"]]
gpu_out = vm["main"](gpu_data, *gpu_params)

pytorch_out = torch_module(torch_data).detach().numpy()
actual = gpu_out[0].numpy()
desired = pytorch_out
np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)


@tvm.testing.parametrize_targets("cuda")
def test_softmax_non_last_dim_large_tensor(target, dev):
"""
Tests ingesting a PyTorch exported model that uses softmax on a large
tensor, with the softmax dimension not being the last dimension, and
running it on CUDA.
"""
torch_module = Softmax(dim=2).eval()
raw_data = np.random.rand(10, 4, 32, 16384).astype("float32")
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)


if __name__ == "__main__":
tvm.testing.main()
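The test can be run on its own, assuming a CUDA-enabled TVM build and a local NVIDIA GPU, either directly through the tvm.testing.main() entry point or via pytest:

python tests/python/relax/test_from_exported_to_cuda.py
pytest tests/python/relax/test_from_exported_to_cuda.py -k softmax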