Skip to content

Commit

Permalink
Merge pull request #16 from CUNY-CL/bar
Browse files Browse the repository at this point in the history
Improves progress bars.
  • Loading branch information
kylebgorman authored May 5, 2024
2 parents b9eb681 + f86fe7c commit 76ec8df
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 21 deletions.
96 changes: 77 additions & 19 deletions maxwell/sed.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def read_params(cls, filepath: str) -> ParamDict:

class StochasticEditDistance(abc.ABC):
params: ParamDict
_train_progress_bar: Optional[tqdm.tqdm]
_val_progress_bar: Optional[tqdm.tqdm]

def __init__(self, params):
"""SED model.
Expand All @@ -115,6 +117,8 @@ def __init__(self, params):
psum = self.params.sum()
if not numpy.isclose(0.0, psum):
raise SEDParameterError(f"Parameters do not sum to 1: {psum:.4f}")
self._train_progress_bar = None
self._val_progress_bar = None

@classmethod
def build_sed(
Expand Down Expand Up @@ -298,14 +302,14 @@ def log_likelihood(
targets: Iterable[Sequence[Any]],
) -> float:
"""Computes log likelihood."""
with tqdm.tqdm(
zip(sources, targets), total=len(sources), leave=False
) as pbar:
ll = []
pbar.set_description("Calculating log-likelihood")
for source, target in pbar:
ll.append(self.forward_evaluate(source, target)[-1, -1])
return float(numpy.mean(ll))
loglikes = []
self.val_progress_bar.total = len(sources)
self.on_validation_start()
for source, target in zip(sources, targets):
loglikes.append(self.forward_evaluate(source, target)[-1, -1])
self.on_validation_step_end()
self.on_validation_epoch_end()
return numpy.mean(loglikes)

def em(
self,
Expand All @@ -320,18 +324,19 @@ def em(
targets (Sequence[Any]): target strings.
epochs (int): number of EM epochs.
"""
loglike = numpy.NINF
gammas = ParamDict.from_params(self.params)
with tqdm.tqdm(zip(sources, targets), total=len(sources)) as pbar:
for epoch in range(epochs):
pbar.set_description(f"Epoch {epoch}")
pbar.set_postfix(loglike=loglike)
for source, target in pbar:
self.e_step(source, target, gammas) # Updates gammas.
self.m_step(gammas) # Updates gammas.
self.params.update_params(gammas) # Updates model parameters.
loglike = self.log_likelihood(sources, targets)
util.log_info(f"Final log-likelihood: {loglike:.4f}")
self.train_progress_bar.total = len(sources)
for epoch in range(epochs):
self.on_train_epoch_start(epoch)
for source, target in zip(sources, targets):
self.e_step(source, target, gammas) # Updates gammas.
self.on_train_step_end()
self.m_step(gammas) # Updates gammas.
self.params.update_params(gammas) # Updates model parameters.
loglike = self.log_likelihood(sources, targets)
self.on_train_epoch_end(loss=-loglike)
self.on_train_end()
self.on_validation_end()

def e_step(
self, source: Sequence[Any], target: Sequence[Any], gammas: ParamDict
Expand Down Expand Up @@ -523,3 +528,56 @@ def action_cost(self, action: actions.Edit) -> float:
(action.old, action.old), self.default
)
raise SEDActionError(f"Unknown action: {action}")

# The progress bar implementation is based on:
#
# https://github.com/Lightning-AI/pytorch-lightning/
# blob/master/src/lightning/pytorch/callbacks/progress/tqdm_progress.py

# Template handed to tqdm as ``bar_format`` by both the training and the
# validation progress bars.
BAR_FORMAT = (
    "{l_bar}{bar}| "
    "{n_fmt}/{total_fmt} "
    "[{elapsed}<{remaining}, "
    "{rate_noinv_fmt}{postfix}]"
)

@property
def train_progress_bar(self) -> tqdm.tqdm:
    """Returns the training progress bar, creating it on first access.

    The bar is cached in ``_train_progress_bar``; it is created at
    position 0 with ``leave=True`` and the class-level ``BAR_FORMAT``.
    """
    bar = self._train_progress_bar
    if bar is None:
        bar = tqdm.tqdm(position=0, leave=True, bar_format=self.BAR_FORMAT)
        self._train_progress_bar = bar
    return bar

def on_train_epoch_start(self, epoch: int) -> None:
    """Prepares the training bar for a new epoch.

    Args:
        epoch (int): epoch index shown in the bar's description.
    """
    bar = self.train_progress_bar
    bar.initial = 0
    bar.set_description(f"Epoch {epoch}")

def on_train_step_end(self) -> None:
    """Advances the training bar by one step."""
    self.train_progress_bar.update(1)

def on_train_epoch_end(self, loss: float) -> None:
    """Reports the epoch's loss and rewinds the training bar.

    Args:
        loss (float): value displayed as the bar's ``loss`` postfix.
    """
    bar = self.train_progress_bar
    bar.set_postfix(loss=loss)
    # Rewind the counter so the same bar can be reused next epoch.
    bar.reset()

def on_train_end(self) -> None:
    """Closes the training progress bar and forgets it.

    The cached bar is dropped after closing so that the next access to
    ``train_progress_bar`` builds a fresh one; otherwise a second training
    run would keep driving an already-closed bar. If no bar was ever
    created, nothing is done (no bar is created just to be closed).
    """
    bar = self._train_progress_bar
    if bar is None:
        return
    bar.close()
    self._train_progress_bar = None

@property
def val_progress_bar(self) -> tqdm.tqdm:
    """Returns the validation progress bar, creating it on first access.

    The bar is cached in ``_val_progress_bar``; it is created at
    position 1 with ``leave=False`` and the class-level ``BAR_FORMAT``.
    """
    bar = self._val_progress_bar
    if bar is None:
        bar = tqdm.tqdm(position=1, leave=False, bar_format=self.BAR_FORMAT)
        self._val_progress_bar = bar
    return bar

def on_validation_start(self) -> None:
    """Prepares the validation bar for a new pass over the data."""
    bar = self.val_progress_bar
    bar.initial = 0
    bar.set_description("Validating")

def on_validation_step_end(self) -> None:
    """Advances the validation bar by one step."""
    self.val_progress_bar.update(1)

def on_validation_epoch_end(self) -> None:
    """Rewinds the validation bar so it can be reused for the next pass."""
    bar = self.val_progress_bar
    bar.reset()

def on_validation_end(self) -> None:
    """Closes the validation progress bar and forgets it.

    The cached bar is dropped after closing so that the next access to
    ``val_progress_bar`` builds a fresh one; otherwise later validation
    passes would keep driving an already-closed bar. If no bar was ever
    created, nothing is done (no bar is created just to be closed).
    """
    bar = self._val_progress_bar
    if bar is None:
        return
    bar.close()
    self._val_progress_bar = None
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ include = ["maxwell*"]

[project]
name = "maxwell"
version = "0.2.3.post1"
description = "Stochastic Edit Distance aligenr for string transduction"
version = "0.2.4"
description = "Stochastic Edit Distance aligner for string transduction"
readme = "README.md"
requires-python = "> 3.9"
license = { text = "Apache 2.0" }
Expand Down

0 comments on commit 76ec8df

Please sign in to comment.