forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FusedSGDKernel.cpp
268 lines (257 loc) · 8.45 KB
/
FusedSGDKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/FusedSGD.h>
#include <ATen/Dispatch.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
namespace at::native {
namespace{
template <typename scalar_t, typename opmath_t>
typename std::enable_if<
std::is_same<scalar_t, Half>::value || std::is_same<scalar_t, BFloat16>::value,
void>::
type inline sgd_math(
scalar_t* param_ptr,
scalar_t* grad_ptr,
scalar_t* momentum_buf_ptr,
const double weight_decay,
const double momentum,
const double lr,
const double dampening,
const bool nesterov,
const bool maximize,
const bool is_first_step,
const float* grad_scale_ptr,
int64_t size
){
using lpVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<opmath_t>;
int64_t d = 0;
for (; d < size - (size % lpVec::size()); d += lpVec::size()) {
lpVec param_lpvec = lpVec::loadu(param_ptr + d);
auto [param_vec1, param_vec2] = vec::convert_to_float<scalar_t>(param_lpvec);
lpVec grad_lpvec = lpVec::loadu(grad_ptr + d);
auto [grad_vec1, grad_vec2] = vec::convert_to_float<scalar_t>(grad_lpvec);
if (grad_scale_ptr) {
grad_vec1 = grad_vec1 / fVec(float(*grad_scale_ptr));
grad_vec2 = grad_vec2 / fVec(float(*grad_scale_ptr));
lpVec grad_vec_to_store = vec::convert_from_float<scalar_t>(grad_vec1, grad_vec2);
grad_vec_to_store.store(grad_ptr + d);
}
if (maximize){
grad_vec1 = grad_vec1 * fVec(opmath_t(-1.0));
grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0));
}
if (weight_decay != 0.0){
grad_vec1 = vec::fmadd(param_vec1, fVec(scalar_t(weight_decay)), grad_vec1);
grad_vec2 = vec::fmadd(param_vec2, fVec(scalar_t(weight_decay)), grad_vec2);
}
if (momentum != 0.0) {
fVec momentum_vec1, momentum_vec2;
if (is_first_step) {
momentum_vec1 = grad_vec1;
momentum_vec2 = grad_vec2;
} else {
momentum_vec1 = fVec::loadu(momentum_buf_ptr + d) * fVec(scalar_t(momentum));
momentum_vec2 = fVec::loadu(momentum_buf_ptr + d + fVec::size()) * fVec(scalar_t(momentum));
momentum_vec1 = vec::fmadd(fVec(scalar_t(1 - dampening)), grad_vec1, momentum_vec1);
momentum_vec2 = vec::fmadd(fVec(scalar_t(1 - dampening)), grad_vec2, momentum_vec2);
}
vec::convert_from_float<scalar_t>(momentum_vec1, momentum_vec2).store(momentum_buf_ptr + d);;
if (nesterov) {
grad_vec1 = vec::fmadd(momentum_vec1, fVec(scalar_t(momentum)), grad_vec1);
grad_vec2 = vec::fmadd(momentum_vec2, fVec(scalar_t(momentum)), grad_vec2);
} else {
grad_vec1 = momentum_vec1;
grad_vec2 = momentum_vec2;
}
}
}
for (; d < size; d++) {
opmath_t grad_val = grad_ptr[d];
opmath_t param_val = param_ptr[d];
if (grad_scale_ptr) {
grad_val = grad_ptr[d] / opmath_t(*grad_scale_ptr);
grad_ptr[d] = grad_val;
}
if (maximize) grad_val = -grad_val;
if (weight_decay != 0.0){
grad_val += param_val * opmath_t(weight_decay);
}
if (momentum != 0.0) {
opmath_t momentum_buf_var = momentum_buf_ptr[d];
if (is_first_step) {
momentum_buf_var = grad_val;
} else {
momentum_buf_var = momentum_buf_var * opmath_t(momentum) +
grad_val * opmath_t(1 - dampening);
}
momentum_buf_ptr[d] = momentum_buf_var;
if (nesterov) {
grad_val += momentum_buf_var * opmath_t(momentum);
} else {
grad_val = momentum_buf_var;
}
}
param_ptr[d] = param_val - grad_val * opmath_t(lr);
}
}
template <typename scalar_t, typename opmath_t>
typename std::enable_if<
std::is_same<scalar_t, float>::value || std::is_same<scalar_t, double>::value,
void>::
type inline sgd_math(
scalar_t* param_ptr,
scalar_t* grad_ptr,
scalar_t* momentum_buf_ptr,
const double weight_decay,
const double momentum,
const double lr,
const double dampening,
const bool nesterov,
const bool maximize,
const bool is_first_step,
const float* grad_scale_ptr,
int64_t size
){
using Vec = at::vec::Vectorized<scalar_t>;
int64_t d = 0;
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec param_vec = Vec::loadu(param_ptr + d);
Vec grad_vec = Vec::loadu(grad_ptr + d);
if (grad_scale_ptr) {
grad_vec = grad_vec / Vec(scalar_t(*grad_scale_ptr));
Vec grad_vec_to_store = grad_vec;
grad_vec_to_store.store(grad_ptr + d);
}
if (maximize) grad_vec = grad_vec * Vec(scalar_t(-1.0));
if (weight_decay != 0.0){
grad_vec = vec::fmadd(param_vec, Vec(scalar_t(weight_decay)), grad_vec);
}
if (momentum != 0.0) {
Vec momentum_vec;
if (is_first_step) {
momentum_vec = grad_vec;
} else {
momentum_vec =
Vec::loadu(momentum_buf_ptr + d) * Vec(scalar_t(momentum));
momentum_vec = vec::fmadd(Vec(scalar_t(1 - dampening)), grad_vec, momentum_vec);
}
momentum_vec.store(momentum_buf_ptr + d);
if (nesterov) {
grad_vec = vec::fmadd(momentum_vec, Vec(scalar_t(momentum)), grad_vec);
} else {
grad_vec = momentum_vec;
}
}
param_vec += grad_vec * Vec(scalar_t(-lr));
param_vec.store(param_ptr + d);
}
for (; d < size; d++) {
scalar_t grad_val = grad_ptr[d];
if (grad_scale_ptr) {
grad_val = grad_ptr[d] / scalar_t(*grad_scale_ptr);
grad_ptr[d] = grad_val;
}
if (maximize) grad_val = -grad_val;
if (weight_decay != 0.0){
grad_val += param_ptr[d] * scalar_t(weight_decay);
}
if (momentum != 0.0) {
if (is_first_step) {
momentum_buf_ptr[d] = grad_val;
} else {
momentum_buf_ptr[d] = momentum_buf_ptr[d] * scalar_t(momentum) +
grad_val * scalar_t(1 - dampening);
}
if (nesterov) {
grad_val += momentum_buf_ptr[d] * scalar_t(momentum);
} else {
grad_val = momentum_buf_ptr[d];
}
}
param_ptr[d] -= grad_val * scalar_t(lr);
}
}
// Runs one fused SGD step over `param` in parallel. The flat element range is
// cut into units of 64 / sizeof(scalar_t) elements (64 bytes of data each).
// Each parallel_for task handles a run of whole units through sgd_math.
// The momentum buffer is only accessed when momentum != 0.
// NOTE(review): param, grad and momentum_buffer are indexed flat from their
// data pointers; this assumes they are contiguous and the same size — confirm.
template <typename scalar_t>
void sgd_fused_step_impl(
    const at::Tensor& param,
    const at::Tensor& grad,
    const at::Tensor& momentum_buffer,
    const double weight_decay,
    const double momentum,
    const double lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const float* grad_scale_ptr) {
  using opmath_t = at::opmath_type<scalar_t>;
  const bool has_momentum_buffer = momentum != 0.0;
  scalar_t* const param_base = param.data_ptr<scalar_t>();
  scalar_t* const grad_base = grad.data_ptr<scalar_t>();
  scalar_t* const buf_base =
      has_momentum_buffer ? momentum_buffer.data_ptr<scalar_t>() : nullptr;

  constexpr size_t bytes_per_unit = 64;
  constexpr int64_t elems_per_unit = bytes_per_unit / sizeof(scalar_t);
  const int64_t numel = param.numel();
  const int64_t num_units = divup(numel, elems_per_unit);

  at::parallel_for(0, num_units, 0, [&](int64_t unit_begin, int64_t unit_end) {
    // Translate unit indices to element offsets; the last unit may be partial.
    const int64_t first = unit_begin * elems_per_unit;
    const int64_t last = std::min(unit_end * elems_per_unit, numel);
    sgd_math<scalar_t, opmath_t>(
        param_base + first,
        grad_base + first,
        has_momentum_buffer ? buf_base + first : nullptr,
        weight_decay,
        momentum,
        lr,
        dampening,
        nesterov,
        maximize,
        is_first_step,
        grad_scale_ptr,
        last - first);
  });
}
// Dispatch entry for fused SGD: picks scalar_t from param's dtype (floating
// types plus BFloat16 and Half) and runs sgd_fused_step_impl.
// grad_scale_ptr, when non-null, makes the kernel divide grad by *grad_scale_ptr
// and write the unscaled gradient back into grad.
void fused_sgd_kernel(
    const at::Tensor& param,
    const at::Tensor& grad,
    const at::Tensor& momentum_buffer,
    const double weight_decay,
    const double momentum,
    const double lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const float* grad_scale_ptr
  ) {
  // The step kernel walks grad with flat indexing, so it must see a dense
  // contiguous buffer. contiguous() returns grad itself when it already is.
  // Previously this tensor was computed but the original grad was passed on.
  // NOTE(review): param and momentum_buffer are also indexed flat; this assumes
  // the caller guarantees they are contiguous — confirm.
  Tensor grad_contiguous = grad.contiguous();
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, param.scalar_type(), "fused_sgd_kernel", [&] {
    sgd_fused_step_impl<scalar_t>(
      param,
      grad_contiguous,
      momentum_buffer,
      weight_decay,
      momentum,
      lr,
      dampening,
      nesterov,
      maximize,
      is_first_step,
      grad_scale_ptr);
  });
  // With a grad scale, the kernel writes the unscaled gradient in place. If
  // contiguous() had to make a copy, propagate those writes back to grad.
  if (grad_scale_ptr != nullptr && !grad_contiguous.is_same(grad)) {
    grad.copy_(grad_contiguous);
  }
}
}
// Register fused_sgd_kernel as the implementation behind fused_sgd_stub, via
// the REGISTER_DISPATCH machinery from ATen/native/DispatchStub.h.
// NOTE(review): the stub is presumably declared in ATen/native/FusedSGD.h
// (included above) — confirm.
REGISTER_DISPATCH(fused_sgd_stub, &fused_sgd_kernel);
} // namespace at::native