forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
RNN.cpp
1680 lines (1507 loc) · 66.5 KB
/
RNN.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <ATen/cuda/Exceptions.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/MatrixRef.h>
#include <ATen/native/RNN.h>
#include <ATen/NativeFunctions.h>
#include <ATen/TensorUtils.h>
#include <c10/util/accumulate.h>
#include <c10/util/Exception.h>
#if !AT_CUDNN_ENABLED()
namespace at { namespace native {

// Fallback definitions compiled when ATen is built without cuDNN
// (AT_CUDNN_ENABLED() is false). Each function has the same signature as the
// cuDNN-backed definition in the #else branch so the symbols exist in every
// build configuration, but every one of them unconditionally throws.
// See Note [ATen preprocessor philosophy]

// Stub: always throws because ATen was not compiled with cuDNN.
Tensor _cudnn_rnn_flatten_weight(
    TensorList weight_arr, int64_t weight_stride0,
    int64_t input_size,
    int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_proj_size,
    int64_t fn_num_layers, bool batch_first,
    bool fn_bidirectional
    ) {
  AT_ERROR("_cudnn_rnn_flatten_weight: ATen not compiled with cuDNN support");
}

// Stub: always throws because ATen was not compiled with cuDNN.
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
    const Tensor& input_r,
    TensorList weight, int64_t weight_stride0, const c10::optional<Tensor>& weight_buf_r_opt, const Tensor& hx, const c10::optional<Tensor>& cx_opt,
    int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_proj_size,
    int64_t fn_num_layers, bool batch_first, double fn_dropout,
    bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes, const c10::optional<Tensor>& fn_dropout_state_opt
    ) {
  AT_ERROR("_cudnn_rnn: ATen not compiled with cuDNN support");
}

// Stub: always throws because ATen was not compiled with cuDNN.
std::tuple<Tensor, Tensor, Tensor, std::vector<Tensor>> _cudnn_rnn_backward(
    const Tensor& input, TensorList weight, int64_t weight_stride0, const Tensor& weight_buf, const Tensor& hx, const c10::optional<Tensor>& cx_opt,
    const Tensor& output, const c10::optional<Tensor>& grad_output_r_opt, const c10::optional<Tensor>& grad_hy_r_opt, const c10::optional<Tensor>& grad_cy_r_opt,
    int64_t mode, int64_t hidden_size, int64_t proj_size,
    int64_t num_layers, bool batch_first, double dropout,
    bool train, bool bidirectional, IntArrayRef batch_sizes, const c10::optional<Tensor>& dropout_state_opt, const Tensor& reserve,
    std::array<bool, 4> output_mask
    ) {
  AT_ERROR("_cudnn_rnn_backward: ATen not compiled with cuDNN support");
}

// Stub: always throws because ATen was not compiled with cuDNN.
// The TensorOptions-style arguments (dtype/layout/device/pin_memory) are not
// assembled into a TensorOptions here: the function throws before any of
// them could be used, so building that object would be dead work.
Tensor _cudnn_init_dropout_state(double dropout, bool train, int64_t dropout_seed,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory) {
  AT_ERROR("_cudnn_init_dropout_state: ATen not compiled with cuDNN support");
}

}} // namespace at::native
#else // AT_CUDNN_ENABLED()
#include <ATen/native/cudnn/RNNUtils.h>
namespace at { namespace native {
namespace {
// DropoutDescriptor
// Plain description of the dropout configuration of one RNN call. It is
// turned into a cuDNN DropoutDescriptor on demand by descriptor().
struct DropoutDescriptorParams {
  // Defaults so that calling descriptor() on an object that was never set()
  // produces a no-dropout descriptor instead of reading indeterminate values.
  bool train = false;
  double dropout = 0.0;
  Tensor dropout_state;

  DropoutDescriptorParams() {}

  // Record the configuration. dropout_state_ arrives by value, so it is moved
  // into the member rather than copied a second time.
  void set(bool train_, double dropout_, Tensor dropout_state_) {
    train = train_;
    dropout = dropout_;
    dropout_state = std::move(dropout_state_);
  }

