Add natural gradient descent optimizer method #1971
base: main
Conversation
Thanks @SammySuliman! Please take a look at my comments below.

On the general side, I think we should turn your test function into a notebook (some kind of how-to showing how to use `NaturalGradientDescent`), and add a minimal test that shows we can do one `step` properly, in the most basic model we can think of (see the sketch after this comment).

For the `step` method, please see my comments below (there are probably small bugs to fix).

@ninamiolane, on a more philosophical backend note: this code only works with `pytorch`, as it assumes the NN structure of `pytorch` and inherits from its `Optimizer` class. It is not trivial (nor necessarily relevant?) to make it work with other backends. Is this a direction we are happy to pursue?
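For concreteness, here is a minimal sketch of what I have in mind, not the PR's actual code: the class name, constructor signature, `damping` term, and the rank-one empirical Fisher approximation are all assumptions on my part. It shows the two things I'm asking for: an optimizer that subclasses `torch.optim.Optimizer`, and a one-`step` test on the most basic model we can think of.

```python
import torch
from torch.optim import Optimizer


class NaturalGradientDescent(Optimizer):
    """Hypothetical sketch: natural gradient step with a damped,
    per-parameter empirical Fisher approximation."""

    def __init__(self, params, lr=0.01, damping=1e-3):
        defaults = dict(lr=lr, damping=damping)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            lr, damping = group["lr"], group["damping"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                g = p.grad.reshape(-1, 1)
                # Empirical Fisher approximation (an assumption here,
                # not necessarily what the PR computes): F ≈ g gᵀ + damping · I.
                fisher = g @ g.T + damping * torch.eye(
                    g.numel(), device=p.device, dtype=p.dtype
                )
                # Natural gradient direction: solve F x = g, reshape to p.
                nat_grad = torch.linalg.solve(fisher, g).reshape(p.shape)
                p.add_(nat_grad, alpha=-lr)
        return loss


def test_one_step():
    """One `step` on the most basic model: parameters should move."""
    torch.manual_seed(0)
    model = torch.nn.Linear(2, 1)
    x, y = torch.randn(8, 2), torch.randn(8, 1)
    optimizer = NaturalGradientDescent(model.parameters(), lr=0.1)

    before = [p.detach().clone() for p in model.parameters()]
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

    assert any(
        not torch.allclose(p, q) for p, q in zip(model.parameters(), before)
    )
```

Note that forming the full `numel × numel` Fisher per parameter is only feasible for tiny models; in practice a diagonal or factored approximation would be needed, but for a smoke test like this it keeps the step exact and easy to check.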
Very nice! The code style needs to be cleaned up to follow Python's standard coding guidelines. See the results of the Lint GitHub action.
Checklist
Description
Issue
Additional context