Add natural gradient descent optimizer method by SammySuliman · Pull Request #1971 · geomstats/geomstats · GitHub

Add natural gradient descent optimizer method #1971


Open · SammySuliman wants to merge 27 commits into base: main


SammySuliman

Checklist

  • My pull request has a clear and explanatory title.
  • If necessary, my code is vectorized.
  • I added appropriate unit tests.
  • I made sure the code passes all unit tests. (refer to comment below)
  • My PR follows PEP8 guidelines. (refer to comment below)
  • My PR follows geomstats coding style and API.
  • My code is properly documented and I made sure the documentation renders properly. (Link)
  • I linked to issues and PRs that are relevant to this PR.

Description

Issue

Additional context

@luisfpereira luisfpereira self-requested a review March 8, 2024 10:51
@luisfpereira changed the title "Added natural gradient descent optimizer method" to "Add natural gradient descent optimizer method" Mar 18, 2024
@luisfpereira (Collaborator) left a comment


Thanks @SammySuliman! Please take a look at my comments below.

On the general side, I think we should turn your test function into a notebook (a how-to showing how to use NaturalGradientDescent), and add a minimal test showing that we can take one step correctly in the most basic model we can think of.

For the step method, please see my comments below (there are probably small bugs to fix).

@ninamiolane, on a more philosophical note about backends: this code only works with PyTorch, as it assumes PyTorch's neural-network structure and inherits from its Optimizer class. Making it work with other backends is not trivial (and perhaps not even relevant). Is this a direction we're happy to pursue?
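The minimal one-step test suggested above could look roughly like the sketch below. It is framework-agnostic NumPy, not the PR's NaturalGradientDescent class (which, per the review, subclasses PyTorch's Optimizer); the helper natural_gradient_step and the choice of model are hypothetical. Assumed model: a Gaussian with unknown mean and known covariance Σ. There the Fisher information of the mean is Σ^{-1}, so one natural gradient step with learning rate 1 lands exactly on the sample mean (the MLE), whereas a plain gradient step generally does not.

```python
import numpy as np


def natural_gradient_step(params, grad, fisher, lr=1.0):
    """Hypothetical helper: one natural gradient step, params - lr * F^{-1} grad."""
    # Solve F d = grad instead of forming F^{-1} explicitly.
    direction = np.linalg.solve(fisher, grad)
    return params - lr * direction


# Assumed "most basic model": Gaussian with unknown mean mu, known covariance sigma.
rng = np.random.default_rng(0)
sigma = np.array([[2.0, 0.5], [0.5, 1.0]])
x = rng.multivariate_normal(mean=[1.0, -2.0], cov=sigma, size=100)
x_bar = x.mean(axis=0)

sigma_inv = np.linalg.inv(sigma)
mu = np.zeros(2)

# Gradient of the average negative log-likelihood with respect to mu.
grad = -sigma_inv @ (x_bar - mu)
# Fisher information of the mean parameter is sigma^{-1}.
fisher = sigma_inv

mu_natural = natural_gradient_step(mu, grad, fisher, lr=1.0)
# For comparison: a plain gradient descent step with the same learning rate.
mu_plain = mu - 1.0 * grad
```

Because the update only needs the gradient, a Fisher matrix, and a linear solve, the step rule itself does not depend on PyTorch; only the parameter and gradient bookkeeping does.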

@ninamiolane (Collaborator) left a comment


Very nice! The code style needs to be cleaned up to follow Python's coding guidelines (PEP 8). See the results of the Lint GitHub action.

Labels
None yet
Projects
None yet
3 participants