  // Build the cuDNN dropout descriptor. Dropout is forced to 0 outside of
  // training; a zero probability uses set_no_dropout(), which does not take
  // dropout_state at all.
  DropoutDescriptor descriptor(cudnnHandle_t handle) const {
    const double dropout_p = train ? dropout : 0.0;
    DropoutDescriptor dropout_desc;
    if (dropout_p == 0) {
      dropout_desc.set_no_dropout(handle);
    } else {
      dropout_desc.set(handle, dropout_p, dropout_state);
    }
    return dropout_desc;
  }
};
// RNNDescriptor
struct RNNDescriptorParams {
int64_t hidden_size;
int64_t proj_size;
int64_t num_layers;
cudnnDirectionMode_t bidirectional;
cudnnRNNMode_t mode;
cudnnDataType_t datatype;
cudnnDataType_t input_datatype;
cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD;
cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
int64_t num_directions() const {
return bidirectional ? 2 : 1;
}
void set_mode(int64_t fn_mode) {
switch (fn_mode) {
case CUDNN_RNN_RELU:
mode = CUDNN_RNN_RELU;
break;
case CUDNN_RNN_TANH:
mode = CUDNN_RNN_TANH;
break;
case CUDNN_LSTM:
mode = CUDNN_LSTM;
break;
case CUDNN_GRU:
mode = CUDNN_GRU;
break;
default:
{
std::ostringstream oss;
oss << "unrecognized cuDNN RNN mode " << fn_mode;
AT_ERROR(oss.str());
}
}
}
void set_bidirectional(bool fn_bidirectional) {
bidirectional = fn_bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
}
void set_algo(cudnnRNNAlgo_t algo){
this->algo = algo;
}
void set(int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool bidirectional, cudnnDataType_t datatype, cudnnDataType_t input_datatype) {
this->set_mode(mode);
this->hidden_size = hidden_size;
this->proj_size = proj_size;
this->num_layers = num_layers;
this->set_bidirectional(bidirectional);
this->datatype = datatype;
this->input_datatype = input_datatype;
}
RNNDescriptor descriptor(cudnnHandle_t handle, DropoutDescriptor&& dropout_desc) const {
RNNDescriptor rnn_desc;
rnn_desc.set(handle, hidden_size, proj_size, num_layers, std::move(dropout_desc), input_mode, bidirectional, mode, datatype, input_datatype, algo, at::globalContext().allowTF32CuDNN());
return rnn_desc;
}
// In some cases, a use of RNNDescriptor does not rely on the
// DropoutDescriptor. In this case, we fake up a no-dropout
// descriptor to make the RNN descriptor initialization go through.
// This is used by _cudnn_rnn_flatten_weight, which needs an
// RNNDescriptor for get_parameters(), but does not actually need
// a fully initialized dropout descriptor. This lets us avoid
// having to pass the dropout state to flatten, which has no business
// knowing what the dropout state is.
RNNDescriptor descriptor(cudnnHandle_t handle) const {
DropoutDescriptor dropout_desc;
dropout_desc.set_no_dropout(handle);
return descriptor(handle, std::move(dropout_desc));
}
};
// TensorDescriptor list
// Build one tensor descriptor per time step of a packed sequence.
// Descriptor `step` uses the sizes and strides of `tensor`, except that
// dimension 0 is replaced by batch_sizes[step].
std::vector<TensorDescriptor> rnn_descriptor_sequence(const Tensor& tensor, IntArrayRef batch_sizes) {
  const size_t num_steps = batch_sizes.size();
  std::vector<TensorDescriptor> descriptors(num_steps);
  // Scratch copy of the sizes; only the leading entry changes per step.
  auto step_sizes = tensor.sizes().vec();
  for (size_t step = 0; step < num_steps; ++step) {
    step_sizes[0] = batch_sizes[step];
    // cuDNN's RNN API rejects 2-d descriptors, so pad them out to 3-d.
    descriptors[step].set(getCudnnDataType(tensor), step_sizes, tensor.strides(), 3);
  }
  return descriptors;
}
std::vector<TensorDescriptor> rnn_descriptor(const Tensor& tensor, int64_t N) {
std::vector<TensorDescriptor> descriptors(N);
for (int64_t i = 0; i < N; i++) {
descriptors[i].set(tensor, 5);
}
return descriptors;
}
// The best way to understand the meaning of the values stored in
// this struct is to consider each of the possible ways our
// input can be structured.
//
// Suppose you want to run RNN on the following variable
// length inputs:
//
// Sequence 1: ABCD
// Sequence 2: EF
// Sequence 3: G
//
// (Let _ be padding when we have non-packed representations.)
//
// # Packed input (batch_sizes is non-empty)
//
// input_size
// +------+ +
// | A | |
// | E | mini_batch = |
// | G | batch_sizes[0] = 3 |
// +------+ |
// | B | | batch_sizes_sum = 7
// | F | batch_sizes[1] = 2 |
// +------+ |
// | C | batch_sizes[2] = 1 |
// +------+ |
// | D | batch_sizes[3] = 1 |
// +------+ +
//
// (seq_length = 4)
//
// input.size() = batch_sizes_sum x input_size
//
// # Unpacked input (batch_first = false)
//
// mini_batch = 3
// +-------+
// | A E G |
// | B F _ | seq_length = 4
// | C _ _ |
// | D _ _ |
// +-------+
// ... input_size
// +-------+
//
// input.size() = seq_length x mini_batch x input_size
//
// # Unpacked input (batch_first = true)
//
// seq_length = 4
// +---------+
// | A B C D |
// | E F _ _ | mini_batch = 3
// | G _ _ _ |
// +---------+
// ... input_size
// +---------+
//
// input.size() = mini_batch x seq_length x input_size
//
struct TensorDescriptorListParams {
IntArrayRef batch_sizes;
int64_t seq_length;
int64_t mini_batch;
// NB: this is not input.size(), which is an IntArrayRef; instead, this
// size of the inner-most dimension. In NL applications, this is usually
// the size of the embedding. You can also think of this as the size
// of the "channel" dimension (at risk of confusing vision researchers :)
int64_t input_size;
// Only valid when !is_input_packed
int64_t batch_sizes_sum; // == sum(batch_sizes)
bool is_input_packed() const {
return batch_sizes.size() != 0;
}
void set(IntArrayRef input_sizes, IntArrayRef batch_sizes_, bool batch_first) {
batch_sizes = batch_sizes_;
if (is_input_packed()) {
seq_length = batch_sizes.size();
mini_batch = batch_sizes[0];
// NB: When input is packed, the mini_batch size is NOT the size
// of the outer dimension
batch_sizes_sum = input_sizes[0];
input_size = input_sizes[1];
} else {
if (batch_first) {
seq_length = input_sizes[1];
mini_batch = input_sizes[0];
} else {
seq_length = input_sizes[0];
mini_batch = input_sizes[1];
}
input_size = input_sizes[2];
// TODO: Actually, would this make ASAN's job harder catching
// an uninitialized access?
batch_sizes_sum = -1; // something bogus in case we access it
}
}
// TODO: check x for consistency with input_size?
std::vector<TensorDescriptor> descriptors(Tensor x) const {
auto is_input_packed = batch_sizes.size() != 0;
if (is_input_packed) {
return rnn_descriptor_sequence(x, batch_sizes);
} else {
return rnn_descriptor(x[0], seq_length);
}
}
};
// Everything together
// Bundles all per-call RNN configuration that RNNDescriptors consumes:
// dropout settings, RNN shape/datatype settings, and the input tensor layout.
struct RNNParams {
// Dropout probability, training flag and dropout state.
DropoutDescriptorParams dropout;
// Cell mode, sizes, direction, datatypes and algorithm.
RNNDescriptorParams rnn;
// Sequence length, mini-batch size, input size and packing of the input.
TensorDescriptorListParams tensors;
};
// NB: Doesn't include the weight descriptor
struct RNNDescriptors {
RNNDescriptor rnn_desc;
// NB: this won't actually lay out the tensor descriptor pointers
// in the right way, so you'll have to preprocess them
std::vector<TensorDescriptor> x_descs;
std::vector<TensorDescriptor> y_descs;
TensorDescriptor hx_desc;
TensorDescriptor hy_desc;
TensorDescriptor cx_desc;
TensorDescriptor cy_desc;
RNNDescriptors(const RNNParams& fn, cudnnHandle_t handle, Tensor x, Tensor y, Tensor hx, Tensor cx) {
rnn_desc = fn.rnn.descriptor(handle, fn.dropout.descriptor(handle));
x_descs = fn.tensors.descriptors(x);
y_descs = fn.tensors.descriptors(y);
hx_desc.set(hx, 5);
hy_desc.set(hx, 5);
if (cx.defined()) {
cx_desc.set(cx, 5);
cy_desc.set(cx, 5);
}
}
// TODO: This is annoying, having to put the cudnnTensorDescriptor_t
// in a contiguous array...
std::vector<cudnnTensorDescriptor_t> get_descs(const std::vector<TensorDescriptor>& descs) {
std::vector<cudnnTensorDescriptor_t> r;
r.reserve(descs.size());
for (auto& desc : descs) {
r.emplace_back(desc.desc());
}
return r;
}
std::vector<cudnnTensorDescriptor_t> get_x_descs() {
return get_descs(x_descs);
}
std::vector<cudnnTensorDescriptor_t> get_y_descs() {
return get_descs(y_descs);
}
};
// Ask cuDNN for the byte size of the packed weight buffer this RNN needs for
// inputs described by x_desc, and convert it into an element count of
// `datatype`.
int64_t get_num_weights(cudnnHandle_t handle, const RNNDescriptor& rnn_desc,
                        const TensorDescriptor& x_desc, cudnnDataType_t datatype) {
  size_t weight_bytes = 0;
  AT_CUDNN_CHECK(cudnnGetRNNParamsSize(handle, rnn_desc.desc(), x_desc.desc(), &weight_bytes, datatype));
  const auto bytes_per_elem = dataSize(datatype);
  TORCH_INTERNAL_ASSERT(weight_bytes % bytes_per_elem == 0, "cudnnGetRNNParamsSize returned nonsensical weight_size");
  return weight_bytes / bytes_per_elem;
}
// Number of cuDNN "linear layers" (weight-matrix slots) per pseudo-layer for
// the given cell type. The first half pairs with the layer input and the
// second half with the hidden state: LSTM 8, GRU 6, plain RNN 2.
int64_t _num_linear_layers(cudnnRNNMode_t mode) {
  switch (mode) {
    case CUDNN_LSTM:
      return 8;
    case CUDNN_GRU:
      return 6;
    case CUDNN_RNN_RELU:
    case CUDNN_RNN_TANH:
      return 2;
    default:
      AT_ERROR("unknown cuDNN RNN mode ", mode);
  }
}
// Append to `params` a view into `weight_buf` covering the projection matrix
// (w_hr) of pseudo-layer `layer`. The view has shape {numel, 1} and starts at
// the matrix's element offset inside weight_buf. Callers in this file only
// invoke it when proj_size != 0 (LSTM with projections).
void add_projection_weights(
    cudnnHandle_t handle,
    const RNNDescriptor& rnn_desc,
    const TensorDescriptor& x_desc,
    const FilterDescriptor& w_desc,
    const Tensor& weight_buf,
    int64_t layer,
    std::vector<Tensor>& params
    ) {
  // For LSTM, linear layers 0..7 are the 4 gates x {input, hidden} weight
  // matrices (i.e. 4 weights and 4 biases per half), and cuDNN exposes the
  // projection matrix as linear layer 8. This assumes the RNN is an LSTM.
  constexpr int64_t kProjectionLinearId = 8;
  void* matrix_pointer = nullptr;
  FilterDescriptor lin_layer_mat_desc;
  AT_CUDNN_CHECK(cudnnGetRNNLinLayerMatrixParams(
      /*handle=*/handle,
      /*rnnDesc=*/rnn_desc.desc(),
      /*layer=*/layer,
      /*xDesc=*/x_desc.desc(),
      /*wDesc=*/w_desc.desc(),
      /*w=*/weight_buf.data_ptr(),
      /*linLayerID=*/kProjectionLinearId,
      /*linLayerMatDesc=*/lin_layer_mat_desc.mut_desc(),
      /*linLayerMat=*/&matrix_pointer));

  // Read back the matrix's dimensions (at most min_dim of them are requested).
  cudnnDataType_t data_type;
  cudnnTensorFormat_t format;
  int nb_dims;
  constexpr int min_dim = 3;
  int filter_dim_a[min_dim];
  AT_CUDNN_CHECK(cudnnGetFilterNdDescriptor(
      lin_layer_mat_desc.desc(),
      min_dim,
      &data_type,
      &format,
      &nb_dims,
      filter_dim_a
      ));
  TORCH_INTERNAL_ASSERT(nb_dims <= min_dim, "nb_dims = ", nb_dims, "; min_dim = ", min_dim);

  // Turn the matrix's address into an element offset inside weight_buf.
  const auto elem_size = dataSize(getCudnnDataType(weight_buf));
  const auto offset_bytes = static_cast<char*>(matrix_pointer) - static_cast<char*>(weight_buf.data_ptr());
  TORCH_INTERNAL_ASSERT(offset_bytes % elem_size == 0, "offset_bytes = ", offset_bytes, "; elem_size = ", elem_size);
  const size_t offset = offset_bytes / elem_size;

  // int64_t is the type multiply_integers returns; storing it in an `int`
  // could silently truncate the element count of a very large matrix.
  const int64_t mat_numel = c10::multiply_integers(filter_dim_a, filter_dim_a + nb_dims);
  // Generate a new parameter tensor which is a view into the weight_buf.
  std::initializer_list<int64_t> size = {mat_numel, 1};
  Tensor param = at::empty({0}, weight_buf.options()).set_(weight_buf.storage(), offset, size);
  params.emplace_back(std::move(param));
}
/*
  Returns weight and bias tensors for each layer of the RNN. These tensors
  are views on the underlying weight buffer allocated by CuDNN.

  Note: for LSTM and GRU, which have multiple parameters of each type (4 and 3, respectively),
        these parameters are concatenated along the first dimension.
        These parameters are returned in a consistent order by CuDNN:
            (input, forget, cell, output) for LSTM
            (reset, update, new) for GRU
  Args:
      handle: a CuDNN handle
      rnn: the RNN configuration (mode, layers, directions, proj_size)
      rnn_desc: the cuDNN RNN descriptor built from `rnn`
      x_desc: the input tensor descriptor passed through to cuDNN
      w_desc: the filter descriptor of weight_buf
      weight_buf: a 1D tensor containing the CuDNN-allocated weight (or grad_weight) buffer
      include_bias: whether bias views are included in the result
  Returns:
      A pair (params, stride0). `params` is a flat list holding, for each of the
      num_layers * num_directions pseudo-layers in turn:
          weight_ih, weight_hh,
          [bias_ih, bias_hh]   if include_bias,
          [weight_hr]          if proj_size != 0.
      Each entry is a view of shape {numel, 1}. `stride0` is the number of entries
      per pseudo-layer, so the caller can wrap the pair in a MatrixRef.
      (NB: Can't return MatrixRef because we need to allocate the underlying tensor)
*/
std::pair<std::vector<Tensor>, size_t> // stride0
get_parameters(
    cudnnHandle_t handle,
    const RNNDescriptorParams& rnn,
    const RNNDescriptor& rnn_desc,
    const TensorDescriptor& x_desc,
    const FilterDescriptor& w_desc,
    const Tensor& weight_buf,
    bool include_bias=true
    ) {
  const auto cudnn_methods = { cudnnGetRNNLinLayerMatrixParams, cudnnGetRNNLinLayerBiasParams };
  std::vector<Tensor> params;
  const int64_t num_linear_layers = _num_linear_layers(rnn.mode);
  const int64_t num_layers = rnn.num_directions() * rnn.num_layers;
  // Element offset just past the previously visited matrix; used to check
  // that gates of the same half are laid out back to back.
  size_t cur_offset = 0;
  size_t global_layer_params_count = 0;
  for (int64_t layer = 0; layer < num_layers; layer++) {
    size_t layer_params_count = 0;
    for (auto cudnn_method : cudnn_methods) {
      for (int64_t linear_id = 0; linear_id < num_linear_layers; linear_id++) {
        FilterDescriptor lin_layer_mat_desc;
        void* matrix_pointer = nullptr;
        AT_CUDNN_CHECK(cudnn_method(
            handle,
            rnn_desc.desc(),
            layer,
            x_desc.desc(),
            w_desc.desc(),
            weight_buf.data_ptr(),
            linear_id,
            lin_layer_mat_desc.mut_desc(),
            &matrix_pointer
            ));
        cudnnDataType_t data_type;
        cudnnTensorFormat_t format;
        int nb_dims;
        constexpr int min_dim = 3;
        int filter_dim_a[min_dim];
        AT_CUDNN_CHECK(cudnnGetFilterNdDescriptor(
            lin_layer_mat_desc.desc(),
            min_dim,
            &data_type,
            &format,
            &nb_dims,
            filter_dim_a
            ));
        TORCH_INTERNAL_ASSERT(nb_dims <= min_dim, "nb_dims = ", nb_dims, "; min_dim = ", min_dim);
        const auto elem_size = dataSize(getCudnnDataType(weight_buf));
        const auto offset_bytes = static_cast<char*>(matrix_pointer) - static_cast<char*>(weight_buf.data_ptr());
        TORCH_INTERNAL_ASSERT(offset_bytes % elem_size == 0, "offset_bytes = ", offset_bytes, "; elem_size = ", elem_size);
        const size_t offset = offset_bytes / elem_size;
        // for all the RNN types provided by CUDNN, all the ih weights
        // are the same size and are allocated in a contiguous chunk
        // (same for the hh weights, and the ih and hh biases).
        // Since we're storing all the weights in a single tensor anyway,
        // might as well merge the CUDNN ones into a single tensor as well.
        // int64_t matches multiply_integers' return type (no truncation).
        const int64_t mat_numel = c10::multiply_integers(filter_dim_a, filter_dim_a + nb_dims);
        if (linear_id == 0 || linear_id == num_linear_layers / 2) {
          // We could also exclude bias params by restricting cudnn_methods to just { cudnnGetRNNLinLayerMatrixParams }
          // at the very top. However, to do so would throw off the cur_offset account, which is currently a strict
          // and informative check that all params are laid out the way we think they are. If include_bias is false,
          // I'd rather keep full cur_offset checks rather than save some CPU overhead by skipping the cudnn_method =
          // cudnnGetRNNLinLayerBiasParams iteration.
          if (include_bias || cudnn_method != cudnnGetRNNLinLayerBiasParams) {
            // Generate a new parameter tensor which is a view into the weight_buf,
            // spanning all gates of this half.
            std::initializer_list<int64_t> size = {
              mat_numel * num_linear_layers / 2, 1};
            Tensor param = at::empty({0}, weight_buf.options()).set_(weight_buf.storage(), offset, size);
            params.emplace_back(std::move(param));
            layer_params_count++;
          }
        } else {
          TORCH_INTERNAL_ASSERT(cur_offset == offset, "cur_offset = ", cur_offset, "; offset = ", offset);
        }
        cur_offset = offset + mat_numel;
      }
    } // for cudnn_method
    if (rnn.proj_size != 0) {
      add_projection_weights(handle, rnn_desc, x_desc, w_desc, weight_buf, layer, params);
      layer_params_count++;
    }

    // Every pseudo-layer must contribute the same number of tensors, since
    // that count becomes the MatrixRef stride.
    if (layer == 0) {
      global_layer_params_count = layer_params_count;
    } else {
      TORCH_INTERNAL_ASSERT(global_layer_params_count == layer_params_count,
                            "global_layer_params_count = ", global_layer_params_count,
                            "; layer_params_count = ", layer_params_count);
    }
  } // for layer
  return std::make_pair(std::move(params), global_layer_params_count);
}
// Lightweight counterpart of get_parameters used to quickly get the expected
// parameter addresses inside weight_buf.
// For each of the num_layers * num_directions pseudo-layers, in order, the
// result holds: weight_ih, weight_hh, bias_ih, bias_hh, and (only when
// proj_size != 0) the projection matrix. `datatype` is not used.
std::vector<void*> get_expected_data_ptrs(
    const Tensor& weight_buf, cudnnHandle_t handle, const RNNDescriptorParams& rnn,
    const RNNDescriptor& rnn_desc, const TensorDescriptor& x_desc, cudnnDataType_t datatype) {
  FilterDescriptor w_desc;
  w_desc.set(weight_buf, 3);

  // Ask cuDNN (through either the matrix or the bias query) where linear
  // layer `linear_id` of pseudo-layer `layer` lives inside weight_buf.
  auto locate = [&](auto query, int64_t layer, int64_t linear_id) {
    FilterDescriptor lin_layer_mat_desc;
    void* matrix_pointer;
    AT_CUDNN_CHECK(query(
        handle,
        rnn_desc.desc(),
        layer,
        x_desc.desc(),
        w_desc.desc(),
        weight_buf.data_ptr(),
        linear_id,
        lin_layer_mat_desc.mut_desc(),
        &matrix_pointer));
    return matrix_pointer;
  };

  const int64_t num_linear_layers = _num_linear_layers(rnn.mode);
  const int64_t num_dir_layers = rnn.num_directions() * rnn.num_layers;
  const bool has_projection = rnn.proj_size != 0;
  // Two matrices and two biases per pseudo-layer, plus the projection.
  const int64_t ptrs_per_layer = has_projection ? 5 : 4;

  std::vector<void*> data_ptrs;
  data_ptrs.reserve(num_dir_layers * ptrs_per_layer);

  // cuDNN reports a separate pointer for every gate, but the gates are kept
  // as one concatenated tensor, so only the first gate of each half matters.
  const std::array<int64_t, 2> first_gate_ids = { 0, num_linear_layers / 2 };
  const auto queries = { cudnnGetRNNLinLayerMatrixParams, cudnnGetRNNLinLayerBiasParams };

  for (int64_t layer = 0; layer < num_dir_layers; ++layer) {
    for (auto query : queries) {
      for (const int64_t linear_id : first_gate_ids) {
        data_ptrs.push_back(locate(query, layer, linear_id));
      }
    }
    if (has_projection) {
      // Linear layer 8 is the LSTM projection matrix (assumes an LSTM, i.e.
      // 4 weights and 4 biases per half before it).
      data_ptrs.push_back(locate(cudnnGetRNNLinLayerMatrixParams, layer, 8));
    }
  }
  return data_ptrs;
}
// Transfer a single parameter between two tensors.
//   copy == true : copy param_from's data into param_to (viewed as param_to's
//                  shape); a dtype change is allowed only if allow_type_change.
//   copy == false: resize param_from in place to param_to's shape; dtypes must
//                  already match and allow_type_change must be false.
void _viewOrCopyOneParam(const Tensor& param_from, const Tensor& param_to,
                         bool copy, bool allow_type_change=false) {
  TORCH_INTERNAL_ASSERT(copy || !allow_type_change,
                        "if viewing, type change is not allowed.");
  TORCH_INTERNAL_ASSERT(allow_type_change || (param_from.scalar_type() == param_to.scalar_type()),
                        "parameter types mismatch");
  if (!copy) {
    param_from.resize_as_(param_to);
    return;
  }
  param_to.copy_(param_from.view_as(param_to));
}
// Apply _viewOrCopyOneParam (copy or view mode) to every parameter of every
// layer, pairing params_from with params_to.
// Both matrices must have the same number of rows (layers). Within a row,
// tensors are paired positionally, except for the LSTM-projection-without-bias
// layout, which is handled explicitly below.
void _viewOrCopyParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to,
                       bool copy, bool allow_type_change=false) {
  TORCH_INTERNAL_ASSERT(params_from.size(0) == params_to.size(0), "number of layers mismatch");
  for (size_t i = 0; i < params_from.size(0); i++) {
    auto layer_params_from = params_from[i];
    auto layer_params_to = params_to[i];
    // NOTE: these lists have all weights before all biases, so if the layer
    // doesn't use biases, iteration will terminate once layer_params_from ends
    // and ignore them.

    // NOTE: there is an exception from the above statement. If LSTMs with projections
    // are used, weights layout will be w_ih, w_hh, b_ih, b_hh, w_hr. So need to handle no-bias
    // case specially, because will need to copy 0->0, 1->1, 2->4. This case can be uniquely
    // identified by checking if number of defined parameters for each layer is 3.
    if (layer_params_from.size() == 3 && layer_params_to.size() != 3) {
      // Index 4 is only valid for the full 5-entry projection layout; ArrayRef
      // indexing is unchecked, so verify before touching it.
      TORCH_INTERNAL_ASSERT(layer_params_to.size() == 5,
                            "expected 5 destination parameters per layer, got ", layer_params_to.size());
      _viewOrCopyOneParam(layer_params_from[0], layer_params_to[0], copy, allow_type_change);
      _viewOrCopyOneParam(layer_params_from[1], layer_params_to[1], copy, allow_type_change);
      _viewOrCopyOneParam(layer_params_from[2], layer_params_to[4], copy, allow_type_change);
      continue;
    }
    if (layer_params_to.size() == 3 && layer_params_from.size() != 3) {
      TORCH_INTERNAL_ASSERT(layer_params_from.size() == 5,
                            "expected 5 source parameters per layer, got ", layer_params_from.size());
      _viewOrCopyOneParam(layer_params_from[0], layer_params_to[0], copy, allow_type_change);
      _viewOrCopyOneParam(layer_params_from[1], layer_params_to[1], copy, allow_type_change);
      _viewOrCopyOneParam(layer_params_from[4], layer_params_to[2], copy, allow_type_change);
      continue;
    }
    for (auto a = layer_params_from.begin(), b = layer_params_to.begin();
         a != layer_params_from.end() && b != layer_params_to.end();
         ++a, ++b) {
      _viewOrCopyOneParam(*a, *b, copy, allow_type_change);
    }
  }
}
// Copy every tensor of params_from into its counterpart in params_to.
// allow_type_change keeps its default (false), so dtypes must match.
void _copyParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to) {
  constexpr bool kCopy = true;
  _viewOrCopyParams(params_from, params_to, kCopy);
}
// "View" mode of _viewOrCopyParams: each tensor of params_from is resized in
// place to the shape of its counterpart in params_to (dtypes must match).
void _viewParams(MatrixRef<Tensor> params_from, MatrixRef<Tensor> params_to) {
  constexpr bool kCopy = false;
  _viewOrCopyParams(params_from, params_to, kCopy);
}
// Expected input shape: {batch_sizes_sum, input_size} for packed input,
// otherwise {seq_length, mini_batch, input_size}.
std::vector<int64_t> _input_size(const TensorDescriptorListParams& tensors) {
  if (!tensors.is_input_packed()) {
    return {tensors.seq_length, tensors.mini_batch, tensors.input_size};
  }
  return {tensors.batch_sizes_sum, tensors.input_size};
}
std::vector<int64_t> _hidden_size(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) {
if (rnn.proj_size != 0) {
return {rnn.num_layers * rnn.num_directions(), tensors.mini_batch, rnn.proj_size};
} else {
return {rnn.num_layers * rnn.num_directions(), tensors.mini_batch, rnn.hidden_size};
}
}
std::vector<int64_t> _cell_size(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) {
return {rnn.num_layers * rnn.num_directions(), tensors.mini_batch, rnn.hidden_size};
}
std::vector<int64_t> _output_size(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors) {
auto out_size = rnn.hidden_size;
if (rnn.proj_size != 0) {
out_size = rnn.proj_size;
}
if (tensors.is_input_packed()) {
return {tensors.batch_sizes_sum, out_size * rnn.num_directions()};
} else {
return {tensors.seq_length, tensors.mini_batch, out_size * rnn.num_directions()};
}
}
// Device-independent preconditions for the persistent RNN algorithm: a
// single unidirectional layer, hidden_size a multiple of 128 and at most 1024,
// and input_size a multiple of 128.
inline bool use_persist_common_heuristics(const RNNDescriptorParams& rnn,
                                          const TensorDescriptorListParams& tensors) {
  if (rnn.num_layers != 1 || rnn.num_directions() != 1) {
    return false;
  }
  const bool hidden_ok = rnn.hidden_size <= 1024 && rnn.hidden_size % 128 == 0;
  const bool input_ok = tensors.input_size % 128 == 0;
  return hidden_ok && input_ok;
}
// Device-dependent part of the persistent-RNN heuristic, keyed on the
// current device's compute capability (major.minor) and the mini-batch size:
//   * major < 7 : never persistent.
//   * 7.5       : never persistent (Turing).
//   * other 7.x : batch size must be 8, or a multiple of 16 other than 80
//                 and 112, and the sequence must be long enough for it.
//   * major >= 8: never when minor == 6, never for GRU; otherwise the batch
//                 size must be a multiple of 8 and at most 128, with an
//                 extra limit for RNN_RELU / RNN_TANH.
// NOTE(review): the minor == 6 check is meant to exclude sm_86, but as written
// it also excludes any future X.6 device with X >= 8.
inline bool use_persist_device_heuristics(const RNNDescriptorParams& rnn,
                                          const TensorDescriptorListParams& tensors) {
  const auto bsize = tensors.mini_batch;
  const auto seq_len = tensors.seq_length;
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();

  if (prop->major < 7) {
    return false;
  }

  if (prop->major == 7) {
    if (prop->minor == 5) {
      // Excludes Turing from using persistent rnn.
      return false;
    }
    // technically, batch size should be multiple of 8, but there are quite a
    // few multiple-of-8 batch sizes that give bad perf; weed them out
    const bool batch_ok =
        (bsize % 16 == 0 && bsize != 80 && bsize != 112) || bsize == 8;
    const bool length_ok =
        (seq_len >= 40 && bsize <= 128) ||
        (seq_len >= 20 && bsize <= 96) ||
        (seq_len >= 10 && bsize <= 32);
    return batch_ok && length_ok;
  }

  // From here on, prop->major >= 8.
  if (prop->minor == 6) {
    // Excludes sm_86 GPU devices from using persistent rnn.
    // This is because there are some edge cases that will throw exceptions with cudnn 8.0.5 on Nvidia A40 GPU.
    return false;
  }
  // Based on tests by Vasily Volkov and xwang233. Vasily only tried bsize <= 128,
  // so conservatively enable persistence for bsize <= 128 only.
  // TODO: Run more tests for bsize > 128.
  if (rnn.mode == CUDNN_GRU) {
    // Persistent GRU performance is flakier than other RNN types. Exclude them for now.
    // TODO: Write a more refined GRU heuristic.
    return false;
  }
  const bool small_aligned_batch = bsize % 8 == 0 && bsize <= 128;
  if (rnn.mode == CUDNN_LSTM) {
    // Persistent LSTMs are comparable to or better than non-persistent for bsize <= 128.
    return small_aligned_batch;
  }
  // Persistent RNN_RELU and TANH show poor performance when bsize >= 96 AND hidden size >= 896.
  return small_aligned_batch && (bsize < 96 || rnn.hidden_size < 896);
}
// Choose the cuDNN RNN algorithm. CUDNN_RNN_ALGO_PERSIST_STATIC is used only
// when all of these hold: no projections, half-precision input, unpacked
// input, and both persistence heuristics pass. Everything else uses
// CUDNN_RNN_ALGO_STANDARD.
// `input` is only inspected for its dtype, so it is taken by const reference
// rather than copying the Tensor handle.
cudnnRNNAlgo_t get_algo(const RNNDescriptorParams& rnn, const TensorDescriptorListParams& tensors, const Tensor& input) {
  // LSTM with projections only works with standard algorithm
  if (rnn.proj_size != 0) {
    return CUDNN_RNN_ALGO_STANDARD;
  }
  if (getCudnnDataType(input) == CUDNN_DATA_HALF &&
      !tensors.is_input_packed()) {
    if (use_persist_common_heuristics(rnn, tensors) &&
        use_persist_device_heuristics(rnn, tensors)) {
      return CUDNN_RNN_ALGO_PERSIST_STATIC;
    }
  }
  return CUDNN_RNN_ALGO_STANDARD;
}
// Math (compute) type for an RNN whose data has type `dtype`: half precision
// is promoted to float, and every other type is used as-is.
cudnnDataType_t promote_rnn_math_type(cudnnDataType_t dtype) {
  return (dtype == CUDNN_DATA_HALF) ? CUDNN_DATA_FLOAT : dtype;
}
} // anonymous namespace
// Utilities exposed in RNNUtils.h
namespace cudnn_rnn {
// Allocate a zero-initialised flat cuDNN weight buffer for the described RNN
// and copy the weights in `weight_arr` into it.
//
// weight_arr is interpreted as a matrix with weight_stride0 tensors per row.
// Its row count must equal the number of pseudo-layers
// (num_layers * num_directions) reported by get_parameters.
// batch_first is accepted but not used by this function.
//
// Returns (weight_buf, params): the flat buffer and the per-parameter views
// into it produced by get_parameters. If set_orig_weights_to_flat_buf is
// true, every tensor in weight_arr is additionally re-pointed (via set_) at
// its view inside weight_buf.
TORCH_CUDA_CPP_API std::tuple<Tensor, std::vector<Tensor>>
copy_weights_to_flat_buf_views(
    TensorList weight_arr,
    int64_t weight_stride0,
    int64_t input_size,
    int64_t mode,
    int64_t hidden_size,
    int64_t proj_size,
    int64_t num_layers,
    bool batch_first,
    bool bidirectional,
    const cudnnDataType_t flat_buf_datatype,
    const TensorOptions& flat_buf_options,
    bool set_orig_weights_to_flat_buf,
    bool allow_type_change /*=false*/,
    bool include_bias /*=true*/) {
  // flat_buf_datatype is accepted as a separate argument (rather than extracted
  // from flat_buf_options) because to extract flat_buf_datatype from
  // flat_buf_options, we'd need to say auto flat_buf_datatype =
  // getCudnnDataTypeFromScalarType(typeMetaToScalarType(options.dtype()));
  // typeMetaToScalarType is a surprisingly nontrivial function.  We should
  // avoid it if we can.
  TORCH_CHECK(
      weight_arr.size() > 0,
      "copy_weights_to_flat_buf_views: cannot flatten empty weight list");

  RNNDescriptorParams rnn;
  rnn.set(
      mode,
      hidden_size,
      proj_size,
      num_layers,
      bidirectional,
      promote_rnn_math_type(flat_buf_datatype),
      flat_buf_datatype);

  auto handle = getCudnnHandle();
  RNNDescriptor rnn_desc = rnn.descriptor(handle);

  TensorGeometry x_geom({1, input_size});
  TensorDescriptor x_desc;
  // cuDNN's RNN API does not accept 2-d descriptors (it needs at least 3
  // dims); this file pads its RNN tensor descriptors to 5 dims, and x_desc
  // follows the same convention.
  // https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnRNNForwardTraining
  x_desc.set(flat_buf_datatype, x_geom.sizes(), x_geom.strides(), 5);

  auto num_weights =
      get_num_weights(handle, rnn_desc, x_desc, flat_buf_datatype);
  auto weight_buf = at::zeros(num_weights, flat_buf_options);

  FilterDescriptor w_desc;
  w_desc.set(weight_buf, 3);

  // Slice off views into weight_buf
  std::vector<Tensor> params_arr;
  size_t params_stride0 = 0;
  std::tie(params_arr, params_stride0) = get_parameters(
      handle, rnn, rnn_desc, x_desc, w_desc, weight_buf, include_bias);

  MatrixRef<Tensor> weight{weight_arr, static_cast<size_t>(weight_stride0)},
      params{params_arr, params_stride0};

  // Copy weights
  _viewOrCopyParams(weight, params, /*copy=*/true, allow_type_change);

  if (set_orig_weights_to_flat_buf) {
    // Update the storage
    for (size_t i = 0; i < weight.size(0); i++) {
      // There is a special case for LSTM with projections and no bias,
      // where weight copy is done in 0->0, 1->1, 2->4 layout
      if (weight[i].size() == 3 && params[i].size() == 5) {
        weight[i][0].set_(params[i][0].view_as(weight[i][0]));
        weight[i][1].set_(params[i][1].view_as(weight[i][1]));
        weight[i][2].set_(params[i][4].view_as(weight[i][2]));
      } else {
        for (auto orig_param_it = weight[i].begin(),
                  new_param_it = params[i].begin();
             orig_param_it != weight[i].end() &&
             new_param_it != params[i].end();
             orig_param_it++, new_param_it++) {
          auto orig_param = *orig_param_it, new_param = *new_param_it;
          orig_param.set_(new_param.view_as(orig_param));
        }
      }
    }
  }

  // `params` (a MatrixRef over params_arr) is not used past this point, so
  // params_arr is moved into the result instead of being copied.
  return std::make_tuple(weight_buf, std::move(params_arr));
}
} // namespace cudnn_rnn
using namespace cudnn_rnn;
// Copies the RNN weights in `weight_arr` into a single freshly allocated flat
// buffer (via copy_weights_to_flat_buf_views) and returns that buffer.
//
// NB: does inplace update into TensorList. copy_weights_to_flat_buf_views is
// called with set_orig_weights_to_flat_buf=true, so the tensors in
// `weight_arr` are re-pointed (via set_) at views into the returned buffer.
// It would be a relatively simple matter to refactor this into multiple
// functions, only one of which does an inplace update, but we leave this
// for future work.
//
// The flat buffer's cuDNN data type and tensor options are taken from the
// first weight, so `weight_arr` must be non-empty.
Tensor _cudnn_rnn_flatten_weight(
    TensorList weight_arr, int64_t weight_stride0,
    int64_t input_size,
    int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_proj_size,
    int64_t fn_num_layers, bool batch_first,
    bool fn_bidirectional
    ) {
  // Validate before touching weight_arr[0]. The emptiness check inside
  // copy_weights_to_flat_buf_views only runs after its arguments have been
  // evaluated, and those arguments index weight_arr[0]. That is too late to
  // prevent an out-of-bounds read on an empty list.
  TORCH_CHECK(
      weight_arr.size() > 0,
      "_cudnn_rnn_flatten_weight: cannot flatten empty weight list");
  const Tensor& first_weight = weight_arr[0];
  // returns flat weight_buf
  return std::get<0>(copy_weights_to_flat_buf_views(
      weight_arr,
      weight_stride0,
      input_size,
      fn_mode,
      fn_hidden_size,
      fn_proj_size,
      fn_num_layers,
      batch_first,
      fn_bidirectional,
      /*flat_buf_datatype=*/getCudnnDataType(first_weight),
      /*flat_buf_options=*/first_weight.options(),
      /*set_orig_weights_to_flat_buf=*/true));
}
// Warning text that _cudnn_rnn emits when it is called without a weight
// buffer (weight_buf undefined). The pointer itself is const so this shared
// message cannot be reassigned at runtime.
const char* const WEIGHT_FORMAT_WARN = "RNN module weights are not part of single contiguous "
                                       "chunk of memory. This means they need to be compacted "
                                       "at every call, possibly greatly increasing memory usage. "
                                       "To compact weights again call flatten_parameters().";
