add lstm layer support #99

Open · wants to merge 1 commit into base: main
3 changes: 3 additions & 0 deletions fvcore/nn/activation_count.py
@@ -19,6 +19,9 @@
"aten::einsum": generic_activation_jit(),
"aten::matmul": generic_activation_jit(),
"aten::linear": generic_activation_jit(),
"aten::lstm": generic_activation_jit("lstm"),
"aten::mul": generic_activation_jit(),
"aten::mul_": generic_activation_jit(),
}


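For reference (not part of the diff): with this PR applied, the new aten::lstm handle shows up in activation_count output under the simplified op name "lstm". A minimal sketch; the TinyLSTM wrapper and its dimensions are assumptions made for illustration, similar in spirit to the LSTMNet used in the tests below.

import torch
import torch.nn as nn

from fvcore.nn import activation_count


class TinyLSTM(nn.Module):
    # Hypothetical wrapper: nn.LSTM is traced to an aten::lstm call.
    def __init__(self) -> None:
        super().__init__()
        self.lstm = nn.LSTM(input_size=3, hidden_size=4, num_layers=1)

    def forward(self, x: torch.Tensor):
        return self.lstm(x)


x = torch.randn(1, 5, 3)  # (time, batch, features); batch_first defaults to False
ac_dict, skipped = activation_count(TinyLSTM(), (x,))
print(ac_dict["lstm"])  # activations attributed to aten::lstm (fvcore reports mega-activations)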
5 changes: 5 additions & 0 deletions fvcore/nn/flop_count.py
@@ -17,6 +17,7 @@
einsum_flop_jit,
elementwise_flop_counter,
linear_flop_jit,
lstm_flop_jit,
matmul_flop_jit,
norm_flop_counter,
)
@@ -29,8 +30,12 @@
"aten::_convolution": conv_flop_jit,
"aten::einsum": einsum_flop_jit,
"aten::matmul": matmul_flop_jit,
"aten::mul": elementwise_flop_counter(1),
"aten::mul_": elementwise_flop_counter(1),
"aten::multiply": elementwise_flop_counter(1),
"aten::mm": matmul_flop_jit,
"aten::linear": linear_flop_jit,
"aten::lstm": lstm_flop_jit,
# You might want to ignore BN flops due to inference-time fusion.
# Use `set_op_handle("aten::batch_norm", None)`
"aten::batch_norm": batchnorm_flop_jit,
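For reference (not part of the diff): with lstm_flop_jit registered above, flop_count attributes LSTM FLOPs under the "lstm" key. A minimal sketch; the TinyBiLSTM wrapper and the dimensions are assumptions made for illustration only.

import torch
import torch.nn as nn

from fvcore.nn import flop_count


class TinyBiLSTM(nn.Module):
    # Hypothetical wrapper around a 2-layer bidirectional LSTM.
    def __init__(self) -> None:
        super().__init__()
        self.lstm = nn.LSTM(input_size=3, hidden_size=4, num_layers=2, bidirectional=True)

    def forward(self, x: torch.Tensor):
        return self.lstm(x)


x = torch.randn(7, 5, 3)  # (time, batch, features); batch_first defaults to False
flop_dict, skipped = flop_count(TinyBiLSTM(), (x,))
print(flop_dict["lstm"])  # FLOPs attributed to aten::lstm (fvcore reports GFLOPs)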
37 changes: 34 additions & 3 deletions fvcore/nn/jit_handles.py
@@ -36,6 +36,10 @@ def get_shape(val: Any) -> Optional[List[int]]:
return None


def get_values(vals: List[Any]) -> List[Any]:
"""Return the concrete Python values (IValues) of a list of jit Values."""
return [v.toIValue() for v in vals]


"""
Below are flop/activation counters for various ops. Every counter has the following signature:

@@ -63,7 +67,7 @@ def generic_activation_jit(op_name: Optional[str] = None) -> Handle:
"""

def _generic_activation_jit(
inputs: Any, outputs: List[Any]
) -> Union[typing.Counter[str], Number]:
"""
This is a generic jit handle that counts the number of activations for any
@@ -73,8 +77,19 @@ def _generic_activation_jit(
ac_count = prod(out_shape)
if op_name is None:
return ac_count

if op_name == "lstm":
# inputs[0] is the input sequence, unpacked here as (time, batch, features).
time_dim, batch_size, input_dim = get_shape(inputs[0])
# outputs[1] / outputs[2] are the final hidden and cell states; their last
# dimensions give the (projected) hidden size and the cell size.
*_, proj_size = get_shape(outputs[1])
*_, hidden_dim = get_shape(outputs[2])

# Trailing scalar arguments of aten::lstm:
# has_biases, num_layers, dropout, train, bidirectional, batch_first.
*_, bias, lstm_layers, dropout, _, bidirectional, batch_first = get_values(inputs)

# Activation estimate per sample, per time step, per direction, per layer.
ac_count = 11 * proj_size + (0 if proj_size == hidden_dim else hidden_dim)

return ac_count * batch_size * (2 if bidirectional else 1) * lstm_layers * time_dim

return Counter({op_name: ac_count})

return _generic_activation_jit

@@ -110,6 +125,22 @@ def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
return flops


def lstm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for the aten::lstm operator.
"""
# inputs[0] is the input sequence, unpacked here as (time, batch, features).
time_dim, batch_size, input_dim = get_shape(inputs[0])
# Final hidden and cell states; their last dimensions give the (projected)
# hidden size and the cell size.
*_, proj_size = get_shape(outputs[1])
*_, hidden_dim = get_shape(outputs[2])

# Trailing scalar arguments of aten::lstm:
# has_biases, num_layers, dropout, train, bidirectional, batch_first.
*_, _, lstm_layers, _, _, bidirectional, batch_first = get_values(inputs)

# Gate matmuls over the concatenated [input, hidden] vector, plus the optional
# hidden-state projection when proj_size differs from hidden_dim.
mm_flops = 4 * ((input_dim + hidden_dim) * proj_size) + (hidden_dim * proj_size if hidden_dim != proj_size else 0)
# Elementwise gate applications.
mul_flops = 3 * proj_size

return (mm_flops + mul_flops) * batch_size * (2 if bidirectional else 1) * lstm_layers * time_dim


def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for the bmm operation.
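As a sanity check (not part of the diff), the per-step costs assigned by the two new handles can be reproduced by hand. The numbers below assume input_dim=3, hidden_dim=4 and no projection (so proj_size == hidden_dim), matching the smallest test case; the handlers then scale these per-step figures by batch size, number of directions, layer count and sequence length.

# Dimensions from the smallest test case; purely illustrative.
input_dim, hidden_dim, proj_size = 3, 4, 4  # proj_size == hidden_dim when no projection is used

# lstm_flop_jit, per sample / time step / direction / layer:
mm_flops = 4 * (input_dim + hidden_dim) * proj_size  # 4 * 7 * 4 = 112 (gate matmuls)
mm_flops += hidden_dim * proj_size if hidden_dim != proj_size else 0  # + 0 (no projection)
mul_flops = 3 * proj_size  # 12 (elementwise gate applications)
flops_per_step = mm_flops + mul_flops  # 124

# generic_activation_jit("lstm"), per sample / time step / direction / layer:
ac_per_step = 11 * proj_size + (0 if proj_size == hidden_dim else hidden_dim)  # 44

# Both handlers multiply the per-step figure by
# batch_size * (2 if bidirectional else 1) * lstm_layers * time_dim.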
152 changes: 152 additions & 0 deletions tests/test_activation_count.py
@@ -44,6 +44,36 @@ def get_gt_activation(self, x: torch.Tensor) -> Tuple[int, int, int]:
return (count1, count2, count3)


class LSTMNet(nn.Module):
"""
A network with LSTM layers. This is used for testing the activation
count of LSTM layers.
"""

def __init__(
self,
input_dim,
hidden_dim,
lstm_layers,
bias,
batch_first,
bidirectional,
proj_size
) -> None:
super(LSTMNet, self).__init__()
self.lstm = nn.LSTM(input_dim,
hidden_dim,
lstm_layers,
bias=bias,
batch_first=batch_first,
bidirectional=bidirectional,
proj_size=proj_size)

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
x = self.lstm(x)
return x


class TestActivationCountAnalysis(unittest.TestCase):
"""
Unittest for activation_count.
@@ -99,6 +129,128 @@ def test_linear(self) -> None:
gt_dict, ac_dict, "FC layer failed to pass the activation count test."
)

def test_lstm(self) -> None:
"""
Test the activation count of LSTM layers.
"""

class LSTMCellNet(nn.Module):
"""
A network with a single LSTM cell. This is used for testing whether the activation
count of an LSTM layer for one time step equals that of an LSTM cell.
"""

def __init__(
self,
input_dim,
hidden_dim,
bias: bool
) -> None:
super(LSTMCellNet, self).__init__()
self.lstm_cell = nn.LSTMCell(input_size=input_dim,
hidden_size=hidden_dim,
bias=bias)

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
x = self.lstm_cell(x[0])
return x

def _test_lstm(
batch_size,
time_dim,
input_dim,
hidden_dim,
lstm_layers,
proj_size,
bidirectional=False,
bias=True,
batch_first=True,
):
lstmNet = LSTMNet(input_dim, hidden_dim, lstm_layers, bias, batch_first, bidirectional, proj_size)
x = torch.randn(time_dim, batch_size, input_dim)
ac_dict, _ = activation_count(lstmNet, (x,))

lstmcellNet = LSTMCellNet(input_dim, hidden_dim, bias)
lstmcell_ac_dict, _ = activation_count(lstmcellNet, (x,))

if time_dim == 1 and lstm_layers == 1:
gt_dict = defaultdict(float)
gt_dict["lstm"] = sum(e for _, e in lstmcell_ac_dict.items())
elif time_dim == 5 and lstm_layers == 5 and bidirectional:
gt_dict = defaultdict(float)
gt_dict["lstm"] = sum(e for _, e in lstmcell_ac_dict.items()) * time_dim * lstm_layers * 2
elif time_dim == 5 and lstm_layers == 5:
gt_dict = defaultdict(float)
gt_dict["lstm"] = sum(e for _, e in lstmcell_ac_dict.items()) * time_dim * lstm_layers
else:
raise ValueError(
f'No test implemented for parameters "time_dim": {time_dim}, "lstm_layers": {lstm_layers}'
f' and "bidirectional": {bidirectional}.'
)

self.assertAlmostEqual(
ac_dict['lstm'],
gt_dict['lstm'],
msg="LSTM layer failed to pass the flop count test.",
)

# Test LSTM for 1 layer and 1 time step.
batch_size1 = 5
time_dim1 = 1
input_dim1 = 3
hidden_dim1 = 4
lstm_layers1 = 1
bidirectional1 = False
proj_size1 = 0

_test_lstm(
batch_size1,
time_dim1,
input_dim1,
hidden_dim1,
lstm_layers1,
proj_size1,
bidirectional1,
)

# Test LSTM for 5 layers and 5 time steps.
batch_size2 = 5
time_dim2 = 5
input_dim2 = 3
hidden_dim2 = 4
lstm_layers2 = 5
bidirectional2 = False
proj_size2 = 0

_test_lstm(
batch_size2,
time_dim2,
input_dim2,
hidden_dim2,
lstm_layers2,
proj_size2,
bidirectional2,
)

# Test bidirectional LSTM for 5 layers and 5 time steps.
batch_size3 = 5
time_dim3 = 5
input_dim3 = 3
hidden_dim3 = 4
lstm_layers3 = 5
bidirectional3 = True
proj_size3 = 0

_test_lstm(
batch_size3,
time_dim3,
input_dim3,
hidden_dim3,
lstm_layers3,
proj_size3,
bidirectional3,
)

def test_supported_ops(self) -> None:
"""
Test the activation count for user provided handles.