feat: [WIP] add RNN training #1

Open

ValerianRey wants to merge 3 commits into main from add-rnn-training

Conversation

@ValerianRey

@PierreQuinton this is the code for what I explained on Signal. Super messy so far.

Comment on lines 60 to 70

@PierreQuinton Dec 18, 2025

I think it could be possible to clone the parameters of the memory model at each call; it should not require more memory. But then, if we do a backward, we obtain a grad for each of the copies, and we can stack them. Of course this also works, and later on we can also make this quite efficient with hooks.
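
A rough PyTorch sketch of this clone-per-call idea (the module, shapes, and the `steps` loop below are made up for illustration; this is not the PR's code):

```python
import torch
from torch import nn
from torch.func import functional_call

# Toy stand-ins for the real memory model and data.
memory_model = nn.Linear(8, 8)
steps, batch = 5, 4
x = torch.randn(steps, batch, 8)
h = torch.zeros(batch, 8)

# One clone of the parameters per call. clone() is differentiable, so these
# copies are non-leaf tensors; we read their grads with autograd.grad below.
param_copies = [
    {name: p.clone() for name, p in memory_model.named_parameters()}
    for _ in range(steps)
]

# Unroll the recurrence, using a different parameter copy at each step.
for t in range(steps):
    h = functional_call(memory_model, param_copies[t], (x[t] + h,))

loss = h.pow(2).mean()

# One gradient per copy, then re-group into a (steps, *param.shape) stack per parameter.
names = list(param_copies[0].keys())
flat_copies = [param_copies[t][n] for t in range(steps) for n in names]
grads = torch.autograd.grad(loss, flat_copies)
stacked = {
    n: torch.stack([grads[t * len(names) + i] for t in range(steps)])
    for i, n in enumerate(names)
}
print({n: g.shape for n, g in stacked.items()})
```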

@PierreQuinton Dec 18, 2025

I guess there is no training at all in this code? (no `.grad = ...`)

@ValerianRey (Author)

> I think it could be possible to clone the parameters of the memory model at each call; it should not require more memory. But then, if we do a backward, we obtain a grad for each of the copies, and we can stack them. Of course this also works, and later on we can also make this quite efficient with hooks.

I think the current method is almost maximally efficient. But maybe it's not expressive enough (for now, we can't really select paths of length 1, 2, 4, 8, etc., without also computing lengths 3, 5, 6, 7, ...).

@ValerianRey (Author)

We could maybe do what you say with a detached view of the parameters (I think cloning duplicates memory, and it is differentiable, so the gradients would flow back to the original params).
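
For reference, a tiny sketch of the clone vs. detach behavior being discussed (plain PyTorch, not code from this PR):

```python
import torch

p = torch.ones(3, requires_grad=True)

# clone() is differentiable: gradients flow back to the original parameter.
loss = (p.clone() * 2).sum()
loss.backward()
print(p.grad)  # tensor([2., 2., 2.])

p.grad = None

# detach() shares storage (no extra memory) but cuts the autograd graph,
# so the original parameter gets no gradient through the detached view.
d = p.detach().requires_grad_(True)
loss = (d * 2).sum()
loss.backward()
print(p.grad)  # None
print(d.grad)  # tensor([2., 2., 2.]) -- the grad ends up on the detached view
```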

@PierreQuinton Dec 19, 2025

Selecting only some paths is doable only with a residual RNN. But note that if you only select the path to the level-2 memory, then you don't train the interaction between level 1 and level 2, which is typically not what we want.

* Reset memories, memories_wrt and param_to_gradients
* Use transform to aggregate and accumulate into .grad
* Train head too
* Change some values
* At this point, it seems hard to train with Mean() and doable with
  UPGrad()
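
The second bullet above ("Use transform to aggregate and accumulate into .grad") roughly corresponds to something like the following PyTorch-only sketch. A plain mean over the step dimension stands in for the Mean()/UPGrad() aggregators mentioned in the last bullet, and all names and shapes are illustrative:

```python
import torch

# Stand-in for one parameter of the memory model and its stacked
# per-step gradients (shape: (steps, *param.shape)), as discussed above.
param = torch.nn.Parameter(torch.zeros(4))
per_step_grads = torch.randn(5, 4)

# Aggregate over the step dimension (plain mean here) ...
aggregated = per_step_grads.mean(dim=0)

# ... and accumulate into .grad, the way an optimizer expects it.
if param.grad is None:
    param.grad = aggregated.clone()
else:
    param.grad += aggregated
```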