Move reshape after softmax calculations, to avoid reshaping back to the original shape where needed.

PiperOrigin-RevId: 686542189
lingvo-bot authored and copybara-github committed Oct 16, 2024
1 parent f1f5671 commit b26149e
Showing 1 changed file with 31 additions and 39 deletions.

lingvo/core/attention.py
@@ -510,7 +510,6 @@ def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor:
  """Generates probs."""
  source_batch = py_utils.GetShape(inputs.source_vecs)[2]
  multiplier = py_utils.GetShape(inputs.query_vec)[1]
- target_batch = multiplier * source_batch

  # Shape of summed is [sl, tb/sb, sb, hidden_dim].
  summed = tf.tanh(inputs.source_vecs + inputs.query_vec)
@@ -568,13 +567,10 @@ def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor:
  tf.logging.warning(
      'packed_input is False but query_segment_id is passed.'
  )
- source_padding = tf.reshape(source_padding, [-1, target_batch])
- source_padding = tf.transpose(source_padding)
+ source_padding = tf.transpose(source_padding, [1, 2, 0])

- # [source_length, target_batch]
- logits = tf.reshape(logits, [-1, target_batch])
- # [target_batch, source_length]
- logits = tf.transpose(logits)
+ # [multiplier, source_batch, source_length]
+ logits = tf.transpose(logits, [1, 2, 0])
  # take the softmax to compute the probabilities.
  probs = self._PaddedSoftmax(logits, source_padding)
  return probs
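This hunk is the core of the change: logits leave the tanh-summed tensor as [source_length, multiplier, source_batch], and a single transpose now replaces the old flatten-then-transpose round trip. A minimal standalone sketch (toy dimensions and variable names assumed here, not taken from the repo) confirming the two layouts carry the same values in the same order:

import tensorflow as tf

sl, multiplier, sb = 5, 3, 2   # source_length, target_batch // source_batch, source_batch
tb = multiplier * sb           # target_batch
logits = tf.random.normal([sl, multiplier, sb])

# Removed path: flatten to [sl, tb], then transpose to [tb, sl].
old = tf.transpose(tf.reshape(logits, [-1, tb]))
# New path: one transpose to [multiplier, sb, sl], no flattening.
new = tf.transpose(logits, [1, 2, 0])

# The new result is the old one, just left unflattened.
assert bool(tf.reduce_all(tf.equal(old, tf.reshape(new, [tb, sl]))))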
@@ -638,31 +634,32 @@ def Atten(
  args.source_segment_id = source_segment_id
  if query_segment_id is not None:
    args.query_segment_id = query_segment_id
- # probs is of shape [target_batch, source_length]
+ # probs is of shape [multiplier, source_batch, source_length]
  probs = py_utils.CallDefun(AttenProbs, args)
- probs = py_utils.HasShape(probs, [target_batch, source_length])
+ probs = py_utils.HasShape(
+     probs, [multiplier, source_batch, source_length]
+ )

  # Apply dropout to weights if applicable.
  if not self.do_eval:
    probs = _ApplyAttentionDropout(p, probs)

- # Reshape probs to be of shape
- # [target_batch/source_batch, source_batch, source_length]
- probs_reshaped = tf.reshape(probs, [multiplier, source_batch, -1])
+ # Shape of returned probs is [target_batch, source_length]
+ probs_returned = tf.reshape(probs, [target_batch, source_length])
  # Transpose probs to be of shape
  # [source_batch, target_batch/source_batch, source_length]
- probs_reshaped = tf.transpose(probs_reshaped, [1, 0, 2])
+ probs = tf.transpose(probs, [1, 0, 2])
  # Batched matmul
  # [source_batch, target_batch/source_batch, source_length] *
  # [source_batch, source_length, context_dim] =
  # [source_batch, target_batch/source_batch, context_dim]
- summed = tf.matmul(probs_reshaped, source_contexts)
+ summed = tf.matmul(probs, source_contexts)

  # summed is of shape
  # [target_batch/source_batch, source_batch, context_dim]
  summed = tf.transpose(summed, [1, 0, 2])

- return tf.reshape(summed, [target_batch, -1]), probs
+ return tf.reshape(summed, [target_batch, -1]), probs_returned

  # The source batch size equals to the target batch size.
  def AttenSameBatchSize(
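In the Atten hunk above, probs now stays three-dimensional through dropout and the batched matmul; only the copy handed back to callers is flattened to [target_batch, source_length]. A hedged sketch of that data flow (toy shapes, and plain softmax standing in for the layer's _PaddedSoftmax):

import tensorflow as tf

m, sb, sl, dim = 3, 2, 5, 4    # multiplier, source_batch, source_length, context_dim
tb = m * sb                    # target_batch
probs = tf.nn.softmax(tf.random.normal([m, sb, sl]), axis=-1)
source_contexts = tf.random.normal([sb, sl, dim])

# Flattened copy for callers, as in the new return statement.
probs_returned = tf.reshape(probs, [tb, sl])
# [sb, m, sl] @ [sb, sl, dim] -> [sb, m, dim], then back to [tb, dim].
summed = tf.matmul(tf.transpose(probs, [1, 0, 2]), source_contexts)
context = tf.reshape(tf.transpose(summed, [1, 0, 2]), [tb, -1])
print(context.shape, probs_returned.shape)  # (6, 4) (6, 5)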
@@ -1036,7 +1033,8 @@ def AttenProbs(
  * query_segment_id: [target_batch].
  Returns:
-   logits [target_batch, source_time].
+   logits: [n, source_batch, time]
+     where n = target_batch // source_batch
  """
  source_vecs = inputs.source_vecs

@@ -1121,14 +1119,9 @@ def AttenProbs(
      'packed_input is False but query_segment_id is passed.'
  )
  source_padding = tf.transpose(source_padding, [2, 0, 1])
- source_padding = tf.reshape(source_padding, [target_batch, -1])

  # => [n, source_batch, time]
  logits = tf.transpose(logits, [2, 0, 1])
-
- # => [n * source_batch, time].
- # This makes logits store content in the same order as query_vec.
- logits = tf.reshape(logits, [target_batch, -1])
  if p.atten_logit_cap is not None and p.atten_logit_cap > 0:
    logits = py_utils.MaybeSoftCapLogits(logits, p.atten_logit_cap)
  probs = self._PaddedSoftmax(logits, source_padding)
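The logit capping kept as context here is worth a note: soft-capping is commonly implemented as cap * tanh(logits / cap), which bounds logits smoothly in (-cap, cap); whether py_utils.MaybeSoftCapLogits matches this form exactly is an assumption:

import tensorflow as tf

def soft_cap_logits(logits: tf.Tensor, cap: float) -> tf.Tensor:
  # Assumed form: smooth squash into (-cap, cap), near-linear when |logits| << cap.
  return cap * tf.tanh(logits / cap)

print(soft_cap_logits(tf.constant([-50.0, 0.0, 50.0]), cap=30.0))
# -> approximately [-27.9, 0.0, 27.9]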
@@ -1193,11 +1186,12 @@ def Atten(
  args.source_segment_id = source_segment_id
  if query_segment_id is not None:
    args.query_segment_id = query_segment_id
- returned_probs = py_utils.CallDefun(AttenProbs, args)
- returned_probs = py_utils.HasShape(returned_probs, [target_batch, time])
+ # [n, source_batch, time] where n = target_batch // source_batch
+ probs = py_utils.CallDefun(AttenProbs, args)
+ probs = py_utils.HasShape(probs, [-1, source_batch, time])
+ # [target_batch, time]
+ returned_probs = tf.reshape(probs, [-1, time])

- # => [n, source_batch, time] where n = target_batch // source_batch
- probs = tf.reshape(returned_probs, [-1, source_batch, time])
  # => [source_batch, n, time].
  probs = tf.transpose(probs, [1, 0, 2])
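The reshape that used to happen before the softmax now happens after it, which is safe because n is the leading axis: flattening [n, source_batch, time] enumerates rows in exactly the order the removed pre-softmax reshape produced. A toy check (assumed shapes; tf.nn.softmax in place of _PaddedSoftmax):

import tensorflow as tf

n, sb, time = 3, 2, 5          # target_batch // source_batch, source_batch, time
tb = n * sb
logits = tf.random.normal([n, sb, time])

# Removed path: flatten first, then softmax over time.
old = tf.nn.softmax(tf.reshape(logits, [tb, time]), axis=-1)
# New path: softmax in 3-D, then one reshape for the returned copy.
new = tf.reshape(tf.nn.softmax(logits, axis=-1), [-1, time])

# Row i*sb + j of the old layout equals row (i, j) of the new one.
assert bool(tf.reduce_all(tf.abs(old - new) < 1e-6))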

@@ -2455,36 +2449,34 @@ def Atten(
      ),
  )
  # Take out the padding states.
- # _source_padding is of shape [sl, sb].
- # reshaped to [sl, 1, sb].
+ # [sl, 1, sb]
  source_padding = tf.expand_dims(source_padding, 1)
+ # [sl, tb/sb, sb]
  per_step_source_padding = tf.reshape(
      tf.transpose(per_step_source_padding), [-1, multiplier, sb]
  )
  if source_padding.dtype != tf.bool:
    source_padding = source_padding > 0
  if per_step_source_padding.dtype != tf.bool:
    per_step_source_padding = per_step_source_padding > 0
+ # [sl, tb/sb, sb]
  source_padding = tf.logical_or(source_padding, per_step_source_padding)

- # Reshape logits to a matrix of shape [tb, sl] and takes the
- # softmax to compute the probabilities.
- logits = tf.transpose(tf.reshape(logits, [-1, tb]))
- source_padding = tf.transpose(tf.reshape(source_padding, [-1, tb]))
+ # [tb/sb, sb, sl]
+ logits = tf.transpose(logits, [1, 2, 0])
+ source_padding = tf.transpose(source_padding, [1, 2, 0])
+ # Take the softmax to compute the probabilities.
  probs = self._PaddedSoftmax(logits, source_padding)
- # Reshape probs to be of shape [tb/sb, sb, sl].
- probs_reshaped = tf.reshape(probs, [multiplier, sb, -1])
+ # [tb, sl]
+ probs_returned = tf.reshape(probs, [tb, -1])
  # Transpose probs to be of shape [sb, tb/sb, sl]
- probs_reshaped = tf.transpose(probs_reshaped, [1, 0, 2])
+ probs = tf.transpose(probs, [1, 0, 2])
  # [sb, tb/sb, sl] * [sb, sl, context_dim] = [sb, tb/sb, context_dim]
- summed = tf.matmul(
-     tf.cast(probs_reshaped, source_contexts.dtype),
-     source_contexts,
- )
+ summed = tf.matmul(tf.cast(probs, source_contexts.dtype), source_contexts)
  summed = self.QAct('atten_context', summed)
  # summed is of shape [tb/sb, sb, context_dim]
  summed = tf.transpose(summed, [1, 0, 2])
- return tf.reshape(summed, [tb, -1]), probs
+ return tf.reshape(summed, [tb, -1]), probs_returned

  def AttenSameBatchSize(
      hidden_var,
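The final hunk applies the same idea to the quantized path, and its padding math also stays three-dimensional: the [sl, 1, sb] source padding broadcasts against the [sl, tb/sb, sb] per-step padding under tf.logical_or, and one transpose aligns it with the logits for the softmax. A toy sketch of just that broadcast (assumed sizes; QAct and _PaddedSoftmax omitted):

import tensorflow as tf

sl, m, sb = 5, 3, 2            # source_length, tb // sb, source_batch
source_padding = tf.random.uniform([sl, sb]) > 0.8
per_step_padding = tf.random.uniform([sl, m, sb]) > 0.8

# [sl, 1, sb] | [sl, m, sb] broadcasts to [sl, m, sb].
padding = tf.logical_or(tf.expand_dims(source_padding, 1), per_step_padding)
# Match the [tb/sb, sb, sl] layout of the logits before the softmax.
padding = tf.transpose(padding, [1, 2, 0])
print(padding.shape)           # (3, 2, 5)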