From 77db199392ff967e49156ccb24c916b270b47eca Mon Sep 17 00:00:00 2001
From: Lukasz Kaiser
Date: Mon, 23 Nov 2020 19:16:12 -0800
Subject: [PATCH] Bidirectional FAVOR.

PiperOrigin-RevId: 343972992
---
 trax/layers/research/sparsity.py      | 230 ++++++++++++++++----------
 trax/layers/research/sparsity_test.py |  23 ++-
 2 files changed, 166 insertions(+), 87 deletions(-)

diff --git a/trax/layers/research/sparsity.py b/trax/layers/research/sparsity.py
index fceb947fc..09b244c49 100644
--- a/trax/layers/research/sparsity.py
+++ b/trax/layers/research/sparsity.py
@@ -23,7 +23,7 @@
 from trax import fastmath
 from trax import layers as tl
-from trax.fastmath import numpy as np
+from trax.fastmath import numpy as jnp
 from trax.layers import base
 from trax.layers import core
 from trax.layers import initializers as init
@@ -52,14 +52,14 @@ class ReversibleReshapePermute(reversible.ReversibleLayer):
   def forward(self, x):
     shape = x.shape
     x = x.reshape(shape[:-1]+(-1, self._get_multiplier(x)))
-    t_x = np.einsum('...ab->...ba', x)  # transpose
+    t_x = jnp.einsum('...ab->...ba', x)  # transpose
     return t_x.reshape(shape)
 
   def reverse(self, x, weights=(), state=(), new_state=(), rng=None):
     del state, new_state, rng
     shape = x.shape
     x = x.reshape(shape[:-1]+(self._get_multiplier(x), -1))
-    t_x = np.einsum('...ab->...ba', x)  # transpose
+    t_x = jnp.einsum('...ab->...ba', x)  # transpose
     return t_x.reshape(shape)
 
   def _get_multiplier(self, x):
@@ -133,13 +133,13 @@ def _get_permutation_and_reverse_permutation(self, x):
 @assert_shape('...a->...bc')
 def SplitLastAxis(num_splits):
   return tl.Fn(f'SplitLastAxis_{num_splits}',
-               lambda x: np.reshape(x, x.shape[:-1] + (num_splits, -1)))
+               lambda x: jnp.reshape(x, x.shape[:-1] + (num_splits, -1)))
 
 
 @assert_shape('...ab->...c')
 def MergeLastTwoAxes():
   return tl.Fn('SplitLastAxis',
-               lambda x: np.reshape(x, x.shape[:-2] + (-1,)))
+               lambda x: jnp.reshape(x, x.shape[:-2] + (-1,)))
 
 
 @assert_shape('...a->...b')
@@ -243,7 +243,7 @@ def LocallyConvDense(n_modules, n_units, kernel_size=1, length_kernel_size=1):
   pad_widths = [[0, 0], [length_kernel_size - 1, 0], [half, half], [0, 0]]
   return tl.Serial(
       tl.SplitLastAxis(n_modules),
-      tl.Fn('Pad', lambda x: np.pad(x, pad_width=pad_widths)),
+      tl.Fn('Pad', lambda x: jnp.pad(x, pad_width=pad_widths)),
       tl.Conv(n_units, kernel_size=(length_kernel_size, kernel_size)),
       tl.MergeLastTwoAxes()
   )
@@ -349,7 +349,7 @@ def MultiplicativeSparseDense(sparsity, d_input, d_output=None):
       # kernel is done in the same einsum.
       tl.Fn('AttentionEinsum',
             (lambda kernel, multiplier, embeds:  # pylint: disable=g-long-lambda
-             np.einsum('dx,hd,bld->blhx', kernel, multiplier, embeds))),
+             jnp.einsum('dx,hd,bld->blhx', kernel, multiplier, embeds))),
       MergeLastTwoAxes(),
       # Weight below is bias after dense, per-head.
       tl.Weights(init.RandomNormalInitializer(1e-6), [d_output]),
