Skip to content

Commit

Permalink
Add one flag to allow to disable use_padding for ffn.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673374355
  • Loading branch information
lingvo-bot authored and copybara-github committed Sep 11, 2024
1 parent f910d4f commit 1f53897
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions lingvo/core/batch_major_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -8611,6 +8611,7 @@ def Params(cls):
'y = x + ff_residual_weight * F(x) in Feedforward layer, else '
'y = ff_residual_weight * F(x). This is experimental and could be '
'removed in the future. See b/174568214.')
p.Define('ff_use_paddings', True, 'If true, zero out padding tokens in FFN.')
p.Define('atten_apply_residual', True, 'If true, '
'y = x + F(x) in attention layer, else y = F(x). This is '
'experimental and could be removed in the future. See '
Expand Down Expand Up @@ -8883,9 +8884,11 @@ def MoE(self, name, is_causal=False, ff_hidden_dim=None):
*sub_list)

def GatedFeedforward(self, name, is_causal=False, ff_hidden_dim=None,
activation_fn=tf.nn.relu, use_paddings=True):
activation_fn=tf.nn.relu, use_paddings=None):
del is_causal
p = self.params
use_paddings = p.ff_use_paddings if use_paddings is None else use_paddings

if ff_hidden_dim is None:
ff_hidden_dim = p.ff_hidden_dim

Expand Down Expand Up @@ -8926,9 +8929,10 @@ def GatedFn(x, y):
*sub_list)

def Feedforward(self, name, is_causal=False, ff_hidden_dim=None,
qdomain=None, use_paddings=True):
qdomain=None, use_paddings=None):
del is_causal
p = self.params
use_paddings = p.ff_use_paddings if use_paddings is None else use_paddings
if ff_hidden_dim is None:
ff_hidden_dim = p.ff_hidden_dim
if p.device_mesh is not None:
Expand Down Expand Up @@ -8972,7 +8976,7 @@ def Feedforward(self, name, is_causal=False, ff_hidden_dim=None,
if use_paddings:
sub_list += [
('added,i.paddings->o.vec', self._Pad('pad')),
('i.paddings->o.paddings', self._Id('id'))
('i.paddings->o.paddings', self._Id('id')),
]
else:
sub_list.append(
Expand Down

0 comments on commit 1f53897

Please sign in to comment.