-
Notifications
You must be signed in to change notification settings - Fork 352
/
Copy pathbasic_linear.py
1167 lines (1085 loc) · 42.5 KB
/
basic_linear.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fusible operation for linear layer without bias."""
from __future__ import annotations
from collections.abc import Callable, Iterable
import contextlib
import math
from typing import Any, Optional
import torch
from transformer_engine.pytorch.cpp_extensions import (
FP8TensorMeta,
fp8_gemm,
gemm,
)
from transformer_engine.pytorch.distributed import (
CudaRNGStatesTracker,
gather_along_first_dim,
reduce_scatter_along_first_dim,
)
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import (
FP8GlobalStateManager,
get_fp8_te_dtype,
)
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import (
canonicalize_device,
canonicalize_dtype,
convert_tensor,
devices_match,
is_float8_tensor,
reshape,
)
from ...utils import clear_tensor_data
def _wait_async(handle: Optional[Any]) -> None:
    """Block until a pending asynchronous communication has completed.

    ``handle`` is an object exposing ``wait()``, e.g. the handle returned by
    a collective launched with ``async_op=True``. ``None`` means nothing is
    in flight, in which case this is a no-op.
    """
    if handle is None:
        return
    handle.wait()
class BasicLinear(BasicOperation):
"""Apply linear transformation: :math:`y = x A^T`
This is a drop-in replacement for `torch.nn.Linear` with
`bias=False`.
Parameters
----------
in_features: int
Inner dimension of input tensor
out_features: int
Inner dimension of output tensor
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
Mode for tensor parallelism
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
Process group for tensor parallelism
sequence_parallel: bool, default = `False`
Whether to apply sequence parallelism together with tensor
parallelism, i.e. distributing input or output tensors along
outer dimension (sequence or batch dim) when not distributing
along inner dimension (embedding dim)
rng_state_tracker_function: callable
Function that returns `CudaRNGStatesTracker`, which is used
for model-parallel weight initialization
accumulate_into_main_grad: bool, default = `False`
Whether to directly accumulate weight gradients into the
weight's `main_grad` attribute instead of relying on PyTorch
autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be
meaningful.
"""
def __init__(
    self,
    in_features: int,
    out_features: int,
    *,
    device: Optional[torch.device | str] = None,
    dtype: Optional[torch.dtype] = None,
    tensor_parallel_mode: Optional[str] = None,
    tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    sequence_parallel: bool = False,
    rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = None,
    accumulate_into_main_grad: bool = False,
) -> None:
    """Set up the op's configuration and its (possibly deferred) weight.

    Requesting ``device="meta"`` defers weight initialization. The op then
    records the device returned by ``canonicalize_device(None)``, the
    weight stays a meta-device placeholder, and ``reset_parameters`` is
    not called here.

    Raises
    ------
    ValueError
        If the resolved device is not CUDA, if ``dtype`` is not float32,
        float16 or bfloat16, or if ``_canonicalize_tensor_parallelism``
        rejects the tensor-parallel configuration.
    """
    super().__init__()

    # Global (unsharded) weight dimensions
    self.in_features: int = in_features
    self.out_features: int = out_features

    # Resolve the device. "meta" only means "initialize later"; the op
    # itself still needs a concrete CUDA device.
    device = canonicalize_device(device)
    postpone_init = device.type == "meta"
    if postpone_init:
        device = canonicalize_device(None)
    if device.type != "cuda":
        raise ValueError(f"Only CUDA devices are supported (got {device})")
    self.device: torch.device = device

    # Resolve and validate the weight dtype
    dtype = canonicalize_dtype(dtype)
    if dtype not in (torch.float32, torch.float16, torch.bfloat16):
        raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")

    # Tensor-parallel configuration, including the per-rank weight shape
    tp_settings = self._canonicalize_tensor_parallelism(
        mode=tensor_parallel_mode,
        process_group=tensor_parallel_group,
        sequence_parallel=sequence_parallel,
        in_features=in_features,
        out_features=out_features,
    )
    tp_mode, tp_group, tp_size, seq_parallel, local_in, local_out = tp_settings
    self.tensor_parallel_mode: Optional[str] = tp_mode
    self.tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = tp_group
    self.tensor_parallel_size: int = tp_size
    self.sequence_parallel: bool = seq_parallel
    self.local_in_features: int = local_in
    self.local_out_features: int = local_out

    # Natively-FP8 weights carry their own FP8 metadata
    self._with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
    if self._with_fp8_parameters:
        self._fp8_metas = self._make_fp8_metas()

    # Register a meta-device placeholder; real storage is allocated by
    # reset_parameters (now, or later when initialization is postponed).
    placeholder = torch.nn.Parameter(
        torch.empty(
            self.local_out_features,
            self.local_in_features,
            device="meta",
            dtype=dtype,
        )
    )
    self.weight: torch.nn.Parameter
    self.register_parameter("weight", placeholder)

    # Optional RNG tracker used while drawing initial weight values
    self._rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] = (
        rng_state_tracker_function
    )
    if not postpone_init:
        self.reset_parameters()

    # Whether weight gradients go directly into the weight's `main_grad`
    self._accumulate_into_main_grad = accumulate_into_main_grad
@classmethod
def _canonicalize_tensor_parallelism(
    cls,
    *,
    mode: Optional[str],
    process_group: Optional[torch.distributed.ProcessGroup],
    sequence_parallel: bool,
    in_features: int,
    out_features: int,
) -> tuple[
    Optional[str],
    Optional[torch.distributed.ProcessGroup],
    int,
    bool,
    int,
    int,
]:
    """Check configuration for tensor parallelism

    Parameters
    ----------
    mode: {`None`, "column", "row"}
        Mode for tensor parallelism
    process_group: torch.distributed.ProcessGroup
        Process group for tensor parallelism
    sequence_parallel: bool
        Whether to apply sequence parallelism together with tensor
        parallelism, i.e. distributing input or output tensors
        along outer dimension (sequence or batch dim) when not
        distributing along inner dimension (embedding dim)
    in_features: int
        Inner dimension of global input tensor
    out_features: int
        Inner dimension of global output tensor

    Returns
    -------
    mode: {`None`, "column", "row"}
        Mode for tensor parallelism
    process_group: torch.distributed.ProcessGroup
        Process group for tensor parallelism
    group_size: int
        Size of tensor-parallel process group
    sequence_parallel: bool
        Whether to apply sequence parallelism
    local_in_features: int
        Inner dimension of local input tensor
    local_out_features: int
        Inner dimension of local output tensor

    Raises
    ------
    ValueError
        If ``mode`` is not `None`, "column" or "row", or if the
        distributed feature dimension is not divisible by the
        process-group size.
    """
    # Validate the mode before anything else. Previously an unknown mode
    # was silently accepted whenever the group had a single rank (it was
    # reset to `None` before validation), and `torch.distributed` was
    # queried before the mode was checked at all.
    if mode not in (None, "column", "row"):
        raise ValueError(
            "Supported modes for tensor parallelism are "
            f'`None`, "row", and "column" (got {mode=})'
        )

    # Tensor-parallel group size (only query torch.distributed when needed)
    group_size = 1 if mode is None else torch.distributed.get_world_size(process_group)

    # Disable tensor parallelism if not needed
    if group_size == 1:
        mode = None
        process_group = None
        sequence_parallel = False

    # Determine local tensor dims
    local_in_features = in_features
    local_out_features = out_features
    if mode == "column":
        # Distribute output tensor
        if out_features % group_size != 0:
            raise ValueError(
                "Invalid configuration for tensor parallelism "
                f"({mode=}, {out_features=}, {group_size=})"
            )
        local_out_features //= group_size
    elif mode == "row":
        # Distribute input tensor
        if in_features % group_size != 0:
            raise ValueError(
                "Invalid configuration for tensor parallelism "
                f"({mode=}, {in_features=}, {group_size=})"
            )
        local_in_features //= group_size

    return (
        mode,
        process_group,
        group_size,
        sequence_parallel,
        local_in_features,
        local_out_features,
    )
def num_fp8_scales(self, mode: str) -> int:
    """Return how many FP8 scaling factors this op uses for ``mode``.

    One scale each for the "input", "param" and "grad_output" modes;
    no scales for any other mode.
    """
    return 1 if mode in ("input", "param", "grad_output") else 0
def reset_parameters(self) -> None:
    """Allocate the weight on ``self.device`` and fill it with initial values.

    Values are drawn with ``torch.nn.init.kaiming_uniform_`` (``a=sqrt(5)``).
    When a tracker function was supplied, this runs inside that tracker's
    ``fork()`` context. With FP8 parameters enabled, the result is cast to a
    ``Float8Tensor`` using the "param" FP8 metadata. The final tensor is
    stored as ``self.weight``, wrapped in ``torch.nn.Parameter`` if it is
    not one already.
    """
    current = self.weight

    # A CUDA, non-FP8 weight is reused (moved to self.device if needed);
    # anything else (meta/CPU/FP8) gets a fresh buffer of the same shape.
    if current.device.type == "cuda" and not is_float8_tensor(current):
        buffer = current.to(device=self.device)
    else:
        buffer = torch.empty_like(current, device=self.device)

    # Draw initial values, optionally under the RNG tracker's fork
    if self._rng_state_tracker_function is None:
        rng_context = contextlib.nullcontext()
    else:
        rng_context = self._rng_state_tracker_function().fork()
    with rng_context:
        torch.nn.init.kaiming_uniform_(buffer, a=math.sqrt(5))

    if self._with_fp8_parameters:
        # Scratch amax buffer so this cast does not overwrite amax history
        scratch_amax = torch.empty((1, 1), dtype=torch.float32, device=self.device)
        buffer = Float8Tensor.to_float8(
            buffer,
            fp8_meta=self.get_fp8_meta("param"),
            fp8_meta_forward=True,
            fp8_meta_index=0,
            amax=scratch_amax,
            with_transpose_cache=torch.is_grad_enabled(),
        )

    # Store the result as the module's parameter
    if isinstance(buffer, torch.nn.Parameter):
        self.weight = buffer
    else:
        self.weight = torch.nn.Parameter(buffer)
def pre_forward(self) -> None:
    """Run the base pre-forward hook, then materialize a deferred weight.

    If the weight is still a meta-device placeholder (initialization was
    deferred in ``__init__``), it is allocated and initialized now via
    ``reset_parameters``.
    """
    super().pre_forward()
    weight_is_placeholder = self.weight.device.type == "meta"
    if weight_is_placeholder:
        self.reset_parameters()
@staticmethod
def _functional_forward(
    input: torch.Tensor,  # pylint: disable=redefined-builtin
    weight: torch.Tensor,
    *,
    bias: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    out: Optional[torch.Tensor] = None,
    accumulate_into_out: bool = False,
    tensor_parallel_mode: Optional[str] = None,
    tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    sequence_parallel: bool = False,
    with_fp8_compute: bool = False,
    input_fp8_meta: Optional[dict[str, Any]] = None,
    weight_fp8_meta: Optional[dict[str, Any]] = None,
    output_fp8_meta: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Functional API for forward pass

    Validates devices, dtypes and shapes, optionally casts the operands to
    FP8, and performs a single GEMM of the flattened 2D input with the
    weight (plus optional bias). Tensor-parallel communication depends on
    ``tensor_parallel_mode`` / ``sequence_parallel``:

    * "column" with sequence parallelism: the local input is gathered
      along its first dim (launched asynchronously and waited on just
      before the GEMM).
    * "row": the output is all-reduced, or reduce-scattered along its
      first dim with sequence parallelism.

    Parameters
    ----------
    input: torch.Tensor
        Input tensor
    weight: torch.Tensor
        Weight tensor
    bias: torch.Tensor, optional
        Bias tensor
    device: torch.device, default = default CUDA device
        Tensor device
    dtype: torch.dtype, default = default dtype
        Tensor datatype
    out: torch.Tensor, optional
        Output tensor
    accumulate_into_out: bool, default = `False`
        Add result to output tensor instead of overwriting
    tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
        Mode for tensor parallelism
    tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
        Process group for tensor parallelism
    sequence_parallel: bool, default = `False`
        Whether to apply sequence parallelism together with tensor
        parallelism, i.e. distributing input or output tensors
        along outer dimension (sequence or batch dim) when not
        distributing along inner dimension (embedding dim)
    with_fp8_compute: bool, default = `False`
        Whether to perform compute in FP8
    input_fp8_meta: dict, optional
        FP8 metadata for casting input tensor to FP8. Required for
        FP8 compute if input is not already in FP8.
    weight_fp8_meta: dict, optional
        FP8 metadata for casting weight tensor to FP8. Required for
        FP8 compute if weight is not already in FP8.
    output_fp8_meta: dict, optional
        FP8 metadata for casting output tensor to FP8

    Returns
    -------
    torch.Tensor
        Output tensor
    torch.Tensor
        Local input tensor used in GEMM (before any sequence-parallel
        gather), possibly cast to FP8 and reshaped to 2D from the
        provided input tensor
    torch.Tensor
        Weight tensor used in GEMM, possibly converted or cast to FP8
        from the provided weight tensor

    Raises
    ------
    ValueError
        If the device is not CUDA, the dtype is unsupported, ``out`` has
        a mismatched device/dtype/shape, input and weight shapes are
        incompatible, accumulation is requested without ``out`` or with
        row tensor parallelism, FP8 compute lacks the required FP8
        metadata, or ``out`` is a Float8Tensor while FP8 output is not
        supported.
    """

    # Check device
    if device is None:
        device = weight.device if out is None else out.device
    device = canonicalize_device(device)
    if device.type != "cuda":
        raise ValueError(f"Only CUDA devices are supported (got {device})")
    if out is not None and not devices_match(out.device, device):
        raise ValueError(
            f"Output tensor has invalid device (expected {device}, got {out.device})"
        )

    # Check datatype
    if dtype is None:
        dtype = weight.dtype if out is None else out.dtype
    dtype = canonicalize_dtype(dtype)
    if dtype not in (torch.float32, torch.float16, torch.bfloat16):
        raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
    if out is not None and out.dtype != dtype:
        raise ValueError(f"Output tensor has invalid dtype (expected {dtype}, got {out.dtype})")

    # Check input tensor dims
    input_dims = tuple(input.size())
    weight_dims = tuple(weight.size())
    if len(weight_dims) != 2:
        raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})")
    if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]:
        raise ValueError(
            f"Input tensor (shape={input_dims}) "
            f"and weight tensor (shape={weight_dims}) "
            "are not compatible"
        )

    # Check output tensor dims
    output_dims: list[int]
    if out is None:
        # Input shape with the last dim replaced by the weight's output
        # dim. The leading -1 lets the final reshape absorb row-count
        # changes from the sequence-parallel gather / reduce-scatter.
        output_dims = list(input_dims)
        output_dims[0] = -1
        output_dims[-1] = weight_dims[0]
    else:
        output_dims = list(out.size())
        if len(output_dims) == 0 or weight_dims[0] != output_dims[-1]:
            raise ValueError(
                f"Output tensor (shape={output_dims}) "
                f"and weight tensor (shape={weight_dims}) "
                "are not compatible"
            )

    # Check if accumulating into output tensor
    if accumulate_into_out:
        if out is None:
            raise ValueError(
                "Attempted to accumulate into output tensor without providing output tensor"
            )
        if tensor_parallel_mode == "row":
            raise ValueError(
                "Accumulating into output tensor is not supported with row tensor parallelism"
            )

    # Check if FP8 is enabled
    if with_fp8_compute:
        if input_fp8_meta is None and not is_float8_tensor(input):
            raise ValueError("No FP8 metadata was provided for casting input to FP8")
        if weight_fp8_meta is None and not is_float8_tensor(weight):
            raise ValueError("No FP8 metadata was provided for casting weight to FP8")
    else:
        input_fp8_meta = None
        weight_fp8_meta = None
        output_fp8_meta = None
    # FP8 output is never produced in row mode, where the output is reduced
    # across ranks afterwards. Otherwise it needs either FP8 metadata for a
    # newly allocated output or a caller-provided Float8Tensor `out`.
    with_fp8_output = with_fp8_compute and tensor_parallel_mode != "row"
    if out is None:
        with_fp8_output = with_fp8_output and output_fp8_meta is not None
    else:
        if is_float8_tensor(out):
            if not with_fp8_output:
                raise ValueError(
                    "Output tensor is a Float8Tensor, but FP8 output is not supported"
                )
            # `out` is about to be overwritten; presumably this drops
            # cached data derived from its old contents (TODO confirm)
            out._reset_caches()
        else:
            with_fp8_output = False

    # Check input tensor
    x_local = reshape(
        input,
        (-1, input_dims[-1]),
        device=device,
        dtype=dtype,
    )
    if with_fp8_compute and not is_float8_tensor(x_local):
        fp8_dtype = get_fp8_te_dtype(
            input_fp8_meta["recipe"],
            fprop_tensor=True,
        )
        # Cache the transpose only when the weight needs a grad, and not
        # when the local input is a shard that is gathered below
        with_transpose_cache = weight.requires_grad
        if tensor_parallel_mode == "column" and sequence_parallel:
            with_transpose_cache = False
        x_local = Float8Tensor.to_float8(
            x_local,
            fp8_meta=input_fp8_meta,
            fp8_meta_forward=True,
            fp8_meta_index=0,
            fp8_dtype=fp8_dtype,
            with_transpose_cache=with_transpose_cache,
        )
    elif not with_fp8_compute and is_float8_tensor(x_local):
        x_local = x_local.dequantize()
    x = x_local
    x_async = None
    if tensor_parallel_mode == "column" and sequence_parallel:
        # Launched asynchronously; waited on just before the GEMM so the
        # weight/bias/output preparation below can overlap with it
        x, x_async = gather_along_first_dim(
            x_local,
            tensor_parallel_group,
            async_op=True,
        )

    # Check weight tensor
    w = convert_tensor(
        weight,
        device=device,
        dtype=dtype,
        memory_format=torch.contiguous_format,
    )
    if with_fp8_compute and not is_float8_tensor(w):
        fp8_dtype = get_fp8_te_dtype(
            weight_fp8_meta["recipe"],
            fprop_tensor=True,
        )
        w = Float8Tensor.to_float8(
            w,
            fp8_meta=weight_fp8_meta,
            fp8_meta_forward=True,
            fp8_meta_index=0,
            fp8_dtype=fp8_dtype,
        )
    elif not with_fp8_compute and is_float8_tensor(w):
        w = w.dequantize()

    # Check bias tensor
    b = None
    if bias is not None:
        b = convert_tensor(
            bias,
            device=device,
            dtype=dtype,
            memory_format=torch.contiguous_format,
        )

    # Construct output tensor
    y = None
    if out is not None:
        # 2D view onto the caller's buffer
        y = reshape(out, (-1, output_dims[-1]))
    elif with_fp8_output:
        # Raw uint8 storage wrapped as a Float8Tensor with nominal `dtype`
        fp8_dtype = get_fp8_te_dtype(
            output_fp8_meta["recipe"],
            fprop_tensor=True,
        )
        data = torch.empty(
            (x.size(0), weight_dims[0]),
            dtype=torch.uint8,
            device=device,
        )
        y = Float8Tensor(
            data=data,
            fp8_meta=output_fp8_meta,
            fp8_meta_forward=True,
            fp8_meta_index=0,
            fp8_dtype=fp8_dtype,
            dtype=dtype,
        )
    else:
        y = torch.empty(
            (x.size(0), weight_dims[0]),
            dtype=dtype,
            device=device,
        )

    # Perform GEMM
    _wait_async(x_async)
    x_async = None
    if with_fp8_compute:
        kwargs = {
            "accumulate": accumulate_into_out,
            "out": y,
            "bias": b,
            "use_bias": (b is not None),
        }
        if with_fp8_output:
            # FP8 output: the GEMM writes raw FP8 data into y._data and
            # needs an FP8TensorMeta plus index describing y's scaling
            if y._fp8_meta is None:
                # Hackily create FP8TensorMeta if needed
                # (scale = 1 / scale_inv, throwaway 1x1 amax history)
                fp8_meta = FP8TensorMeta()
                fp8_meta.scale = y._scale_inv.reciprocal()
                fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=device)
                fp8_meta.scale_inv = y._scale_inv
                fp8_meta_index = 0
            else:
                # Get FP8TensorMeta from Float8Tensor
                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                    forward=y._fp8_meta_forward,
                )
                fp8_meta = y._fp8_meta[fp8_meta_key]
                fp8_meta_index = y._fp8_meta_index
            kwargs.update(
                {
                    "out": y._data,
                    "out_index": fp8_meta_index,
                    "fp8_meta_tensor": fp8_meta,
                    "D_dtype": y._fp8_dtype,
                }
            )
        fp8_gemm(
            w._data,
            w._scale_inv,
            0,
            w._fp8_dtype,
            x._data,
            x._scale_inv,
            0,
            x._fp8_dtype,
            y.dtype,
            get_workspace(),
            **kwargs,
        )
    else:
        gemm(
            w,
            x,
            y.dtype,
            get_workspace(),
            accumulate=accumulate_into_out,
            out=y,
            bias=b,
            use_bias=(b is not None),
        )

    # Reduce tensor-parallel output if needed
    if tensor_parallel_mode == "row":
        if sequence_parallel:
            # NOTE(review): the reduce-scattered result is rebound to `y`
            # (presumably a new, smaller tensor). If the caller provided
            # `out`, it is not updated with the reduced result, and the
            # un-reduced `out` is returned below. Confirm callers never pass
            # `out` with row + sequence parallelism.
            y, _ = reduce_scatter_along_first_dim(y, tensor_parallel_group)
        else:
            # In-place on `y` (a view of `out` when `out` was provided)
            torch.distributed.all_reduce(y, group=tensor_parallel_group)

    # Reshape output tensor if needed
    if out is None:
        out = reshape(y, output_dims)
    return out, x_local, w
