-
Notifications
You must be signed in to change notification settings - Fork 255
/
layers.py
2337 lines (1997 loc) · 88.1 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2023 The Mesh TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layers implemented in Mesh TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import gin
from mesh_tensorflow import ops_with_redefined_builtins as mtf
import tensorflow.compat.v1 as tf
def summary_for_clip_activation_gradient(t, name=None, batch_dims=None):
  """Emit scalar summaries describing an activation-gradient tensor.

  Always records the mean and the max of `t`.  If `batch_dims` is given, also
  computes the RMS of `t` reduced to `batch_dims` and records the mean, the
  max and the variance of that RMS.

  Every summary tag is suffixed with `name` so that summaries for different
  tensors do not collide on the same tag.

  Args:
    t: a mtf.Tensor.
    name: an optional string appended to every summary tag.
    batch_dims: an optional output shape passed to mtf.reduce_mean when
      computing the RMS; the RMS summaries are skipped when it is falsy.
  """
  # Previously these two tags were the bare "mean/" and "max/", so every call
  # wrote to the same tag regardless of `name` (unlike the rms_* tags below).
  mtf.scalar_summary("mean/%s" % name, mtf.reduce_mean(t))
  mtf.scalar_summary("max/%s" % name, mtf.reduce_max(t))
  if batch_dims:
    rms_batch = mtf.sqrt(
        mtf.reduce_mean(mtf.square(t), output_shape=batch_dims))
    mtf.scalar_summary("rms_mean/%s" % name, mtf.reduce_mean(rms_batch))
    mtf.scalar_summary("rms_max/%s" % name, mtf.reduce_max(rms_batch))
    mtf.scalar_summary(
        "rms_var/%s" % name,
        mtf.reduce_mean(mtf.square(rms_batch - mtf.reduce_mean(rms_batch))))
@gin.configurable
def unit_scaling_convention(value=False):
  """Gin switch for the unit-scaling convention (off by default).

  TODO(noam): turn this comment into a position paper and post to arxiv

  Under the unit-scaling convention, every weight is initialized with unit
  variance, and the output of most contractions (matmul/einsum operations) is
  divided by the square root of the product of the sizes of the contracting
  dimensions.

  This differs from the usual inverse-square-root weight-initialization
  convention, often attributed to
  http://proceedings.mlr.press/v9/glorot10a.html
  where weights are drawn from a distribution with mean zero and standard
  deviation equal to the inverse square root of the contracting dimension(s).

  Both conventions apply inverse-square-root scaling for the same reason: the
  activations of one layer should be scaled similarly to those of the previous
  layer (models are typically initialized so that activations in every layer
  have RMS=O(1)).  They differ only in where the scaling lives: in the
  parameters (their way), or as an explicit multiplier on the activations (our
  way).

  In our opinion, parameter-scaling (their way) has three main disadvantages:

  1. Optimizers need to be aware of differently-scaled parameters.  The
  learning rates of adaptive optimizers represent target step sizes for the
  parameters, and the desired step size logically depends on the scale of the
  parameter itself.  One therefore typically has to lower the learning rate
  as layers get bigger and their parameters get correspondingly smaller.
  Under the unit-scaling convention this is unnecessary, since all parameters
  are on the same unit scale.

  2. From an engineering standpoint, it is often unwieldy to tell both the
  variable initializer and the optimizer what the scale of a variable should
  be.  Typically the initializer guesses it by inferring, from the dimension
  order, which dimensions of the variable are contracting dimensions.  This is
  highly error-prone.

  3. Some contractions are not associated with parameters at all, as in neural
  attention.  Dividing by the square root of the contracting dimensions may
  matter there too, in order to maintain activation scale.  See the
  discussion in section 3.2.1 of https://arxiv.org/abs/1706.03762.  Being in
  the habit of scaling contraction outputs this way makes it more likely that
  one remembers to do the same in these circumstances.

  Note: when switching to the unit-scaling convention, it is probably necessary
  to raise the learning rate, since larger parameters need larger updates.  An
  exception is Adafactor, which by default scales its updates relative to the
  scale of the current parameter values.

  Args:
    value: a boolean
  Returns:
    a boolean -- `value`, unchanged.
  """
  return value
def us_einsum(xs, *args, **kwargs):
  """mtf.einsum, rescaled when the unit-scaling convention is on.

  When unit_scaling_convention() returns True, the einsum result is multiplied
  by N ** -0.5, where N is the product of the sizes of every dimension that
  appears in some input but not in the output (the contracted dimensions).

  Args:
    xs: a list of mtf.Tensor
    *args: positional arguments forwarded to mtf.einsum
    **kwargs: keyword arguments forwarded to mtf.einsum
  Returns:
    a mtf.Tensor
  """
  result = mtf.einsum(xs, *args, **kwargs)
  if not unit_scaling_convention():
    return result
  output_dims = result.shape.dims
  # De-duplicated union of all input dimensions.
  input_dims = {dim for tensor in xs for dim in tensor.shape.dims}
  contracted = [dim for dim in input_dims if dim not in output_dims]
  result *= mtf.Shape(contracted).size ** -0.5
  return result
def _as_dim_list(dims):
  """Turn a single mtf.Dimension or an mtf.Shape into a list of Dimensions.

  Any other value (e.g. a list or tuple) is returned unchanged.
  """
  if isinstance(dims, mtf.Dimension):
    return [dims]
  if isinstance(dims, mtf.Shape):
    return list(dims.dims)
  return dims