@@ -387,7 +387,7 @@ def MultiplicativeModularSparseDense(sparsity, d_feature):
       # kernels is done in a single einsum.
       tl.Fn('SparseDenseEinsum',
             (lambda kmod, kmult, multiplier, embeds:  # pylint: disable=g-long-lambda
-             np.einsum('hxo,dx,hd,bld->blho', kmod, kmult, multiplier, embeds))),
+             jnp.einsum('hxo,dx,hd,bld->blho', kmod, kmult, multiplier, embeds))),
       MergeLastTwoAxes(),
       # Weight below is bias after dense, per-head.
       tl.Weights(init.RandomNormalInitializer(1e-6), [d_feature]),
@@ -506,6 +506,65 @@ def MultiplicativeConvCausalAttention(
   )
 
 
+def Favor(d_feature, n_heads=1, dropout=0.0,
+          numerical_stabilizer=0.001, mode='train'):
+  """Returns a layer that maps (activations, mask) to (new_activations, mask).
+
+  See the FAVOR paper for details: https://arxiv.org/abs/2006.03555
+
+  Args:
+    d_feature: Depth/dimensionality of feature embedding.
+    n_heads: Number of attention heads.
+    dropout: Probabilistic rate for internal dropout applied to attention
+      activations (based on query-key pairs) before dotting them with values.
+    numerical_stabilizer: float, small number used for numerical stability.
+    mode: One of `'train'`, `'eval'`, or `'predict'`.
+  """
+  del dropout, mode  # not implemented yet but needed in the API
+
+  def bidirectional_numerator(query_prime, key_prime, value):
+    kvs = jnp.einsum('lbm,lbd->bmd', key_prime, value)
+    return jnp.einsum('lbm,bmd->lbd', query_prime, kvs)
+
+  def bidirectional_denominator(query_prime, key_prime):
+    all_ones = jnp.ones([query_prime.shape[0]])
+    ks_sum = jnp.einsum('lbm,l->bm', key_prime, all_ones)
+    return jnp.einsum('lbm,bm->lb', query_prime, ks_sum)
+
+  def relu(x):
+    return jnp.where(x <= 0, jnp.zeros_like(x), x)
+
+  def favor(query, key, value, mask):
+    query_prime = relu(query) + numerical_stabilizer
+    key_prime = relu(key) + numerical_stabilizer
+    mask_batch_1_length = jnp.reshape(
+        mask, [key.shape[0] // n_heads, 1, key.shape[1]]).astype(jnp.float32)
+    mask_heads = mask_batch_1_length + jnp.zeros((1, n_heads, 1))
+    key_prime *= jnp.reshape(mask_heads, [key.shape[0], key.shape[1], 1])
+
+    w = bidirectional_numerator(jnp.moveaxis(query_prime, 1, 0),
+                                jnp.moveaxis(key_prime, 1, 0),
+                                jnp.moveaxis(value, 1, 0))
+    r = bidirectional_denominator(jnp.moveaxis(query_prime, 1, 0),
+                                  jnp.moveaxis(key_prime, 1, 0))
+    w = jnp.moveaxis(w, 0, 1)
+    r = jnp.moveaxis(r, 0, 1)
+    r = jnp.reciprocal(r)
+    r = jnp.expand_dims(r, len(r.shape))
+    renormalized_attention = w * r
+    return renormalized_attention, mask
+
+  return tl.Serial(
+      tl.Branch(
+          [tl.Dense(d_feature), tl.SplitIntoHeads(n_heads)],
+          [tl.Dense(d_feature), tl.SplitIntoHeads(n_heads)],
+          [tl.Dense(d_feature), tl.SplitIntoHeads(n_heads)],
+      ),
+      tl.Fn('FAVOR', favor, n_out=2),
+      tl.Dense(d_feature),
+  )
+
+
 def CausalFavor(d_feature, n_heads=1, dropout=0.0,
                 numerical_stabilizer=0.001, precision=None, mode='train'):
   """Returns a layer that maps activations to activations, with causal masking.
@@ -520,7 +579,7 @@ def CausalFavor(d_feature, n_heads=1, dropout=0.0,
     dropout: Probababilistic rate for internal dropout applied to attention
       activations (based on query-key pairs) before dotting them with values.
     numerical_stabilizer: float, small number used for numerical stability.
-    precision: passed to np.einsum to define arithmetic precision.
+    precision: passed to jnp.einsum to define arithmetic precision.
     mode: One of `'train'`, `'eval'`, or `'predict'`.
""" del dropout, mode # not implemented yet but needed in the API @@ -529,8 +588,8 @@ def favor_numerator_fwd(init_prefix_sum_value, precision, query_prime, key_prime, value): def body(p, qkv): (q, k, v) = qkv - p += np.einsum('...m,...d->...md', k, v, precision=precision) - x_slice = np.einsum('...m,...md->...d', q, p, precision=precision) + p += jnp.einsum('...m,...d->...md', k, v, precision=precision) + x_slice = jnp.einsum('...m,...md->...d', q, p, precision=precision) return p, x_slice p, w = fastmath.scan(body, init_prefix_sum_value, (query_prime, key_prime, value)) @@ -542,15 +601,15 @@ def favor_numerator_bwd(pqkv, w_ct): def body(carry, qkv_xct): p, p_ct = carry q, k, v, x_ct = qkv_xct - q_ct = np.einsum('...d,...md->...m', x_ct, p, precision=precision) - p_ct += np.einsum('...d,...m->...md', x_ct, q, precision=precision) - k_ct = np.einsum('...md,...d->...m', p_ct, v, precision=precision) - v_ct = np.einsum('...md,...m->...d', p_ct, k, precision=precision) - p -= np.einsum('...m,...d->...md', k, v, precision=precision) + q_ct = jnp.einsum('...d,...md->...m', x_ct, p, precision=precision) + p_ct += jnp.einsum('...d,...m->...md', x_ct, q, precision=precision) + k_ct = jnp.einsum('...md,...d->...m', p_ct, v, precision=precision) + v_ct = jnp.einsum('...md,...m->...d', p_ct, k, precision=precision) + p -= jnp.einsum('...m,...d->...md', k, v, precision=precision) return (p, p_ct), (q_ct, k_ct, v_ct) _, (qs_ct, ks_ct, vs_ct) = fastmath.scan( - body, (p, np.zeros_like(p)), (qs, ks, vs, w_ct), reverse=True) + body, (p, jnp.zeros_like(p)), (qs, ks, vs, w_ct), reverse=True) return (None, None, qs_ct, ks_ct, vs_ct) def favor_numerator(init_prefix_sum_value, precision, query_prime, @@ -567,7 +626,7 @@ def favor_denominator_fwd(init_prefix_sum_value, precision, def body(p, qk): q, k = qk p += k - x = np.einsum('...m,...m->...', q, p, precision=precision) + x = jnp.einsum('...m,...m->...', q, p, precision=precision) return p, x p, r = fastmath.scan(body, init_prefix_sum_value, (query_prime, key_prime)) @@ -579,14 +638,14 @@ def favor_denominator_bwd(qkp, r_ct): def body(carry, qkx): p, p_ct = carry q, k, x_ct = qkx - q_ct = np.einsum('...,...m->...m', x_ct, p, precision=precision) - p_ct += np.einsum('...,...m->...m', x_ct, q, precision=precision) + q_ct = jnp.einsum('...,...m->...m', x_ct, p, precision=precision) + p_ct += jnp.einsum('...,...m->...m', x_ct, q, precision=precision) k_ct = p_ct p -= k return (p, p_ct), (q_ct, k_ct) _, (qs_ct, ks_ct) = fastmath.scan( - body, (p, np.zeros_like(p)), (qs, ks, r_ct), reverse=True) + body, (p, jnp.zeros_like(p)), (qs, ks, r_ct), reverse=True) return (None, None, qs_ct, ks_ct) def favor_denominator(init_prefix_sum_value, precision, query_prime, @@ -601,36 +660,35 @@ def favor_denominator(init_prefix_sum_value, precision, query_prime, favor_denominator.defvjp(favor_denominator_fwd, favor_denominator_bwd) def relu(x): - return np.where(x <= 0, np.zeros_like(x), x) + return jnp.where(x <= 0, jnp.zeros_like(x), x) def favor(query, key, value): query_prime = relu(query) + numerical_stabilizer key_prime = relu(key) + numerical_stabilizer prefix_sum_tensor_shape = (key.shape[0], key.shape[-1], value.shape[-1]) t_slice_shape = (key.shape[0], key.shape[-1]) - init_prefix_sum_value_numerator = np.zeros(prefix_sum_tensor_shape) - init_prefix_sum_value_denominator = np.zeros(t_slice_shape) + init_prefix_sum_value_numerator = jnp.zeros(prefix_sum_tensor_shape) + init_prefix_sum_value_denominator = jnp.zeros(t_slice_shape) w = 
-                        np.moveaxis(query_prime, 1, 0),
-                        np.moveaxis(key_prime, 1, 0),
-                        np.moveaxis(value, 1, 0))
+                        jnp.moveaxis(query_prime, 1, 0),
+                        jnp.moveaxis(key_prime, 1, 0),
+                        jnp.moveaxis(value, 1, 0))
     r = favor_denominator(init_prefix_sum_value_denominator, precision,
-                          np.moveaxis(query_prime, 1, 0),
-                          np.moveaxis(key_prime, 1, 0))
-    w = np.moveaxis(w, 0, 1)
-    r = np.moveaxis(r, 0, 1)
-    # r = r + 2 * numerical_stabilizer * (np.abs(r) <= numerical_stabilizer)
-    r = np.reciprocal(r)
-    r = np.expand_dims(r, len(r.shape))
+                          jnp.moveaxis(query_prime, 1, 0),
+                          jnp.moveaxis(key_prime, 1, 0))
+    w = jnp.moveaxis(w, 0, 1)
+    r = jnp.moveaxis(r, 0, 1)
+    r = jnp.reciprocal(r)
+    r = jnp.expand_dims(r, len(r.shape))
     renormalized_attention = w * r
     return renormalized_attention
 
   return tl.ConfigurableAttention(
       core.Dense(d_feature), core.Dense(d_feature), core.Dense(d_feature),
       core.Dense(d_feature), n_heads=n_heads,
-      qkv_attention_layer=base.Fn('FAVOR', favor))
+      qkv_attention_layer=base.Fn('CausalFAVOR', favor))
 
 
 class SparseFF(base.Layer):