@staticmethod
def _functional_backward(
grad_output: torch.Tensor,
input: Optional[torch.Tensor], # pylint: disable=redefined-builtin
weight: Optional[torch.Tensor],
input_dims: Iterable[int],
weight_dims: Iterable[int],
*,
input_requires_grad: bool = True,
weight_requires_grad: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
grad_weight: Optional[torch.Tensor] = None,
accumulate_into_grad_weight: bool = False,
grad_input: Optional[torch.Tensor] = None,
accumulate_into_grad_input: bool = False,
tensor_parallel_mode: Optional[str] = None,
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
sequence_parallel: bool = False,
with_fp8_compute: bool = False,
input_fp8_meta: Optional[dict[str, Any]] = None,
weight_fp8_meta: Optional[dict[str, Any]] = None,
grad_output_fp8_meta: Optional[dict[str, Any]] = None,
grad_input_fp8_meta: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Functional API for backward pass
Parameters
----------
grad_output: torch.Tensor
Loss gradient w.r.t. output tensor
input: torch.Tensor, optional
Input tensor. Required to compute loss gradient w.r.t.
weight.
weight: torch.Tensor, optional
Weight tensor. Required to compute loss gradient w.r.t.
input.
input_dims: iterable of int
Input tensor dimensions
weight_dims: iterable of int
Weight tensor dimensions
input_requires_grad: bool
Whether to compute loss gradient w.r.t. input tensor
weight_requires_grad: bool
Whether to compute loss gradient w.r.t. weight tensor
device: torch.device, default = default CUDA device
Tensor device
dtype: torch.dtype, default = default dtype
Tensor datatype
grad_weight: torch.Tensor, optional
Loss gradient w.r.t. weight tensor
accumulate_into_grad_weight: bool, default = `False`
Add result to weight grad instead of overwriting
grad_input: torch.Tensor, optional
Loss gradient w.r.t. input tensor
accumulate_into_grad_input: bool, default = `False`
Add result to input grad instead of overwriting
tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
Mode for tensor parallelism
tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
Process group for tensor parallelism
sequence_parallel: bool, default = `False`
Whether to apply sequence parallelism together with tensor
parallelism, i.e. distributing input or output tensors
along outer dimension (sequence or batch dim) when not
distributing along inner dimension (embedding dim)
with_fp8_compute: bool, default = `False`
Whether to perform compute in FP8
input_fp8_meta: dict, optional
FP8 metadata for casting input tensor to FP8. Required for
FP8 compute if input is not already in FP8.
weight_fp8_meta: dict, optional
FP8 metadata for casting weight tensor to FP8. Required for
FP8 compute if weight is not already in FP8.
grad_output_fp8_meta: dict, optional
FP8 metadata for casting loss gradient w.r.t. output
tensor to FP8. Required if output grad is not already in
FP8.
grad_output_fp8_meta: dict, optional
FP8 metadata for casting loss gradient w.r.t. input
tensor to FP8
Returns
-------
torch.Tensor
Loss gradient w.r.t. input tensor
torch.Tensor
Loss gradient w.r.t. weight tensor
"""
# Check device
if device is None:
device = weight.device
device = canonicalize_device(device)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
# Check datatype
if dtype is None:
dtype = weight.dtype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
# Check tensor dims
output_dims = tuple(grad_output.size())
input_dims = tuple(input_dims)
weight_dims = tuple(weight_dims)
if len(weight_dims) != 2:
raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})")
if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]:
raise ValueError(
f"Input tensor (shape={input_dims}) "
f"and weight tensor (shape={weight_dims}) "
"are not compatible"
)
if weight_dims[0] != output_dims[-1]:
raise ValueError(
f"Grad output tensor (shape={output_dims}) "
f"and weight tensor (shape={weight_dims}) "
"are not compatible"
)
if grad_input is not None and tuple(grad_input.size()) != input_dims:
raise ValueError(
f"Grad input tensor (shape={tuple(grad_input.size())}) "
f"does not match expected shape ({input_dims})"
)
# Check grad input tensor
if not input_requires_grad:
grad_input = None
if grad_input is not None and not devices_match(grad_input.device, device):
raise ValueError(
f"Grad input tensor has invalid device (expected {device}, got {grad_input.device})"
)
if grad_input is not None and grad_input.dtype != dtype:
raise ValueError(
f"Grad input tensor has invalid dtype (expected {dtype}, got {grad_input.dtype})"
)
if accumulate_into_grad_input:
if grad_input is None:
raise ValueError(
"Attempted to accumulate into grad input tensor "
"without providing grad input tensor"
)
if tensor_parallel_mode == "column":
raise ValueError(
"Accumulating into grad input tensor "
"is not supported with column tensor parallelism"
)
# Check if FP8 is enabled
if with_fp8_compute:
if grad_output_fp8_meta is None and not is_float8_tensor(grad_output):
raise ValueError("No FP8 metadata was provided for casting output gradient to FP8")
else:
input_fp8_meta = None
weight_fp8_meta = None
grad_output_fp8_meta = None
grad_input_fp8_meta = None
with_fp8_grad_input = (
with_fp8_compute and input_requires_grad and tensor_parallel_mode != "column"
)
if grad_input is None:
with_fp8_grad_input = with_fp8_grad_input and grad_input_fp8_meta is not None
else:
if is_float8_tensor(grad_input):
if not with_fp8_grad_input:
raise ValueError(
"Grad input tensor is a Float8Tensor, but FP8 output is not supported"
)
grad_input._reset_caches()
else:
with_fp8_grad_input = False
# Check grad output tensor
dy_async = None
dy = reshape(
grad_output,
(-1, output_dims[-1]),
device=device,
dtype=dtype,
)
if with_fp8_compute and not is_float8_tensor(dy):
fp8_dtype = get_fp8_te_dtype(
grad_output_fp8_meta["recipe"],
fprop_tensor=False,
)
with_transpose_cache = weight_requires_grad
if tensor_parallel_mode == "row" and sequence_parallel:
with_transpose_cache = False
dy = Float8Tensor.to_float8(
dy,
fp8_meta=grad_output_fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
with_transpose_cache=with_transpose_cache,
)
elif not with_fp8_compute and is_float8_tensor(dy):
dy = dy.dequantize()
if tensor_parallel_mode == "row" and sequence_parallel:
dy, dy_async = gather_along_first_dim(
dy,
tensor_parallel_group,
async_op=True,
)
# Check input tensor
x = None
x_async = None
if weight_requires_grad:
if input is None:
raise ValueError("Input tensor is required to compute weight grad")
x_local = reshape(
input,
(-1, input_dims[-1]),
device=device,
dtype=dtype,
)
x_is_sharded = tensor_parallel_mode == "column" and sequence_parallel
if with_fp8_compute and not is_float8_tensor(x_local):
fp8_dtype = get_fp8_te_dtype(
input_fp8_meta["recipe"],
fprop_tensor=True,
)
x_local = Float8Tensor.to_float8(
x_local,
fp8_meta=input_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
with_transpose_cache=(not x_is_sharded),
)
elif not with_fp8_compute and is_float8_tensor(x_local):
x_local = x_local.from_float8()
x = x_local
if x_is_sharded:
x, x_async = gather_along_first_dim(
x_local,
tensor_parallel_group,
async_op=True,
)
# Compute grad input
dx = None
dx_async = None
if input_requires_grad:
# Check weight tensor
if weight is None:
raise ValueError("Weight tensor is required to compute input grad")
w = convert_tensor(
weight,
device=device,
dtype=dtype,
memory_format=torch.contiguous_format,
)
if with_fp8_compute and not is_float8_tensor(w):
fp8_dtype = get_fp8_te_dtype(
weight_fp8_meta["recipe"],
fprop_tensor=True,
)
w = Float8Tensor.to_float8(
w,
fp8_meta=weight_fp8_meta,
fp8_meta_forward=True,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
with_transpose_cache=True,
)
elif not with_fp8_compute and is_float8_tensor(w):
w = w.dequantize()
# Construct grad input tensor
if grad_input is not None:
dx = reshape(grad_input, (-1, input_dims[-1]))
elif with_fp8_grad_input:
fp8_dtype = get_fp8_te_dtype(
grad_input_fp8_meta["recipe"],
fprop_tensor=False,
)
data = torch.empty(
(dy.size(0), weight_dims[1]),
dtype=torch.uint8,
device=device,
)
dx = Float8Tensor(
data=data,
fp8_meta=grad_input_fp8_meta,
fp8_meta_forward=False,
fp8_meta_index=0,
fp8_dtype=fp8_dtype,
dtype=dtype,
)
else:
dx = torch.empty(
(dy.size(0), weight_dims[1]),
dtype=dtype,
device=device,
)
# Perform dgrad GEMM
_wait_async(dy_async)
dy_async = None
if with_fp8_compute:
kwargs = {"accumulate": accumulate_into_grad_input, "out": dx}
if with_fp8_grad_input:
if dx._fp8_meta is None:
# Hackily create FP8TensorMeta if needed
fp8_meta = FP8TensorMeta()
fp8_meta.scale = dx._scale_inv.reciprocal()
fp8_meta.amax_history = torch.empty(
1, 1, dtype=torch.float32, device=device
)
fp8_meta.scale_inv = dx._scale_inv
fp8_meta_index = 0
else:
# Get FP8TensorMeta from Float8Tensor
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=dx._fp8_meta_forward,
)
fp8_meta = dx._fp8_meta[fp8_meta_key]
fp8_meta_index = dx._fp8_meta_index
kwargs.update(
{
"out": dx._data,
"out_index": fp8_meta_index,
"fp8_meta_tensor": fp8_meta,
"D_dtype": dx._fp8_dtype,
}
)
fp8_gemm(
w.transpose_2d(),
w._scale_inv,
0,
w._fp8_dtype,
dy._data,
dy._scale_inv,
0,
dy._fp8_dtype,
dx.dtype,
get_workspace(),
**kwargs,
)
else:
gemm(
w,
dy,
dx.dtype,
get_workspace(),
accumulate=accumulate_into_grad_input,
layout="NN",
out=dx,
)
# Reduce tensor-parallel grad input if needed
if tensor_parallel_mode == "column":
if sequence_parallel:
dx, dx_async = reduce_scatter_along_first_dim(
dx,
tensor_parallel_group,
async_op=True,
)
else:
dx_async = torch.distributed.all_reduce(
dx,
group=tensor_parallel_group,