From 2fd6269f0f1a1cb51e1e11930225f7098aecdcf5 Mon Sep 17 00:00:00 2001
From: Lingvo Maintenance
Date: Fri, 8 Dec 2023 15:07:26 -0800
Subject: [PATCH] Change `layers_with_attention` to `pytype_strict_library`.

Fix warnings / errors found by the linter.

PiperOrigin-RevId: 589252477
---
 lingvo/core/BUILD                    |  3 +-
 lingvo/core/layers_with_attention.py | 67 +++++++++++++++++-----------
 2 files changed, 44 insertions(+), 26 deletions(-)

diff --git a/lingvo/core/BUILD b/lingvo/core/BUILD
index 14e05f37b..fe020e725 100644
--- a/lingvo/core/BUILD
+++ b/lingvo/core/BUILD
@@ -2036,7 +2036,7 @@ lingvo_cuda_py_test(
     ],
 )
 
-py_library(
+pytype_strict_library(
     name = "layers_with_attention",
     srcs = ["layers_with_attention.py"],
     deps = [
@@ -2051,6 +2051,7 @@ py_library(
         ":py_utils",
         ":symbolic",
         "//lingvo:compat",
+        # Implicit numpy dependency.
     ],
 )
 
diff --git a/lingvo/core/layers_with_attention.py b/lingvo/core/layers_with_attention.py
index 3530b87f0..3788bf07c 100644
--- a/lingvo/core/layers_with_attention.py
+++ b/lingvo/core/layers_with_attention.py
@@ -1198,15 +1198,15 @@ def _split(t_in, sharding):
       # Residual dropout.
       after_residual = self.residual_dropout.FProp(theta.residual_dropout,
                                                    combined_output)
-      if p.add_skip_connection:
-        if p.residual_droppath_prob:
-          out = self.residual_droppath.FProp(
-              theta.residual_droppath,
-              inputs,
-              after_residual,
-          )
-        else:
-          out = inputs + after_residual * self.params.residual_weight
+      assert p.add_skip_connection
+      if p.residual_droppath_prob:
+        out = self.residual_droppath.FProp(
+            theta.residual_droppath,
+            inputs,
+            after_residual,
+        )
+      else:
+        out = inputs + after_residual * self.params.residual_weight
 
       if not p.pre_layer_norm:
         out = self.layer_norm.FProp(theta.layer_norm, out)
@@ -1233,7 +1233,7 @@ def Params(cls):
     p.ln_tpl = layers.ReshapedLayerNorm.Params()
     return p
 
-  def FProp(self, theta, inputs, paddings):
+  def FProp(self, theta, inputs, paddings, tasks=None):
     """Feed-forward, residual and layer-norm.
 
     Args:
@@ -1244,10 +1244,15 @@ def FProp(self, theta, inputs, paddings):
         first augmented (resp. reduced) by splitting the last dimension
         according to the device_mesh (resp. merging the last two dimensions).
       paddings: [time, batch].
+      tasks: Not supported, must be None.
 
     Returns:
       tensor of the same shape with inputs.
     """
+    if tasks is not None:
+      raise ValueError(
+          'multi-task is not supported in ReshapedTransformerFeedForwardLayer.'
+      )
     p = self.params
     with tf.name_scope(p.name):
       inputs_shape = py_utils.GetShape(inputs)
@@ -2220,14 +2225,16 @@ def FProp(self,
       hidden = self.adapters.FProp(theta.adapters, hidden, source_task_id)
     return hidden, atten_prob
 
-  def ExtendStep(self,
-                 theta,
-                 source_vecs,
-                 prefix_states,
-                 aux_vecs=None,
-                 aux_paddings=None,
-                 timestep=None,
-                 source_task_id=None):
+  def ExtendStep(
+      self,
+      theta,
+      source_vecs,
+      prefix_states,
+      aux_vecs=None,
+      aux_paddings=None,
+      t=None,
+      source_task_id=None,
+  ):
     """Transformer Layer with adapters, extend one step in decoding.
 
     Applies TransformerLayer.ExtendStep, then applies adapters.
@@ -2240,7 +2247,7 @@ def ExtendStep(self,
         attentions, used for fast decoding.
       aux_vecs: [aux_time, aux_batch, dim]
       aux_paddings: [aux_time, aux_batch]
-      timestep: a scalar, the current time step, 0-based.
+      t: a scalar, the current time step, 0-based.
       source_task_id: [source_batch]
 
     Returns:
@@ -2260,7 +2267,8 @@ def ExtendStep(self,
 
     # First the self-attention layer.
     atten_vec, atten_prob, new_states = self.self_atten.ExtendStep(
-        theta.self_atten, source_vecs, prefix_states, timestep)
+        theta.self_atten, source_vecs, prefix_states, t
+    )
     atten_vec = tf.expand_dims(atten_vec, axis=0)
 
     # Next the source attention layer.
@@ -2454,6 +2462,8 @@ def FProp(self,
           -1, 0)
     elif p.mask_type == 'eye':
       padding = tf.eye(target_time, target_time, dtype=py_utils.FPropDtype(p))
+    else:
+      raise ValueError('Unsupported mask type')
 
     # [time, batch, time]
     causal_padding = tf.tile(tf.expand_dims(padding, 1), [1, target_bs, 1])
@@ -2716,6 +2726,7 @@ def FProp(self, theta, inputs, paddings):
     """
     p = self.params
     ff_outputs = []
+    inputs_normalized: tf.Tensor = None
     for i in range(p.num_blocks):
       inputs_normalized = self.layer_norm[i].FProp(theta.layer_norm[i], inputs)
       ff_output = self.fflayers[i].FProp(
@@ -2857,14 +2868,20 @@ def FProp(self,
       assert tertiary_segment_id is not None, ('Need to specify segment id for '
                                                'packed input.')
 
-    atten_vec, atten_prob = self.self_atten.FProp(
+    atten_vec, _ = self.self_atten.FProp(
         theta.self_atten,
         source_vecs,
         source_paddings,
-        query_segment_id=source_segment_id)
-    atten_vec, atten_prob = self.tertiary_atten.FProp(
-        theta.tertiary_atten, atten_vec, tertiary_paddings, tertiary_vecs,
-        source_segment_id, tertiary_segment_id)
+        query_segment_id=source_segment_id,
+    )
+    atten_vec, _ = self.tertiary_atten.FProp(
+        theta.tertiary_atten,
+        atten_vec,
+        tertiary_paddings,
+        tertiary_vecs,
+        source_segment_id,
+        tertiary_segment_id,
+    )
     atten_vec, atten_prob = self.atten.FProp(theta.atten, atten_vec,
                                              aux_paddings, aux_vecs,
                                              source_segment_id, aux_segment_id,