forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Loss.cpp
494 lines (433 loc) · 19.7 KB
/
Loss.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Dispatch.h>
#include <ATen/CPUApplyUtils.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/PointwiseOps.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
// Small positive constant used by the losses below: it is added to squared
// magnitudes in cosine_embedding_loss and used as a lower bound for the
// denominator in the BCE backward kernel.
constexpr float EPSILON = 1e-12;
namespace {
// Applies the requested reduction to an element-wise loss tensor.
//   Reduction::Mean -> mean over all elements
//   Reduction::Sum  -> sum over all elements
//   anything else (e.g. Reduction::None) -> `unreduced` is returned unchanged
inline at::Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
  switch (reduction) {
    case at::Reduction::Mean:
      return unreduced.mean();
    case at::Reduction::Sum:
      return unreduced.sum();
    default:
      return unreduced;
  }
}
} // namespace
namespace at { namespace native {
// Dispatch-stub definitions for the element-wise loss kernels used below.
// Each stub is invoked in this file as `stub(iter.device_type(), iter, ...)`;
// the per-device kernel implementations are registered elsewhere (not visible
// in this file).
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(smooth_l1_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(smooth_l1_backward_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(huber_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(huber_backward_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(mse_stub);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
DEFINE_DISPATCH(mse_backward_stub);
// Cosine embedding loss.
//
// For each row i:
//   cos_i  = <x1_i, x2_i> / sqrt((|x1_i|^2 + EPSILON) * (|x2_i|^2 + EPSILON))
//   loss_i = 1 - cos_i                 if target_i == 1
//            max(0, cos_i - margin)    if target_i == -1
//            0                         otherwise
// input1/input2 must be 2-D (rows along dim 0, features along dim 1) and
// target must be 1-D. `reduction` is applied to the per-row loss.
Tensor cosine_embedding_loss(const Tensor& input1, const Tensor& input2, const Tensor& target, double margin, int64_t reduction) {
  TORCH_CHECK(
      target.dim() == 1,
      "1D target tensor expected, multi-target not supported");
  // The similarity is reduced over dim 1 and then combined with the 1-D target
  // via at::where. For inputs of any other rank, that broadcast either fails
  // with an unrelated dim error (1-D inputs) or silently aligns the target with
  // the wrong axis (3-D and higher), so reject such inputs up front.
  TORCH_CHECK(
      input1.dim() == 2 && input2.dim() == 2,
      "1D target tensor expects 2D input tensors, but found inputs with sizes ",
      input1.sizes(), " and ", input2.sizes(), ".");

  auto prod_sum = (input1 * input2).sum(1);
  auto mag_square1 = (input1 * input1).sum(1) + EPSILON;
  auto mag_square2 = (input2 * input2).sum(1) + EPSILON;
  auto denom = (mag_square1 * mag_square2).sqrt_();
  auto cos = prod_sum / denom;

  auto zeros = at::zeros_like(cos, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto pos = 1 - cos;
  auto neg = (cos - margin).clamp_min_(0);
  auto output_pos = at::where(target == 1, pos, zeros);
  auto output_neg = at::where(target == -1, neg, zeros);
  auto output = output_pos + output_neg;
  return apply_loss_reduction(output, reduction);
}
// Hinge embedding loss, computed element-wise and then reduced:
//   contributes max(0, margin - self) wherever target != 1, and
//   contributes self                  wherever target != -1
// (so an element whose target is neither 1 nor -1 receives both terms).
Tensor hinge_embedding_loss(const Tensor& self, const Tensor& target, double margin, int64_t reduction) {
  const auto zero = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  const auto hinge = (margin - self).clamp_min_(0);
  const auto margin_term = at::where(target != 1, hinge, zero);
  const auto self_term = at::where(target != -1, self, zero);
  return apply_loss_reduction(margin_term + self_term, reduction);
}
// Triplet margin loss: max(0, margin + d(anchor, positive) - d(anchor, negative)),
// where d is at::pairwise_distance with norm `p` and `eps`. When `swap` is set,
// the negative distance is the element-wise minimum of d(anchor, negative) and
// d(positive, negative). `reduction` is applied to the result.
Tensor triplet_margin_loss(const Tensor& anchor, const Tensor& positive, const Tensor& negative, double margin,
                           double p, double eps, bool swap, int64_t reduction) {
  const auto d_anchor_pos = at::pairwise_distance(anchor, positive, p, eps);
  const auto d_anchor_neg = at::pairwise_distance(anchor, negative, p, eps);
  const auto d_neg = swap
      ? at::min(d_anchor_neg, at::pairwise_distance(positive, negative, p, eps))
      : d_anchor_neg;
  return apply_loss_reduction(at::clamp_min(margin + d_anchor_pos - d_neg, 0), reduction);
}
// Margin ranking loss: max(0, -target * (input1 - input2) + margin), computed
// element-wise and then reduced according to `reduction`.
Tensor margin_ranking_loss(const Tensor& input1, const Tensor& input2, const Tensor& target, double margin, int64_t reduction) {
  const auto difference = input1 - input2;
  auto unreduced = (-target * difference + margin).clamp_min_(0);
  return apply_loss_reduction(unreduced, reduction);
}
// KL-divergence term for a target given in log space:
// exp(target) * (target - input), element-wise, followed by `reduction`.
Tensor _kl_div_log_target(const Tensor& input, const Tensor& target, int64_t reduction) {
  const auto target_prob = at::exp(target);
  const auto log_ratio = target - input;
  return apply_loss_reduction(target_prob * log_ratio, reduction);
}
// KL-divergence term for a target given as plain values:
// target * (log(target) - input), element-wise, followed by `reduction`.
// Entries with target <= 0 are replaced by 0; the raw expression would be NaN
// there (0 * -inf, or log of a negative number).
Tensor _kl_div_non_log_target(const Tensor& input, const Tensor& target, int64_t reduction) {
  const auto pointwise = target * (at::log(target) - input);
  const auto masked = at::where(
      target > 0,
      pointwise,
      at::zeros_like(pointwise, LEGACY_CONTIGUOUS_MEMORY_FORMAT));
  return apply_loss_reduction(masked, reduction);
}
// KL divergence loss. Selects the log-space or value-space formulation
// depending on whether `target` holds log values (`log_target`).
Tensor kl_div(const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) {
  if (log_target) {
    return _kl_div_log_target(input, target, reduction);
  }
  return _kl_div_non_log_target(input, target, reduction);
}
// Gradient of kl_div with respect to `input` (CPU):
//   log_target == false: -target * grad where target > 0, otherwise 0
//   log_target == true:  -exp(target) * grad
// `grad` is expanded to the shape of `input`. For Reduction::Mean the result is
// additionally divided by input.numel().
Tensor kl_div_backward_cpu(const Tensor& grad, const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) {
  auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_expand = grad.expand_as(input);
  if (log_target) {
    grad_input = -at::exp(target) * grad_expand;
  } else {
    // Operand order (output, target, grad) fixes the kernel's argument order.
    auto iter = TensorIteratorConfig()
      .add_output(grad_input)
      .add_input(target)
      .add_input(grad_expand)
      .build();
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "kl_div_backward_cpu", [&]() {
      cpu_serial_kernel(iter, [](scalar_t target_val, scalar_t grad_val) -> scalar_t {
        if (target_val > 0) {
          return -target_val * grad_val;
        }
        return 0;
      });
    });
  }
  return reduction == at::Reduction::Mean ? grad_input / input.numel() : grad_input;
}
// Functional BCE on CPU. Allocates an output shaped like `input` and forwards
// to binary_cross_entropy_out_cpu, which fills it and applies the reduction.
Tensor binary_cross_entropy_cpu(const Tensor& input, const Tensor& target, const c10::optional<Tensor>& weight_opt, int64_t reduction) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  Tensor loss = at::empty_like(input);
  return at::native::binary_cross_entropy_out_cpu(
      input, target, *weight_maybe_owned, reduction, loss);
}
// Binary cross entropy on CPU, written into `loss`.
//
// Computes, element-wise, -(y * log(x) + (1 - y) * log(1 - x)) with each log
// clamped below at -100. The result is then multiplied by `weight` (if given)
// and reduced according to `reduction`. For Mean/Sum, `loss` is resized to the
// reduced result's shape.
// Throws if any input element is outside [0, 1]; NaN inputs also fail this check.
Tensor& binary_cross_entropy_out_cpu(const Tensor& input, const Tensor& target, const c10::optional<Tensor>& weight_opt, int64_t reduction, Tensor& loss) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
// The kernel runs over squeezed (size-1 dims removed) versions of the operands;
// loss_squeezed is a view of `loss`, so the results land in `loss`.
Tensor loss_squeezed = at::squeeze(loss);
// Operand order (output, input, target) fixes the kernel lambda's argument order.
auto iter = TensorIteratorConfig()
.add_output(loss_squeezed)
.add_owned_input(at::squeeze(input))
.add_owned_input(at::squeeze(target))
.build();
AT_DISPATCH_FLOATING_TYPES(loss.scalar_type(), "binary_cross_entropy", [&] {
at::native::cpu_kernel(
iter,
[] (scalar_t input_val, scalar_t target_val) {
TORCH_CHECK(
(input_val >= 0) && (input_val <= 1),
"all elements of input should be between 0 and 1"
);
// Binary cross entropy tensor is defined by the equation:
// L = -w (y ln(x) + (1-y) ln(1-x))
// Written here as (y - 1) * ln(1 - x) - y * ln(x); the weight w is applied
// after the kernel. Each log is clamped at -100 so x == 0 or x == 1 gives a
// finite value instead of -inf.
return (target_val - scalar_t(1))
* std::max(scalar_t(std::log(scalar_t(1) - input_val)), scalar_t(-100))
- target_val * std::max(scalar_t(std::log(input_val)), scalar_t(-100));
}
);
});
if (weight.defined()) {
loss.mul_(weight);
}
if (reduction != at::Reduction::None) {
// Reduce, then make `loss` take the reduced shape and value.
Tensor loss_reduced = apply_loss_reduction(loss, reduction);
loss.resize_as_(loss_reduced).copy_(loss_reduced);
}
return loss;
}
// Functional BCE gradient on CPU. Allocates a buffer shaped like `input` and
// forwards to binary_cross_entropy_backward_out_cpu.
Tensor binary_cross_entropy_backward_cpu(const Tensor& grad, const Tensor& input, const Tensor& target, const c10::optional<Tensor>& weight_opt, int64_t reduction) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  Tensor grad_input = at::empty_like(input);
  return at::native::binary_cross_entropy_backward_out_cpu(
      grad, input, target, *weight_maybe_owned, reduction, grad_input);
}
// Gradient of binary cross entropy w.r.t. `input` on CPU, written into `grad_input`.
//
// Element-wise: grad * (x - y) / max(x * (1 - x), EPSILON). The result is then
// multiplied by `weight` (if given) and divided by input.numel() for
// Reduction::Mean.
Tensor& binary_cross_entropy_backward_out_cpu(const Tensor& grad, const Tensor& input, const Tensor& target, const c10::optional<Tensor>& weight_opt, int64_t reduction, Tensor& grad_input) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
// grad_input_squeezed is a view of grad_input, so the kernel writes into grad_input.
Tensor grad_input_squeezed = at::squeeze(grad_input);
// Operand order (output, grad, input, target) fixes the kernel lambda's argument order.
auto iter = TensorIteratorConfig()
.add_output(grad_input_squeezed)
.add_owned_input(at::squeeze(grad))
.add_owned_input(at::squeeze(input))
.add_owned_input(at::squeeze(target))
.build();
AT_DISPATCH_FLOATING_TYPES(grad_input.scalar_type(), "binary_cross_entropy_backward", [&] {
at::native::cpu_kernel(
iter,
[] (scalar_t grad_val, scalar_t input_val, scalar_t target_val) {
// The gradient is the partial derivative of BCELoss
// with respect to x
// d(L)/d(x) = -w (y - x) / (x - x^2)
// The denominator is clamped at EPSILON to avoid dividing by zero when x is 0 or 1.
return grad_val * (input_val - target_val)
/ (scalar_t(std::max(
(scalar_t(1) - input_val) * input_val,
scalar_t(EPSILON)
)));
}
);
});
if (weight.defined()) {
grad_input.mul_(weight);
}
if (reduction == at::Reduction::Mean) {
grad_input.div_(input.numel());
}
return grad_input;
}
// Binary cross entropy on logits `input`, computed in a numerically stable form.
//
// With x = input, y = target, and m = max(-x, 0):
//   log(1 + exp(-x)) = m + log(exp(-m) + exp(-x - m))
//   loss = (1 - y) * x + lw * log(1 + exp(-x))
// where lw = 1 + (pos_weight - 1) * y when pos_weight is given, and lw = 1 otherwise.
// The result is multiplied by `weight` (if given) and then reduced.
Tensor binary_cross_entropy_with_logits(const Tensor& input, const Tensor& target, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& pos_weight_opt, int64_t reduction) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const Tensor& pos_weight = c10::value_or_else(pos_weight_opt, [] {return Tensor();});
Tensor loss;
// m = max(-x, 0): the shift used to keep the exponentials below from overflowing.
auto max_val = (-input).clamp_min_(0);
if (pos_weight.defined()) {
// pos_weight needs to be broadcast, so mul(target) is not in-place.
auto log_weight = (pos_weight - 1).mul(target).add_(1);
loss = (1 - target).mul_(input).add_(log_weight.mul_(((-max_val).exp_().add_((-input - max_val).exp_())).log_().add_(max_val)));
} else {
loss = (1 - target).mul_(input).add_(max_val).add_((-max_val).exp_().add_((-input -max_val).exp_()).log_());
}
if (weight.defined()) {
loss.mul_(weight);
}
return apply_loss_reduction(loss, reduction);
}
// Gradient of binary_cross_entropy_with_logits w.r.t. `input`.
//   without pos_weight: (sigmoid(x) - y) * grad
//   with pos_weight p:  ((p * y + 1 - y) * sigmoid(x) - p * y) * grad
// The result is multiplied by `weight` (if given) and divided by input.numel()
// for Reduction::Mean.
Tensor binary_cross_entropy_with_logits_backward(const Tensor& grad, const Tensor& input, const Tensor& target, const c10::optional<Tensor>& weight_opt, const c10::optional<Tensor>& pos_weight_opt, int64_t reduction) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  const Tensor& pos_weight = c10::value_or_else(pos_weight_opt, [] {return Tensor();});