def dense(x,
          new_dims,
          reduced_dims=None,
          expert_dims=None,
          use_bias=True,
          activation=None,
          master_dtype=tf.float32,
          slice_dtype=tf.float32,
          variable_dtype=None,
          kernel_initializer=None,
          kernel_weights=None,
          name=None):
  """Dense layer doing (kernel*x + bias) computation.

  Args:
    x: a mtf.Tensor of shape [..., reduced_dims].
    new_dims: a mtf.Dimension, a list of mtf.Dimension, or an mtf.Shape.
    reduced_dims: a mtf.Dimension, a list of mtf.Dimensions, or an mtf.Shape:
      the dimensions of x to be reduced.
      If omitted (deprecated interface), we reduce the last dimension.
    expert_dims: an optional mtf.Dimension, list of mtf.Dimension, or
      mtf.Shape which represent different experts. Different experts get
      different weights.
    use_bias: a boolean, whether to add bias.
    activation: an optional function from mtf.Tensor to mtf.Tensor
    master_dtype: a tf.dtype (deprecated - use variable_dtype)
    slice_dtype: a tf.dtype (deprecated - use variable_dtype)
    variable_dtype: a mtf.VariableDType
    kernel_initializer: an initializer for kernel variable.
    kernel_weights: an optional mtf.Tensor weights matrix to use for the dense
      computation. If None, a kernel variable is created with
      get_dense_kernel_weights().
    name: a string used for tf.variable_scope.
  Returns:
    a mtf.Tensor of shape [..., new_dims].
  """
  new_dims = _as_dim_list(new_dims)
  if not isinstance(new_dims, list):
    new_dims = [new_dims]
  if variable_dtype is None:
    variable_dtype = mtf.VariableDType(master_dtype, slice_dtype, x.dtype)
  expert_dims = [] if expert_dims is None else _as_dim_list(expert_dims)
  if reduced_dims is None:
    tf.logging.warning(
        "Deprecation warning - it is recommended to pass reduced_dims "
        "explicitly to mtf.layers.dense() so as not to depend on dimension "
        "order. To silence this warning, explicitly pass "
        "reduced_dims=x.shape.dims[-1:] (in scope %s)"
        % tf.get_variable_scope().name)
    reduced_dims = x.shape.dims[-1:]
  # Copy, so the renaming below never mutates a list owned by the caller.
  reduced_dims = _as_dim_list(reduced_dims)[:]
  # If any reduced dim is also one of the new dims, rename it (in both the
  # input and reduced_dims) so the weight matrix has no duplicate dimension.
  for i, dim in enumerate(reduced_dims):
    if dim in new_dims:
      tmp_name = "_" + dim.name
      reduced_dims[i] = mtf.Dimension(tmp_name, dim.size)
      x = mtf.rename_dimension(x, dim.name, tmp_name)
  output_shape = mtf.Shape(
      [d for d in x.shape.dims if d not in reduced_dims] + new_dims)
  # Explicit None check: truthiness is not a reliable "was it passed" test.
  if kernel_weights is None:
    kernel_weights = get_dense_kernel_weights(x, new_dims, reduced_dims,
                                              expert_dims, kernel_initializer,
                                              name, variable_dtype,
                                              master_dtype, slice_dtype)
  with tf.variable_scope(name, default_name="dense"):
    y = us_einsum([x, kernel_weights], output_shape)
    if use_bias:
      b = mtf.get_variable(
          x.mesh,
          "bias",
          mtf.Shape(expert_dims + new_dims),
          initializer=tf.zeros_initializer(),
          dtype=variable_dtype)
      y += b
    if activation is not None:
      y = activation(y)
    return y
def get_dense_kernel_weights(x,
                             new_dims,
                             reduced_dims,
                             expert_dims,
                             kernel_initializer,
                             name=None,
                             variable_dtype=None,
                             master_dtype=tf.float32,
                             slice_dtype=tf.float32):
  """Create the kernel variable for dense() and cast it to x's dtype.

  The kernel has shape expert_dims + reduced_dims + new_dims and is created
  with mtf.get_variable("kernel") under
  tf.variable_scope(name, default_name="dense").

  Args:
    x: a mtf.Tensor.
    new_dims: a list of mtf.Dimension.
    reduced_dims: a list of mtf.Dimensions of x to be reduced.
    expert_dims: an optional list of mtf.Dimension which represent different
      experts. Different experts get different weights.
    kernel_initializer: an initializer for the kernel variable. None means a
      default VarianceScalingInitializer(). A DenseInitializer is first called
      with (reduced_dims, new_dims) to obtain the actual initializer.
    name: a string used for tf.variable_scope.
    variable_dtype: a mtf.VariableDType
    master_dtype: a tf.dtype (deprecated - use variable_dtype)
    slice_dtype: a tf.dtype (deprecated - use variable_dtype)
  Returns:
    a mtf.Tensor with dtype x.dtype.
  """
  dtype = variable_dtype
  if dtype is None:
    dtype = mtf.VariableDType(master_dtype, slice_dtype, x.dtype)
  kernel_shape = mtf.Shape(expert_dims + reduced_dims + new_dims)
  with tf.variable_scope(name, default_name="dense"):
    initializer = kernel_initializer
    if initializer is None:
      initializer = VarianceScalingInitializer()
    if isinstance(initializer, DenseInitializer):
      initializer = initializer(reduced_dims, new_dims)
    kernel = mtf.get_variable(
        x.mesh, "kernel", kernel_shape, initializer=initializer, dtype=dtype)
    kernel = mtf.cast(kernel, x.dtype)
  return kernel