@@ -679,23 +737,23 @@ def forward(self, x):
     """
     m1, m2, mb, w1, w2, b2 = self.weights
     if self._mode != 'predict':
-      w1 = np.reshape(w1.T, (-1, self._d_ff))
-      w2 = np.reshape(w2, (self._d_ff, -1))
+      w1 = jnp.reshape(w1.T, (-1, self._d_ff))
+      w2 = jnp.reshape(w2, (self._d_ff, -1))
     x_shape = x.shape
-    x = np.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
+    x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
 
     # Q: should we add bias and/or put relu after the low-rank m1 dot?
-    mask_logits = np.dot(np.dot(x, m1), m2) + mb
-    mask_logits = np.reshape(mask_logits, [-1, self._d1, self._d2])
+    mask_logits = jnp.dot(jnp.dot(x, m1), m2) + mb
+    mask_logits = jnp.reshape(mask_logits, [-1, self._d1, self._d2])
     # Softmax.
     mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
     log_mask = mask_logits - mask_logsumexp
-    mask = np.exp(log_mask)
+    mask = jnp.exp(log_mask)
     # Gumbel-softmax with straight-through discretization.
     rng1, rng2 = fastmath.random.split(self.rng, 2)
-    u = fastmath.random.uniform(rng1, mask.shape, np.float32, 1e-6, 1.0 - 1e-6)
-    g = -np.log(-np.log(u))
-    quant_mask = np.argmax(log_mask + g * self._temperature, axis=-1)
+    u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
+    g = -jnp.log(-jnp.log(u))
+    quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)
     if self._mode == 'train':
       # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
       quant_mask = metrics.one_hot(quant_mask, self._n_elements_in_block)
@@ -703,18 +761,18 @@ def forward(self, x):
       quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
       # We will sometimes (quant_prob of the batches) use the soft-mask instead
       # of the quantized mask to improve training stability (see paper above).
-      select = fastmath.random.uniform(rng2, (), np.float32, 0.0, 1.0)
-      quant_mask = np.where(select < self._quant_prob, quant_mask, mask)
-      quant_mask = np.reshape(quant_mask, [-1, self._d_ff])
+      select = fastmath.random.uniform(rng2, (), jnp.float32, 0.0, 1.0)
+      quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)
+      quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])
 
     if self._mode == 'train':
       # In training, run full matmul to get benefits from the above tricks.
-      mid = np.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
-      relu = np.where(mid <= 0, np.zeros_like(mid), mid)
-      res = np.dot(relu, w2) + b2
+      mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
+      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
+      res = jnp.dot(relu, w2) + b2
     elif self._mode == 'predict':
-      # w1 = np.reshape(w1.T, (self._d1, self._d2, -1))
-      # w2 = np.reshape(w2, (self._d1, self._d2, -1))
+      # w1 = jnp.reshape(w1.T, (self._d1, self._d2, -1))
+      # w2 = jnp.reshape(w2, (self._d1, self._d2, -1))
       # This implementation mimicks inference. It's not efficient for large
       # size of joint_batch, but at inference that will be 1 most of the time.
       # Shapes:
@@ -723,26 +781,26 @@ def forward(self, x):
      # we'll index w1 with advanced numpy indexing, first range over
      # self._d1 times the batch size, second range being quant_mask
       batch_size = quant_mask.shape[0]
-      idx1 = np.array([np.arange(self._d1)] * batch_size)
+      idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)
       # flatten indices and select from w1
-      idx1 = np.reshape(idx1, [-1])
-      idx2 = np.reshape(quant_mask, [-1])
+      idx1 = jnp.reshape(idx1, [-1])
+      idx2 = jnp.reshape(quant_mask, [-1])
       w = w1[idx1, idx2, :]  # now we have per-element weights with batch dim
-      w = np.reshape(w, [batch_size, self._d1, -1])
-      mid = np.einsum('ai,aji->aj', x, w)
-      relu = np.where(mid <= 0, np.zeros_like(mid), mid)
+      w = jnp.reshape(w, [batch_size, self._d1, -1])
+      mid = jnp.einsum('ai,aji->aj', x, w)
+      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
       # w2 is [self._d1, self._d2, d_model]
       v = w2[idx1, idx2, :]
-      v = np.reshape(v, [batch_size, self._d1, -1])
-      res = np.einsum('ai,aij->aj', relu, v) + b2
+      v = jnp.reshape(v, [batch_size, self._d1, -1])
+      res = jnp.einsum('ai,aij->aj', relu, v) + b2
     else:
       quant_mask = metrics.one_hot(quant_mask, self._n_elements_in_block)
-      quant_mask = np.reshape(quant_mask, [-1, self._d_ff])
-      mid = np.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
-      relu = np.where(mid <= 0, np.zeros_like(mid), mid)
-      res = np.dot(relu, w2) + b2
+      quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])
+      mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
+      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
+      res = jnp.dot(relu, w2) + b2
 
-    return np.reshape(res, x_shape)  # un-flatten if needed
+    return jnp.reshape(res, x_shape)  # un-flatten if needed
 
   def init_weights_and_state(self, input_signature):
     """Randomly initializes this layer's weights."""
