Skip to content

Commit

Permalink
Change reshapes in MultiHeadedAttention to avoid querying the time and batch size.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 679755949
  • Loading branch information
lingvo-bot authored and copybara-github committed Sep 27, 2024
1 parent 3fe6275 commit b0e0e80
Showing 1 changed file with 42 additions and 28 deletions.
70 changes: 42 additions & 28 deletions lingvo/core/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,16 +1436,19 @@ def PackSource(self,
"""

p = self.params
num_heads = p.num_attention_heads
fns = self.fns
if not p.enable_source_proj:
assert p.source_dim == p.hidden_dim
if not p.enable_query_proj:
assert p.query_dim == p.hidden_dim
# Check input tensor shapes
# [time_steps, batch_size, source_dim]
source_vecs = py_utils.HasRank(source_vecs, 3)
[time_steps, batch_size] = py_utils.GetShape(source_vecs, 2)
if p.use_source_vec_as_attention_value:
assert source_contexts is not None
# [time_steps, batch_size, context_dim]
source_contexts = py_utils.HasShape(source_contexts,
[time_steps, batch_size, -1])
source_padding = py_utils.HasShape(source_padding,
Expand All @@ -1454,9 +1457,7 @@ def PackSource(self,
source_segment_id = py_utils.HasShape(source_segment_id,
[time_steps, batch_size])

with tf.name_scope('init__0'):
# source_vecs shape after (optional) projection is
# [time_steps, batch_size, hidden]
with tf.name_scope('vecs'):
if p.enable_source_proj:
source_vecs, w_source_proj = self.ToAqtInputs(
'source_proj',
Expand All @@ -1472,21 +1473,26 @@ def PackSource(self,
source_vecs,
self.QWeight(theta.source_proj_b),
qout_name='source_proj_add')
source_vecs = gshard_utils.MeshSplit(
source_vecs,
p.device_mesh,
p.activation_split_dims_mapping)
with tf.name_scope('init__1'):
num_heads = p.num_attention_heads
# => [time_steps, batch_size * num_heads, hidden / num_heads]
[time_steps_vecs] = py_utils.GetShape(source_vecs, 1)
source_vecs = tf.reshape(
source_vecs,
[time_steps_vecs, -1, symbolic.ToStatic(p.hidden_dim // num_heads)])
source_vecs = gshard_utils.MeshSplit(
source_vecs,
p.device_mesh,
p.activation_split_dims_mapping)
# => [time_steps, batch_size, hidden_dim]
source_vecs = py_utils.HasShape(source_vecs,
[time_steps,
batch_size,
symbolic.ToStatic(p.hidden_dim)])
# => [time_steps, batch_size * num_heads, hidden_dim / num_heads]
source_vecs = tf.reshape(source_vecs,
[-1,
batch_size * num_heads,
symbolic.ToStatic(p.hidden_dim // num_heads)])
source_vecs = gshard_utils.MeshSplit(source_vecs,
p.device_mesh,
p.activation_split_dims_mapping)
source_vecs = self.ProcessProjectionVec(theta, source_vecs, 'source')

with tf.name_scope('contexts'):
if p.use_source_vec_as_attention_value:
source_contexts = source_vecs
else:
Expand All @@ -1509,12 +1515,14 @@ def PackSource(self,
source_contexts,
p.device_mesh,
p.activation_split_dims_mapping)

time_steps_contexts = py_utils.GetShape(source_contexts)[0]
source_context_depth = py_utils.GetShape(source_contexts)[-1]
# => [time_steps, batch_size, context_dim]
source_contexts = py_utils.HasShape(source_contexts,
[time_steps, batch_size, -1])
context_dim = py_utils.GetShape(source_contexts)[-1]
# => [time_steps, batch_size * num_heads, context_dim / num_heads]
source_contexts = tf.reshape(
source_contexts,
[time_steps_contexts, -1, source_context_depth // num_heads])
[-1, batch_size * num_heads, context_dim // num_heads])
source_contexts = gshard_utils.MeshSplit(
source_contexts,
p.device_mesh,
Expand All @@ -1523,25 +1531,31 @@ def PackSource(self,
source_contexts,
'ctx')

with tf.name_scope('init__2'):
[time_steps_paddings] = py_utils.GetShape(source_padding, 1)
with tf.name_scope('padding'):
# => [time_steps, batch_size, 1]
source_padding = tf.expand_dims(source_padding, 2)
# => [time_steps, batch_size, num_heads]
source_padding = tf.tile(source_padding, [1, 1, num_heads])
source_padding = tf.reshape(source_padding, [time_steps_paddings, -1])
# => [time_steps, batch_size * num_heads]
source_padding = tf.reshape(source_padding, [-1, batch_size * num_heads])

with tf.name_scope('segment_id'):
if source_segment_id is None:
source_segment_id = tf.zeros_like(source_padding)
else:
[time_steps_segment_id] = py_utils.GetShape(source_segment_id, 1)
# => [time_steps, batch_size, 1]
source_segment_id = tf.expand_dims(source_segment_id, 2)
# => [time_steps, batch_size, num_heads]
source_segment_id = tf.tile(source_segment_id, [1, 1, num_heads])
# => [time_steps, batch_size * num_heads]
source_segment_id = tf.reshape(source_segment_id,
[time_steps_segment_id, -1])
[-1, batch_size * num_heads])

return self.atten.PackSource(theta.atten,
source_vecs,
source_contexts,
source_padding,
source_segment_id)
return self.atten.PackSource(theta.atten,
source_vecs,
source_contexts,
source_padding,
source_segment_id)

@py_utils.NameScopeDecorator('MultiHeadedAttention/ExtendSourcePacked')
def ExtendSourcePacked(self,
Expand Down

0 comments on commit b0e0e80

Please sign in to comment.