From 9b1168f59ba6a521ce35ae03042c1821b189cc9f Mon Sep 17 00:00:00 2001
From: HendrikKlug-synthara
Date: Tue, 14 Dec 2021 16:36:01 +0100
Subject: [PATCH] add lstm layer support

---
 fvcore/nn/activation_count.py  |   3 +
 fvcore/nn/flop_count.py        |   5 ++
 fvcore/nn/jit_handles.py       |  37 +++++++-
 tests/test_activation_count.py | 152 ++++++++++++++++++++++++++++++++
 tests/test_flop_count.py       | 153 +++++++++++++++++++++++++++++++++
 5 files changed, 347 insertions(+), 3 deletions(-)

diff --git a/fvcore/nn/activation_count.py b/fvcore/nn/activation_count.py
index 55e6e8a..1a7d77a 100644
--- a/fvcore/nn/activation_count.py
+++ b/fvcore/nn/activation_count.py
@@ -19,6 +19,9 @@
     "aten::einsum": generic_activation_jit(),
     "aten::matmul": generic_activation_jit(),
     "aten::linear": generic_activation_jit(),
+    "aten::lstm": generic_activation_jit("lstm"),
+    "aten::mul": generic_activation_jit(),
+    "aten::mul_": generic_activation_jit(),
 }
diff --git a/fvcore/nn/flop_count.py b/fvcore/nn/flop_count.py
index 6e5043e..cbb5ef2 100644
--- a/fvcore/nn/flop_count.py
+++ b/fvcore/nn/flop_count.py
@@ -17,6 +17,7 @@
     einsum_flop_jit,
     elementwise_flop_counter,
     linear_flop_jit,
+    lstm_flop_jit,
     matmul_flop_jit,
     norm_flop_counter,
 )
@@ -29,8 +30,12 @@
     "aten::_convolution": conv_flop_jit,
     "aten::einsum": einsum_flop_jit,
     "aten::matmul": matmul_flop_jit,
+    "aten::mul": elementwise_flop_counter(1),
+    "aten::mul_": elementwise_flop_counter(1),
+    "aten::multiply": elementwise_flop_counter(1),
     "aten::mm": matmul_flop_jit,
     "aten::linear": linear_flop_jit,
+    "aten::lstm": lstm_flop_jit,
     # You might want to ignore BN flops due to inference-time fusion.
     # Use `set_op_handle("aten::batch_norm", None)
     "aten::batch_norm": batchnorm_flop_jit,
diff --git a/fvcore/nn/jit_handles.py b/fvcore/nn/jit_handles.py
index 3d76f21..7dee98d 100644
--- a/fvcore/nn/jit_handles.py
+++ b/fvcore/nn/jit_handles.py
@@ -36,6 +36,10 @@ def get_shape(val: Any) -> Optional[List[int]]:
         return None
 
 
+def get_values(vals: List[Any]) -> Optional[List[Any]]:
+    return [v.toIValue() for v in vals]
+
+
 """
 Below are flop/activation counters for various ops. Every counter has the following signature:
@@ -63,7 +67,7 @@ def generic_activation_jit(op_name: Optional[str] = None) -> Handle:
     """
 
     def _generic_activation_jit(
-        i: Any, outputs: List[Any]
+        inputs: Any, outputs: List[Any]
     ) -> Union[typing.Counter[str], Number]:
         """
         This is a generic jit handle that counts the number of activations for any
@@ -73,8 +77,19 @@ def _generic_activation_jit(
         ac_count = prod(out_shape)
         if op_name is None:
             return ac_count
-        else:
-            return Counter({op_name: ac_count})
+
+        if op_name == "lstm":
+            time_dim, batch_size, input_dim = get_shape(inputs[0])
+            *_, proj_size = get_shape(outputs[1])
+            *_, hidden_dim = get_shape(outputs[2])
+
+            *_, bias, lstm_layers, dropout, _, bidirectional, batch_first = get_values(inputs)
+
+            ac_count = 11 * proj_size + (0 if proj_size == hidden_dim else hidden_dim)
+
+            return ac_count * batch_size * (2 if bidirectional else 1) * lstm_layers * time_dim
+
+        return Counter({op_name: ac_count})
 
     return _generic_activation_jit
 
@@ -110,6 +125,22 @@ def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
     return flops
 
 
+def lstm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
+    """
+    Count flops for the aten::lstm operator.
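+
+    The count below models one LSTM step, per layer and per direction, as four
+    gate matmuls over the concatenated input and hidden state producing
+    proj_size outputs, an extra projection matmul when hidden_dim differs from
+    proj_size, and three elementwise multiplications on the gates and cell
+    state. The per-step cost is then scaled by the batch size, the number of
+    directions, the number of layers and the number of time steps; every layer
+    is costed with the first layer's input_dim, so this is an estimate for
+    stacked LSTMs.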
+    """
+    time_dim, batch_size, input_dim = get_shape(inputs[0])
+    *_, proj_size = get_shape(outputs[1])
+    *_, hidden_dim = get_shape(outputs[2])
+
+    *_, _, lstm_layers, _, _, bidirectional, batch_first = get_values(inputs)
+
+    mm_flops = 4 * ((input_dim + hidden_dim) * proj_size) + (hidden_dim * proj_size if hidden_dim != proj_size else 0)
+    mul_flops = 3 * proj_size
+
+    return (mm_flops + mul_flops) * batch_size * (2 if bidirectional else 1) * lstm_layers * time_dim
+
+
 def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
     """
     Count flops for the bmm operation.
diff --git a/tests/test_activation_count.py b/tests/test_activation_count.py
index 53cdd04..e32cdfa 100644
--- a/tests/test_activation_count.py
+++ b/tests/test_activation_count.py
@@ -44,6 +44,36 @@ def get_gt_activation(self, x: torch.Tensor) -> Tuple[int, int, int]:
         return (count1, count2, count3)
 
 
+class LSTMNet(nn.Module):
+    """
+    A network with LSTM layers. This is used for testing the activation
+    count for LSTM layers.
+    """
+
+    def __init__(
+        self,
+        input_dim,
+        hidden_dim,
+        lstm_layers,
+        bias,
+        batch_first,
+        bidirectional,
+        proj_size
+    ) -> None:
+        super(LSTMNet, self).__init__()
+        self.lstm = nn.LSTM(input_dim,
+                            hidden_dim,
+                            lstm_layers,
+                            bias=bias,
+                            batch_first=batch_first,
+                            bidirectional=bidirectional,
+                            proj_size=proj_size)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.lstm(x)
+        return x
+
+
 class TestActivationCountAnalysis(unittest.TestCase):
     """
     Unittest for activation_count.
@@ -99,6 +129,128 @@ def test_linear(self) -> None:
             gt_dict, ac_dict, "FC layer failed to pass the activation count test."
         )
 
+    def test_lstm(self) -> None:
+        """
+        Test a network with LSTM layers.
+        """
+
+        class LSTMCellNet(nn.Module):
+            """
+            A network with a single LSTM cell. This is used for testing if the
+            activation count of LSTM layers equals the activation count of an
+            LSTM cell for one time-step.
+            """
+
+            def __init__(
+                self,
+                input_dim,
+                hidden_dim,
+                bias: bool
+            ) -> None:
+                super(LSTMCellNet, self).__init__()
+                self.lstm_cell = nn.LSTMCell(input_size=input_dim,
+                                             hidden_size=hidden_dim,
+                                             bias=bias)
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                x = self.lstm_cell(x[0])
+                return x
+
+        def _test_lstm(
+            batch_size,
+            time_dim,
+            input_dim,
+            hidden_dim,
+            lstm_layers,
+            proj_size,
+            bidirectional=False,
+            bias=True,
+            batch_first=True,
+        ):
+            lstmNet = LSTMNet(input_dim, hidden_dim, lstm_layers, bias, batch_first, bidirectional, proj_size)
+            x = torch.randn(time_dim, batch_size, input_dim)
+            ac_dict, _ = activation_count(lstmNet, (x,))
+
+            lstmcellNet = LSTMCellNet(input_dim, hidden_dim, bias)
+            lstmcell_ac_dict, _ = activation_count(lstmcellNet, (x,))
+
+            if time_dim == 1 and lstm_layers == 1:
+                gt_dict = defaultdict(float)
+                gt_dict["lstm"] = sum(e for _, e in lstmcell_ac_dict.items())
+            elif time_dim == 5 and lstm_layers == 5 and bidirectional:
+                gt_dict = defaultdict(float)
+                gt_dict["lstm"] = sum(e for _, e in lstmcell_ac_dict.items()) * time_dim * lstm_layers * 2
+            elif time_dim == 5 and lstm_layers == 5:
+                gt_dict = defaultdict(float)
+                gt_dict["lstm"] = sum(e for _, e in lstmcell_ac_dict.items()) * time_dim * lstm_layers
+            else:
+                raise ValueError(
+                    f'No test implemented for parameters "time_dim": {time_dim}, "lstm_layers": {lstm_layers}'
+                    f' and "bidirectional": {bidirectional}.'
+                )
+
+            self.assertAlmostEqual(
+                ac_dict['lstm'],
+                gt_dict['lstm'],
+                msg="LSTM layer failed to pass the activation count test.",
+            )
+
+        # Test LSTM for 1 layer and 1 time step.
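+        # With a single unidirectional layer and one time step, the LSTM
+        # activation count should match that of an equivalent nn.LSTMCell.
+        # The deeper and bidirectional cases below scale the single-cell count
+        # by the number of time steps, layers and directions.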
+        batch_size1 = 5
+        time_dim1 = 1
+        input_dim1 = 3
+        hidden_dim1 = 4
+        lstm_layers1 = 1
+        bidirectional1 = False
+        proj_size1 = 0
+
+        _test_lstm(
+            batch_size1,
+            time_dim1,
+            input_dim1,
+            hidden_dim1,
+            lstm_layers1,
+            proj_size1,
+            bidirectional1,
+        )
+
+        # Test LSTM for 5 layers and 5 time steps.
+        batch_size2 = 5
+        time_dim2 = 5
+        input_dim2 = 3
+        hidden_dim2 = 4
+        lstm_layers2 = 5
+        bidirectional2 = False
+        proj_size2 = 0
+
+        _test_lstm(
+            batch_size2,
+            time_dim2,
+            input_dim2,
+            hidden_dim2,
+            lstm_layers2,
+            proj_size2,
+            bidirectional2,
+        )
+
+        # Test bidirectional LSTM for 5 layers and 5 time steps.
+        batch_size3 = 5
+        time_dim3 = 5
+        input_dim3 = 3
+        hidden_dim3 = 4
+        lstm_layers3 = 5
+        bidirectional3 = True
+        proj_size3 = 0
+
+        _test_lstm(
+            batch_size3,
+            time_dim3,
+            input_dim3,
+            hidden_dim3,
+            lstm_layers3,
+            proj_size3,
+            bidirectional3,
+        )
+
     def test_supported_ops(self) -> None:
         """
         Test the activation count for user provided handles.
diff --git a/tests/test_flop_count.py b/tests/test_flop_count.py
index 8158bda..535be6f 100644
--- a/tests/test_flop_count.py
+++ b/tests/test_flop_count.py
@@ -47,6 +47,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class LSTMNet(nn.Module):
+    """
+    A network with LSTM layers. This is used for testing flop
+    count for LSTM layers.
+    """
+
+    def __init__(
+        self,
+        input_dim,
+        hidden_dim,
+        lstm_layers,
+        bias,
+        batch_first,
+        bidirectional,
+        proj_size
+    ) -> None:
+        super(LSTMNet, self).__init__()
+        self.lstm = nn.LSTM(input_dim,
+                            hidden_dim,
+                            lstm_layers,
+                            bias=bias,
+                            batch_first=batch_first,
+                            bidirectional=bidirectional,
+                            proj_size=proj_size)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.lstm(x)
+        return x
+
+
 class ConvNet(nn.Module):
     """
     A network with a single convolution layer. This is used for testing flop
@@ -310,6 +340,129 @@ def test_linear(self) -> None:
             "Fully connected layer failed to pass the flop count test.",
         )
 
+    def test_lstm(self) -> None:
+        """
+        Test a network with LSTM layers.
+        """
+
+        class LSTMCellNet(nn.Module):
+            """
+            A network with a single LSTM cell. This is used for testing if the flop
+            count of LSTM layers equals the flop count of an LSTM cell for one time-step.
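+            The forward pass feeds only the first time step of the input
+            (x[0]) to the cell, so its count corresponds to one step of a
+            single unidirectional layer.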
+            """
+
+            def __init__(
+                self,
+                input_dim,
+                hidden_dim,
+                bias: bool
+            ) -> None:
+                super(LSTMCellNet, self).__init__()
+                self.lstm_cell = nn.LSTMCell(input_size=input_dim,
+                                             hidden_size=hidden_dim,
+                                             bias=bias)
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                x = self.lstm_cell(x[0])
+                return x
+
+        def _test_lstm(
+            batch_size,
+            time_dim,
+            input_dim,
+            hidden_dim,
+            lstm_layers,
+            proj_size,
+            bidirectional=False,
+            bias=True,
+            batch_first=True,
+        ):
+            lstmNet = LSTMNet(input_dim, hidden_dim, lstm_layers, bias, batch_first, bidirectional, proj_size)
+            x = torch.randn(time_dim, batch_size, input_dim)
+            flop_dict, _ = flop_count(lstmNet, (x,))
+
+            lstmcellNet = LSTMCellNet(input_dim, hidden_dim, bias)
+            lstmcell_flop_dict, _ = flop_count(lstmcellNet, (x,))
+
+            if time_dim == 1 and lstm_layers == 1:
+                gt_dict = defaultdict(float)
+                gt_dict["lstm"] = sum(e for _, e in lstmcell_flop_dict.items())
+            elif time_dim == 5 and lstm_layers == 5 and bidirectional:
+                gt_dict = defaultdict(float)
+                gt_dict["lstm"] = sum(e for _, e in lstmcell_flop_dict.items()) * time_dim * lstm_layers * 2
+            elif time_dim == 5 and lstm_layers == 5:
+                gt_dict = defaultdict(float)
+                gt_dict["lstm"] = sum(e for _, e in lstmcell_flop_dict.items()) * time_dim * lstm_layers
+            else:
+                raise ValueError(
+                    f'No test implemented for parameters "time_dim": {time_dim}, "lstm_layers": {lstm_layers}'
+                    f' and "bidirectional": {bidirectional}.'
+                )
+
+            self.assertDictEqual(
+                flop_dict,
+                gt_dict,
+                "LSTM layer failed to pass the flop count test.",
+            )
+
+        # Test LSTM for 1 layer and 1 time step.
+        batch_size1 = 5
+        time_dim1 = 1
+        input_dim1 = 3
+        hidden_dim1 = 4
+        lstm_layers1 = 1
+        bidirectional1 = False
+        proj_size1 = 0
+
+        _test_lstm(
+            batch_size1,
+            time_dim1,
+            input_dim1,
+            hidden_dim1,
+            lstm_layers1,
+            proj_size1,
+            bidirectional1,
+        )
+
+        # Test LSTM for 5 layers and 5 time steps.
+        batch_size2 = 5
+        time_dim2 = 5
+        input_dim2 = 3
+        hidden_dim2 = 4
+        lstm_layers2 = 5
+        bidirectional2 = False
+        proj_size2 = 0
+
+        _test_lstm(
+            batch_size2,
+            time_dim2,
+            input_dim2,
+            hidden_dim2,
+            lstm_layers2,
+            proj_size2,
+            bidirectional2,
+        )
+
+        # Test bidirectional LSTM for 5 layers and 5 time steps.
+        batch_size3 = 5
+        time_dim3 = 5
+        input_dim3 = 3
+        hidden_dim3 = 4
+        lstm_layers3 = 5
+        bidirectional3 = True
+        proj_size3 = 0
+
+        _test_lstm(
+            batch_size3,
+            time_dim3,
+            input_dim3,
+            hidden_dim3,
+            lstm_layers3,
+            proj_size3,
+            bidirectional3,
+        )
+
+
     def test_conv(self) -> None:
         """
         Test a network with a single convolution layer. The test cases are: 1)