@@ -763,8 +821,8 @@ def init_weights_and_state(self, input_signature):
     w2 = self._kernel_initializer(shape_w2, rng_w2)
     b2 = self._bias_initializer(shape_b2, rng_b2)
 
-    w1 = np.reshape(w1.T, (self._d1, self._d2, -1))
-    w2 = np.reshape(w2, (self._d1, self._d2, -1))
+    w1 = jnp.reshape(w1.T, (self._d1, self._d2, -1))
+    w2 = jnp.reshape(w2, (self._d1, self._d2, -1))
 
     self.weights = (m1, m2, mb, w1, w2, b2)
 
@@ -812,20 +870,20 @@ def forward(self, x):
     """
     m1, w1, w2, b2 = self.weights
     x_shape = x.shape
-    x = np.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
+    x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
 
     # Q: check if we need bias and/or put relu after the m1 dot?
-    mask_logits = np.dot(x, m1)
+    mask_logits = jnp.dot(x, m1)
     # Softmax.
     mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
     log_mask = mask_logits - mask_logsumexp
-    mask = np.exp(log_mask)
+    mask = jnp.exp(log_mask)
     # Gumbel-softmax with straight-through discretization.
     # TODO(lukaszkaiser, chowdhery): Extract this block and share
     rng1, rng2 = fastmath.random.split(self.rng, 2)
-    u = fastmath.random.uniform(rng1, mask.shape, np.float32, 1e-6, 1.0 - 1e-6)
-    g = -np.log(-np.log(u))
-    selected_experts = np.argmax(log_mask + g * self._temperature, axis=-1)
+    u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
+    g = -jnp.log(-jnp.log(u))
+    selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)
     if self._mode == 'train':
       # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
       quant_mask = metrics.one_hot(selected_experts, self._num_experts)
@@ -834,11 +892,11 @@ def forward(self, x):
       # We will sometimes (50% of the batches) use the soft-mask instead of
      # the quantized mask to improve training stability (see the paper above).
       # Q: is selecting 50% of batches the best? Other %? Mixed in-batch?
-      select = fastmath.random.uniform(rng2, (), np.float32, -1.0, 1.0)
-      quant_mask = np.where(select > 0.0, quant_mask, mask)
+      select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0)
+      quant_mask = jnp.where(select > 0.0, quant_mask, mask)
     else:
       quant_mask = metrics.one_hot(selected_experts, self._num_experts)
-    quant_mask = np.reshape(quant_mask, [-1, self._num_experts, 1])
+    quant_mask = jnp.reshape(quant_mask, [-1, self._num_experts, 1])
 
     quant_mask_shape = quant_mask.shape
     batch_size = quant_mask.shape[0]
@@ -848,23 +906,23 @@ def forward(self, x):
       # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]
       w = jax.lax.dynamic_slice(w1, [0, start_idx],
                                 [w1.shape[0], self._n_elements_in_block])
-      mid = np.dot(x, w)
-      relu = np.where(mid <= 0, np.zeros_like(mid), mid)
+      mid = jnp.dot(x, w)
+      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
       # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]
       v = jax.lax.dynamic_slice(w2, [start_idx, 0],
                                 [self._n_elements_in_block, w2.shape[-1]])
-      v = np.reshape(v, [self._n_elements_in_block, -1])
-      res = np.dot(relu, v) + b2
+      v = jnp.reshape(v, [self._n_elements_in_block, -1])
+      res = jnp.dot(relu, v) + b2
     else:
-      expanded_mask = np.broadcast_to(
+      expanded_mask = jnp.broadcast_to(
           quant_mask, (quant_mask_shape[0], quant_mask.shape[1],
                        self._n_elements_in_block))
-      expanded_mask = np.reshape(expanded_mask, (-1, self._d_ff))
-      mid = np.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]
-      relu = np.where(mid <= 0, np.zeros_like(mid), mid)
-      res = np.dot(relu, w2) + b2
+      expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))
+      mid = jnp.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]
+      relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
+      res = jnp.dot(relu, w2) + b2
 
-    return np.reshape(res, x_shape)  # un-flatten if needed
+    return jnp.reshape(res, x_shape)  # un-flatten if needed
 
   def init_weights_and_state(self, input_signature):
     """Randomly initializes this layer's weights."""