def dense_product(x,
                  reduced_dims,
                  new_dims,
                  activation_functions=None,
                  name="dense_product",
                  **kwargs):
  """Component-wise product of multiple dense layers.

  e.g. if activation_functions=["linear", "sigmoid"], then this implements
  Gated Linear Units https://arxiv.org/pdf/1612.08083.pdf

  Args:
    x: a Tensor
    reduced_dims: a list of Dimensions.
    new_dims: a list of Dimensions.
    activation_functions: a list of activation functions (or a singleton).
      Each can be either:
        - a callable function from Tensor to Tensor
        - a string function name from namespace mtf
        - None or "linear", meaning no activation function
    name: an optional string. With more than one factor, factor i uses the
      name "<name>_<i>"; with a single factor, `name` is used as-is.
    **kwargs: additional kwargs for mtf.layers.dense()
  Returns:
    a Tensor: the element-wise product (mtf.multiply) of one dense() output
    per activation function.
  Raises:
    ValueError: if activation_functions is an empty list.
  """
  if not isinstance(activation_functions, list):
    activation_functions = [activation_functions]
  # functools.reduce on an empty list would fail with an opaque TypeError.
  if not activation_functions:
    raise ValueError(
        "dense_product requires at least one activation function, "
        "got an empty list")
  num_factors = len(activation_functions)
  factors = []
  for i, activation in enumerate(activation_functions):
    if activation == "linear":
      activation = None
    elif isinstance(activation, str):
      activation = getattr(mtf, activation)
    factors.append(
        dense(x,
              reduced_dims=reduced_dims,
              new_dims=new_dims,
              activation=activation,
              name="%s_%d" % (name, i) if num_factors > 1 else name,
              **kwargs))
  return functools.reduce(mtf.multiply, factors)
class DenseInitializer(object):
  """Abstract initializer type accepted by dense().

  When a DenseInitializer is used as a kernel initializer, it is called with
  the kernel's reduced_dims and new_dims, and must return the tf initializer
  that will actually be used to create the kernel variable.  Subclasses
  override __call__.
  """

  def __call__(self, reduced_dims, new_dims):
    """Return a tf initializer for the given dims; subclasses must override."""
    del reduced_dims, new_dims  # Unused by the abstract base class.
    raise NotImplementedError("not implemented")
@gin.configurable
class VarianceScalingInitializer(DenseInitializer):
  """Initializer capable of adapting its scale to the shape of weights.

  With `distribution="normal"`, samples are drawn from a truncated normal
  distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:
    1.0 if unit_scaling_convention() is turned on (only mode="fan_in" is
      allowed then; the other modes raise ValueError)
    otherwise:
      number of input units in the weight tensor, if mode = "fan_in"
      number of output units, if mode = "fan_out"
      average of the numbers of input and output units, if mode = "fan_avg"
  n is clamped to be at least 1.
  With `distribution="uniform"`,
  samples are drawn from a uniform distribution
  within [-limit, limit], with `limit = sqrt(3 * scale / n)`.

  Input units are the product of the sizes of `reduced_dims` and output units
  the product of the sizes of `new_dims`, as passed to __call__.

  # Arguments
    scale: Scaling factor (positive float).
    mode: One of "fan_in", "fan_out", "fan_avg" (case-insensitive).
    distribution: Random distribution to use. One of "normal", "uniform"
      (case-insensitive).
    seed: An optional Python integer. Used to seed the random generator.
  """

  def __init__(self, scale=1.0,
               mode="fan_in",
               distribution="normal",
               seed=None):
    self.scale = scale
    self.mode = mode.lower()
    self.distribution = distribution.lower()
    # Previously documented but not accepted; forwarded to the tf initializer.
    self.seed = seed

  def __call__(self, reduced_dims, new_dims):
    """Return a tf initializer for a kernel with the given dims.

    Raises:
      ValueError: on an invalid mode/distribution, or a mode other than
        "fan_in" under the unit-scaling convention.
    """
    fan_in = mtf.list_product(d.size for d in reduced_dims)
    fan_out = mtf.list_product(d.size for d in new_dims)
    scale = self.scale
    if self.mode == "fan_in":
      # Under unit scaling, the 1/sqrt(fan_in) factor is applied to the
      # activations (see us_einsum) instead of the weights.
      if not unit_scaling_convention():
        scale /= max(1., fan_in)
    elif self.mode == "fan_out":
      if unit_scaling_convention():
        raise ValueError("Unit scaling convention only works with \"fan_in\"")
      scale /= max(1., fan_out)
    elif self.mode == "fan_avg":
      if unit_scaling_convention():
        raise ValueError("Unit scaling convention only works with \"fan_in\"")
      scale /= max(1., float(fan_in + fan_out) / 2)
    else:
      raise ValueError(
          "Invalid `mode` argument: "
          "expected one of {\"fan_in\", \"fan_out\", \"fan_avg\"} "
          "but got %s" % (self.mode,))
    stddev = scale ** 0.5
    if self.distribution == "normal":
      return tf.truncated_normal_initializer(stddev=stddev, seed=self.seed)
    elif self.distribution == "uniform":
      # A uniform on [-limit, limit] has stddev limit / sqrt(3).
      limit = stddev * 3. ** 0.5
      return tf.random_uniform_initializer(
          minval=-limit, maxval=limit, seed=self.seed)
    else:
      raise ValueError("Invalid `distribution` argument: "
                       "expected one of {\"normal\", \"uniform\"} "
                       "but got %s" % (self.distribution,))
