[WebGPU EP] SoftMax Implementation #23538

Open · wants to merge 8 commits into main (showing changes from 4 commits)
231 changes: 231 additions & 0 deletions onnxruntime/core/providers/webgpu/math/softmax.cc
@@ -0,0 +1,231 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <numeric>
#include <string>
#include <vector>

#include "core/common/inlined_containers.h"
#include "core/providers/webgpu/math/softmax.h"
#include "core/providers/webgpu/tensor/transpose.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_variable.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Softmax,
kOnnxDomain,
1, 10,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
Softmax);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Softmax,
kOnnxDomain,
11, 12,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
Softmax);

ONNX_OPERATOR_KERNEL_EX(
Softmax,
kOnnxDomain,
13,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
Softmax);

static std::string MaxVector(const std::string& name, int components) {
switch (components) {
case 1:
return name;
case 2:
return "max(" + name + ".x, " + name + ".y)";
case 3:
return "max(max(" + name + ".x, " + name + ".y), " + name + ".z)";
case 4:
return "max(max(" + name + ".x, " + name + ".y), max(" + name + ".z, " + name + ".w))";
default:
ORT_THROW("Unsupported number of components: ", components);
}
}

static std::string SumVector(const std::string& x, int components) {
switch (components) {
case 1:
return x;
case 2:
return "(" + x + ".x + " + x + ".y" + ")";
    case 4:
      return "(" + x + ".x + " + x + ".y + " + x + ".z + " + x + ".w" + ")";
default:
ORT_THROW("Unsupported number of components: ", components);
}
}

static int GetMaxComponents(int64_t size) {
if (size % 4 == 0) {
return 4;
} else if (size % 2 == 0) {
return 2;
}
return 1;
}
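
// For reference, a sketch of what these helpers emit for a vectorized value
// named `v` (illustrative, not part of the diff):
//   MaxVector("v", 4)    -> "max(max(v.x, v.y), max(v.z, v.w))"
//   SumVector("v", 4)    -> "(v.x + v.y + v.z + v.w)"
//   GetMaxComponents(6)  -> 2;  GetMaxComponents(7) -> 1;  GetMaxComponents(8) -> 4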

Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
// Add input and output variables
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
int components = input.NumComponents();

std::string threadMaxDecl = input.ElementType() == "f32" ? "var threadMax = x_value_t(-3.402823e+38f);\n" : "var threadMax = x_value_t(-65504.0h);\n";

// Define shared memory for row max and row sum
shader.AdditionalImplementation()
<< "var<workgroup> rowMaxShared : x_value_t;\n"
<< "var<workgroup> rowSumShared : x_value_t;\n"
<< "var<workgroup> threadShared : array<x_value_t, " << WG << ">;\n";

// Define helper functions to get and set values
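  // (row, col and row_stride are measured in packed units: each x_value_t
  // packs `components` scalar lanes.)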
shader.AdditionalImplementation()
<< "fn getValue(row: i32, col: i32, row_stride: i32) -> x_value_t {\n"
<< " let index = row * row_stride + col;\n"
<< " return x[index];\n"
<< "}\n"
<< "fn setValue(row: i32, col: i32, row_stride: i32, value: x_value_t) {\n"
<< " let index = row * row_stride + col;\n"
<< " result[index] = value;\n"
<< "}\n";

// Main function body
shader.MainFunctionBody()
<< " let gindex = i32(global_idx);\n"
<< " let lindex = i32(local_idx);\n"
<< " const wg = " << WG << ";\n"
<< " let row = gindex / wg;\n"
<< " let cols = uniforms.packedCols;\n"
<< " let row_stride : i32 = uniforms.packedCols;\n"

// Find the row's max value
<< threadMaxDecl
<< " for (var col = lindex; col < cols; col += wg) {\n"
<< " let value = getValue(row, col, row_stride);\n"
<< " threadMax = max(threadMax, value);\n"
<< " }\n"
<< " if (lindex < cols) {\n"
<< " threadShared[lindex] = threadMax;\n"
<< " }\n"
<< " workgroupBarrier();\n"

// Reduce to find the max value
<< " var reduceSize = min(cols, wg);\n"
<< " for (var currSize = reduceSize >> 1; currSize > 0; currSize = reduceSize >> 1) {\n"
<< " reduceSize = currSize + (reduceSize & 1);\n"
<< " if (lindex < currSize) {\n"
<< " threadShared[lindex] = max(threadShared[lindex], threadShared[lindex + reduceSize]);\n"
<< " }\n"
<< " workgroupBarrier();\n"
<< " }\n"
<< " if (lindex == 0) {\n"
<< " rowMaxShared = x_value_t(" << MaxVector("threadShared[0]", components) << ");\n"
<< " }\n"
<< " workgroupBarrier();\n"

// Find the row's sum of exponentials
<< " var threadSum = x_value_t(0.0);\n"
<< " for (var col = lindex; col < cols; col += wg) {\n"
<< " let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);\n"
<< " threadSum += subExp;\n"
<< " }\n"
<< " threadShared[lindex] = threadSum;\n"
<< " workgroupBarrier();\n"

// Reduce to find the sum of exponentials
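      // wg is always a power of two here (64 or 256; see ComputeInternal), so
      // this halving loop needs no odd-size handling.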
<< " for (var currSize = wg >> 1; currSize > 0; currSize = currSize >> 1) {\n"
<< " if (lindex < currSize) {\n"
<< " threadShared[lindex] = threadShared[lindex] + threadShared[lindex + currSize];\n"
<< " }\n"
<< " workgroupBarrier();\n"
<< " }\n"
<< " if (lindex == 0) {\n"
<< " rowSumShared = x_value_t(" << SumVector("threadShared[0]", components) << ");\n"
<< " }\n"
<< " workgroupBarrier();\n"

// Calculate the final value for each element in the row
<< " for (var col = lindex; col < cols; col += wg) {\n"
<< " let value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared;\n"
<< " setValue(row, col, row_stride, value);\n"
<< " }\n";

return Status::OK();
}
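
The shader above is the standard numerically stable softmax, computed per row in three passes (row max, sum of shifted exponentials, normalize). A minimal scalar C++ sketch of the same math, for reference only; `SoftmaxRow` is a hypothetical helper, and the shader parallelizes each pass across a workgroup and vectorizes by `components`:

#include <algorithm>
#include <cmath>
#include <vector>

// Reference row-wise softmax; subtracting the row max before exp() avoids
// overflow without changing the result.
void SoftmaxRow(std::vector<float>& row) {
  const float row_max = *std::max_element(row.begin(), row.end());  // pass 1
  float row_sum = 0.0f;
  for (float v : row) row_sum += std::exp(v - row_max);             // pass 2
  for (float& v : row) v = std::exp(v - row_max) / row_sum;         // pass 3
}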

Status Softmax::ComputeInternal(ComputeContext& context) const {
const auto* input_tensor = context.Input(0);
const TensorShape& input_shape = input_tensor->Shape();
int64_t input_rank = input_shape.NumDimensions();
auto* output_tensor = context.Output(0, input_shape);

// normalize axis
int64_t axis = axis_ < 0 ? axis_ + input_rank : axis_;
bool is_transpose_required = axis < input_rank - 1;

TensorShape transposed_input_shape;
Tensor transposed_input_tensor;
Tensor intermediate_output;
InlinedVector<size_t> perm(input_rank);

if (is_transpose_required) {
std::iota(std::begin(perm), std::end(perm), 0);
perm[axis] = input_rank - 1;
perm[input_rank - 1] = axis;

std::vector<int64_t> transposed_input_dims;

for (auto e : perm) {
transposed_input_dims.push_back(input_shape[e]);
}

transposed_input_shape = TensorShape(transposed_input_dims);
transposed_input_tensor = context.CreateGPUTensor(input_tensor->DataType(), transposed_input_shape);
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, *input_tensor, transposed_input_tensor));
intermediate_output = context.CreateGPUTensor(output_tensor->DataType(), transposed_input_shape);
}

const int64_t cols = is_transpose_required ? transposed_input_shape[input_rank - 1] : input_shape[input_rank - 1];
const int64_t rows = input_shape.Size() / cols;
  const int components = GetMaxComponents(cols);
const auto packedCols = cols / components;
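  // One workgroup is dispatched per row (see SetDispatchGroupSize below); a
  // lone row gets a single 256-thread workgroup, multi-row inputs get 64
  // threads per row.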
uint32_t WG = rows == 1 ? 256 : 64;

SoftmaxProgram program{WG};
if (is_transpose_required) {
program
.AddInputs({{&transposed_input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}})
.AddOutputs({{&intermediate_output, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
} else {
program
.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}})
.AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(components)}});
}

program
.CacheHint(std::to_string(components), std::to_string(WG))
.SetWorkgroupSize(WG)
      .SetDispatchGroupSize(static_cast<uint32_t>(rows))
.AddUniformVariables({{static_cast<int32_t>(packedCols)}});

ORT_RETURN_IF_ERROR(context.RunProgram(program));

// If transpose was required, transpose the result back
if (is_transpose_required) {
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, intermediate_output, *output_tensor));
}

return Status::OK();
}
} // namespace webgpu
} // namespace onnxruntime
52 changes: 52 additions & 0 deletions onnxruntime/core/providers/webgpu/math/softmax.h
@@ -0,0 +1,52 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"
#include "core/framework/op_kernel.h"

namespace onnxruntime {
namespace webgpu {

class Softmax final : public WebGpuKernel {
public:
  explicit Softmax(const OpKernelInfo& info) : WebGpuKernel{info} {
    int opset = info.node().SinceVersion();
int64_t axis;
Status status = info.GetAttr<int64_t>("axis", &axis);

if (status.IsOK()) {
axis_ = axis;
} else {
      if (opset < 13) {
axis_ = 1; // opset-12 and below, the default axis value is 1
} else {
axis_ = -1; // opset-13, the default axis value is -1
}
}
}

Status ComputeInternal(ComputeContext& context) const override;

private:
int64_t axis_;
};

class SoftmaxProgram final : public Program<SoftmaxProgram> {
public:
  explicit SoftmaxProgram(uint32_t wg) : Program{"Softmax"}, WG{wg} {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"packedCols", ProgramUniformVariableDataType::Int32});

private:
uint32_t WG;
};

} // namespace webgpu
} // namespace onnxruntime
7 changes: 4 additions & 3 deletions onnxruntime/core/providers/webgpu/shader_variable.h
@@ -176,16 +176,17 @@ class ShaderVariableHelper : public ShaderIndicesHelper {
template <typename TOffset>
inline std::string GetByOffset(TOffset&& offset) const;

std::string_view StorageType() const;
std::string_view ValueType() const;
std::string_view ElementType() const;

private:
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariableHelper);

void Impl(std::ostream& ss) const;

std::string GetByOffsetImpl(std::string_view offset) const;
std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const;
std::string_view StorageType() const;
std::string_view ValueType() const;
std::string_view ElementType() const;

friend class ShaderHelper;
};
53 changes: 53 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/transpose.cc
@@ -97,6 +97,59 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
return Status::OK();
}

Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span<const size_t> permutations, const Tensor& input, Tensor& output) {
const auto& input_shape = input.Shape();
const auto& input_dims = input_shape.GetDims();
int32_t rank = gsl::narrow_cast<int32_t>(input_shape.NumDimensions());

TensorShapeVector output_dims(rank);

for (int32_t i = 0; i < rank; i++) {
output_dims[i] = input_dims[permutations[i]];
}

TensorShape output_shape(output_dims);

InlinedVector<int64_t> new_shape{};
InlinedVector<int64_t> new_perm{};
SqueezeShape(input_shape.GetDims(), permutations, new_shape, new_perm);
const bool channels_last = new_perm == InlinedVector<int64_t>({2, 3, 1});
const bool channels_first = new_perm == InlinedVector<int64_t>({3, 1, 2});
const bool use_shared = (new_shape.size() == 2 && new_perm[0] > new_perm[1]) || channels_last || channels_first;
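  // The tiled shared-memory kernel applies when the permutation squeezes down
  // to a plain 2D swap or to a channels-first/last reshuffle of a 3D shape.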
auto new_input_shape = input_shape;
TensorShape new_output_shape(output_dims);

if (use_shared) {
new_input_shape = channels_last
? TensorShape({new_shape[0], new_shape[1] * new_shape[2]})
: channels_first
? TensorShape({new_shape[0] * new_shape[1], new_shape[2]})
: new_shape;
new_output_shape = TensorShape({new_input_shape[1], new_input_shape[0]});
}

  uint32_t output_size = gsl::narrow_cast<uint32_t>(input_shape.Size());
TransposeProgram program{permutations, use_shared};

if (use_shared) {
program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1);
}
  program
      .CacheHint(absl::StrJoin(permutations, "-"))
      .AddInputs({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}})
      .AddOutputs({{&output, ProgramTensorMetadataDependency::None, new_output_shape, 1}})
      .AddUniformVariables({{output_size}});

  if (use_shared) {
    program.SetDispatchGroupSize(static_cast<uint32_t>((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE),
                                 static_cast<uint32_t>((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE));
  } else {
    program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
  }
return context.RunProgram(program);
}
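
As a usage note, this helper is what Softmax::ComputeInternal above relies on to move the softmax axis into the innermost position. A minimal sketch of that call pattern (illustrative only; `permuted_shape` is assumed to be the input shape with `axis` and the last dimension swapped):

// Build a permutation that swaps `axis` with the innermost dimension, then
// transpose into a scratch tensor, as Softmax::ComputeInternal does.
InlinedVector<size_t> perm(rank);
std::iota(perm.begin(), perm.end(), 0);
std::swap(perm[axis], perm[rank - 1]);
Tensor transposed = context.CreateGPUTensor(input.DataType(), permuted_shape);
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, input, transposed));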

Status Transpose::ComputeInternal(ComputeContext& context) const {
const auto* input_tensor = context.Input(0);
const TensorShape& input_shape = input_tensor->Shape();
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/transpose.h
@@ -16,6 +16,8 @@ class Transpose final : public WebGpuKernel, public TransposeBase {
Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {
}
Status ComputeInternal(ComputeContext& context) const override;
  static Status DoTranspose(onnxruntime::webgpu::ComputeContext& context, gsl::span<const size_t> permutations, const Tensor& input, Tensor& output);

constexpr static uint32_t TILE_SIZE = 16;
};

@@ -625,9 +625,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,

// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Softmax)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Softmax)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Softmax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Softmax)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 3, Concat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 4, 10, Concat)>,