cherry pick softmax infer kernel (PaddlePaddle#45957)
JZZ-NOTE authored Sep 13, 2022
1 parent 29c44eb commit 0903020
Showing 6 changed files with 88 additions and 25 deletions.
23 changes: 12 additions & 11 deletions paddle/fluid/framework/ir/is_test_pass.cc
@@ -25,17 +25,18 @@ class Graph;
void IsTestPass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Sets is_test attrbiute to true and if it is missing, inserts it "
"for activations and pooling.";
auto op_list = {"pool2d", "sigmoid", "logsigmoid",
"softshrink", "exp", "brelu",
"pow", "leaky_relu", "stanh",
"relu", "tanh", "tanh_shrink",
"sqrt", "abs", "ceil",
"elu", "floor", "cos",
"sin", "round", "reciprocal",
"hard_shrink", "hard_sigmoid", "relu6",
"soft_relu", "swish", "thresholded_relu",
"log", "square", "softplus",
"softsign", "silu", "mish"};
auto op_list = {"pool2d", "sigmoid", "logsigmoid",
"softshrink", "exp", "brelu",
"pow", "leaky_relu", "stanh",
"relu", "tanh", "tanh_shrink",
"sqrt", "abs", "ceil",
"elu", "floor", "cos",
"sin", "round", "reciprocal",
"hard_shrink", "hard_sigmoid", "relu6",
"soft_relu", "swish", "thresholded_relu",
"log", "square", "softplus",
"softsign", "silu", "mish",
"gumbel_softmax"};
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
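For readers unfamiliar with this pass: based on the visible context and the VLOG message, it walks the graph and forces `is_test = true` on every op whose type appears in `op_list`, inserting the attribute when missing. A minimal standalone sketch of that logic, with simplified stand-in types (the real pass uses `ir::Graph`, `Node`, and `OpDesc`):

```cpp
#include <map>
#include <set>
#include <string>

// Simplified stand-in for OpDesc: just a type name and bool attributes.
struct FakeOpDesc {
  std::string type;
  std::map<std::string, bool> bool_attrs;
  void SetAttr(const std::string& name, bool v) { bool_attrs[name] = v; }
};

// Set is_test = true for ops in the list, inserting the attribute if absent.
void ApplyIsTestSketch(const std::set<std::string>& op_list, FakeOpDesc* op) {
  if (op_list.count(op->type) > 0) {
    op->SetAttr("is_test", true);  // std::map::operator[] inserts if missing
  }
}
```

Adding "gumbel_softmax" to the list is what lets inference graphs mark the op as a test-time op, which the argument mapping added later in this commit keys off of.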
7 changes: 7 additions & 0 deletions paddle/phi/kernels/cpu/gumbel_softmax_kernel.cc
@@ -119,3 +119,10 @@ struct OneHotGenerator<CPUContext, T> {

PD_REGISTER_KERNEL(
gumbel_softmax, CPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}

+ PD_REGISTER_KERNEL(gumbel_softmax_infer,
+                    CPU,
+                    ALL_LAYOUT,
+                    phi::GumbelSoftmaxInferKernel,
+                    float,
+                    double) {}
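For context, `PD_REGISTER_KERNEL` instantiates the templated kernel once per listed dtype and records it in phi's kernel registry under the given name, backend, and layout. A rough, self-contained sketch of the idea (a hypothetical toy registry, not phi's real machinery):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>

// Toy registry keyed by "kernel_name.backend.dtype".
std::map<std::string, std::function<void()>>& KernelRegistry() {
  static std::map<std::string, std::function<void()>> registry;
  return registry;
}

template <typename T>
void GumbelSoftmaxInferSketch() {
  std::cout << "gumbel_softmax_infer running, sizeof(T)=" << sizeof(T) << "\n";
}

// Roughly what the macro achieves: one registration per listed dtype.
static const bool kRegFloat =
    (KernelRegistry()["gumbel_softmax_infer.CPU.float32"] =
         GumbelSoftmaxInferSketch<float>, true);
static const bool kRegDouble =
    (KernelRegistry()["gumbel_softmax_infer.CPU.float64"] =
         GumbelSoftmaxInferSketch<double>, true);

int main() { KernelRegistry()["gumbel_softmax_infer.CPU.float32"](); }
```

The GPU file below registers the same kernel under the GPU backend in exactly the same way.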
7 changes: 7 additions & 0 deletions paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
@@ -170,3 +170,10 @@ struct GumbleNoiseGenerator<GPUContext, T> {

PD_REGISTER_KERNEL(
gumbel_softmax, GPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}

+ PD_REGISTER_KERNEL(gumbel_softmax_infer,
+                    GPU,
+                    ALL_LAYOUT,
+                    phi::GumbelSoftmaxInferKernel,
+                    float,
+                    double) {}
8 changes: 8 additions & 0 deletions paddle/phi/kernels/gumbel_softmax_kernel.h
@@ -25,4 +25,12 @@ void GumbelSoftmaxKernel(const Context& dev_ctx,
int axis,
DenseTensor* out);

+ template <typename T, typename Context>
+ void GumbelSoftmaxInferKernel(const Context& dev_ctx,
+                               const DenseTensor& x,
+                               float temperature,
+                               bool hard,
+                               int axis,
+                               DenseTensor* out);

} // namespace phi
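For reference, the computation both declarations describe is the standard Gumbel-softmax (Jang et al., 2016): `temperature` is τ below, `axis` is the softmax axis, and `hard` additionally discretizes the output to one-hot along that axis, as done by the `OneHotGenerator` in the implementation:

```latex
g_i = -\log(-\log u_i), \quad u_i \sim \mathrm{Uniform}(0, 1)
\qquad
y_i = \frac{\exp\big((\log \pi_i + g_i)/\tau\big)}
           {\sum_j \exp\big((\log \pi_j + g_j)/\tau\big)}
```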
50 changes: 36 additions & 14 deletions paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h
@@ -43,12 +43,13 @@ template <typename Context, typename T>
struct OneHotGenerator;

template <typename T, typename Context>
- void GumbelSoftmaxKernel(const Context& ctx,
-                          const DenseTensor& x,
-                          float temperature,
-                          bool hard,
-                          int axis,
-                          DenseTensor* out) {
+ void GumbelSoftmaxKernelHelper(const Context& ctx,
+                                const DenseTensor& x,
+                                float temperature,
+                                bool hard,
+                                int axis,
+                                DenseTensor* out,
+                                bool is_test) {
const int rank = x.dims().size();
axis = funcs::CanonicalAxis(axis, rank);
int axis_dim = x.dims()[axis];
@@ -80,18 +81,39 @@ void GumbelSoftmaxKernel(const Context& ctx,
size_to_axis,
size_from_axis,
temperature);

- #ifdef PADDLE_ON_INFERENCE
-   paddle::operators::math::SoftmaxFunctor<Context, T, true>()(
-       ctx, axis_dim, &x_noise_2d, &out_2d);
- #else
-   paddle::operators::math::SoftmaxFunctor<Context, T, false>()(
-       ctx, axis_dim, &x_noise_2d, &out_2d);
- #endif
+   if (is_test) {
+     paddle::operators::math::SoftmaxFunctor<Context, T, true>()(
+         ctx, axis_dim, &x_noise_2d, &out_2d);
+   } else {
+     paddle::operators::math::SoftmaxFunctor<Context, T, false>()(
+         ctx, axis_dim, &x_noise_2d, &out_2d);
+   }

if (hard) {
OneHotGenerator<Context, T>::Transform(ctx, x, out, axis);
}
}

+ template <typename T, typename Context>
+ void GumbelSoftmaxKernel(const Context& ctx,
+                          const DenseTensor& x,
+                          float temperature,
+                          bool hard,
+                          int axis,
+                          DenseTensor* out) {
+   GumbelSoftmaxKernelHelper<T, Context>(
+       ctx, x, temperature, hard, axis, out, false);
+ }
+
+ template <typename T, typename Context>
+ void GumbelSoftmaxInferKernel(const Context& ctx,
+                               const DenseTensor& x,
+                               float temperature,
+                               bool hard,
+                               int axis,
+                               DenseTensor* out) {
+   GumbelSoftmaxKernelHelper<T, Context>(
+       ctx, x, temperature, hard, axis, out, true);
+ }

} // namespace phi
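The net effect of this refactor: the old compile-time `#ifdef PADDLE_ON_INFERENCE` switch becomes a runtime `is_test` flag, but each branch still selects a distinct compile-time `SoftmaxFunctor` specialization, so the inference-optimized code path is preserved. A minimal standalone sketch of that dispatch pattern (hypothetical toy types, not Paddle's):

```cpp
#include <iostream>

// The bool template parameter mirrors SoftmaxFunctor's third argument:
// each specialization can be compiled with its own optimizations.
template <bool IsTest>
struct SoftmaxSketch {
  void operator()() const {
    std::cout << (IsTest ? "inference path\n" : "training path\n");
  }
};

// A runtime flag selects between the two compile-time specializations.
void RunSoftmaxSketch(bool is_test) {
  if (is_test) {
    SoftmaxSketch<true>()();
  } else {
    SoftmaxSketch<false>()();
  }
}

int main() {
  RunSoftmaxSketch(false);  // what gumbel_softmax does
  RunSoftmaxSketch(true);   // what gumbel_softmax_infer does
}
```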
18 changes: 18 additions & 0 deletions paddle/phi/ops/compat/gumbel_softmax_sig.cc
@@ -16,6 +16,23 @@ limitations under the License. */

namespace phi {

+ KernelSignature GumbelSoftmaxOpArgumentMapping(
+     const ArgumentMappingContext& ctx) {
+   bool is_test = false;
+   if (ctx.HasAttr("is_test")) {
+     is_test = paddle::any_cast<bool>(ctx.Attr("is_test"));
+   }
+   if (is_test) {
+     return KernelSignature("gumbel_softmax_infer",
+                            {"X"},
+                            {"temperature", "hard", "axis"},
+                            {"Out"});
+   } else {
+     return KernelSignature(
+         "gumbel_softmax", {"X"}, {"temperature", "hard", "axis"}, {"Out"});
+   }
+ }

KernelSignature GumbelSoftmaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
@@ -24,5 +24,6 @@ KernelSignature GumbelSoftmaxGradOpArgumentMapping(

} // namespace phi

+ PD_REGISTER_ARG_MAPPING_FN(gumbel_softmax, phi::GumbelSoftmaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(gumbel_softmax_grad,
phi::GumbelSoftmaxGradOpArgumentMapping);
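Putting the pieces together: `IsTestPass` stamps `is_test = true` on `gumbel_softmax` ops in inference graphs, the argument mapping above then resolves those ops to `gumbel_softmax_infer`, and that kernel takes the inference-optimized softmax path. A tiny standalone sketch of the name-resolution step (mirroring `GumbelSoftmaxOpArgumentMapping`, with hypothetical plain types in place of phi's context):

```cpp
#include <iostream>
#include <string>

// Mirrors GumbelSoftmaxOpArgumentMapping: the is_test attribute decides
// which kernel name the op resolves to.
std::string ResolveKernelName(bool has_is_test_attr, bool is_test_value) {
  const bool is_test = has_is_test_attr && is_test_value;
  return is_test ? "gumbel_softmax_infer" : "gumbel_softmax";
}

int main() {
  std::cout << ResolveKernelName(true, true) << "\n";    // gumbel_softmax_infer
  std::cout << ResolveKernelName(false, false) << "\n";  // gumbel_softmax
}
```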