  Tensor grad_input;
  if (!pos_weight.defined()) {
    grad_input = (input.sigmoid() - target).mul_(grad);
  } else {
    // pos_weight must be broadcast against target, so this product is out-of-place.
    const auto weighted_target = pos_weight.mul(target);
    grad_input = weighted_target.add(1)
                     .sub_(target)
                     .mul_(input.sigmoid())
                     .sub_(weighted_target)
                     .mul_(grad);
  }
  if (weight.defined()) {
    grad_input.mul_(weight);
  }
  return reduction == at::Reduction::Mean ? grad_input / input.numel() : grad_input;
}
// Poisson negative log-likelihood.
//   log_input:  exp(input) - target * input
//   otherwise:  input - target * log(input + eps)
// With `full`, the Stirling approximation of log(target!),
// target * log(target) - target + 0.5 * log(2 * pi * target), is added where
// target > 1 (zero elsewhere). `reduction` is applied last.
Tensor poisson_nll_loss(const Tensor& input, const Tensor& target, const bool log_input, const bool full, const double eps, const int64_t reduction) {
  Tensor loss = log_input
      ? at::exp(input) - target * input
      : input - target * at::log(input + eps);
  if (full) {
    const auto stirling = target * at::log(target) - target + 0.5 * at::log(2 * c10::pi<double> * target);
    loss += stirling.masked_fill(target <= 1, 0);
  }
  return apply_loss_reduction(loss, reduction);
}
// Gradient of soft_margin_loss w.r.t. `input`, written into `grad_input`:
//   grad_input = -norm * target * z / (1 + z) * grad_output,  z = exp(-target * input)
// norm is 1/numel for Reduction::Mean and 1 otherwise. Computed in place.
Tensor& soft_margin_loss_backward_out(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, Tensor& grad_input) {
  const double norm = (reduction == Reduction::Mean) ? 1. / input.numel() : 1.;
  auto z = at::exp(-target * input);
  at::mul_out(grad_input, target, z);
  grad_input.mul_(-norm);
  z.add_(1);
  grad_input.div_(z);
  grad_input.mul_(grad_output);
  return grad_input;
}
// Functional form of soft_margin_loss_backward_out: starts from an empty
// tensor that the out= variant fills.
Tensor soft_margin_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) {
  Tensor grad_input = at::empty({0}, input.options());
  at::soft_margin_loss_backward_out(grad_input, grad_output, input, target, reduction);
  return grad_input;
}
// Soft margin loss, log(1 + exp(-input * target)), written into `output`.
// For Mean/Sum, `output` is resized to a 0-dim tensor holding the reduced value.
Tensor& soft_margin_loss_out(const Tensor& input,
    const Tensor& target,
    int64_t reduction,
    Tensor& output) {
  // Build the element-wise loss in place inside `output`.
  at::neg_out(output, input);
  output.mul_(target).exp_().add_(1.).log_();
  if (reduction == Reduction::None) {
    return output;
  }
  const auto reduced = apply_loss_reduction(output, reduction);
  output.resize_({});
  output.copy_(reduced);
  return output;
}
// Functional form of soft_margin_loss_out: starts from an empty tensor that
// the out= variant fills (and reduces if requested).
Tensor soft_margin_loss(
    const Tensor& input,
    const Tensor& target,
    int64_t reduction) {
  Tensor result = at::empty({0}, input.options());
  at::soft_margin_loss_out(result, input, target, reduction);
  return result;
}
// Smooth L1 loss with threshold `beta`, computed element-wise by smooth_l1_stub
// and then reduced. beta == 0 is delegated to l1_loss; negative beta is rejected.
Tensor smooth_l1_loss(const Tensor& input, const Tensor& target, const int64_t reduction, double beta) {
  TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta.");
  if (beta == 0) {
    return at::native::l1_loss(input, target, reduction);
  }
  // `loss` is left undefined so TensorIterator allocates the output; the
  // computed values are read back through iter.output().
  Tensor loss;
  auto iter = TensorIterator::borrowing_binary_op(loss, input, target);
  smooth_l1_stub(iter.device_type(), iter, beta);
  return apply_loss_reduction(iter.output(), reduction);
}
// out= variant of smooth_l1_loss. With Reduction::None the element-wise loss is
// written directly into `result`. Otherwise it is computed into a temporary and
// its mean/sum over all elements is written into `result`.
Tensor& smooth_l1_loss_out(const Tensor& input, const Tensor& target, int64_t reduction, double beta, Tensor& result) {
  TORCH_CHECK(beta >= 0, "smooth_l1_loss does not support negative values for beta.");
  if (beta == 0) {
    return at::native::l1_loss_out(input, target, reduction, result);
  }
  if (reduction != Reduction::None) {
    Tensor loss;
    auto iter = TensorIterator::borrowing_binary_op(loss, input, target);
    smooth_l1_stub(iter.device_type(), iter, beta);
    // Reduce over *all* dimensions (empty dim list), matching smooth_l1_loss
    // and l1_loss_out. A literal `0` here binds to IntArrayRef{0} and reduces
    // only dim 0, leaving a non-scalar result for inputs with more than one dimension.
    if (reduction == Reduction::Mean) {
      at::mean_out(result, iter.output(), IntArrayRef{});
    } else {
      at::sum_out(result, iter.output(), IntArrayRef{});
    }
  } else {
    auto iter = TensorIterator::borrowing_binary_op(result, input, target);
    smooth_l1_stub(iter.device_type(), iter, beta);
  }
  return result;
}
// Gradient of smooth_l1_loss w.r.t. `input`, written into `grad_input`.
// Non-positive beta falls back to the L1 gradient. Otherwise smooth_l1_backward_stub
// runs with a scale of 1/numel for Reduction::Mean and 1 for other reductions.
Tensor& smooth_l1_loss_backward_out(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double beta, Tensor& grad_input) {
  if (beta <= 0) {
    return at::native::l1_loss_backward_out(
        grad_output, input, target, reduction, grad_input);
  }
  const double norm = (reduction == Reduction::Mean) ? 1. / input.numel() : 1.;
  // Operand order (output, input, target, grad_output) presumably matches the
  // order the registered kernel reads them in — keep it in sync with the stub.
  auto iter = at::TensorIteratorConfig()
    .add_output(grad_input)
    .add_input(input)
    .add_input(target)
    .add_input(grad_output)
    .build();
  smooth_l1_backward_stub(iter.device_type(), iter, norm, beta);
  return grad_input;
}
// Functional gradient of smooth_l1_loss. Non-positive beta uses the L1 gradient;
// otherwise a zero-initialized buffer is filled by the out= variant.
Tensor smooth_l1_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double beta) {
  if (beta <= 0) {
    return at::native::l1_loss_backward(grad_output, input, target, reduction);
  }
  Tensor grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return at::smooth_l1_loss_backward_out(grad_input, grad_output, input, target, reduction, beta);
}
// Huber loss with threshold `delta` (must be > 0), computed element-wise by
// huber_stub into a buffer allocated like `input`, then reduced.
Tensor huber_loss(const Tensor& input, const Tensor& target, int64_t reduction, double delta) {
  TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.");
  Tensor loss = at::empty_like(input);
  auto iter = TensorIterator::borrowing_binary_op(loss, input, target);
  huber_stub(iter.device_type(), iter, delta);
  return apply_loss_reduction(loss, reduction);
}
// out= variant of huber_loss. The element-wise loss is first written into
// `result`. For Mean/Sum it is then reduced and `result` becomes a 0-dim tensor
// holding the reduced value.
Tensor& huber_loss_out(const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& result) {
  TORCH_CHECK(delta > 0, "huber_loss does not support non-positive values for delta.");
  auto iter = TensorIterator::borrowing_binary_op(result, input, target);
  huber_stub(iter.device_type(), iter, delta);
  if (reduction != Reduction::None) {
    auto reduced = apply_loss_reduction(result, reduction);
    result.resize_({});
    result.copy_(reduced);
  }
  return result;
}
// Functional gradient of huber_loss: a zero-initialized, contiguous buffer
// shaped like `input` is filled by huber_loss_backward_out.
Tensor huber_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double delta) {
  Tensor grad_input = at::zeros_like(input, MemoryFormat::Contiguous);
  return at::huber_loss_backward_out(grad_input, grad_output, input, target, reduction, delta);
}
// Gradient of huber_loss w.r.t. `input`, written into `grad_input` by
// huber_backward_stub. The scale is 1/numel for Reduction::Mean and 1 otherwise.
Tensor& huber_loss_backward_out(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double delta, Tensor& grad_input) {
  const double norm = reduction == Reduction::Mean ? 1. / input.numel() : 1.;
  // Operand order (output, input, target, grad_output) presumably matches the
  // order the registered kernel reads them in — keep it in sync with the stub.
  auto iter = at::TensorIteratorConfig()
    .add_output(grad_input)
    .add_input(input)
    .add_input(target)
    .add_input(grad_output)
    .build();
  huber_backward_stub(iter.device_type(), iter, norm, delta);
  return grad_input;
}
// Mean squared error, computed element-wise by mse_stub and then reduced.
Tensor mse_loss(const Tensor& input, const Tensor& target, int64_t reduction) {
  // Left undefined so TensorIterator allocates the output; the computed values
  // are read back through iter.output().
  Tensor unreduced;
  auto iter = TensorIterator::borrowing_binary_op(unreduced, input, target);
  mse_stub(iter.device_type(), iter);
  return apply_loss_reduction(iter.output(), reduction);
}
// out= variant of mse_loss. With Reduction::None the element-wise loss is
// written directly into `result`. Otherwise it is computed into a temporary and
// its mean/sum over all elements is written into `result`.
Tensor& mse_loss_out(const Tensor& input, const Tensor& target, int64_t reduction, Tensor&result) {
  if (reduction != Reduction::None) {
    Tensor loss;
    auto iter = TensorIterator::borrowing_binary_op(loss, input, target);
    mse_stub(iter.device_type(), iter);
    // Reduce over *all* dimensions (empty dim list), matching mse_loss and
    // l1_loss_out. A literal `0` here binds to IntArrayRef{0} and reduces only
    // dim 0, leaving a non-scalar result for inputs with more than one dimension.
    if (reduction == Reduction::Mean) {
      at::mean_out(result, iter.output(), IntArrayRef{});
    } else {
      at::sum_out(result, iter.output(), IntArrayRef{});
    }
  } else {
    auto iter = TensorIterator::borrowing_binary_op(result, input, target);
    mse_stub(iter.device_type(), iter);
  }
  return result;
}
// Functional gradient of mse_loss: a zero-initialized buffer shaped like
// `input` is filled by mse_loss_backward_out.
Tensor mse_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) {
  auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return at::mse_loss_backward_out(grad_input, grad_output, input, target, reduction);
}
// Gradient of mse_loss w.r.t. `input`, written into `grad_input` by
// mse_backward_stub. The scale is 2/numel for Reduction::Mean and 2 otherwise;
// the factor 2 comes from d/dx (x - y)^2 = 2 (x - y).
Tensor& mse_loss_backward_out(const Tensor& grad_output,
    const Tensor& input, const Tensor& target, int64_t reduction, Tensor& grad_input) {
  const double norm = (reduction == Reduction::Mean) ? 2. / input.numel() : 2.;
  // Operand order (output, input, target, grad_output) presumably matches the
  // order the registered kernel reads them in — keep it in sync with the stub.
  auto iter = at::TensorIteratorConfig()
    .add_output(grad_input)
    .add_input(input)
    .add_input(target)
    .add_input(grad_output)
    .build();
  mse_backward_stub(iter.device_type(), iter, norm);
  return grad_input;
}
// L1 loss |input - target| followed by `reduction`. The result uses the value
// (real) dtype of `input`, since the absolute value of a complex number is real.
Tensor l1_loss(const Tensor& input, const Tensor& target, int64_t reduction) {
  const auto real_options = input.options().dtype(c10::toValueType(input.scalar_type()));
  Tensor result = at::empty({0}, real_options);
  return at::l1_loss_out(result, input, target, reduction);
}
// out= variant of l1_loss.
// Reduction::None: |input - target| is written element-wise into `result`.
// Mean/Sum: the reduction over all elements (empty dim list) is written into `result`.
Tensor& l1_loss_out(const Tensor& input, const Tensor& target, int64_t reduction, Tensor& result) {
  if (reduction == Reduction::None) {
    // For real input the difference is written straight into `result` and the
    // abs is taken there. Complex input uses a separate temporary, because
    // `result` may have the real value dtype (as l1_loss allocates it).
    auto diff = input.is_complex() ? at::sub(input, target) : at::sub_out(result, input, target);
    return at::abs_out(result, diff);
  }
  auto diff = at::sub(input, target);
  // The in-place abs_() is used only for real dtypes; complex input takes the out-of-place abs().
  auto loss = diff.is_complex() ? diff.abs() : diff.abs_();
  if (reduction == Reduction::Mean) {
    return at::mean_out(result, loss, IntArrayRef{});
  }
  return at::sum_out(result, loss, IntArrayRef{});
}
// Functional gradient of l1_loss: a zero-initialized buffer shaped like `input`
// is filled by l1_loss_backward_out.
Tensor l1_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) {
  auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return at::l1_loss_backward_out(grad_input, grad_output, input, target, reduction);
}
// Gradient of l1_loss w.r.t. `input`, written into `grad_input`:
//   sgn(input - target) * grad_output   (grad_output / numel for Reduction::Mean)
Tensor& l1_loss_backward_out(const Tensor& grad_output,
    const Tensor& input, const Tensor& target, int64_t reduction, Tensor& grad_input) {
  Tensor scale = grad_output;
  if (reduction == Reduction::Mean) {
    scale = grad_output / input.numel();
  }
  at::sub_out(grad_input, input, target);
  return grad_input.sgn_().mul_(scale);
}
}} // namespace at::native