def conv1d(x, output_dim, filter_size=3, stride=1, **kw_args):
  """1D convolution, implemented as a conv2d over a height-1 image.

  x may have several batch dimensions. Its last dimension is treated as the
  channel dimension and its second-to-last dimension as the width (length)
  dimension.

  A size-1 "fake_height" dimension is inserted before the length dimension,
  conv2d is applied with filter (1, filter_size) and strides (1, stride), and
  the fake height is reshaped away again. Either "SAME" or "VALID" padding is
  supported, selected by the `padding` kwarg forwarded to conv2d:
    padding="SAME"
      [batch, fake_height, length, d_model]
        -> [batch, fake_height, length, output_dim]
    padding="VALID"
      [batch, fake_height, length, d_model]
        -> [batch, fake_height, output_length, output_dim]

  Args:
    x: a mtf.Tensor of format NWC where N can be multiple batch dimensions.
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a positive integer, the filter width.
    stride: a positive integer, the stride.
    **kw_args: optional keyword arguments to mtf.layers.conv2d.
  Returns:
    a mtf.Tensor of format NWO, where O is the output dimension.
  """
  unit_height = mtf.Dimension("fake_height", 1)
  in_dims = x.shape.dims
  image = mtf.reshape(
      x, mtf.Shape(in_dims[:-2] + [unit_height] + in_dims[-2:]))
  conv_out = conv2d(
      image,
      output_dim,
      filter_size=(1, filter_size),
      strides=(1, stride),
      **kw_args)
  out_dims = conv_out.shape.dims
  # Drop the fake height: keep batch dims, the output length, and output_dim.
  squeezed = out_dims[:-3] + [out_dims[-2], output_dim]
  return mtf.reshape(conv_out, mtf.Shape(squeezed))
def _depthwise_conv1d_hack(x,
                           depth_dim,
                           length_dim,
                           min_relative_pos=-1,
                           max_relative_pos=1,
                           name=None,
                           use_bias=True,
                           initializer_scale=1.0,
                           kernel_depth_weights=None):
  """Hacky version of a 1d depthwise convolution, built from dense() calls.

  For each relative position p in [min_relative_pos, max_relative_pos], x is
  shifted by -p along length_dim (mtf.shift with wrap=False). Each shifted copy
  goes through dense() with expert_dims=[depth_dim] and no reduced/new dims,
  i.e. a separate kernel of shape [depth_dim] per position. The per-position
  results are summed. Only the first position's dense() adds a bias.

  Args:
    x: a mtf.Tensor
    depth_dim: mtf.Dimension,
    length_dim: mtf.Dimension,
    min_relative_pos: int, min relative position,
    max_relative_pos: int, max relative position,
    name: str, variable_scope name,
    use_bias: Bool, whether to use bias,
    initializer_scale: number, initializer scale; each of the kernel_size
      per-position kernels uses initializer_scale / kernel_size,
    kernel_depth_weights: an optional list of kernel weight tensors. The list
      contains one element for each relative position in the kernel. Each
      element has a width equal to the depth over which the separable conv
      operation is being "separated"
  Returns:
    an mtf.Tensor (the integer 0 if the position range is empty)
  """
  kernel_size = max_relative_pos - min_relative_pos + 1

  def _tap(index):
    """Kernel-weighted copy of x shifted to relative position `index`."""
    offset = min_relative_pos + index
    shifted = mtf.shift(x, -offset, length_dim, wrap=False)
    weights = kernel_depth_weights[index] if kernel_depth_weights else None
    return dense(
        shifted,
        new_dims=[],
        reduced_dims=[],
        expert_dims=[depth_dim],
        kernel_weights=weights,
        name="depthwise_dense_%d" % index,
        use_bias=use_bias and (index == 0),
        kernel_initializer=VarianceScalingInitializer(
            scale=initializer_scale / kernel_size))

  with tf.variable_scope(name, default_name="depthwise_conv_hack"):
    # sum() starts from 0, exactly like an explicit `total = 0; total += ...`.
    total = sum(_tap(index) for index in range(kernel_size))
  return total
def separable_conv1d(x,
                     output_dim,
                     min_relative_pos=-1,
                     max_relative_pos=1,
                     depthwise_filter_initializer_scale=1.0,
                     pointwise_filter_initializer_scale=1.0,
                     name=None,
                     use_bias=True,
                     kernel_depth_weights=None):
  """1-D convolution with separable filters.

  A depthwise step (_depthwise_conv1d_hack over the last dimension of x) is
  followed by a pointwise dense() that maps that dimension to output_dim.
  The filter size will be `max_relative_pos - min_relative_pos + 1`.

  Args:
    x: a mtf.Tensor of format NWC.
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    min_relative_pos: an integer, the inclusive minimum relative position of
      the depthwise filter, where a relative position of zero means the left
      end of the filter aligns with the left end of the input.
    max_relative_pos: an integer, the inclusive maximum relative position of
      the depthwise filter, where a relative position of zero means the right
      end of the filter aligns with the right end of the input.
    depthwise_filter_initializer_scale: a positive float, the scale of the
      initializer for the depthwise filter.
    pointwise_filter_initializer_scale: a positive float, the scale of the
      initializer for the pointwise filter.
    name: a string used for tf.variable_scope.
    use_bias: a bool, whether to use bias in the convolutions.
    kernel_depth_weights: an optional list of kernel weight tensors. The list
      contains one element for each relative position in the kernel. Each
      element has a width equal to the dimension over which the separable conv
      operation is being "separated"
  Returns:
    a mtf.Tensor of format NWO, where O is the output dimension.
  """
  in_dims = x.shape.dims
  channel_dim = in_dims[-1]
  width_dim = in_dims[-2]
  with tf.variable_scope(name, default_name="separable_conv1d"):
    per_channel = _depthwise_conv1d_hack(
        x,
        depth_dim=channel_dim,
        length_dim=width_dim,
        min_relative_pos=min_relative_pos,
        max_relative_pos=max_relative_pos,
        use_bias=use_bias,
        initializer_scale=depthwise_filter_initializer_scale,
        kernel_depth_weights=kernel_depth_weights)
    pointwise_initializer = VarianceScalingInitializer(
        scale=pointwise_filter_initializer_scale)
    return dense(
        per_channel,
        new_dims=[output_dim],
        reduced_dims=[channel_dim],
        name="pointwise_dense",
        use_bias=use_bias,
        kernel_initializer=pointwise_initializer)
