Skip to content

Commit

Permalink
[webgpu native] Add transpose shared (#22098)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
axinging authored Sep 27, 2024
1 parent 0f7a5f6 commit 41f6ff3
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 28 deletions.
91 changes: 74 additions & 17 deletions onnxruntime/core/providers/webgpu/tensor/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ ONNX_OPERATOR_KERNEL_EX(
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
Transpose);

const std::string AppendPermFunction(gsl::span<const size_t> perm) {
const std::string AppendPermFunction(gsl::span<const int64_t> perm) {
std::ostringstream ss;
ss.imbue(std::locale::classic());
ss << "fn perm(i: y_indices_t)->x_indices_t {\n"
" var a: x_indices_t;\n";
ss << "fn perm(i: output_indices_t)->a_indices_t {\n"
" var a: a_indices_t;\n";
for (size_t i = 0; i < perm.size(); ++i) {
ss << " a[" << perm[i] << "] = i[" << i << "];\n";
}
Expand All @@ -60,21 +60,52 @@ const std::string AppendPermFunction(gsl::span<const size_t> perm) {
return ss.str();
}

auto SqueezeShape(const gsl::span<const int64_t>& shape, const gsl::span<const size_t>& adjusted_perm, InlinedVector<int64_t>& new_shape, InlinedVector<int64_t>& new_perm) {
for (auto i = 0; i < shape.size(); ++i) {
if (shape[i] != 1) {
new_shape.push_back(shape[i]);
}
if (shape[adjusted_perm[i]] != 1) {
new_perm.push_back(adjusted_perm[i]);
}
}
};

Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
shader.AppendImplementation(AppendPermFunction(this->perm_));
shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"),
" let indices = ", output.OffsetToIndices("global_idx"),
";\n"
" let x_indices = perm(indices); \n"
" ",
output.SetByOffset("global_idx", input.GetByIndices("x_indices")));
const auto& input = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);

if (use_shared_) {
shader.AppendImplementation("var<workgroup> tile : array<array<output_value_t, tile_size + 1>, tile_size>;\n");
shader.SetMainFunctionBody(
" let stride = (uniforms.output_shape[1] - 1) / tile_size + 1;\n"
" let workgroup_id_x = workgroup_idx % stride;\n"
" let workgroup_id_y = workgroup_idx / stride;\n"
" let input_col = workgroup_id_y * tile_size + local_id.x;\n"
" let input_row = workgroup_id_x * tile_size + local_id.y;\n"
" if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {\n"
" tile[local_id.y][local_id.x] = " +
input.GetByIndices("a_indices_t(input_row, input_col)") +
";\n"
" }\n"
" workgroupBarrier();\n"
" let output_col = workgroup_id_x * tile_size + local_id.x;\n"
" let output_row = workgroup_id_y * tile_size + local_id.y;\n"
" if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {\n " +
output.SetByIndices("output_indices_t(output_row, output_col)", "tile[local_id.x][local_id.y]") + "\n }");
} else {
shader.AppendImplementation(AppendPermFunction(this->perm_));
shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"),
" let indices = ", output.OffsetToIndices("global_idx"),
";\n"
" let x_indices = perm(indices);\n",
" ",
output.SetByOffset("global_idx", input.GetByIndices("x_indices")));
}
return Status::OK();
}

Status Transpose::ComputeInternal(ComputeContext& context) const {
// TODO: there is an optimized version of transpose to port.
const auto* input_tensor = context.Input(0);
const TensorShape& input_shape = input_tensor->Shape();
int32_t rank = gsl::narrow_cast<int32_t>(input_shape.NumDimensions());
Expand All @@ -86,16 +117,42 @@ Status Transpose::ComputeInternal(ComputeContext& context) const {
TensorShape output_shape(output_dims);
auto* output_tensor = context.Output(0, output_shape);

InlinedVector<int64_t> new_shape{};
InlinedVector<int64_t> new_perm{};
SqueezeShape(input_shape.GetDims(), *p_perm, new_shape, new_perm);
const bool channels_last = new_perm == InlinedVector<int64_t>({2, 3, 1});
const bool channels_first = new_perm == InlinedVector<int64_t>({3, 1, 2});
const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first;
auto new_input_shape = input_shape;
TensorShape new_output_shape(output_dims);
if (use_shared) {
new_input_shape = channels_last
? TensorShape({new_shape[0], new_shape[1] * new_shape[2]})
: channels_first
? TensorShape({new_shape[0] * new_shape[1], new_shape[2]})
: new_shape;
new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
}

uint32_t output_size = gsl::narrow_cast<int32_t>(input_tensor->Shape().Size());
TransposeProgram program{*p_perm};
TransposeProgram program{*p_perm, use_shared};
if (use_shared) {
program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1);
}

program
.CacheHint(absl::StrJoin(*p_perm, "-"))
.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
.AddOutputs({output_tensor})
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, new_output_shape, 1}})
.SetDispatchGroupSize(static_cast<uint32_t>((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
static_cast<uint32_t>(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE)))
.AddUniformVariables({
{static_cast<uint32_t>(output_size)},
});

use_shared ? program.SetDispatchGroupSize(static_cast<uint32_t>((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
static_cast<uint32_t>(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE)))
: program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
return context.RunProgram(program);
}

Expand Down
24 changes: 13 additions & 11 deletions onnxruntime/core/providers/webgpu/tensor/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,28 @@
namespace onnxruntime {
namespace webgpu {

class Transpose final : public WebGpuKernel, public TransposeBase {
public:
Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {
}
Status ComputeInternal(ComputeContext& context) const override;
constexpr static uint32_t TILE_SIZE = 16;
};

class TransposeProgram final : public Program<TransposeProgram> {
public:
TransposeProgram(const gsl::span<const size_t>& permutations)
: Program{"Transpose"}, perm_(permutations.begin(), permutations.end()) {
TransposeProgram(const gsl::span<const size_t>& permutations, bool use_shared)
: Program{"Transpose"}, perm_(permutations.begin(), permutations.end()), use_shared_(use_shared) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32});
WEBGPU_PROGRAM_DEFINE_CONSTANTS({"tile_size", Transpose::TILE_SIZE});

private:
InlinedVector<size_t> perm_;
};

class Transpose final : public WebGpuKernel, public TransposeBase {
public:
Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {
}

Status ComputeInternal(ComputeContext& context) const override;
InlinedVector<int64_t> perm_;
const bool use_shared_;
};

} // namespace webgpu
Expand Down

0 comments on commit 41f6ff3

Please sign in to comment.