adding fixes so transducer can work again #247

Open · wants to merge 5 commits into base: master
Changes from 3 commits
2 changes: 1 addition & 1 deletion yoyodyne/models/base.py
@@ -166,7 +166,7 @@ def __init__(
self.decoder = self.get_decoder()
# Saves hyperparameters for PL checkpointing.
self.save_hyperparameters(
ignore=["source_encoder", "decoder", "expert", "features_encoder"]
ignore=["source_encoder", "decoder", "features_encoder"]
)
# Logs the module names.
util.log_info(f"Model: {self.name}")
121 changes: 63 additions & 58 deletions yoyodyne/models/expert.py
@@ -17,12 +17,17 @@
from torch.utils import data

from .. import defaults
from ..data import indexes


class ActionError(Exception):
pass


class AlignerError(Exception):
pass


class ActionVocabulary:
"""Manages encoding of action vocabulary for transducer training."""

@@ -32,19 +37,24 @@ class ActionVocabulary:
start_vocab_idx: int
target_characters: Set[Any]
Collaborator:
Why Any and not str?

Collaborator (author):
Technically the edit actions can accept any hashable symbol: strings, ints, even tuples. So Any is a better representation of its coverage, given that Maxwell is also symbol-agnostic.
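
A minimal illustration of the point (toy example, not project code): "hashable" just means usable as a dict key, which is why Set[Any] fits better than Set[str].

# Illustrative only: strings, ints, and tuples are all hashable,
# so any of them can key an action-vocabulary table.
w2i = {}
for symbol in ["a", 42, ("sh", "voiced")]:
    w2i.setdefault(symbol, len(w2i))
# w2i == {"a": 0, 42: 1, ("sh", "voiced"): 2}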


def __init__(self, unk_idx: int, i2w=None):
def __init__(self, index: indexes.Index):
self.target_characters = set()
self.i2w = [
actions.Start(),
actions.End(),
actions.ConditionalDel(),
actions.ConditionalCopy(),
]
self.start_vocab_idx = len(self.i2w)
if i2w:
self.i2w.extend(i2w)
self.w2i = {w: i for i, w in enumerate(self.i2w)}
self.target_characters = set()
self.encode_actions([unk_idx]) # Sets unknown character decoding.
# Use index from dataset to create action vocabulary.
self.encode_actions([index(t) for t in index.target_vocabulary])
Collaborator:
Can we start an issue to document exactly what's going on here? encode_actions converts vocab into Actions and stores them in a separate vocabulary, right?

Collaborator (author):
Yeah, it needs to track all potential edit actions for a given symbol.

self.encode_actions(
[index.unk_idx]
) # Sets unknown character decoding.
# Add source characters if index has tied embeddings.
if index.tie_embeddings:
self.encode_actions([index(s) for s in index.source_vocabulary])

def encode(self, symb: actions.Edit) -> int:
"""Returns index referencing symbol in encoding table.
@@ -240,7 +250,9 @@ class Expert(abc.ABC):
oracle_factor: int
roll_in: int

def __init__(self, actions, aligner, oracle_factor=defaults.ORACLE_FACTOR):
def __init__(
self, actions, aligner=None, oracle_factor=defaults.ORACLE_FACTOR
):
self.actions = actions
self.oracle_factor = oracle_factor
self.roll_in = 1
@@ -296,6 +308,12 @@ def find_valid_actions(

def roll_in_schedule(self, epoch: int) -> float:
"""Gets probability of sampling from oracle given current epoch."""
if self.aligner is None:
raise AlignerError(
"""Expert called `roll_in_schedule` but there
is no aligner present to allow oracle
predictions!"""
)
self.roll_in = 1 - self.oracle_factor / (
self.oracle_factor + numpy.exp(epoch / self.oracle_factor)
)
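
A quick numeric sanity check of this schedule (illustrative values only; the actual default oracle_factor may differ):

# Illustrative only: the probability 1 - k / (k + exp(epoch / k))
# rises toward 1 as the epoch count grows.
import numpy

def roll_in(epoch: int, oracle_factor: float = 4.0) -> float:
    return 1 - oracle_factor / (
        oracle_factor + numpy.exp(epoch / oracle_factor)
    )

for epoch in (0, 5, 10, 20):
    print(epoch, round(roll_in(epoch), 3))
# 0 0.2, 5 0.466, 10 0.753, 20 0.974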
@@ -325,6 +343,11 @@ def roll_out(
Returns:
Dict[Edit, float]: edit actions and their respective scores.
"""
if self.aligner is None:
raise AlignerError(
"""Expert called 'roll_out' but no aligner was instantiated.
Check the parameters passed to the expert module."""
)
costs_to_go = {}
for action_prefix in action_prefixes:
suffix_begin = action_prefix.prefix.alignment
@@ -408,82 +431,64 @@ def get_expert(
train_data: data.Dataset,
epochs: int = defaults.ORACLE_EM_EPOCHS,
oracle_factor: int = defaults.ORACLE_FACTOR,
sed_params_path: str = None,
sed_params_path: str = "",
) -> Expert:
"""Generates expert object for training transducer.

Args:
data (data.Dataset): dataset for generating expert vocabulary.
epochs (int): number of EM epochs.
sched_factor (float): scaling factor to determine rate of
oracle_factor (float): scaling factor to determine rate of
expert rollout sampling.
sed_params_path (str): path to read/write location of sed parameters.
Contributor:
Interesting, but can you remove the hard line breaks here? That would cause problems with tools like Sphinx that can parse args lists in docstrings.

Collaborator (author):
What would be the best way to preserve formatting then? I want to have something along the lines of a case statement for readability.

Contributor:
You can't within the arglist, so move it out of the arglist. Something like this:

def get_expert(
    train_data: data.Dataset,
    epochs: int = defaults.ORACLE_EM_EPOCHS,
    oracle_factor: int = defaults.ORACLE_FACTOR,
    sed_params_path: str = "",
) -> Expert:
    """Generates expert object for training transducer.

    * If epochs > 0, sed_params_path is a write path.
    * If epochs == 0, sed_params_path is a read path.
    * If it is an empty string, then it creates a 'dummy' expert.

    Args:
        data (data.Dataset): dataset for generating expert vocabulary.
        epochs (int): number of EM epochs.
        oracle_factor (float): scaling factor to determine rate of
            expert rollout sampling.
        sed_params_path (str): path to read/write location of sed parameters.

    Returns:
        expert.Expert.
    """

If epochs > 0, this is a write path.
If epochs == 0, this is a read path.
If empty string then creates 'dummy' expert.
Collaborator:
Maybe add a comment here or in the README about what behavior a 'dummy' expert entails.

Collaborator (author):
Removed dummy expert. Not important for this go around.


Returns:
expert.Expert.
"""

# TODO: Figure out a way to avoid these functions.

def _generate_data_and_encode_vocabulary(
data: data.Dataset, actions: ActionVocabulary
def _generate_data(
data: data.Dataset,
) -> Iterator[Tuple[List[int], List[int]]]:
"""Function to manage data encoding while aligning SED."
"""Helper function to manage data encoding for SED."

SED training over the default data sampling is expensive.
Training is quicker if tensors are converted to lists.
For efficiency, we encode action vocabulary simultaneously.
We want just the encodings without BOS or EOS tokens. This
Collaborator:
Can we add a more general comment here before this for the unfamiliar user: this basically converts the dataset into a format that is usable by the Expert -- is that right?

Collaborator (author):
We just want the raw symbol strings. We don't want to worry about BOS and EOS being encoded by Maxwell.

What comment would make sense to you? (I'm too familiar with the code to write a good user-friendly docstring.)

encodes just raw source-target text for the Maxwell library.

Args:
data (data.Dataset): Dataset to iterate over.
actions (ActionVocabulary): Vocabulary object
to encode actions for expert.

Returns:
Iterator[Tuple[List[int], List[int]]]: Iterator that
yields list version of source and target entries
in dataset.
"""
for item in data:
# Dataset encodes BOW and EOW symbols for source. EOW
# for target. Removes these for SED training.
source = item.source.tolist()[1:-1]
target = item.target.tolist()[:-1]
actions.encode_actions(target)
if data.index.tie_embeddings:
actions.encode_actions(source)
yield source, target

def _encode_action_vocabulary(
data: data.Dataset, actions: ActionVocabulary
) -> None:
"""Encodes action vocabulary for expert oracle.

For instantiating SED objects from file.

Args:
data (data.Dataset): Dataset to iterate over.
actions (ActionVocabulary): Vocabulary object
to encode actions for expert.
"""
for item in data:
# Ignores last symbol since EOW.
target = item.target.tolist()[:-1]
actions.encode_actions(target)
if data.index.tie_embeddings:
source = item.source.tolist()[1:-1]
actions.encode_actions(source)

actions = ActionVocabulary(unk_idx=train_data.index.unk_idx)
assert data.has_target, """Passed dataset with no target to expert
Contributor:
These are exceptions elsewhere, why is this an assertion?

Collaborator (author):
This is more of a sanity-check thing to me. It makes more sense to write a single assert line than defining an error class, doing a conditional check, then raising. Just an economy thing, I guess.

module, cannot perform SED!"""
for sampl in data.samples:
source, target = sampl[0], sampl[-1]
yield [data.index(s) for s in source], [
data.index(t) for t in target
]
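
Roughly what a single yielded pair looks like, assuming a hypothetical character-level index (illustrative only; the real Index object is callable, e.g. index("a")):

# Hypothetical index mapping, standing in for data.index.
index = {"a": 4, "b": 5, "c": 6}
source, target = "ab", "abc"
pair = [index[s] for s in source], [index[t] for t in target]
# pair == ([4, 5], [4, 5, 6]): plain integer lists with no BOS/EOS markers,
# which is the format the Maxwell SED trainer consumes.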

actions = ActionVocabulary(train_data.index)
if sed_params_path:
sed_params = sed.ParamDict.read_params(sed_params_path)
sed_aligner = sed.StochasticEditDistance(sed_params)
# Loads vocabulary into action vocabulary.
_encode_action_vocabulary(train_data, actions)
else:
sed_aligner = sed.StochasticEditDistance.fit_from_data(
_generate_data_and_encode_vocabulary(train_data, actions),
epochs=epochs,
if epochs:
sed_aligner = sed.StochasticEditDistance.fit_from_data(
_generate_data(train_data),
epochs=epochs,
)
sed_aligner.params.write_params(sed_params_path)
else:
sed_params = sed.ParamDict.read_params(sed_params_path)
Collaborator:
Am I understanding correctly that in order to use an existing SED model, the user would specify 0 or None for oracle_em_epochs? Avoiding rerunning EM every time is a great feature, but I wonder if we can think through a more intuitive user interface for it; this feels a bit buried. Maybe you've already put some thought into this, though, so please let me know if you think this is already a good interface.

Collaborator:
Oh sorry I see the comment in the method header. I missed it before.

Collaborator (author):
I just added a bool to the train script that's triggered if a sed file already exists. I think this is a bit cleaner. Thoughts?

sed_aligner = sed.StochasticEditDistance(sed_params)
return Expert(
actions, aligner=sed_aligner, oracle_factor=oracle_factor
)
return Expert(actions, sed_aligner, oracle_factor=oracle_factor)
else:
return Expert(actions)


def add_argparse_args(parser: argparse.ArgumentParser) -> None:
29 changes: 20 additions & 9 deletions yoyodyne/models/transducer.py
@@ -36,15 +36,16 @@ def __init__(
*args,
**kwargs,
):
# Gets number of non-target symbols.
source_vocab_size = kwargs["vocab_size"] - kwargs["target_vocab_size"]
# This is the size of the shared embedding matrix.
# It must contain every possible source AND target symbol.
kwargs["vocab_size"] = source_vocab_size + len(expert.actions)
# Alternate outputs than dataset targets.
kwargs["target_vocab_size"] = len(expert.actions)
"""Initializes transducer model.

Args:
expert (expert.Expert): oracle that guides training for transducer.
*args: passed to superclass.
**kwargs: passed to superclass.
"""
super().__init__(*args, **kwargs)
# Model specific variables.
self.vocab_offset = self.vocab_size - self.target_vocab_size
self.expert = expert # Oracle to train model.
self.actions = self.expert.actions
self.substitutions = self.actions.substitutions
@@ -195,8 +196,10 @@ def decode(
not_complete.to(self.device), action_count + 1, action_count
)
# Decoding.
# We offset the action idx by the symbol vocab size so that we
# can index into the shared embeddings matrix.
decoder_output = self.decoder(
last_action.unsqueeze(dim=1),
last_action.unsqueeze(dim=1) + self.vocab_offset,
last_hiddens,
encoder_out,
# To accomodate LSTMDecoder. See encoder_mask behavior.
@@ -557,11 +560,19 @@ def predict_step(self, batch: Tuple[torch.tensor], batch_idx: int) -> Dict:

def convert_prediction(self, prediction: List[List[int]]) -> torch.Tensor:
"""Converts prediction values to tensor for evaluator compatibility."""
# FIXME: the two steps below may be partially redundant.
Contributor:
What are the "two steps" referred to here?

Collaborator (author):
That's from Adam's PR, no comment.

Contributor:
Ah, let's remove it then since we don't know what it means.

Collaborator:
I think this is copied from #233

iirc I meant that looping and stacking predictions, and calling util.pad_tensor_after_eos may do some redundant things that could be cleaned up at some point.

Collaborator (author):
So should we leave the comment to clean up later or just remove now?

# TODO: Clean this up and make it more efficient.
max_len = len(max(prediction, key=len))
for i, pred in enumerate(prediction):
pad = [self.actions.end_idx] * (max_len - len(pred))
pred.extend(pad)
prediction[i] = torch.tensor(pred, dtype=torch.int)
prediction = torch.stack(prediction)
# Uses the same util that all other models use.
# This turns all symbols after the first EOS into PADs
# so prediction tensors match gold tensors.
return util.pad_tensor_after_eos(
torch.tensor(prediction),
prediction,
self.end_idx,
self.pad_idx,
)
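
A toy walkthrough of the padding step above (made-up indices; illustrative only):

import torch

end_idx = 2                       # hypothetical EOS index
prediction = [[5, 7, 2], [9, 2]]  # ragged action predictions
max_len = len(max(prediction, key=len))
padded = [
    torch.tensor(pred + [end_idx] * (max_len - len(pred)), dtype=torch.int)
    for pred in prediction
]
stacked = torch.stack(padded)     # tensor([[5, 7, 2], [9, 2, 2]])
# pad_tensor_after_eos then turns everything after the first EOS into PAD
# so the prediction tensor lines up with the gold targets.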
18 changes: 15 additions & 3 deletions yoyodyne/train.py
@@ -204,7 +204,11 @@ def get_model_from_argparse_args(
datamodule.train_dataloader().dataset,
epochs=args.oracle_em_epochs,
oracle_factor=args.oracle_factor,
sed_params_path=args.sed_params,
sed_params_path=(
args.sed_params
if args.sed_params
else f"{args.model_dir}/{args.experiment}/sed.pkl"
),
)
if args.arch in ["transducer"]
else None
@@ -273,8 +277,16 @@ def get_model_from_argparse_args(
source_attention_heads=args.source_attention_heads,
source_encoder_cls=source_encoder_cls,
start_idx=datamodule.index.start_idx,
target_vocab_size=datamodule.index.target_vocab_size,
vocab_size=datamodule.index.vocab_size,
target_vocab_size=(
Collaborator:
So now we orchestrate the vocab sizes once in the trainer, which already has a handle on the initialized expert? This is much nicer than w/e I was trying to do before.

It still somehow feels clunky I think, but that is an effect of the single embeddings matrix updates not easily aligning to the expert.

Collaborator (author):
Yeah, it's still a bit clunky, but at least the clunkiness is occurring in the train script (which is localized clumsiness). Probably a good next issue is to break up the train script some more so it's less of a monolith.

len(expert.actions)
if expert is not None
else datamodule.index.target_vocab_size
),
vocab_size=(
datamodule.index.vocab_size + len(expert.actions)
if expert is not None
else datamodule.index.vocab_size
),
)
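
The arithmetic here, with made-up sizes (illustrative only), ties back to the vocab_offset used in the transducer's decode step:

# Hypothetical sizes, for illustration.
dataset_vocab_size = 100       # datamodule.index.vocab_size
num_actions = 40               # len(expert.actions)

vocab_size = dataset_vocab_size + num_actions  # rows in the shared embedding matrix
target_vocab_size = num_actions                # decoder output dimension

# In the model, vocab_offset = vocab_size - target_vocab_size == dataset_vocab_size,
# so action index i is embedded via row i + vocab_offset: action embeddings
# occupy the tail of the shared matrix, after all dataset symbols.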

