-
Notifications
You must be signed in to change notification settings - Fork 5.6k
/
cross_grad_kernel.cu
231 lines (209 loc) · 7.98 KB
/
cross_grad_kernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/cross_grad_kernel.h"

#include <limits>
#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/index_calculator.h"
namespace phi {
// Device kernel that writes the two input gradients of a 3-component cross
// product.
//
// Every iteration of the CUDA_KERNEL_LOOP handles one 3-vector:
// `index_calculator(i)` gives the flat offset of its first component and the
// other two components sit `stride` and `2 * stride` elements after it.
// Writing g for the values read from `out` (the host wrapper in this file
// passes the upstream gradient there), the iteration stores
//   out_dx = y x g   and   out_dy = g x x
// component by component. Products and differences are evaluated in
// phi::dtype::MPTypeTrait<T>::Type and cast back to T when stored.
//
// Args:
//   x, y:             operand buffers, read at the three component offsets.
//   out:              buffer holding g, read at the same offsets.
//   out_dx, out_dy:   destination buffers, written at the same offsets.
//   stride:           element distance between consecutive components.
//   N:                number of 3-vectors to process (loop bound).
//   index_calculator: maps the loop index to the first component's offset.
template <typename T>
__global__ void CrossGrad(const T* x,
                          const T* y,
                          const T* out,
                          T* out_dx,
                          T* out_dy,
                          const int stride,
                          const int N,
                          phi::funcs::IndexCalculator index_calculator) {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  CUDA_KERNEL_LOOP(i, N) {
    // Flat offsets of the three components of the current vector.
    const int base = index_calculator(i);
    const int idx0 = base;
    const int idx1 = base + stride;
    const int idx2 = base + 2 * stride;

    // Read every operand before any store so that the stores below cannot
    // influence the values used in this iteration.
    const MPType x0 = static_cast<MPType>(x[idx0]);
    const MPType x1 = static_cast<MPType>(x[idx1]);
    const MPType x2 = static_cast<MPType>(x[idx2]);

    const MPType y0 = static_cast<MPType>(y[idx0]);
    const MPType y1 = static_cast<MPType>(y[idx1]);
    const MPType y2 = static_cast<MPType>(y[idx2]);

    const MPType g0 = static_cast<MPType>(out[idx0]);
    const MPType g1 = static_cast<MPType>(out[idx1]);
    const MPType g2 = static_cast<MPType>(out[idx2]);

    // Component 0.
    out_dx[idx0] = static_cast<T>(g2 * y1 - g1 * y2);
    out_dy[idx0] = static_cast<T>(g1 * x2 - g2 * x1);

    // Component 1.
    out_dx[idx1] = static_cast<T>(g0 * y2 - g2 * y0);
    out_dy[idx1] = static_cast<T>(g2 * x0 - g0 * x2);

    // Component 2.
    out_dx[idx2] = static_cast<T>(g1 * y0 - g0 * y1);
    out_dy[idx2] = static_cast<T>(g0 * x1 - g1 * x0);
  }
}
// Host entry point that computes the gradients of a cross product with
// respect to both operands on the GPU.
//
// Args:
//   dev_ctx:  device context; provides the stream and allocates the outputs
//             and the temporary conjugate buffers.
//   x, y:     the operands. Only x's shape is inspected here.
//   out_grad: gradient arriving from the forward output.
//   axis:     dimension that holds the 3 vector components. Negative values
//             are offset by the rank of x. DDim::kMaxRank means
//             "unspecified": the first dimension of x whose extent is 3 is
//             used.
//   x_grad, y_grad: outputs; both are allocated as T by this function.
//
// Errors (raised through PADDLE_ENFORCE_*):
//   OutOfRange       - axis lies outside [-rank, rank - 1].
//   InvalidArgument  - the selected dimension is not of extent 3, no
//                      dimension of extent 3 exists, or x has more elements
//                      than a 32-bit int can index.
//
// The shape is collapsed to at most three extents [outer, 3, inner] so that
// the kernel can locate each 3-vector with a small IndexCalculator. For
// complex dtypes the kernel is fed conj(x) and conj(y).
template <typename T, typename Context>
void CrossGradKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const DenseTensor& y,
                     const DenseTensor& out_grad,
                     int axis,
                     DenseTensor* x_grad,
                     DenseTensor* y_grad) {
  const auto input_x_dims = x.dims();
  const int rank = input_x_dims.size();

  // Resolve the dimension that carries the vector components.
  int dim = axis;
  if (dim != DDim::kMaxRank) {
    PADDLE_ENFORCE_EQ(
        dim < input_x_dims.size() && dim >= (0 - input_x_dims.size()),
        true,
        errors::OutOfRange(
            "Attr(dim) is out of range, It's expected "
            "to be in range of [-%d, %d]. But received Attr(dim) = %d.",
            input_x_dims.size(),
            input_x_dims.size() - 1,
            dim));
    if (dim < 0) {
      dim += input_x_dims.size();
    }
    PADDLE_ENFORCE_EQ(
        input_x_dims[dim] == 3,
        true,
        errors::InvalidArgument(
            "Input(X/Y).dims[dim] must be equal to 3. But received: "
            "Input(X/Y).dims[dim] = [%d].",
            input_x_dims[dim]));
  } else {
    for (int i = 0; i < rank; i++) {
      if (input_x_dims[i] == 3) {
        dim = i;
        break;
      }
    }
    PADDLE_ENFORCE_EQ(
        dim == DDim::kMaxRank,
        false,
        errors::InvalidArgument("There must be at least one dimension 'd' "
                                "so that Input(X/Y).dims()[d] is equal to 3. "
                                "But received: Input(X/Y).dims() == [%s].",
                                input_x_dims));
  }

  // The merged extents, the strides, and the kernel's `stride` / `N`
  // parameters are all plain int. Reject anything that would be silently
  // truncated instead of letting the kernel index out of bounds.
  const int64_t numel = x.numel();
  PADDLE_ENFORCE_LE(
      numel,
      static_cast<int64_t>(std::numeric_limits<int>::max()),
      errors::InvalidArgument(
          "The number of elements of Input(X) must not exceed %d because "
          "the cross gradient GPU kernel uses 32-bit indexing. But "
          "received: Input(X).numel() = %d.",
          std::numeric_limits<int>::max(),
          numel));

  const T* x_data = x.data<T>();
  const T* y_data = y.data<T>();
  const T* out_grad_data = out_grad.data<T>();
  T* x_grad_data = dev_ctx.template Alloc<T>(x_grad);
  T* y_grad_data = dev_ctx.template Alloc<T>(y_grad);

  // Nothing to compute for an empty tensor; the outputs are already
  // allocated. Returning here also avoids building a launch configuration
  // for zero elements.
  if (numel == 0) {
    return;
  }

  // Collapse the shape around `dim`. Every partial product below is bounded
  // by numel (all extents are >= 1 at this point), so the narrowing casts
  // are lossless.
  int64_t outer = 1;
  for (int i = 0; i < dim; i++) {
    outer *= input_x_dims[i];
  }
  int64_t inner = 1;
  for (int i = dim + 1; i < rank; i++) {
    inner *= input_x_dims[i];
  }

  std::vector<int> merged_dims;
  if (dim > 0) {
    merged_dims.push_back(static_cast<int>(outer));
  }
  // Position of the size-3 extent inside merged_dims (0 or 1).
  const int merge_axis = static_cast<int>(merged_dims.size());
  merged_dims.push_back(static_cast<int>(input_x_dims[dim]));
  if (dim + 1 < rank) {
    merged_dims.push_back(static_cast<int>(inner));
  }
  const int merged_rank = static_cast<int>(merged_dims.size());

  // Row-major strides of the merged shape.
  std::vector<int> full_strides(merged_rank);
  int full_dim = 1;
  for (int i = merged_rank - 1; i >= 0; i--) {
    full_strides[i] = full_dim;
    full_dim *= merged_dims[i];
  }

  // Merged dimensions other than the size-3 one, in ascending order ...
  std::vector<int> cal_dims;
  for (int i = 0; i < merged_rank; i++) {
    if (i != merge_axis) {
      cal_dims.push_back(i);
    }
  }
  // ... and the row-major strides of the shape formed by just those
  // dimensions.
  const int cal_rank = static_cast<int>(cal_dims.size());
  std::vector<int> left_strides(cal_rank);
  int left_dim = 1;
  for (int k = cal_rank - 1; k >= 0; k--) {
    left_strides[k] = left_dim;
    left_dim *= merged_dims[cal_dims[k]];
  }

  auto index_calculator = phi::funcs::IndexCalculator(
      merged_rank - 1, cal_dims, left_strides, full_strides);

  // One loop iteration per 3-vector.
  const int64_t vector_count = numel / 3;
  backends::gpu::GpuLaunchConfig config =
      backends::gpu::GetGpuLaunchConfig1D(dev_ctx, vector_count);

  // Declared at function scope so the temporaries outlive the kernel launch
  // below; they stay empty for non-complex dtypes.
  DenseTensor x_conj, y_conj;
  if (IsComplexType(x.dtype())) {
    DenseTensorMeta meta_xy(x.dtype(), x.dims());
    x_conj.set_meta(meta_xy);
    y_conj.set_meta(meta_xy);
    T* x_conj_data = dev_ctx.template Alloc<T>(&x_conj);
    T* y_conj_data = dev_ctx.template Alloc<T>(&y_conj);

    phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
    phi::funcs::ConjFunctor<T> functor_x(x_data, numel, x_conj_data);
    phi::funcs::ConjFunctor<T> functor_y(y_data, numel, y_conj_data);
    for_range(functor_x);
    for_range(functor_y);

    // Feed the conjugated operands to the kernel instead of the originals.
    x_data = x_conj_data;
    y_data = y_conj_data;
  }

  CrossGrad<<<config.block_per_grid,
              config.thread_per_block,
              0,
              dev_ctx.stream()>>>(x_data,
                                  y_data,
                                  out_grad_data,
                                  x_grad_data,
                                  y_grad_data,
                                  full_strides[merge_axis],
                                  static_cast<int>(vector_count),
                                  index_calculator);
}
} // namespace phi
// Registers phi::CrossGradKernel under the kernel name `cross_grad` for the
// GPU backend with ALL_LAYOUT. The kernel template is instantiated for the
// element types listed below: half and bfloat16 floats, float, double,
// 32- and 64-bit integers, and single/double precision complex numbers.
// The trailing body is empty, so no extra per-kernel setup is performed.
PD_REGISTER_KERNEL(cross_grad,
GPU,
ALL_LAYOUT,
phi::CrossGradKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float,
double,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}