diff --git a/trax/layers/research/sparsity_test.py b/trax/layers/research/sparsity_test.py
index a87ee5055..a85011849 100644
--- a/trax/layers/research/sparsity_test.py
+++ b/trax/layers/research/sparsity_test.py
@@ -172,9 +172,30 @@ def test_simple_call(self):
     self.assertEqual(y.shape, (1, 3, 4))
 
 
-class CausalFavorTest(test.TestCase):
+class FavorTest(test.TestCase):
 
   def test_call_and_grad(self):
+    layer = tl.Serial(
+        tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()),
+        sparsity.Favor(d_feature=4, n_heads=2),
+        tl.Select([0], n_in=2),
+        tl.LogSoftmax(),
+        tl.CrossEntropyLoss()
+    )
+    x = np.ones((1, 2), dtype=np.int32)
+    w = np.ones_like(x).astype(np.float32)
+    x_sig = shapes.signature(x)
+    w_sig = shapes.signature(w)
+    layer.init((x_sig, x_sig, w_sig))
+    y = layer((x, x, w))
+    self.assertEqual(y.shape, ())
+    state = layer.state
+    rng = fastmath.random.get_prng(0)
+    fwd = lambda weights, inp: layer.pure_fn(inp, weights, state, rng=rng)[0]
+    g = fastmath.grad(fwd)(layer.weights, (x, x, w))
+    self.assertEqual(g[0][1][0].shape, (3, 4))
+
+  def test_causal_call_and_grad(self):
     layer = tl.Serial(
         tl.Dense(4),
         sparsity.CausalFavor(d_feature=4, n_heads=2),