
[Hackathon No.46] Implement float16 data type support for the Paddle gumbel_softmax operator #50923

Merged: 5 commits, Mar 17, 2023
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/gumbel_softmax_grad_kernel.cu
@@ -21,5 +21,6 @@ PD_REGISTER_KERNEL(gumbel_softmax_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::GumbelSoftmaxGradKernel,
+                   phi::dtype::float16,
                    float,
                    double) {}
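
For reviewers who want to exercise the new registrations end to end, a minimal dygraph check might look like the sketch below. This is a sketch, assuming a CUDA build of Paddle; the float16 kernels in this PR are GPU-only.

# Minimal check of the float16 forward + backward kernels (assumes CUDA Paddle).
import paddle
import paddle.nn.functional as F

paddle.set_device("gpu")
x = paddle.randn([8, 5]).astype("float16")
x.stop_gradient = False

y = F.gumbel_softmax(x, temperature=0.5)
y.sum().backward()

print(y.dtype)       # paddle.float16 (forward kernel)
print(x.grad.dtype)  # paddle.float16 (the grad kernel registered above)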
30 changes: 21 additions & 9 deletions paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
@@ -13,7 +13,7 @@
 // limitations under the License.

 #include "paddle/phi/kernels/gumbel_softmax_kernel.h"
-
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
 #include "paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h"
@@ -124,9 +124,11 @@ __global__ void AddGumbelNoiseCUDAKernel(const T* input_data,
                                          int64_t n) {
   int index = threadIdx.x + blockIdx.x * blockDim.x;
   int step = blockDim.x * gridDim.x;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   for (int64_t i = index; i < n; i += step) {
-    T gumbel_noise = -log(-log(noise[i]));
-    output_data[i] = (gumbel_noise + input_data[i]) / temperature;
+    MPType gumbel_noise = -log(-log(static_cast<MPType>(noise[i])));
+    output_data[i] = static_cast<T>(
+        (gumbel_noise + static_cast<MPType>(input_data[i])) / temperature);
   }
 }
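
The MPTypeTrait change makes the Gumbel transform run in float32 when T is float16, with a single cast back to half precision at the end. A small numpy sketch (illustration only, not Paddle code) of why one final rounding beats rounding after every operation:

# numpy illustration (not Paddle code): same fp16 inputs, different arithmetic
# precision. `half` rounds to fp16 after every op; `mixed` mirrors the MPType
# kernel, doing the math in wider precision with one final cast to fp16.
import numpy as np

rng = np.random.default_rng(0)
u = rng.uniform(0.01, 0.99, size=100_000).astype(np.float16)  # fp16 noise
x = rng.standard_normal(100_000).astype(np.float16)           # fp16 logits
tau = 0.5

# reference: identical fp16 inputs, all arithmetic in float64
ref = (-np.log(-np.log(u.astype(np.float64))) + x.astype(np.float64)) / tau

half = (-np.log(-np.log(u)) + x) / np.float16(tau)            # all-fp16 math
mixed = ((-np.log(-np.log(u.astype(np.float32))) + x)
         / np.float32(tau)).astype(np.float16)

print(np.abs(half.astype(np.float64) - ref).max())   # error accumulates per op
print(np.abs(mixed.astype(np.float64) - ref).max())  # only one final rounding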

@@ -152,10 +154,15 @@ struct GumbleNoiseGenerator<GPUContext, T> {
     uint64_t offset = seed_offset.second;

     thrust::counting_iterator<int64_t> index_sequence_begin(0);
-    thrust::transform(index_sequence_begin,
-                      index_sequence_begin + size,
-                      thrust::device_ptr<T>(random_data),
-                      UniformCUDAGenerator<T>(0.00001, 1, seed, size * offset));
+    using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+    thrust::transform(
+        index_sequence_begin,
+        index_sequence_begin + size,
+        thrust::device_ptr<T>(random_data),
+        UniformCUDAGenerator<T>(static_cast<phi::dtype::float16>(0.00001),
+                                static_cast<phi::dtype::float16>(1),
+                                seed,
+                                size * offset));

     // add gumbel noise to X
     const int thread_size = 512;

Contributor (on the static_cast<phi::dtype::float16> lines): Is this change necessary? Whatever type T is, the value is cast to FP16 here.

Author: done
@@ -168,5 +175,10 @@ struct GumbleNoiseGenerator<GPUContext, T> {
 } // namespace phi
 #endif

-PD_REGISTER_KERNEL(
-    gumbel_softmax, GPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
+PD_REGISTER_KERNEL(gumbel_softmax,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::GumbelSoftmaxKernel,
+                   phi::dtype::float16,
+                   float,
+                   double) {}
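
With the forward kernel registered, a quick sanity check of both soft and hard sampling in float16 could look like this (a sketch, again assuming a CUDA build of Paddle):

# Forward-only check of the float16 GPU kernel (assumes CUDA Paddle).
import paddle
import paddle.nn.functional as F

paddle.set_device("gpu")
x = paddle.randn([4, 6]).astype("float16")

soft = F.gumbel_softmax(x, temperature=1.0)             # rows sum to 1
hard = F.gumbel_softmax(x, temperature=1.0, hard=True)  # one-hot rows

print(soft.dtype, hard.dtype)  # both paddle.float16
print(hard.sum(axis=-1))       # exactly 1.0 per row when hard=True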
8 changes: 8 additions & 0 deletions python/paddle/fluid/tests/unittests/test_gumbel_softmax_op.py
@@ -103,6 +103,14 @@ def init_attrs(self):
         self.dtype = "float64"


+class TestGumbelSoftmaxOp6(TestGumbelSoftmaxOp):
+    def init_attrs(self):
+        self.shape = [20, 10, 5]
+        self.attrs = {"hard": True, "axis": 1}
+        self.count_expected = 100
+        self.dtype = np.float16
+
+
 class TestGumbelSoftmaxOpSampleDistribution(OpTest):
     def softmax(self, x):
         x_row_max = x.max(axis=-1)

Contributor: Make these 4 unit tests inherit from TestGumbelSoftmaxFP16OP.

Author: Hi, the earlier TestGumbelSoftmax_ZeroDim_FP16OP targets the zero-dim case and has no init_attrs() method, so the name cannot be changed to TestGumbelSoftmaxFP16OP; these tests therefore inherit directly from TestGumbelSoftmaxOp.

Author: done
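
For context, the shared-base-class structure the reviewer seems to be suggesting would look roughly like the sketch below; the class names and the super().init_attrs() pattern are assumptions for illustration, not the code that was merged.

# Hypothetical sketch of the reviewer's suggestion: a shared FP16 base class
# for the four FP16 variants. Names are illustrative, not the merged code.
import numpy as np

class TestGumbelSoftmaxFP16Op(TestGumbelSoftmaxOp):  # base test from this file
    def init_attrs(self):
        super().init_attrs()     # keep the parent's shape/attrs
        self.dtype = np.float16  # only the dtype changes


class TestGumbelSoftmaxFP16Op2(TestGumbelSoftmaxFP16Op):
    def init_attrs(self):
        super().init_attrs()
        self.shape = [20, 10, 5]
        self.attrs = {"hard": True, "axis": 1}
        self.count_expected = 100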
6 changes: 4 additions & 2 deletions python/paddle/nn/functional/activation.py
@@ -1664,7 +1664,7 @@ def gumbel_softmax(x, temperature=1.0, hard=False, axis=-1, name=None):
     Parameters:
         x (Tensor): An N-D Tensor, the first N - 1 dimensions index into a batch
             of independent distributions and the last dimension represents
-            a vector of probabilities with datatype float32, float64.
+            a vector of probabilities with datatype float16, float32, float64.
         temperature (float, optional): non-negative scalar temperature.
             Default is 1.0.
         hard (bool, optional): if True, the returned samples will be discretized as
@@ -1705,7 +1705,9 @@ def gumbel_softmax(x, temperature=1.0, hard=False, axis=-1, name=None):
     )

     helper = LayerHelper("gumbel_softmax", **locals())
-    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'gumbel_softmax')
+    check_variable_and_dtype(
+        x, 'x', ['float16', 'float32', 'float64'], 'gumbel_softmax'
+    )
     out = helper.create_variable_for_type_inference(x.dtype)
     helper.append_op(
         type='gumbel_softmax',
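
For reference, the overall computation this operator performs, uniform noise in (1e-5, 1), the Gumbel transform, temperature scaling, then a softmax along the chosen axis, can be written out in numpy (a reference sketch, not Paddle's implementation):

# Reference numpy sketch of the gumbel_softmax forward computation; this
# mirrors the GPU kernel above but is not Paddle's implementation.
import numpy as np

def gumbel_softmax_ref(x, temperature=1.0, axis=-1, seed=0):
    rng = np.random.default_rng(seed)
    u = rng.uniform(1e-5, 1.0, size=x.shape)   # same range as the CUDA generator
    g = -np.log(-np.log(u))                    # Gumbel(0, 1) noise
    z = (x + g) / temperature                  # AddGumbelNoiseCUDAKernel step
    z = z - z.max(axis=axis, keepdims=True)    # numerically stable softmax
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

sample = gumbel_softmax_ref(np.zeros(5), temperature=0.5)
print(sample, sample.sum())                    # a random point on the simplex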