def conv2d(x, output_dim, filter_size=(3, 3),
           strides=(1, 1), padding="SAME", filter_initializer=None,
           variable_dtype=None, name=None):
  """2D Convolution.

  Creates a "kernel" variable of shape [fh, fw, in_channels, output_dim] and
  applies mtf.Conv2dOperation.

  Args:
    x: a mtf.Tensor of format NHWC.
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format [filter_height, filter_width].
    strides: a list or tuple in format [stride_height, stride_width].
    padding: either "SAME" or "VALID".
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType; defaults to one whose activation
      dtype is x.dtype.
    name: a string used for tf.variable_scope.
  Returns:
    a mtf.Tensor.
  """
  kernel_dims = [
      mtf.Dimension("fh", filter_size[0]),
      mtf.Dimension("fw", filter_size[1]),
      x.shape[-1],
      output_dim,
  ]
  # Stride 1 over the batch and channel dimensions.
  full_strides = [1] + list(strides) + [1]
  with tf.variable_scope(name, default_name="conv2d"):
    if variable_dtype is None:
      variable_dtype = mtf.VariableDType(activation_dtype=x.dtype)
    kernel = mtf.get_variable(
        x.mesh, "kernel", kernel_dims,
        initializer=filter_initializer, dtype=variable_dtype)
    conv = mtf.Conv2dOperation(x, kernel, full_strides, padding)
    return conv.outputs[0]
