Skip to content

Commit

Permalink
Make atten_tpl configurable
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 572247771
  • Loading branch information
lingvo-bot authored and copybara-github committed Oct 10, 2023
1 parent 8f1bbb3 commit 8332fef
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion lingvo/core/batch_major_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -7260,6 +7260,9 @@ def Params(cls):
'expert_capacity_dim', 0,
'If not None, num_groups will be adjusted so that there will be at '
'least min_group_size tokens in each group.')
+ p.Define(
+ 'atten_tpl', MultiHeadedAttention.Params(),
+ 'Multi-Headed Dot-Product Attention default params.')
# SPMD partition related params.
#
# d - model_dim
Expand Down Expand Up @@ -7380,7 +7383,7 @@ def _MultiHeadedAtten(self, name, num_heads=None,
if num_heads is None:
num_heads = p.num_heads

- atten_p = MultiHeadedAttention.Params().Set(
+ atten_p = p.atten_tpl.Copy().Set(
name=name,
input_dim=p.model_dim,
hidden_dim=p.attention_hidden_dim or p.model_dim,
Expand Down

0 comments on commit 8332fef

Please sign in to comment.