
GRU support #254

Open · wants to merge 11 commits into master

Conversation

kylebgorman (Contributor) commented Oct 8, 2024

This adds GRU support; everywhere there is an LSTM model, there is now a GRU model too.

I initially tried to make the RNN type a general flag, but because LSTMs return the cell state in addition to the hidden state, and because various models need to reshape, average, or otherwise manipulate that cell state, this was really not feasible. Therefore, for each model that was previously "LSTM-backed", I instead create an abstract class called FooRNN{Encoder,Decoder,Model}. FooLSTM subclasses this and returns an LSTM module (it may also have special logic in its forward method, decode method, or whatever), as does FooGRU.
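
Schematically, the pattern looks something like the sketch below (the class and argument names here are illustrative stand-ins, not the actual classes in the diff):

```python
import abc

from torch import nn


class FooRNNEncoder(nn.Module, abc.ABC):
    """Shared logic for RNN-backed encoders; subclasses pick the module."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.module = self.get_module(input_size, hidden_size)

    @abc.abstractmethod
    def get_module(self, input_size: int, hidden_size: int) -> nn.RNNBase:
        ...


class FooLSTMEncoder(FooRNNEncoder):

    def get_module(self, input_size, hidden_size):
        return nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, source):
        # LSTMs return (output, (h, c)); model-specific code may need to
        # reshape or average c as well as h.
        output, (h, c) = self.module(source)
        return output, (h, c)


class FooGRUEncoder(FooRNNEncoder):

    def get_module(self, input_size, hidden_size):
        return nn.GRU(input_size, hidden_size, batch_first=True)

    def forward(self, source):
        # GRUs return just (output, h), hence the parallel subclass.
        output, h = self.module(source)
        return output, h
```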

I experimented with traditional Elman RNNs (they have the same, simpler interface as GRUs) but performance was abysmal, so I'm not going to bother.

All models have been tested on CPU and GPU.

Other changes:

  • The fix for the init_hiddens calling convention (#251) is also implied here, but I separated it out for review.
  • The names got confusing, so I also went ahead and replaced EncoderDecoder in our naming convention with simply Model.

Closes #180. (Note, however, that there's still plenty to do to study the effects this has.)

kylebgorman marked this pull request as ready for review October 8, 2024 16:53
bonham79 self-requested a review October 14, 2024 05:22
Adamits (Collaborator) commented Oct 15, 2024

Hey Kyle, I just started looking through the PR, but first, it occurs to me that I am unsure whether "because LSTMs return the cell state in addition to the hidden state, and because various models need to reshape, average, or otherwise manipulate that cell state" is actually a problem.

I didn't look everywhere, but IIRC when we manipulate the state, we manipulate h rather than the cell, right? So if we just have a property for "the thing that should be manipulated", where GRU returns the hidden tensor from its forward call and LSTM returns the first tensor in the tuple, then I think we should be good.
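
Something like the following is what I have in mind (the class and method names are just illustrative, not actual code from the repo):

```python
from torch import nn


class RNNEncoder(nn.Module):

    def hidden(self, state):
        """Returns "the thing that should be manipulated"."""
        raise NotImplementedError


class LSTMEncoder(RNNEncoder):

    def hidden(self, state):
        # LSTM forward returns the state tuple (h, c); manipulate h.
        return state[0]


class GRUEncoder(RNNEncoder):

    def hidden(self, state):
        # GRU forward already returns just h.
        return state
```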

I did not think through what that actually buys us, but I am always pro-anything that reduces the amount of abstraction :D.

EDIT: I suppose what I am suggesting here ADDS abstraction. But it may clean things up in the way that I thought you were suggesting in your comment.

kylebgorman (Contributor, Author) commented
I didn't look everywhere, but IIRC when we manipulate the state, we manipulate h rather than the cell, right? So if we just have a property for "the thing that should be manipulated", where GRU returns the hidden tensor from its forward call and LSTM returns the first tensor in the tuple, then I think we should be good.

There is at least one case where both $h$ and $c$ are manipulated: see here. This seemed to necessitate the design I pursue here.
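
To illustrate the general problem (this is just a hypothetical sketch, not the linked code): merging the two directions of a bidirectional LSTM has to apply the same reshaping to both tensors, so a hook that only exposes $h$ doesn't cover it.

```python
import torch


def merge_directions(h: torch.Tensor, c: torch.Tensor):
    """Averages the two directions of a bidirectional LSTM state.

    Both tensors have shape (num_layers * 2, batch_size, hidden_size), and
    the same reshaping has to be applied to h and c alike.
    """

    def _merge(tensor: torch.Tensor) -> torch.Tensor:
        layers_x2, batch_size, hidden_size = tensor.shape
        return tensor.view(layers_x2 // 2, 2, batch_size, hidden_size).mean(dim=1)

    return _merge(h), _merge(c)
```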

I suspect there is some abstraction that would reduce the amount of boilerplate here, but I'm not even sure it's worth tracking down. If you can think of anything, though, please go ahead and suggest it!
