diff --git a/lingvo/core/attention.py b/lingvo/core/attention.py
index 7d16d7b1e..817c394bc 100644
--- a/lingvo/core/attention.py
+++ b/lingvo/core/attention.py
@@ -510,7 +510,6 @@ def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor:
       """Generates probs."""
       source_batch = py_utils.GetShape(inputs.source_vecs)[2]
       multiplier = py_utils.GetShape(inputs.query_vec)[1]
-      target_batch = multiplier * source_batch
 
       # Shape of summed is [sl, tb/sb, sb, hidden_dim].
       summed = tf.tanh(inputs.source_vecs + inputs.query_vec)
@@ -568,13 +567,10 @@ def AttenProbs(inputs: py_utils.NestedMap) -> tf.Tensor:
         tf.logging.warning(
             'packed_input is False but query_segment_id is passed.'
         )
-      source_padding = tf.reshape(source_padding, [-1, target_batch])
-      source_padding = tf.transpose(source_padding)
+      source_padding = tf.transpose(source_padding, [1, 2, 0])
 
-      # [source_length, target_batch]
-      logits = tf.reshape(logits, [-1, target_batch])
-      # [target_batch, source_length]
-      logits = tf.transpose(logits)
+      # [multiplier, source_batch, source_length]
+      logits = tf.transpose(logits, [1, 2, 0])
       # take the softmax to compute the probabilities.
       probs = self._PaddedSoftmax(logits, source_padding)
       return probs
@@ -638,31 +634,32 @@ def Atten(
         args.source_segment_id = source_segment_id
       if query_segment_id is not None:
         args.query_segment_id = query_segment_id
-      # probs is of shape [target_batch, source_length]
+      # probs is of shape [multiplier, source_batch, source_length]
       probs = py_utils.CallDefun(AttenProbs, args)
-      probs = py_utils.HasShape(probs, [target_batch, source_length])
+      probs = py_utils.HasShape(
+          probs, [multiplier, source_batch, source_length]
+      )
 
       # Apply dropout to weights if applicable.
       if not self.do_eval:
         probs = _ApplyAttentionDropout(p, probs)
 
-      # Reshape probs to be of shape
-      # [target_batch/source_batch, source_batch, source_length]
-      probs_reshaped = tf.reshape(probs, [multiplier, source_batch, -1])
+      # Shape of returned probs is [target_batch, source_length]
+      probs_returned = tf.reshape(probs, [target_batch, source_length])
       # Transpose probs to be of shape
       # [source_batch, target_batch/source_batch, source_length]
-      probs_reshaped = tf.transpose(probs_reshaped, [1, 0, 2])
+      probs = tf.transpose(probs, [1, 0, 2])
       # Batched matmul
      # [source_batch, target_batch/source_batch, source_length] *
       # [source_batch, source_length, context_dim] =
       # [source_batch, target_batch/source_batch, context_dim]
-      summed = tf.matmul(probs_reshaped, source_contexts)
+      summed = tf.matmul(probs, source_contexts)
 
       # summed is of shape
       # [target_batch/source_batch, source_batch, context_dim]
       summed = tf.transpose(summed, [1, 0, 2])
 
-      return tf.reshape(summed, [target_batch, -1]), probs
+      return tf.reshape(summed, [target_batch, -1]), probs_returned
 
     # The source batch size equals to the target batch size.
     def AttenSameBatchSize(
@@ -1036,7 +1033,8 @@ def AttenProbs(
         * query_segment_id: [target_batch].
 
       Returns:
-        logits [target_batch, source_time].
+        logits: [n, source_batch, time]
+          where n = target_batch // source_batch
       """
 
      source_vecs = inputs.source_vecs
@@ -1121,14 +1119,9 @@ def AttenProbs(
             'packed_input is False but query_segment_id is passed.'
         )
 
       source_padding = tf.transpose(source_padding, [2, 0, 1])
-      source_padding = tf.reshape(source_padding, [target_batch, -1])
       # => [n, source_batch, time]
       logits = tf.transpose(logits, [2, 0, 1])
-
-      # => [n * source_batch, time].
-      # This makes logits store content in the same order as query_vec.
-      logits = tf.reshape(logits, [target_batch, -1])
       if p.atten_logit_cap is not None and p.atten_logit_cap > 0:
         logits = py_utils.MaybeSoftCapLogits(logits, p.atten_logit_cap)
       probs = self._PaddedSoftmax(logits, source_padding)
@@ -1193,11 +1186,12 @@ def Atten(
         args.source_segment_id = source_segment_id
       if query_segment_id is not None:
         args.query_segment_id = query_segment_id
-      returned_probs = py_utils.CallDefun(AttenProbs, args)
-      returned_probs = py_utils.HasShape(returned_probs, [target_batch, time])
+      # [n, source_batch, time] where n = target_batch // source_batch
+      probs = py_utils.CallDefun(AttenProbs, args)
+      probs = py_utils.HasShape(probs, [-1, source_batch, time])
+      # [target_batch, time]
+      returned_probs = tf.reshape(probs, [-1, time])
 
-      # => [n, source_batch, time] where n = target_batch // source_batch
-      probs = tf.reshape(returned_probs, [-1, source_batch, time])
       # => [source_batch, n, time].
       probs = tf.transpose(probs, [1, 0, 2])
 
@@ -2455,9 +2449,9 @@ def Atten(
           ),
       )
       # Take out the padding states.
-      # _source_padding is of shape [sl, sb].
-      # reshaped to [sl, 1, sb].
+      # [sl, 1, sb]
       source_padding = tf.expand_dims(source_padding, 1)
+      # [sl, tb/sb, sb]
       per_step_source_padding = tf.reshape(
           tf.transpose(per_step_source_padding), [-1, multiplier, sb]
       )
@@ -2465,26 +2459,24 @@ def Atten(
         source_padding = source_padding > 0
       if per_step_source_padding.dtype != tf.bool:
         per_step_source_padding = per_step_source_padding > 0
+      # [sl, tb/sb, sb]
       source_padding = tf.logical_or(source_padding, per_step_source_padding)
-      # Reshape logits to a matrix of shape [tb, sl] and takes the
-      # softmax to compute the probabilities.
-      logits = tf.transpose(tf.reshape(logits, [-1, tb]))
-      source_padding = tf.transpose(tf.reshape(source_padding, [-1, tb]))
+      # [tb/sb, sb, sl]
+      logits = tf.transpose(logits, [1, 2, 0])
+      source_padding = tf.transpose(source_padding, [1, 2, 0])
+      # Take the softmax to compute the probabilities.
       probs = self._PaddedSoftmax(logits, source_padding)
-      # Reshape probs to be of shape [tb/sb, sb, sl].
-      probs_reshaped = tf.reshape(probs, [multiplier, sb, -1])
+      # [tb, sl]
+      probs_returned = tf.reshape(probs, [tb, -1])
       # Transpose probs to be of shape [sb, tb/sb, sl]
-      probs_reshaped = tf.transpose(probs_reshaped, [1, 0, 2])
+      probs = tf.transpose(probs, [1, 0, 2])
       # [sb, tb/sb, sl] * [sb, sl, context_dim] = [sb, tb/sb, context_dim]
-      summed = tf.matmul(
-          tf.cast(probs_reshaped, source_contexts.dtype),
-          source_contexts,
-      )
+      summed = tf.matmul(tf.cast(probs, source_contexts.dtype), source_contexts)
       summed = self.QAct('atten_context', summed)
      # summed is of shape [tb/sb, sb, context_dim]
       summed = tf.transpose(summed, [1, 0, 2])
 
-      return tf.reshape(summed, [tb, -1]), probs
+      return tf.reshape(summed, [tb, -1]), probs_returned
 
     def AttenSameBatchSize(
         hidden_var,
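The patch replaces the reshape-to-2-D round trips with direct 3-D transposes, relying on row-major reshape semantics: flattening a `[tb/sb, sb, sl]` tensor over its first two axes visits rows in the same `t = m * sb + b` order as the old flatten-to-`[sl, tb]`-then-transpose path. Below is a minimal standalone sketch of that equivalence, assuming plain `tf.nn.softmax` in place of lingvo's `_PaddedSoftmax` (padding handling omitted); the sizes `sl`, `sb`, `multiplier` are illustrative, not from the patch.

```python
import numpy as np
import tensorflow as tf

# sl = source_length, sb = source_batch, multiplier = tb // sb.
sl, sb, multiplier = 7, 3, 4
tb = multiplier * sb  # target_batch

# Logits as laid out before the softmax in this file: [sl, tb/sb, sb].
logits = tf.random.normal([sl, multiplier, sb])

# Old path: flatten to [sl, tb], transpose to [tb, sl], softmax per row.
old_probs = tf.nn.softmax(tf.transpose(tf.reshape(logits, [-1, tb])), axis=-1)

# New path: transpose to [tb/sb, sb, sl], softmax over the last axis, and
# reshape to [tb, sl] only for the returned copy.
probs_3d = tf.nn.softmax(tf.transpose(logits, [1, 2, 0]), axis=-1)
new_probs = tf.reshape(probs_3d, [tb, sl])

# Row t = m * sb + b of the old layout is element [m, b, :] of the new
# layout, so the two paths produce identical probabilities.
np.testing.assert_allclose(old_probs.numpy(), new_probs.numpy(), rtol=1e-6)
```

The same identity is what lets the patch return `probs_returned = tf.reshape(probs, [tb, -1])` while feeding the still-3-D `probs` straight into the batched matmul against `source_contexts`.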