def conv2d_with_blocks(
    x, output_dim, filter_size=(3, 3),
    strides=(1, 1), padding="SAME",
    h_blocks_dim=None, w_blocks_dim=None, filter_initializer=None,
    variable_dtype=None, name=None):
  """2D Convolution with spatial partitioning.

  Spatial partitioning is implemented by decomposing the image into blocks.
  Block dimensions represented as h_blocks_dim and w_blocks_dim can be split
  along the mesh axis. If split, then we do a halo exchange where each block
  receives the part of the image from its left and right neighbors necessary to
  do the convolution. Exchange can involve complete or partial blocks depending
  on the filter height and width. A spatial dimension whose blocks dimension is
  None is instead zero-padded by the halo size on each side (mtf.pad). The
  convolution itself then runs with "VALID" padding.

  Currently, only "SAME" padding with dilation rate of 1 is supported.

  Args:
    x: a Tensor of shape
        [batch, h_blocks_dim, w_blocks_dim, h_dim, w_dim, in_channels_dim]
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format [filter_height, filter_width].
    strides: a list or tuple in format [stride_height, stride_width].
    padding: string, "SAME". The type of padding algorithm to use.
        "Valid" is not currently supported.
    h_blocks_dim: Dimension representing number of height blocks.
    w_blocks_dim: Dimension representing number of width blocks.
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType
    name: a name for the operation (optional).
  Returns:
    A Tensor of shape
        [batch, h_blocks_dim, w_blocks_dim, h_dim, w_dim, out_channels_dim]
  Raises:
    ValueError: if a blocks dim is given and a filter size is not odd.
    NotImplementedError: if a blocks dim is given and padding != "SAME".
  """
  # If h_blocks_dim and w_blocks_dim are not split, directly call conv2d.
  if h_blocks_dim is None and w_blocks_dim is None:
    return conv2d(x, output_dim,
                  filter_size, strides, padding, filter_initializer,
                  variable_dtype, name)

  # Explicit exceptions instead of `assert`, which is stripped under -O.
  if filter_size[0] % 2 != 1 or filter_size[1] % 2 != 1:
    raise ValueError(
        "conv2d_with_blocks requires odd filter sizes, got %s"
        % (filter_size,))
  # Padding 'VALID' is not supported yet.
  if padding != "SAME":
    raise NotImplementedError("conv2d_with_blocks requires padding=SAME")

  # Halo exchange for h_blocks and w_blocks.
  h_dim, w_dim = x.shape.dims[-3:-1]
  for blocks_dim, block_size_dim, halo_size in [
      (h_blocks_dim, h_dim, filter_size[0] // 2),
      (w_blocks_dim, w_dim, filter_size[1] // 2)]:
    if halo_size > 0:
      if blocks_dim is not None:
        x = mtf.halo_exchange(x, blocks_dim, block_size_dim, halo_size)
      else:
        x = mtf.pad(x, [halo_size, halo_size], block_size_dim.name)
  return conv2d(x, output_dim,
                filter_size, strides, "VALID", filter_initializer,
                variable_dtype, name)
def conv2d_transpose(x, output_dim,
                     filter_size=(2, 2), strides=(2, 2),
                     padding="SAME", filter_initializer=None,
                     variable_dtype=None, name=None):
  """2D Transposed Convolution.

  Creates a "kernel" variable of shape [fh, fw, output_dim, in_channels]
  (output channels before input channels) and applies
  mtf.Conv2dTransposeOperation.

  Args:
    x: a mtf.Tensor of format NHWC.
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format
        [filter_height, filter_width]. Only filter_size of (2, 2) is tested.
    strides: a list or tuple in format
        [stride_height, stride_width]. Only strides of (2, 2) is tested.
    padding: either "SAME" or "VALID".
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType; defaults to one whose activation
      dtype is x.dtype.
    name: a string used for tf.variable_scope.
  Returns:
    a mtf.Tensor.
  """
  kernel_dims = [
      mtf.Dimension("fh", filter_size[0]),
      mtf.Dimension("fw", filter_size[1]),
      output_dim,
      x.shape[-1],
  ]
  # Stride 1 over the batch and channel dimensions.
  full_strides = [1] + list(strides) + [1]
  with tf.variable_scope(name, default_name="conv2d_transpose"):
    if variable_dtype is None:
      variable_dtype = mtf.VariableDType(activation_dtype=x.dtype)
    kernel = mtf.get_variable(
        x.mesh, "kernel", kernel_dims,
        initializer=filter_initializer, dtype=variable_dtype)
    deconv = mtf.Conv2dTransposeOperation(x, kernel, full_strides, padding)
    return deconv.outputs[0]
def conv2d_transpose_with_blocks(
    x, output_dim, filter_size=(2, 2),
    strides=(2, 2), padding="SAME",
    h_blocks_dim=None, w_blocks_dim=None, filter_initializer=None,
    variable_dtype=None, name=None):
  """2D Transposed Convolution with spatial partitioning.

  Spatial partitioning is implemented by decomposing the image into blocks.
  Block dimensions represented as h_blocks_dim and w_blocks_dim can be split
  along the mesh axis. If split, then we do a halo exchange where each block
  receives the part of the image from its left and right neighbors necessary to
  do the convolution. Exchange can involve complete or partial blocks depending
  on the filter height and width. The halo size per dimension is
  filter_size // 2 - 1, so no exchange happens for a filter size of 2.

  Currently, only "SAME" padding with dilation rate of 1 is supported. Only
  splitting along the height and width dimensions is supported.

  Args:
    x: a Tensor of shape
        [batch, h_blocks_dim, w_blocks_dim, h_dim, w_dim, in_channel_dim]
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format
        [filter_height, filter_width]. Only filter_size of (2, 2) is tested.
    strides: a list or tuple in format
        [stride_height, stride_width]. Only strides of (2, 2) is tested.
    padding: string, "SAME". The type of padding algorithm to use.
        "Valid" is not currently supported.
    h_blocks_dim: Dimension representing number of height blocks.
    w_blocks_dim: Dimension representing number of width blocks.
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType
    name: a name for the operation (optional).
  Returns:
    A Tensor of shape
        [batch, h_blocks_dim, w_blocks_dim, h_dim, w_dim, out_channels_dim]
  Raises:
    ValueError: if a blocks dim is given and a filter size is not even.
    NotImplementedError: if a blocks dim is given and padding != "SAME".
  """
  # If h_blocks_dim and w_blocks_dim are not split, directly call conv2d_trans.
  if h_blocks_dim is None and w_blocks_dim is None:
    return conv2d_transpose(
        x, output_dim, filter_size, strides, padding, filter_initializer,
        variable_dtype, name)

  # Now only supports even-sized filters. Explicit exceptions instead of
  # `assert`, which is stripped under -O.
  if filter_size[0] % 2 != 0 or filter_size[1] % 2 != 0:
    raise ValueError(
        "conv2d_transpose_with_blocks requires even filter sizes, got %s"
        % (filter_size,))
  # Padding 'VALID' is not supported yet.
  if padding != "SAME":
    raise NotImplementedError(
        "conv2d_transpose_with_blocks requires padding=SAME")

  # Halo exchange for h_blocks and w_blocks.
  # TODO(lehou): figure out the halo_size in general cases.
  h_dim, w_dim = x.shape.dims[-3:-1]
  for blocks_dim, block_size_dim, halo_size in [
      (h_blocks_dim, h_dim, filter_size[0] // 2 - 1),
      (w_blocks_dim, w_dim, filter_size[1] // 2 - 1)]:
    if halo_size > 0:
      if blocks_dim is not None:
        x = mtf.halo_exchange(x, blocks_dim, block_size_dim, halo_size)
      else:
        x = mtf.pad(x, [halo_size, halo_size], block_size_dim.name)
  return conv2d_transpose(
      x, output_dim, filter_size, strides, "VALID", filter_initializer,
      variable_dtype, name)
def conv3d(x, output_dim, filter_size=(3, 3, 3),
           strides=(1, 1, 1), padding="SAME",
           filter_initializer=None,
           variable_dtype=None, name=None):
  """3D Convolution.

  Creates a "kernel" variable of shape [fd, fh, fw, in_channels, output_dim]
  and applies mtf.Conv3dOperation.

  Args:
    x: a mtf.Tensor of format NDHWC.
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format
        [filter_depth, filter_height, filter_width].
    strides: a list or tuple in format
        [stride_depth, stride_height, stride_width].
    padding: either "SAME" or "VALID".
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType; defaults to one whose activation
      dtype is x.dtype.
    name: a string used for tf.variable_scope.
  Returns:
    a mtf.Tensor.
  """
  kernel_dims = [
      mtf.Dimension("fd", filter_size[0]),
      mtf.Dimension("fh", filter_size[1]),
      mtf.Dimension("fw", filter_size[2]),
      x.shape[-1],
      output_dim,
  ]
  # Stride 1 over the batch and channel dimensions.
  full_strides = [1] + list(strides) + [1]
  with tf.variable_scope(name, default_name="conv3d"):
    if variable_dtype is None:
      variable_dtype = mtf.VariableDType(activation_dtype=x.dtype)
    kernel = mtf.get_variable(
        x.mesh, "kernel", kernel_dims,
        initializer=filter_initializer, dtype=variable_dtype)
    conv = mtf.Conv3dOperation(x, kernel, full_strides, padding)
    return conv.outputs[0]
def conv3d_with_blocks(
    x, output_dim, filter_size=(3, 3, 3),
    strides=(1, 1, 1), padding="SAME",
    d_blocks_dim=None, h_blocks_dim=None, filter_initializer=None,
    variable_dtype=None, name=None):
  """3D Convolution with spatial partitioning.

  Spatial partitioning is implemented by decomposing the image into blocks.
  Block dimensions represented as d_blocks_dim and h_blocks_dim can be split
  along the mesh axis. If split, then we do a halo exchange where each block
  receives the part of the image from its left and right neighbors necessary to
  do the convolution. Exchange can involve complete or partial blocks depending
  on the filter depth and height. A depth/height dimension whose blocks
  dimension is None, and always the width dimension, is instead zero-padded by
  filter_size // 2 on each side (mtf.pad). The convolution then runs with
  "VALID" padding.

  Currently, only "SAME" padding with dilation rate of 1 is supported. Only
  splitting along the depth and height dimensions is supported.

  Args:
    x: a Tensor of shape
        [batch, d_blocks_dim, h_blocks_dim, d_dim, h_dim, w_dim, in_channel_dim]
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format
        [filter_depth, filter_height, filter_width].
    strides: a list or tuple in format
        [stride_depth, stride_height, stride_width].
    padding: string, "SAME". The type of padding algorithm to use.
        "Valid" is not currently supported.
    d_blocks_dim: Dimension representing number of depth blocks.
    h_blocks_dim: Dimension representing number of height blocks.
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType
    name: a name for the operation (optional).
  Returns:
    A Tensor of shape
        [batch, d_blocks_dim, h_blocks_dim,
         d_dim, h_dim, w_dim, out_channels_dim]
  Raises:
    ValueError: if a blocks dim is given and a filter size is not odd.
    NotImplementedError: if a blocks dim is given and padding != "SAME".
  """
  # If d_blocks_dim and h_blocks_dim are not split, directly call conv3d.
  if d_blocks_dim is None and h_blocks_dim is None:
    return conv3d(x, output_dim,
                  filter_size, strides, padding, filter_initializer,
                  variable_dtype, name)

  # Explicit exceptions instead of `assert`, which is stripped under -O.
  if any(size % 2 != 1 for size in filter_size[:3]):
    raise ValueError(
        "conv3d_with_blocks requires odd filter sizes, got %s"
        % (filter_size,))
  # Padding 'VALID' is not supported yet.
  if padding != "SAME":
    raise NotImplementedError("conv3d_with_blocks requires padding=SAME")

  # Halo exchange for d_blocks and h_blocks.
  d_dim, h_dim, w_dim = x.shape.dims[-4:-1]
  for blocks_dim, block_size_dim, halo_size in [
      (d_blocks_dim, d_dim, filter_size[0] // 2),
      (h_blocks_dim, h_dim, filter_size[1] // 2)]:
    if halo_size > 0:
      if blocks_dim is not None:
        x = mtf.halo_exchange(x, blocks_dim, block_size_dim, halo_size)
      else:
        x = mtf.pad(x, [halo_size, halo_size], block_size_dim.name)

  # Pad w dimension with zeros.
  x = mtf.pad(x, [filter_size[2] // 2, filter_size[2] // 2],
              dim_name=w_dim.name, name="conv3d_pad_w_dim")
  return conv3d(x, output_dim,
                filter_size, strides, "VALID", filter_initializer,
                variable_dtype, name)
def conv3d_transpose(x, output_dim,
                     filter_size=(2, 2, 2), strides=(2, 2, 2),
                     padding="SAME", filter_initializer=None,
                     variable_dtype=None, name=None):
  """Applies a 3D transposed convolution to an NDHWC mtf.Tensor.

  A kernel variable named "kernel" is created (via mtf.get_variable) with
  shape [fd, fh, fw, output_dim, in_channels]. Here in_channels is the last
  dimension of `x`. The transposed convolution is performed by
  mtf.Conv3dTransposeOperation.

  Args:
    x: a mtf.Tensor of format NDHWC.
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format
      [filter_depth, filter_height, filter_width].
      Only filter_size of (2, 2, 2) is tested.
    strides: a list or tuple in format
      [stride_depth, stride_height, stride_width].
      Only strides of (2, 2, 2) is tested.
    padding: either "SAME" or "VALID".
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType. If None, one is built from x.dtype.
    name: a string used for tf.variable_scope
      (defaults to "conv3d_transpose").

  Returns:
    a mtf.Tensor.
  """
  # Spatial filter dimensions, in (depth, height, width) order.
  spatial_filter_dims = [
      mtf.Dimension("fd", filter_size[0]),
      mtf.Dimension("fh", filter_size[1]),
      mtf.Dimension("fw", filter_size[2]),
  ]
  in_channel_dim = x.shape[-1]
  with tf.variable_scope(name, default_name="conv3d_transpose"):
    if variable_dtype is None:
      variable_dtype = mtf.VariableDType(activation_dtype=x.dtype)
    # Transposed-conv kernels store output channels before input channels.
    kernel = mtf.get_variable(
        x.mesh, "kernel",
        spatial_filter_dims + [output_dim, in_channel_dim],
        initializer=filter_initializer, dtype=variable_dtype)
    # Batch and channel dimensions always use a stride of 1.
    full_strides = [1, *strides, 1]
    transpose_op = mtf.Conv3dTransposeOperation(
        x, kernel, full_strides, padding)
    return transpose_op.outputs[0]
def conv3d_transpose_with_blocks(
    x, output_dim, filter_size=(2, 2, 2),
    strides=(2, 2, 2), padding="SAME",
    d_blocks_dim=None, h_blocks_dim=None, filter_initializer=None,
    variable_dtype=None, name=None):
  """3D Transposed Convolution with spatial partitioning.

  Spatial partitioning is implemented by decomposing the image into blocks.
  The block dimensions d_blocks_dim and h_blocks_dim can be split along the
  mesh axis. If split, a halo exchange is performed: each block receives the
  part of the image from its left and right neighbors that is needed to do
  the convolution. The exchange can involve complete or partial blocks,
  depending on the filter depth and height.

  Currently, only "SAME" padding with dilation rate of 1 is supported. Only
  splitting along the depth and height dimensions is supported.

  Args:
    x: a Tensor of shape
      [batch, d_blocks_dim, h_blocks_dim, d_dim, h_dim, w_dim, in_channel_dim]
    output_dim: a mtf.Dimension, indicating the output channel dimension.
    filter_size: a list or tuple in format
      [filter_depth, filter_height, filter_width]. All entries must be even
      when a blocks dim is given. Only filter_size of (2, 2, 2) is tested.
    strides: a list or tuple in format
      [stride_depth, stride_height, stride_width].
      Only strides of (2, 2, 2) is tested.
    padding: string, "SAME". The type of padding algorithm to use.
      "VALID" is not currently supported when a blocks dim is given.
    d_blocks_dim: Dimension representing number of depth blocks.
    h_blocks_dim: Dimension representing number of height blocks.
    filter_initializer: the initializer for tf.get_variable.
    variable_dtype: a mtf.VariableDType
    name: a name for the operation (optional).

  Returns:
    A Tensor of shape
      [batch, d_blocks_dim, h_blocks_dim, d_dim, h_dim, w_dim, out_channels_dim]
    There is no w_blocks_dim. The per-block spatial dims are those produced by
    conv3d_transpose on the (halo-exchanged / padded) input.

  Raises:
    NotImplementedError: if a blocks dim is given and padding != "SAME".
  """
  # If neither d_blocks_dim nor h_blocks_dim is given, nothing is split:
  # call conv3d_transpose directly.
  if d_blocks_dim is None and h_blocks_dim is None:
    return conv3d_transpose(
        x, output_dim, filter_size, strides, padding, filter_initializer,
        variable_dtype, name)
  # Now only supports even-sized filters.
  assert filter_size[0] % 2 == 0
  assert filter_size[1] % 2 == 0
  assert filter_size[2] % 2 == 0
  # Padding 'VALID' is not supported yet.
  if padding != "SAME":
    raise NotImplementedError(
        "conv3d_transpose_with_blocks requires padding=SAME")
  # Halo exchange for d_blocks and h_blocks.
  # TODO(lehou): figure out the halo_size in general cases.
  d_dim, h_dim, w_dim = x.shape.dims[-4:-1]
  for blocks_dim, block_size_dim, halo_size in [
      (d_blocks_dim, d_dim, filter_size[0] // 2 - 1),
      (h_blocks_dim, h_dim, filter_size[1] // 2 - 1)]:
    if halo_size > 0:
      if blocks_dim is not None:
        x = mtf.halo_exchange(x, blocks_dim, block_size_dim, halo_size)
      else:
        # This axis is not blocked, so zero-pad it instead of exchanging.
        x = mtf.pad(x, [halo_size, halo_size], block_size_dim.name)
  # The w dimension is never blocked, so it is zero-padded. As in the loop
  # above, skip the pad when the halo is 0 (e.g. the tested filter width
  # of 2), since a zero-width pad only adds a no-op to the graph.
  w_halo_size = filter_size[2] // 2 - 1
  if w_halo_size > 0:
    x = mtf.pad(x, [w_halo_size, w_halo_size],
                dim_name=w_dim.name, name="conv3d_trans_pad_w_dim")
  # Halos/padding were added explicitly above, so the conv itself is VALID.
  return conv3d_transpose(
      x, output_dim, filter_size, strides, "VALID", filter_initializer,
      variable_dtype, name)
def corr(x, dim, epsilon=1e-20, name="pearson_correlation"):
  """Compute correlation along dimension dim, equiv to tfp.stats.correlation.

  It treats the `dim` Dimension as the random event axis, and all the other
  dims as the sample axis. Pearson correlation is computed between random
  events in the `dim` Dimension, and marginalized over the other dims.
  `dim` may be at any position in the shape of `x`.

  Example usage:
    inputs = tf.random_normal([batch, channels])
    mtf_inputs = mtf.import_tf_tensor(
        mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
    correlation = corr(mtf_inputs, dim=channels_dim)

  Args:
    x: a mtf.Tensor whose shape contains dim.
    dim: a mtf.Dimension.
    epsilon: a small floating point number for numerical stability. It is
      added to each per-event L2 norm before dividing.
    name: a string used for tf.variable_scope.

  Returns:
    a mtf.Tensor of shape [dim_1, dim_2]. These are copies of `dim` renamed
    to f"{dim.name}_1" and f"{dim.name}_2".
  """
  with tf.variable_scope(name):
    dim_name = dim.name
    dim1 = mtf.Dimension(f"{dim_name}_1", dim.size)
    dim2 = mtf.Dimension(f"{dim_name}_2", dim.size)
    # Center each event by its mean over all sample (non-`dim`) dimensions.
    centered = x - mtf.reduce_mean(x, output_shape=[dim])
    x1 = mtf.rename_dimension(centered, dim_name, dim1.name)
    x2 = mtf.rename_dimension(centered, dim_name, dim2.name)

    def _norm(z, event_dim):
      # L2 norm over every sample dim, keeping only `event_dim`. Naming the
      # event dim explicitly (instead of taking z's last dim) keeps this
      # correct when `dim` is not the last dimension of `x`.
      return mtf.sqrt(
          mtf.reduce_sum(mtf.square(z), output_shape=[event_dim])) + epsilon

    # x1 and x2 share every sample dim, so matmul contracts them to
    # [dim_1, dim_2]. The two norms share no dims, so their matmul is the
    # outer product [dim_1, dim_2].
    return mtf.matmul(x1, x2) / mtf.matmul(_norm(x1, dim1), _norm(x2, dim2))
def layer_norm(x, dim, epsilon=1e-6, name="layer_prepostprocess"):
"""Layer normalization over dimension dim.
Args:
x: a mtf.Tensor whose shape contains dim.
dim: a mtf.Dimension
epsilon: a floating point number
name: a string used for tf.variable_scope.
Returns:
a mtf.Tensor with same shape as x.
"""
with tf.variable_scope(name + "/layer_norm"):
scale = mtf.get_variable(
x.mesh,
"layer_norm_scale",
mtf.Shape([dim]),
initializer=tf.ones_initializer(),