Skip to content

Commit

Permalink
Merge pull request #14 from Adamits/faster-logsumexp
Browse files Browse the repository at this point in the history
replaces scipy.logsumexp with numpy.logaddexp
  • Loading branch information
bonham79 authored Apr 22, 2024
2 parents abd17f5 + 22cd4f0 commit 41d97d5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 12 deletions.
16 changes: 7 additions & 9 deletions maxwell/sed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import numpy
import tqdm

from scipy import special

from . import actions, util

LARGE_NEG_CONST = -1e6
Expand Down Expand Up @@ -70,7 +68,7 @@ def sum(self) -> float:
for vs in (self.delta_sub, self.delta_ins, self.delta_del):
values.extend(vs.values())
# Uses sum of exponentiation to maintain logarithmic values.
return special.logsumexp(values)
return numpy.logaddexp.reduce(values)

@classmethod
def from_params(cls, other: ParamDict) -> ParamDict:
Expand Down Expand Up @@ -248,7 +246,7 @@ def forward_evaluate(
)
+ alpha[t - 1, v - 1]
)
alpha[t, v] = special.logsumexp(summands)
alpha[t, v] = numpy.logaddexp.reduce(summands)
alpha[T, V] += self.params.delta_eos
return alpha

Expand Down Expand Up @@ -292,7 +290,7 @@ def backward_evaluate(
)
+ beta[t + 1, v + 1]
)
beta[t, v] = special.logsumexp(summands)
beta[t, v] = numpy.logaddexp.reduce(summands)
return beta

def log_likelihood(
Expand Down Expand Up @@ -348,7 +346,7 @@ def e_step(
"""
alpha = self.forward_evaluate(source, target)
beta = self.backward_evaluate(source, target)
gammas.delta_eos = special.logsumexp([gammas.delta_eos, 0.0])
gammas.delta_eos = numpy.logaddexp(gammas.delta_eos, 0.0)
T = len(source)
V = len(target)
for t in range(T + 1):
Expand All @@ -358,7 +356,7 @@ def e_step(
tchar = target[v - 1]
stpair = schar, tchar
if t > 0 and schar in gammas.delta_del:
gammas.delta_del[schar] = special.logsumexp(
gammas.delta_del[schar] = numpy.logaddexp.reduce(
[
gammas.delta_del[schar],
alpha[t - 1, v]
Expand All @@ -367,7 +365,7 @@ def e_step(
]
)
if v > 0 and tchar in gammas.delta_ins:
gammas.delta_ins[tchar] = special.logsumexp(
gammas.delta_ins[tchar] = numpy.logaddexp.reduce(
[
gammas.delta_ins[tchar],
alpha[t, v - 1]
Expand All @@ -376,7 +374,7 @@ def e_step(
]
)
if t > 0 and v > 0 and stpair in gammas.delta_sub:
gammas.delta_sub[stpair] = special.logsumexp(
gammas.delta_sub[stpair] = numpy.logaddexp.reduce(
[
gammas.delta_sub[stpair],
alpha[t - 1, v - 1]
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ include = ["maxwell*"]

[project]
name = "maxwell"
version = "0.2.2.post2"
version = "0.2.3"
description = "Stochastic Edit Distance aligenr for string transduction"
readme = "README.md"
requires-python = "> 3.9"
Expand All @@ -28,7 +28,6 @@ keywords = [
]
dependencies = [
"numpy >= 1.20.1",
"scipy >= 1.6",
"tqdm >= 4.64.1",
]
classifiers = [
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@ build>=0.10.0
flake8>=3.9.2
numpy>=1.20.1
pytest>=7.4.0
scipy>=1.6
twine>=4.0.2
tqdm>=4.64.1

0 comments on commit 41d97d5

Please sign in to comment.