// NB: an empty fn_batch_sizes means that no batch sizes were specified
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
const Tensor& input_r,
TensorList weight, int64_t weight_stride0, const c10::optional<Tensor>& weight_buf_r_opt, const Tensor& hx, const c10::optional<Tensor>& cx_opt,
int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_proj_size,
int64_t fn_num_layers, bool batch_first, double fn_dropout,
bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes, const c10::optional<Tensor>& fn_dropout_state_opt
) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_buf_r_maybe_owned = at::borrow_from_optional_tensor(weight_buf_r_opt);
const Tensor& weight_buf_r = *weight_buf_r_maybe_owned;
const Tensor& cx = c10::value_or_else(cx_opt, [] {return Tensor();});
const Tensor& fn_dropout_state = c10::value_or_else(fn_dropout_state_opt, [] {return Tensor();});
check_attributes(input_r, weight, {hx, cx}, /*check_dtype=*/true);
auto input = input_r;
auto weight_buf = weight_buf_r;
if (!weight_buf.defined()) {
TORCH_WARN(WEIGHT_FORMAT_WARN);
}
if (fn_dropout_state.defined()) {
auto input_arg = TensorArg(input, "input", 1);
auto dropout_state_arg = TensorArg(fn_dropout_state, "dropout_states", 15);
checkSameGPU("cudnn_rnn", input_arg, dropout_state_arg);
}
RNNParams fn;
auto datatype = getCudnnDataType(input);
fn.rnn.set(fn_mode, fn_hidden_size, fn_proj_size, fn_num_layers, fn_bidirectional, promote_rnn_math_type(datatype), datatype);
fn.dropout.set(fn_train, fn_dropout, fn_dropout_state);
fn.tensors.set(input.sizes(), fn_batch_sizes, batch_first);
// TODO: Set device to input
if (fn.rnn.mode != CUDNN_LSTM) {
TORCH_CHECK(!cx.defined(),
"rnn: illegal defined cx for non-LSTM RNN");
}
// TODO: can batch_first be a wrapper around this function?
auto is_input_packed = fn.tensors.batch_sizes.size() != 0;
if (batch_first && !is_input_packed) {
input = input.transpose(0, 1);
}
auto hidden_size = _hidden_size(fn.rnn, fn.tensors);
auto cell_size = _cell_size(fn.rnn, fn.tensors);
auto output_size = _output_size(fn.rnn, fn.tensors);
TORCH_CHECK(hx.is_contiguous(),
"rnn: hx is not contiguous");
TORCH_CHECK(!cx.defined() || cx.is_contiguous(),
"rnn: cx is not contiguous");
auto x = input.contiguous();
auto output = at::empty(output_size, input.options());
auto hy = at::empty(hidden_size, hx.options());
Tensor cy;
if (cx.defined()) {
cy = at::empty(cell_size, cx.options());
} else {
cy = at::empty({0}, hx.options()); // NB: Not allowed to return undefined tensors
}
auto y = output;
auto handle = getCudnnHandle();
cudnnRNNAlgo_t algo = get_algo(fn.rnn, fn.tensors, input);
fn.rnn.set_algo(algo);
RNNDescriptors descs(fn, handle, x, y, hx, cx);
FilterDescriptor w_desc;
if (!weight_buf.defined()) {
auto num_weights = get_num_weights(handle, descs.rnn_desc, descs.x_descs[0], datatype);
weight_buf = at::empty(num_weights, x.options());
w_desc.set(weight_buf, 3);
weight_buf.zero_();
std::vector<Tensor> params;
size_t params_stride0;
std::tie(params, params_stride0) = get_parameters(handle, fn.rnn, descs.rnn_desc, descs.x_descs[0], w_desc, weight_buf);
_copyParams(MatrixRef<Tensor>{weight, static_cast<size_t>(weight_stride0)},
MatrixRef<Tensor>{params, params_stride0});
} else {
w_desc.set(weight_buf, 3);
}
TORCH_CHECK(!cx.defined() || cx.sizes().equals(cell_size),
"Expected cell size ", IntArrayRef{cell_size}, ", got ", cx.sizes());
size_t workspace_size;
auto x_descs_arr = descs.get_x_descs();
auto y_descs_arr = descs.get_y_descs();
AT_CUDNN_CHECK(cudnnGetRNNWorkspaceSize(
handle,
descs.rnn_desc.desc(),
fn.tensors.seq_length,
x_descs_arr.data(),
&workspace_size