Change layers_with_attention to pytype_strict_library.
Fix warnings / errors found by the linter.

PiperOrigin-RevId: 589252477
lingvo-bot authored and copybara-github committed Dec 8, 2023
1 parent 64ef85f commit 2fd6269
Showing 2 changed files with 44 additions and 26 deletions.
3 changes: 2 additions & 1 deletion lingvo/core/BUILD
@@ -2036,7 +2036,7 @@ lingvo_cuda_py_test(
],
)

py_library(
pytype_strict_library(
name = "layers_with_attention",
srcs = ["layers_with_attention.py"],
deps = [
@@ -2051,6 +2051,7 @@ py_library(
":py_utils",
":symbolic",
"//lingvo:compat",
# Implicit numpy dependency.
],
)

67 changes: 42 additions & 25 deletions lingvo/core/layers_with_attention.py
@@ -1198,15 +1198,15 @@ def _split(t_in, sharding):
# Residual dropout.
after_residual = self.residual_dropout.FProp(theta.residual_dropout,
combined_output)
if p.add_skip_connection:
if p.residual_droppath_prob:
out = self.residual_droppath.FProp(
theta.residual_droppath,
inputs,
after_residual,
)
else:
out = inputs + after_residual * self.params.residual_weight
assert p.add_skip_connection
if p.residual_droppath_prob:
out = self.residual_droppath.FProp(
theta.residual_droppath,
inputs,
after_residual,
)
else:
out = inputs + after_residual * self.params.residual_weight

if not p.pre_layer_norm:
out = self.layer_norm.FProp(theta.layer_norm, out)
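
A standalone sketch of the pattern behind the hunk above, presumably the linter fix: when `out` is assigned only inside `if p.add_skip_connection:`, a strict checker flags it as possibly undefined at its later use, while asserting the precondition leaves `out` assigned on every remaining path. Names, types, and the single residual branch here are simplified and illustrative, not the layer's real implementation.

def add_residual(inputs: float, after_residual: float,
                 add_skip_connection: bool = True,
                 residual_weight: float = 1.0) -> float:
  # Before: `out` was assigned only under `if add_skip_connection:`, so a
  # strict linter reports it as possibly undefined at the return statement.
  # After: the precondition is asserted up front and `out` is always bound.
  assert add_skip_connection
  out = inputs + after_residual * residual_weight
  return out

print(add_residual(1.0, 0.5))  # 1.5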
@@ -1233,7 +1233,7 @@ def Params(cls):
p.ln_tpl = layers.ReshapedLayerNorm.Params()
return p

def FProp(self, theta, inputs, paddings):
def FProp(self, theta, inputs, paddings, tasks=None):
"""Feed-forward, residual and layer-norm.
Args:
@@ -1244,10 +1244,15 @@ def FProp(self, theta, inputs, paddings):
first augmented (resp. reduced) by splitting the last dimension
according to the device_mesh (resp. merging the last two dimensions).
paddings: [time, batch].
tasks: Not supported, must be None.
Returns:
tensor of the same shape with inputs.
"""
if tasks is not None:
raise ValueError(
'multi-task is not supported in ReshapedTransformerFeedForwardLayer.'
)
p = self.params
with tf.name_scope(p.name):
inputs_shape = py_utils.GetShape(inputs)
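
A standalone sketch of the pattern in the hunk above: the override accepts the base-class `tasks` argument so the signatures stay compatible, but rejects it explicitly because multi-task input is unsupported. The class names and the base-class signature below are illustrative assumptions, not lingvo's real classes.

class FeedForwardBase:
  def FProp(self, inputs, paddings, tasks=None):
    return inputs

class ReshapedFeedForward(FeedForwardBase):
  def FProp(self, inputs, paddings, tasks=None):
    # Keep the argument for signature compatibility, but fail loudly if used.
    if tasks is not None:
      raise ValueError('multi-task is not supported in ReshapedFeedForward.')
    return super().FProp(inputs, paddings)

print(ReshapedFeedForward().FProp([1.0], [0.0]))   # ok
# ReshapedFeedForward().FProp([1.0], [0.0], tasks=[0])  # would raise ValueError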
@@ -2220,14 +2225,16 @@ def FProp(self,
hidden = self.adapters.FProp(theta.adapters, hidden, source_task_id)
return hidden, atten_prob

def ExtendStep(self,
theta,
source_vecs,
prefix_states,
aux_vecs=None,
aux_paddings=None,
timestep=None,
source_task_id=None):
def ExtendStep(
self,
theta,
source_vecs,
prefix_states,
aux_vecs=None,
aux_paddings=None,
t=None,
source_task_id=None,
):
"""Transformer Layer with adapters, extend one step in decoding.
Applies TransformerLayer.ExtendStep, then applies adapters.
@@ -2240,7 +2247,7 @@ def ExtendStep(self,
attentions, used for fast decoding.
aux_vecs: [aux_time, aux_batch, dim]
aux_paddings: [aux_time, aux_batch]
timestep: a scalar, the current time step, 0-based.
t: a scalar, the current time step, 0-based.
source_task_id: [source_batch]
Returns:
@@ -2260,7 +2267,8 @@ def ExtendStep(self,

# First the self-attention layer.
atten_vec, atten_prob, new_states = self.self_atten.ExtendStep(
theta.self_atten, source_vecs, prefix_states, timestep)
theta.self_atten, source_vecs, prefix_states, t
)

atten_vec = tf.expand_dims(atten_vec, axis=0)
# Next the source attention layer.
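
A standalone sketch of the rename in the hunks above, where the `timestep` parameter and its call site become `t`: a plausible reading (an assumption, not stated in the commit) is that the override now matches the base-class parameter name, so keyword calls work on both classes and signature-mismatch warnings go away. The classes below are illustrative stand-ins.

class Base:
  def ExtendStep(self, source_vecs, prefix_states, t=None):
    return source_vecs, prefix_states

class WithAdapters(Base):
  # Before the fix the override named this parameter `timestep`, which breaks
  # keyword calls that work on the base class and trips signature checks.
  def ExtendStep(self, source_vecs, prefix_states, t=None):
    return super().ExtendStep(source_vecs, prefix_states, t)

print(WithAdapters().ExtendStep('v', {}, t=0))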
@@ -2454,6 +2462,8 @@ def FProp(self,
-1, 0)
elif p.mask_type == 'eye':
padding = tf.eye(target_time, target_time, dtype=py_utils.FPropDtype(p))
else:
raise ValueError('Unsupported mask type')

# [time, batch, time]
causal_padding = tf.tile(tf.expand_dims(padding, 1), [1, target_bs, 1])
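
A standalone sketch of the hunk above: adding a final `else` that raises makes every branch of the string-valued mask option either assign `padding` or fail loudly, which both rejects bad configs and satisfies a checker that would otherwise see `padding` as possibly undefined. Only the 'eye' branch name comes from the diff; the other branch name and the list-based padding values are illustrative.

def build_padding(mask_type: str, target_time: int):
  if mask_type == 'future':          # illustrative branch name
    padding = [[1.0 if j > i else 0.0 for j in range(target_time)]
               for i in range(target_time)]
  elif mask_type == 'eye':
    padding = [[1.0 if j == i else 0.0 for j in range(target_time)]
               for i in range(target_time)]
  else:
    raise ValueError('Unsupported mask type')
  return padding

print(build_padding('eye', 3))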
@@ -2716,6 +2726,7 @@ def FProp(self, theta, inputs, paddings):
"""
p = self.params
ff_outputs = []
inputs_normalized: tf.Tensor = None
for i in range(p.num_blocks):
inputs_normalized = self.layer_norm[i].FProp(theta.layer_norm[i], inputs)
ff_output = self.fflayers[i].FProp(
@@ -2857,14 +2868,20 @@ def FProp(self,
assert tertiary_segment_id is not None, ('Need to specify segment id for '
'packed input.')

atten_vec, atten_prob = self.self_atten.FProp(
atten_vec, _ = self.self_atten.FProp(
theta.self_atten,
source_vecs,
source_paddings,
query_segment_id=source_segment_id)
atten_vec, atten_prob = self.tertiary_atten.FProp(
theta.tertiary_atten, atten_vec, tertiary_paddings, tertiary_vecs,
source_segment_id, tertiary_segment_id)
query_segment_id=source_segment_id,
)
atten_vec, _ = self.tertiary_atten.FProp(
theta.tertiary_atten,
atten_vec,
tertiary_paddings,
tertiary_vecs,
source_segment_id,
tertiary_segment_id,
)
atten_vec, atten_prob = self.atten.FProp(theta.atten, atten_vec,
aux_paddings, aux_vecs,
source_segment_id, aux_segment_id,
