Deduplicate solver regularization, logging, and local rates and decays by shelhamer · Pull Request #2518 · BVLC/caffe · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Deduplicate solver regularization, logging, and local rates and decays #2518

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 27, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace caffe {
/**
* @brief An interface for classes that perform optimization on Net%s.
*
* Requires implementation of ComputeUpdateValue to compute a parameter update
* Requires implementation of ApplyUpdate to compute a parameter update
* given the current state of the Net parameters.
*/
template <typename Dtype>
Expand Down Expand Up @@ -39,8 +39,8 @@ class Solver {
int iter() { return iter_; }

protected:
// Get the update value for the current iteration.
virtual void ComputeUpdateValue() = 0;
// Make and apply the update value for the current iteration.
virtual void ApplyUpdate() = 0;
// The Solver::Snapshot function implements the basic snapshotting utility
// that stores the learned net. You should implement the SnapshotSolverState()
// function that produces a SolverState protocol buffer that needs to be
Expand Down Expand Up @@ -80,7 +80,9 @@ class SGDSolver : public Solver<Dtype> {
protected:
void PreSolve();
Dtype GetLearningRate();
virtual void ComputeUpdateValue();
virtual void ApplyUpdate();
virtual void Regularize(int param_id);
virtual void ComputeUpdateValue(int param_id, Dtype rate);
virtual void ClipGradients();
virtual void SnapshotSolverState(SolverState * state);
virtual void RestoreSolverState(const SolverState& state);
Expand All @@ -102,7 +104,7 @@ class NesterovSolver : public SGDSolver<Dtype> {
: SGDSolver<Dtype>(param_file) {}

protected:
virtual void ComputeUpdateValue();
virtual void ComputeUpdateValue(int param_id, Dtype rate);

DISABLE_COPY_AND_ASSIGN(NesterovSolver);
};
Expand All @@ -116,7 +118,7 @@ class AdaGradSolver : public SGDSolver<Dtype> {
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }

protected:
virtual void ComputeUpdateValue();
virtual void ComputeUpdateValue(int param_id, Dtype rate);
void constructor_sanity_check() {
CHECK_EQ(0, this->param_.momentum())
<< "Momentum cannot be used with AdaGrad.";
Expand Down
Loading
0