From 4e4c89b026cb0b7f296aae4dfbb45e2eb1654f43 Mon Sep 17 00:00:00 2001 From: PatWie Date: Mon, 3 Aug 2015 17:31:14 +0200 Subject: [PATCH 1/2] Adam solver This commit implements the Adam solver by Kingma et. al for CPU and GPU. All solver parameters are defined in the caffe.proto. This also adds an example for the MNIST dataset. --- examples/mnist/lenet_solver_adam.prototxt | 26 +++ examples/mnist/train_lenet_adam.sh | 3 + include/caffe/solver.hpp | 17 ++ src/caffe/proto/caffe.proto | 7 +- src/caffe/solver.cpp | 104 +++++++++++ src/caffe/test/test_gradient_based_solver.cpp | 170 +++++++++++++++--- 6 files changed, 299 insertions(+), 28 deletions(-) create mode 100644 examples/mnist/lenet_solver_adam.prototxt create mode 100755 examples/mnist/train_lenet_adam.sh diff --git a/examples/mnist/lenet_solver_adam.prototxt b/examples/mnist/lenet_solver_adam.prototxt new file mode 100644 index 00000000000..d22c5718f3f --- /dev/null +++ b/examples/mnist/lenet_solver_adam.prototxt @@ -0,0 +1,26 @@ +# The train/test net protocol buffer definition +# this follows "ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION" +net: "examples/mnist/lenet_train_test.prototxt" +# test_iter specifies how many forward passes the test should carry out. +# In the case of MNIST, we have test batch size 100 and 100 test iterations, +# covering the full 10,000 testing images. +test_iter: 100 +# Carry out testing every 500 training iterations. +test_interval: 500 +# All parameters are from the cited paper above +base_lr: 0.001 +momentum: 0.9 +momentum2: 0.999 +# since Adam dynamically changes the learning rate, we set the base learning +# rate to a fixed value +lr_policy: "fixed" +# Display every 100 iterations +display: 100 +# The maximum number of iterations +max_iter: 10000 +# snapshot intermediate results +snapshot: 5000 +snapshot_prefix: "examples/mnist/lenet" +# solver mode: CPU or GPU +solver_type: ADAM +solver_mode: GPU diff --git a/examples/mnist/train_lenet_adam.sh b/examples/mnist/train_lenet_adam.sh new file mode 100755 index 00000000000..a32ecf2d9c2 --- /dev/null +++ b/examples/mnist/train_lenet_adam.sh @@ -0,0 +1,3 @@ +#!/usr/bin/env sh + +./build/tools/caffe train --solver=examples/mnist/lenet_solver_adam.prototxt diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index d2b99923f23..582aa1427d3 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -217,6 +217,21 @@ class AdaDeltaSolver : public SGDSolver { DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver); }; +template +class AdamSolver : public SGDSolver { + public: + explicit AdamSolver(const SolverParameter& param) + : SGDSolver(param) { AdamPreSolve();} + explicit AdamSolver(const string& param_file) + : SGDSolver(param_file) { AdamPreSolve(); } + + protected: + void AdamPreSolve(); + virtual void ComputeUpdateValue(int param_id, Dtype rate); + + DISABLE_COPY_AND_ASSIGN(AdamSolver); +}; + template Solver* GetSolver(const SolverParameter& param) { SolverParameter_SolverType type = param.solver_type(); @@ -232,6 +247,8 @@ Solver* GetSolver(const SolverParameter& param) { return new RMSPropSolver(param); case SolverParameter_SolverType_ADADELTA: return new AdaDeltaSolver(param); + case SolverParameter_SolverType_ADAM: + return new AdamSolver(param); default: LOG(FATAL) << "Unknown SolverType: " << type; } diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index fc0d961abda..d4c97d2bd06 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -98,7 +98,7 @@ message NetParameter { // NOTE // Update the next 
available ID when you add a new SolverParameter field. // -// SolverParameter next available ID: 39 (last added: rms_decay) +// SolverParameter next available ID: 40 (last added: momentum2) message SolverParameter { ////////////////////////////////////////////////////////////////////////////// // Specifying the train and test networks @@ -216,10 +216,13 @@ message SolverParameter { ADAGRAD = 2; RMSPROP = 3; ADADELTA = 4; + ADAM = 5; } optional SolverType solver_type = 30 [default = SGD]; - // numerical stability for AdaGrad + // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam optional float delta = 31 [default = 1e-8]; + // parameters for the Adam solver + optional float momentum2 = 39 [default = 0.999]; // RMSProp decay value // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t) diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 248f238eb76..9348e11c249 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -1114,11 +1114,115 @@ void AdaDeltaSolver::ComputeUpdateValue(int param_id, Dtype rate) { } } +template +void AdamSolver::AdamPreSolve() { + // Add the extra history entries for Adam after those from + // SGDSolver::PreSolve + const vector*>& net_params = this->net_->learnable_params(); + for (int i = 0; i < net_params.size(); ++i) { + const vector& shape = net_params[i]->shape(); + this->history_.push_back( + shared_ptr >(new Blob(shape))); + } +} + +template +void AdamSolver::ComputeUpdateValue(int param_id, Dtype rate) { + const vector*>& net_params = this->net_->learnable_params(); + const vector& net_params_lr = this->net_->params_lr(); + Dtype local_rate = rate * net_params_lr[param_id]; + const Dtype beta1 = this->param_.momentum(); + const Dtype beta2 = this->param_.momentum2(); + + // we create aliases for convenience + size_t update_history_offset = net_params.size(); + Blob* val_m = this->history_[param_id].get(); + Blob* val_v = this->history_[param_id + update_history_offset].get(); + Blob* val_t = this->temp_[param_id].get(); + + const int t = this->iter_ + 1; + const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) / + (Dtype(1.) 
- pow(beta1, t)); + const int N = net_params[param_id]->count(); + const Dtype eps_hat = this->param_.delta(); + + switch (Caffe::mode()) { + case Caffe::CPU: { + // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t + caffe_cpu_axpby(N, Dtype(1)-beta1, + net_params[param_id]->cpu_diff(), beta1, + val_m->mutable_cpu_data()); + + // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2 + caffe_mul(N, + net_params[param_id]->cpu_diff(), + net_params[param_id]->cpu_diff(), + val_t->mutable_cpu_data()); + caffe_cpu_axpby(N, Dtype(1)-beta2, + val_t->cpu_data(), beta2, + val_v->mutable_cpu_data()); + + // set update + caffe_powx(N, + val_v->cpu_data(), Dtype(0.5), + val_t->mutable_cpu_data()); + caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data()); + caffe_div(N, + val_m->cpu_data(), + val_t->cpu_data(), + val_t->mutable_cpu_data()); + + caffe_cpu_scale(N, local_rate*correction, + val_t->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + break; + } + case Caffe::GPU: { +#ifndef CPU_ONLY + // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t + caffe_gpu_axpby(N, Dtype(1)-beta1, + net_params[param_id]->gpu_diff(), beta1, + val_m->mutable_gpu_data()); + + // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2 + caffe_gpu_mul(N, + net_params[param_id]->gpu_diff(), + net_params[param_id]->gpu_diff(), + val_t->mutable_gpu_data()); + caffe_gpu_axpby(N, Dtype(1)-beta2, + val_t->gpu_data(), beta2, + val_v->mutable_gpu_data()); + + // set update + caffe_gpu_powx(N, + val_v->gpu_data(), Dtype(0.5), + val_t->mutable_gpu_data()); + caffe_gpu_add_scalar(N, eps_hat, + val_t->mutable_gpu_data()); + caffe_gpu_div(N, + val_m->gpu_data(), + val_t->gpu_data(), + val_t->mutable_gpu_data()); + + caffe_gpu_scale(N, local_rate*correction, + val_t->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + } + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + INSTANTIATE_CLASS(Solver); INSTANTIATE_CLASS(SGDSolver); INSTANTIATE_CLASS(NesterovSolver); INSTANTIATE_CLASS(AdaGradSolver); INSTANTIATE_CLASS(RMSPropSolver); INSTANTIATE_CLASS(AdaDeltaSolver); +INSTANTIATE_CLASS(AdamSolver); } // namespace caffe diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index 1d255a86621..dcbfff1cad2 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -42,7 +42,7 @@ class GradientBasedSolverTest : public MultiDeviceTest { // TODO this is brittle and the hdf5 file should be checked instead. int num_, channels_, height_, width_; bool share_; - Dtype delta_; // Stability constant for AdaGrad. + Dtype delta_; // Stability constant for RMSProp, AdaGrad, AdaDelta and Adam // Test data: check out generate_sample_data.py in the same directory. string* input_file_; @@ -65,10 +65,7 @@ class GradientBasedSolverTest : public MultiDeviceTest { LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode(); } InitSolver(param); - delta_ = (solver_type() == SolverParameter_SolverType_ADAGRAD || - solver_type() == SolverParameter_SolverType_RMSPROP || - solver_type() == SolverParameter_SolverType_ADADELTA) ? - param.delta() : 0; + delta_ = param.delta(); } string RunLeastSquaresSolver(const Dtype learning_rate, @@ -216,7 +213,7 @@ class GradientBasedSolverTest : public MultiDeviceTest { // updated_params will store the updated weight and bias results, // using the blobs' diffs to hold the update values themselves. 
void ComputeLeastSquaresUpdate(const Dtype learning_rate, - const Dtype weight_decay, const Dtype momentum, + const Dtype weight_decay, const Dtype momentum, const int num_iters, vector > >* updated_params) { const int N = num_; const int D = channels_ * height_ * width_; @@ -282,7 +279,8 @@ class GradientBasedSolverTest : public MultiDeviceTest { ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]); // Finally, compute update. const vector > >& history = solver_->history(); - if (solver_type() != SolverParameter_SolverType_ADADELTA) { + if (solver_type() != SolverParameter_SolverType_ADADELTA + && solver_type() != SolverParameter_SolverType_ADAM) { ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias } else { ASSERT_EQ(4, history.size()); // additional blobs for update history @@ -312,16 +310,31 @@ class GradientBasedSolverTest : public MultiDeviceTest { case SolverParameter_SolverType_ADADELTA: { const Dtype update_history_value = (i == D) ? - history[3]->cpu_data()[0] : history[2]->cpu_data()[i]; + history[1 + num_param_blobs]->cpu_data()[0] : + history[0 + num_param_blobs]->cpu_data()[i]; const Dtype weighted_gradient_average = momentum * history_value + (1 - momentum) * (grad * grad); update_value = grad * std::sqrt((update_history_value + delta_) / - (weighted_gradient_average + delta_)); + (weighted_gradient_average + delta_)) * learning_rate; // not actually needed, just here for illustrative purposes // const Dtype weighted_update_average = // momentum * update_history_value + (1 - momentum) * (update_value); break; } + case SolverParameter_SolverType_ADAM: { + const Dtype momentum2 = 0.999; + const Dtype m = history_value; + const Dtype v = (i == D) ? + history[1 + num_param_blobs]->cpu_data()[0] : + history[0 + num_param_blobs]->cpu_data()[i]; + const Dtype val_m = (1 - momentum) * grad + momentum * m; + const Dtype val_v = (1 - momentum2) * grad * grad + momentum2 * v; + Dtype alpha_t = learning_rate * + std::sqrt(Dtype(1) - pow(momentum2, num_iters)) / + (Dtype(1.) - pow(momentum, num_iters)); + update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_); + break; + } default: LOG(FATAL) << "Unknown solver type: " << solver_type(); } @@ -465,7 +478,7 @@ class GradientBasedSolverTest : public MultiDeviceTest { // Compute the (K+1)th update using the analytic least squares gradient. vector > > updated_params; ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum, - &updated_params); + iter_to_check + 1, &updated_params); // Reinitialize the solver and run K+1 solver iterations. 
num_ = kNum; @@ -946,13 +959,13 @@ TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices); TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdate) { typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; + const Dtype kLearningRate = 0.1; this->TestLeastSquaresUpdate(kLearningRate); } TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) { typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; + const Dtype kLearningRate = 0.1; const Dtype kWeightDecay = 0.5; const Dtype kMomentum = 0.95; this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); @@ -960,64 +973,64 @@ TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithWeightDecay) { TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithHalfMomentum) { typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; + const Dtype kLearningRate = 0.1; const Dtype kWeightDecay = 0.0; const Dtype kMomentum = 0.5; const int kNumIters = 1; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); } } TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithMomentum) { typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; + const Dtype kLearningRate = 0.1; const Dtype kWeightDecay = 0.0; const Dtype kMomentum = 0.95; const int kNumIters = 1; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); } } TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) { typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; + const Dtype kLearningRate = 0.1; const Dtype kWeightDecay = 0.0; const Dtype kMomentum = 0.95; const int kNumIters = 4; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } } TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverything) { typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; + const Dtype kLearningRate = 0.1; const Dtype kWeightDecay = 0.1; const Dtype kMomentum = 0.95; const int kNumIters = 4; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } } TYPED_TEST(AdaDeltaSolverTest, TestAdaDeltaLeastSquaresUpdateWithEverythingShare) { typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; + const Dtype kLearningRate = 0.1; const Dtype kWeightDecay = 0.1; const Dtype kMomentum = 0.95; const int kNumIters = 4; this->share_ = true; for (int i = 0; i <= kNumIters; ++i) { - this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); } } TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccum) { typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; + const Dtype kLearningRate = 0.1; const Dtype kWeightDecay = 0.1; const Dtype kMomentum = 0.95; const int kNumIters = 4; @@ -1028,7 +1041,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccum) { TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) 
{ typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; + const Dtype kLearningRate = 0.1; const Dtype kWeightDecay = 0.1; const Dtype kMomentum = 0.95; const int kNumIters = 4; @@ -1040,7 +1053,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) { TYPED_TEST(AdaDeltaSolverTest, TestSnapshot) { typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; + const Dtype kLearningRate = 0.1; const Dtype kWeightDecay = 0.1; const Dtype kMomentum = 0.95; const int kNumIters = 4; @@ -1051,7 +1064,7 @@ TYPED_TEST(AdaDeltaSolverTest, TestSnapshot) { TYPED_TEST(AdaDeltaSolverTest, TestSnapshotShare) { typedef typename TypeParam::Dtype Dtype; - const Dtype kLearningRate = 1.0; + const Dtype kLearningRate = 0.1; const Dtype kWeightDecay = 0.1; const Dtype kMomentum = 0.95; const int kNumIters = 4; @@ -1061,6 +1074,111 @@ TYPED_TEST(AdaDeltaSolverTest, TestSnapshotShare) { } } +template +class AdamSolverTest : public GradientBasedSolverTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + virtual void InitSolver(const SolverParameter& param) { + SolverParameter new_param = param; + const Dtype momentum = 0.9; + new_param.set_momentum(momentum); + const Dtype momentum2 = 0.999; + new_param.set_momentum2(momentum2); + this->solver_.reset(new AdamSolver(new_param)); + } + virtual SolverParameter_SolverType solver_type() { + return SolverParameter_SolverType_ADAM; + } +}; + +TYPED_TEST_CASE(AdamSolverTest, TestDtypesAndDevices); + +TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdate) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0; + const Dtype kMomentum = 0.9; + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); +} + +TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithWeightDecay) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.5; + const Dtype kMomentum = 0.9; + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum); +} + +TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithEverything) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.5; + const Dtype kMomentum = 0.9; + const int kNumIters = 4; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(AdamSolverTest, TestAdamLeastSquaresUpdateWithEverythingShare) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.5; + const Dtype kMomentum = 0.9; + const int kNumIters = 4; + this->share_ = true; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingAccum) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.5; + const Dtype kMomentum = 0.9; + const int kNumIters = 4; + const int kIterSize = 2; + this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); +} + +TYPED_TEST(AdamSolverTest, TestLeastSquaresUpdateWithEverythingAccumShare) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.5; + const Dtype kMomentum = 0.9; + const int kNumIters = 4; + const int kIterSize = 2; + this->share_ = true; + 
this->CheckAccumulation(kLearningRate, kWeightDecay, kMomentum, kNumIters, + kIterSize); +} + +TYPED_TEST(AdamSolverTest, TestSnapshot) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.5; + const Dtype kMomentum = 0.9; + const int kNumIters = 4; + for (int i = 1; i <= kNumIters; ++i) { + this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(AdamSolverTest, TestSnapshotShare) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.5; + const Dtype kMomentum = 0.9; + const int kNumIters = 4; + this->share_ = true; + for (int i = 1; i <= kNumIters; ++i) { + this->TestSnapshot(kLearningRate, kWeightDecay, kMomentum, i); + } +} + template class RMSPropSolverTest : public GradientBasedSolverTest { typedef typename TypeParam::Dtype Dtype; From bf42e6ebf7c56ff2f0d13bdcc7294d357d7592c6 Mon Sep 17 00:00:00 2001 From: Ronghang Hu Date: Thu, 13 Aug 2015 22:41:21 -0700 Subject: [PATCH 2/2] Cite Adam paper in solver.hpp --- include/caffe/solver.hpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 582aa1427d3..ab12ef1b1bd 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -217,6 +217,14 @@ class AdaDeltaSolver : public SGDSolver { DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver); }; +/** + * @brief AdamSolver, an algorithm for first-order gradient-based optimization + * of stochastic objective functions, based on adaptive estimates of + * lower-order moments. Described in [1]. + * + * [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization." + * arXiv preprint arXiv:1412.6980v8 (2014). + */ template class AdamSolver : public SGDSolver { public:
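
Note (editor's addition, not part of either patch above): the per-parameter update that AdamSolver<Dtype>::ComputeUpdateValue implements is, element-wise,

  m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
  v_t     = beta2 * v_{t-1} + (1 - beta2) * g_t^2
  theta_t = theta_{t-1} - lr * sqrt(1 - beta2^t) / (1 - beta1^t) * m_t / (sqrt(v_t) + eps_hat)

with beta1 = momentum, beta2 = momentum2, and eps_hat = delta from SolverParameter. The minimal C++ sketch below mirrors the CPU path of the patch on plain arrays; the function name, the array interface, and the direct in-place parameter update are assumptions made for illustration only (the patch itself keeps m and v in history_, uses temp_ as scratch space, and writes the scaled step into the blob's diff rather than updating the parameter directly).

#include <cmath>

template <typename Dtype>
void adam_update_sketch(int n, Dtype* param, const Dtype* grad,
                        Dtype* m, Dtype* v, int t, Dtype lr,
                        Dtype beta1, Dtype beta2, Dtype eps_hat) {
  // Bias correction, as in the patch: sqrt(1 - beta2^t) / (1 - beta1^t).
  const Dtype correction = std::sqrt(Dtype(1) - std::pow(beta2, t)) /
      (Dtype(1) - std::pow(beta1, t));
  for (int i = 0; i < n; ++i) {
    // m <- beta1 * m + (1 - beta1) * g
    m[i] = beta1 * m[i] + (Dtype(1) - beta1) * grad[i];
    // v <- beta2 * v + (1 - beta2) * g^2
    v[i] = beta2 * v[i] + (Dtype(1) - beta2) * grad[i] * grad[i];
    // theta <- theta - lr * correction * m / (sqrt(v) + eps_hat)
    param[i] -= lr * correction * m[i] / (std::sqrt(v[i]) + eps_hat);
  }
}

Adding eps_hat after the square root, rather than inside it, matches the caffe_add_scalar call in the patch and the epsilon-hat form discussed in the Kingma and Ba paper.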