Skip to content

Commit

Permalink
[webgpu-native] Add gather (#22183)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
  • Loading branch information
qjia7 and fs-eire authored Sep 27, 2024
1 parent 41f6ff3 commit b1b5e1f
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 3 deletions.
82 changes: 82 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/gather.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/tensor/gather.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

// Generates the WGSL main function body for Gather.
//
// For every output element (one per `global_idx`) the generated shader:
//   1. converts the flat output offset into `output_indices`;
//   2. copies the output coordinates starting at `axis_` into `indices_indices`
//      and reads the gather index `idx` from the indices input;
//   3. wraps a negative `idx` by adding the size of `data` along `axis_`;
//   4. builds `data_indices` by putting `idx` at position `axis_` and taking the
//      remaining coordinates from the output coordinates outside the span used
//      for the indices tensor;
//   5. writes the addressed `data` element to the output.
//
// Returns Status::OK() once the shader body has been set.
Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& data = shader.AddInput("data", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
  const auto& indices = shader.AddInput("input_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
  const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform);

  // Checked conversion done once, so the loops below compare and add plain ints
  // instead of mixing signed and unsigned operands.
  const int axis = SafeInt<int>(axis_);

  std::ostringstream calc_data_indices;
  calc_data_indices.imbue(std::locale::classic());
  calc_data_indices << " var indices_indices = input_indices_indices_t(0);\n";
  for (int i = 0; i < indices.Rank(); i++) {
    calc_data_indices << " " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", axis + i)) << ";\n";
  }

  // Go through IndicesGet rather than hard-coding `uniforms.data_shape[axis]`, so the
  // shape uniform is accessed the same way as every other indices value in this shader.
  // NOTE(review): IndicesGet is expected to omit the subscript when `data` has rank < 2
  // (where the shape uniform would be a scalar that cannot be indexed) - confirm against
  // the shader variable helper.
  calc_data_indices << " var idx = " << indices.GetByIndices("indices_indices") << ";\n"
                    << " if (idx < 0) {\n"
                    << " idx = idx + input_indices_value_t(" << data.IndicesGet("uniforms.data_shape", axis) << ");\n"
                    << " }\n"
                    << " var data_indices : data_indices_t;\n";

  // `j` walks the output coordinates; the gather axis consumes indices.Rank() of them.
  for (int i = 0, j = 0; i < data.Rank(); i++) {
    if (i == axis) {
      calc_data_indices << " " << data.IndicesSet("data_indices", i, "u32(idx)") << ";\n";
      j += indices.Rank();
    } else {
      calc_data_indices << " " << data.IndicesSet("data_indices", i, output.IndicesGet("output_indices", j)) << ";\n";
      j++;
    }
  }

  shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"),
                             " let output_indices = ", output.OffsetToIndices("global_idx"), ";\n",
                             calc_data_indices.str(), " ",
                             output.SetByOffset("global_idx", data.GetByIndices("data_indices")));

  return Status::OK();
}

// Runs Gather through a GatherProgram.
//
// Resolves the input, indices and output tensors plus the axis via
// PrepareForCompute, returns early when the output holds no elements, and
// otherwise dispatches one shader invocation per output element (rounded up
// to whole workgroups).
Status Gather::ComputeInternal(ComputeContext& context) const {
  Prepare prepared;
  ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), prepared));

  // Total number of output elements; SafeInt guards the narrowing to 32 bits.
  const uint32_t output_size = SafeInt<uint32_t>(prepared.output_tensor->Shape().Size());
  if (output_size == 0) {
    return Status::OK();
  }

  const uint32_t gather_axis = static_cast<uint32_t>(prepared.axis);
  // Ceiling division so a partially filled last workgroup is still dispatched.
  const auto workgroup_count = (output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;

  GatherProgram program{gather_axis};
  program.AddInputs({{prepared.input_tensor, ProgramTensorMetadataDependency::TypeAndRank},
                     {prepared.indices_tensor, ProgramTensorMetadataDependency::TypeAndRank}});
  program.AddOutput({prepared.output_tensor, ProgramTensorMetadataDependency::Rank});
  program.SetDispatchGroupSize(workgroup_count);
  // The axis is baked into the generated shader text, so it is part of the cache hint.
  program.CacheHint(std::to_string(gather_axis));
  program.AddUniformVariables({{output_size}});

  return context.RunProgram(program);
}

// Registers KERNEL_CLASS as the WebGPU execution provider kernel for OP_TYPE in
// the ONNX domain at opset VERSION. "T" is constrained to TYPE and "Tind" to
// int32_t / int64_t.
#define WEBGPU_GATHER_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>()), \
KERNEL_CLASS);

// Same as WEBGPU_GATHER_KERNEL, but registers KERNEL_CLASS for the opset range
// VERSION_FROM..VERSION_TO via the versioned registration macro.
#define WEBGPU_GATHER_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", TYPE).TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList<TypeList<int32_t, int64_t>>()), \
KERNEL_CLASS);

// Gather kernel registrations: versioned entries for opsets 1-10 and 11-12, plus
// a non-versioned entry at opset 13. All use the Gather kernel class with "T"
// constrained to WebGpuSupportedNumberTypes().
WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 1, 10, Gather, WebGpuSupportedNumberTypes())
WEBGPU_GATHER_VERSIONED_KERNEL(Gather, 11, 12, Gather, WebGpuSupportedNumberTypes())
WEBGPU_GATHER_KERNEL(Gather, 13, Gather, WebGpuSupportedNumberTypes())

} // namespace webgpu
} // namespace onnxruntime
34 changes: 34 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/gather.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/cpu/tensor/gatherbase.h"

namespace onnxruntime {
namespace webgpu {

// WebGPU program that generates the shader for the Gather operator.
//
// It declares a single uniform, `data_size` (Uint32), and stores the gather axis
// it was constructed with for use during shader generation.
class GatherProgram final : public Program<GatherProgram> {
 public:
  // `axis` is the axis of the data tensor to gather along. It is unsigned, so any
  // negative axis must be resolved by the caller before constructing the program.
  // The constructor is explicit so a bare integer never converts to a program
  // implicitly.
  explicit GatherProgram(uint32_t axis) : Program{"Gather"}, axis_{axis} {}

  // Emits the shader code for this program into `sh`.
  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32});

 private:
  uint32_t axis_;  // Gather axis supplied at construction.
};

// WebGPU kernel for the Gather operator. It combines the WebGPU kernel base with
// GatherBase, both of which are initialised from the same OpKernelInfo.
class Gather final : public WebGpuKernel, public GatherBase {
 public:
  // Explicit so an OpKernelInfo is never implicitly converted into a kernel.
  explicit Gather(const OpKernelInfo& info) : WebGpuKernel(info), GatherBase(info) {}

 protected:
  // Performs the gather for one kernel invocation; defined in the .cc file.
  Status ComputeInternal(ComputeContext& context) const override;
};

} // namespace webgpu
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -641,9 +641,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 8, 12, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Expand)>,

// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Gather)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gather)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Gather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gather)>,

// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, GatherElements)>,
Expand Down

0 comments on commit b1b5e1f

Please sign in to comment.