LCOV - code coverage report
Current view: top level - src/include/panoc-alm/inner - guarded-aa-pga.hpp (source / functions) Hit Total Coverage
Test: ecee3ec3a495b05c61f077aa7a236b7e00601437 Lines: 0 157 0.0 %
Date: 2021-11-04 22:49:09 Functions: 0 16 0.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : #pragma once
       2             : 
#include <panoc-alm/inner/decl/panoc-stop-crit.hpp>
#include <panoc-alm/inner/detail/anderson-helpers.hpp>
#include <panoc-alm/inner/detail/panoc-helpers.hpp>
#include <panoc-alm/util/atomic_stop_signal.hpp>
#include <panoc-alm/util/lipschitz.hpp>
#include <panoc-alm/util/solverstatus.hpp>

#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <utility>
      16             : 
      17             : namespace pa {
      18             : 
      19           0 : struct GAAPGAParams {
      20             :     /// Parameters related to the Lipschitz constant estimate and step size.
      21             :     LipschitzEstimateParams Lipschitz;
      22             :     /// Length of the history to keep in the limited-memory QR algorithm.
      23           0 :     unsigned limitedqr_mem = 10;
      24             :     /// Maximum number of inner iterations.
      25           0 :     unsigned max_iter = 100;
      26             :     /// Maximum duration.
      27           0 :     std::chrono::microseconds max_time = std::chrono::minutes(5);
      28             :     /// Minimum Lipschitz constant estimate.
      29           0 :     real_t L_min = 1e-5;
      30             :     /// Maximum Lipschitz constant estimate.
      31           0 :     real_t L_max = 1e20;
      32             :     /// What stopping criterion to use.
      33           0 :     PANOCStopCrit stop_crit = PANOCStopCrit::ApproxKKT;
      34             : 
      35             :     /// When to print progress. If set to zero, nothing will be printed.
      36             :     /// If set to N != 0, progress is printed every N iterations.
      37           0 :     unsigned print_interval = 0;
      38             : 
      39           0 :     real_t quadratic_upperbound_tolerance_factor =
      40           0 :         10 * std::numeric_limits<real_t>::epsilon();
      41             : 
      42             :     /// Maximum number of iterations without any progress before giving up.
      43           0 :     unsigned max_no_progress = 10;
      44             : 
      45           0 :     bool full_flush_on_γ_change = true;
      46             : };
      47             : 
      48           0 : struct GAAPGAProgressInfo {
      49             :     unsigned k;
      50             :     crvec x;
      51             :     crvec p;
      52             :     real_t norm_sq_p;
      53             :     crvec x_hat;
      54             :     real_t ψ;
      55             :     crvec grad_ψ;
      56             :     real_t ψ_hat;
      57             :     crvec grad_ψ_hat;
      58             :     real_t L;
      59             :     real_t γ;
      60             :     real_t ε;
      61             :     crvec Σ;
      62             :     crvec y;
      63             :     const Problem &problem;
      64             :     const GAAPGAParams &params;
      65             : };
      66             : 
      67             : /// Guarded Anderson Accelerated Proximal Gradient Algorithm.
      68             : /// Vien V. Mai and Mikael Johansson, Anderson Acceleration of Proximal Gradient Methods.
      69             : /// https://arxiv.org/abs/1910.08590v2
      70             : ///
      71             : /// @ingroup    grp_InnerSolvers
      72           0 : class GAAPGASolver {
      73             :   public:
      74             :     using Params = GAAPGAParams;
      75             : 
      76           0 :     GAAPGASolver(const Params &params) : params(params) {}
      77             : 
      78           0 :     struct Stats {
      79           0 :         SolverStatus status = SolverStatus::Unknown;
      80           0 :         real_t ε            = inf;
      81             :         std::chrono::microseconds elapsed_time;
      82           0 :         unsigned iterations                 = 0;
      83           0 :         unsigned accelerated_steps_accepted = 0;
      84             :     };
      85             : 
      86             :     using ProgressInfo = GAAPGAProgressInfo;
      87             : 
      88             :     Stats operator()(const Problem &problem,        // in
      89             :                      crvec Σ,                       // in
      90             :                      real_t ε,                      // in
      91             :                      bool always_overwrite_results, // in
      92             :                      rvec x,                        // inout
      93             :                      rvec λ,                        // inout
      94             :                      rvec err_z);                   // out
      95             : 
      96             :     GAAPGASolver &
      97             :     set_progress_callback(std::function<void(const ProgressInfo &)> cb) {
      98             :         this->progress_cb = cb;
      99             :         return *this;
     100             :     }
     101             : 
     102             :     std::string get_name() const { return "GAAPGA"; }
     103             : 
     104             :     void stop() { stop_signal.stop(); }
     105             : 
     106             :     const Params &get_params() const { return params; }
     107             : 
     108             :   private:
     109             :     Params params;
     110             :     AtomicStopSignal stop_signal;
     111             :     std::function<void(const ProgressInfo &)> progress_cb;
     112             : };
     113             : 
     114             : using std::chrono::duration_cast;
     115             : using std::chrono::microseconds;
     116             : 
     117             : inline GAAPGASolver::Stats
     118           0 : GAAPGASolver::operator()(const Problem &problem,        // in
     119             :                          crvec Σ,                       // in
     120             :                          real_t ε,                      // in
     121             :                          bool always_overwrite_results, // in
     122             :                          rvec x,                        // inout
     123             :                          rvec y,                        // inout
     124             :                          rvec err_z                     // out
     125             : ) {
     126           0 :     auto start_time = std::chrono::steady_clock::now();
     127           0 :     Stats s;
     128             : 
     129           0 :     const auto n = problem.n;
     130           0 :     const auto m = problem.m;
     131             : 
     132           0 :     vec xₖ = x,      // Value of x at the beginning of the iteration
     133           0 :         x̂ₖ(n),       // Value of x after a projected gradient step
     134           0 :         gₖ(n),       // <?>
     135           0 :         rₖ₋₁(n),     // <?>
     136           0 :         rₖ(n),       // <?>
     137           0 :         pₖ(n),       // Projected gradient step
     138           0 :         yₖ(n),       // Value of x after a gradient or AA step
     139           0 :         ŷₖ(m),       // <?>
     140           0 :         grad_ψₖ(n),  // ∇ψ(xₖ)
     141           0 :         grad_ψx̂ₖ(n); // ∇ψ(x̂ₖ)
     142             : 
     143           0 :     vec work_n(n), work_n2(n), work_m(m);
     144             : 
     145           0 :     unsigned m_AA = std::min(params.limitedqr_mem, n);
     146           0 :     LimitedMemoryQR qr(n, m_AA);
     147           0 :     mat G(n, m_AA);
     148           0 :     vec γ_LS(m_AA);
     149             : 
     150             :     // Helper functions --------------------------------------------------------
     151             : 
     152             :     // Wrappers for helper functions that automatically pass along any arguments
     153             :     // that are constant within AAPGA (for readability in the main algorithm)
     154           0 :     auto calc_ψ_ŷ = [&problem, &y, &Σ](crvec x, rvec ŷ) {
     155           0 :         return detail::calc_ψ_ŷ(problem, x, y, Σ, ŷ);
     156           0 :     };
     157           0 :     auto calc_ψ_grad_ψ = [&problem, &y, &Σ, &work_n, &work_m](crvec x,
     158             :                                                               rvec grad_ψ) {
     159           0 :         return detail::calc_ψ_grad_ψ(problem, x, y, Σ, grad_ψ, work_n, work_m);
     160           0 :     };
     161           0 :     auto calc_grad_ψ_from_ŷ = [&problem, &work_n](crvec x, crvec ŷ,
     162             :                                                   rvec grad_ψ) {
     163           0 :         detail::calc_grad_ψ_from_ŷ(problem, x, ŷ, grad_ψ, work_n);
     164           0 :     };
     165           0 :     auto calc_x̂ = [&problem](real_t γ, crvec x, crvec grad_ψ, rvec x̂, rvec p) {
     166           0 :         detail::calc_x̂(problem, γ, x, grad_ψ, x̂, p);
     167           0 :     };
     168           0 :     auto calc_err_z = [&problem, &y, &Σ](crvec x̂, rvec err_z) {
     169           0 :         detail::calc_err_z(problem, x̂, y, Σ, err_z);
     170           0 :     };
     171           0 :     auto descent_lemma = [this, &problem, &y,
     172             :                           &Σ](crvec xₖ, real_t ψₖ, crvec grad_ψₖ, rvec x̂ₖ,
     173             :                               rvec pₖ, rvec ŷx̂ₖ, real_t &ψx̂ₖ, real_t &pₖᵀpₖ,
     174             :                               real_t &grad_ψₖᵀpₖ, real_t &Lₖ, real_t &γₖ) {
     175           0 :         return detail::descent_lemma(
     176           0 :             problem, params.quadratic_upperbound_tolerance_factor, params.L_max,
     177           0 :             xₖ, ψₖ, grad_ψₖ, y, Σ, x̂ₖ, pₖ, ŷx̂ₖ, ψx̂ₖ, pₖᵀpₖ, grad_ψₖᵀpₖ, Lₖ, γₖ);
     178           0 :     };
     179           0 :     auto print_progress = [&](unsigned k, real_t ψₖ, crvec grad_ψₖ, crvec pₖ,
     180             :                               real_t γₖ, real_t εₖ) {
     181           0 :         std::cout << "[AAPGA] " << std::setw(6) << k
     182           0 :                   << ": ψ = " << std::setw(13) << ψₖ
     183           0 :                   << ", ‖∇ψ‖ = " << std::setw(13) << grad_ψₖ.norm()
     184           0 :                   << ", ‖p‖ = " << std::setw(13) << pₖ.norm()
     185           0 :                   << ", γ = " << std::setw(13) << γₖ
     186           0 :                   << ", εₖ = " << std::setw(13) << εₖ << "\r\n";
     187           0 :     };
     188             : 
     189             :     // Estimate Lipschitz constant ---------------------------------------------
     190             : 
     191           0 :     real_t ψₖ, Lₖ;
     192             :     // Finite difference approximation of ∇²ψ in starting point
     193           0 :     if (params.Lipschitz.L₀ <= 0) {
     194           0 :         Lₖ = detail::initial_lipschitz_estimate(
     195           0 :             problem, xₖ, y, Σ, params.Lipschitz.ε, params.Lipschitz.δ,
     196           0 :             params.L_min, params.L_max,
     197           0 :             /* in ⟹ out */ grad_ψₖ, /* work */ x̂ₖ, work_n2, work_n, work_m);
     198           0 :     }
     199             :     // Initial Lipschitz constant provided by the user
     200             :     else {
     201           0 :         Lₖ = params.Lipschitz.L₀;
     202             :         // Calculate ψ(xₖ), ∇ψ(x₀)
     203           0 :         ψₖ = calc_ψ_grad_ψ(xₖ, /* in ⟹ out */ grad_ψₖ);
     204             :     }
     205           0 :     if (not std::isfinite(Lₖ)) {
     206           0 :         s.status = SolverStatus::NotFinite;
     207           0 :         return s;
     208             :     }
     209             : 
     210           0 :     real_t γₖ = params.Lipschitz.Lγ_factor / Lₖ;
     211             : 
     212             :     // First projected gradient step -------------------------------------------
     213             : 
     214           0 :     rₖ₋₁     = -γₖ * grad_ψₖ;
     215           0 :     yₖ       = xₖ + rₖ₋₁;
     216           0 :     xₖ       = project(yₖ, problem.C);
     217           0 :     G.col(0) = yₖ;
     218             : 
     219           0 :     unsigned no_progress = 0;
     220             : 
     221             :     // Calculate gradient in second iterate ------------------------------------
     222             : 
     223             :     // Calculate ψ(x₁) and ∇ψ(x₁)
     224           0 :     ψₖ = calc_ψ_grad_ψ(xₖ, /* in ⟹ out */ grad_ψₖ);
     225             : 
     226             :     // Main loop
     227             :     // =========================================================================
     228           0 :     for (unsigned k = 0; k <= params.max_iter; ++k) {
     229             :         // From previous iteration:
     230             :         //  - xₖ
     231             :         //  - grad_ψₖ
     232             :         //  - ψₖ
     233             :         //  - rₖ₋₁
     234             :         //  - history in qr and G
     235             : 
     236             :         // Quadratic upper bound -----------------------------------------------
     237             : 
     238             :         // Projected gradient step: x̂ₖ and pₖ
     239           0 :         calc_x̂(γₖ, xₖ, grad_ψₖ, /* in ⟹ out */ x̂ₖ, pₖ);
     240             :         // Calculate ψ(x̂ₖ) and ŷ(x̂ₖ)
     241           0 :         real_t ψx̂ₖ = calc_ψ_ŷ(x̂ₖ, /* in ⟹ out */ ŷₖ);
     242             :         // Calculate ∇ψ(xₖ)ᵀpₖ and ‖pₖ‖²
     243           0 :         real_t grad_ψₖᵀpₖ = grad_ψₖ.dot(pₖ);
     244           0 :         real_t pₖᵀpₖ      = pₖ.squaredNorm();
     245             : 
     246           0 :         real_t old_γₖ = descent_lemma(xₖ, ψₖ, grad_ψₖ, x̂ₖ, pₖ, ŷₖ, ψx̂ₖ, pₖᵀpₖ,
     247             :                                       grad_ψₖᵀpₖ, Lₖ, γₖ);
     248             : 
     249             :         // Flush or update Anderson buffers if step size changed
     250           0 :         if (γₖ != old_γₖ) {
     251           0 :             if (params.full_flush_on_γ_change) {
     252             :                 // Save the latest function evaluation gₖ at the first index
     253           0 :                 size_t newest_g_idx = qr.ring_tail();
     254           0 :                 if (newest_g_idx != 0)
     255           0 :                     G.col(0) = G.col(newest_g_idx);
     256             :                 // Flush everything else and reset indices
     257           0 :                 qr.reset();
     258           0 :             } else {
     259             :                 // When not near the boundaries of the feasible set,
     260             :                 // r(x) = g(x) - x = Π(x - γ∇ψ(x)) - x = -γ∇ψ(x),
     261             :                 // in other words, r(x) is proportional to γ, and so is Δr,
     262             :                 // so when γ changes, these values have to be updated as well
     263           0 :                 qr.scale_R(γₖ / old_γₖ);
     264             :             }
     265           0 :             rₖ₋₁ *= γₖ / old_γₖ;
     266           0 :         }
     267             : 
     268             :         // Calculate ∇ψ(x̂ₖ)
     269           0 :         calc_grad_ψ_from_ŷ(x̂ₖ, ŷₖ, /* in ⟹ out */ grad_ψx̂ₖ);
     270             : 
     271             :         // Check stop condition ------------------------------------------------
     272             : 
     273           0 :         real_t εₖ = detail::calc_error_stop_crit(
     274           0 :             problem.C, params.stop_crit, pₖ, γₖ, xₖ, x̂ₖ, ŷₖ, grad_ψₖ, grad_ψx̂ₖ);
     275             : 
     276             :         // Print progress
     277           0 :         if (params.print_interval != 0 && k % params.print_interval == 0)
     278           0 :             print_progress(k, ψₖ, grad_ψₖ, pₖ, γₖ, εₖ);
     279           0 :         if (progress_cb)
     280           0 :             progress_cb({k, xₖ, pₖ, pₖᵀpₖ, x̂ₖ, ψₖ, grad_ψₖ, ψx̂ₖ, grad_ψx̂ₖ, Lₖ,
     281           0 :                          γₖ, εₖ, Σ, y, problem, params});
     282             : 
     283           0 :         auto time_elapsed = std::chrono::steady_clock::now() - start_time;
     284           0 :         auto stop_status  = detail::check_all_stop_conditions(
     285           0 :             params, time_elapsed, k, stop_signal, ε, εₖ, no_progress);
     286           0 :         if (stop_status != SolverStatus::Unknown) {
     287             :             // TODO: We could cache g(x) and ẑ, but would that faster?
     288             :             //       It saves 1 evaluation of g per ALM iteration, but requires
     289             :             //       many extra stores in the inner loops of PANOC.
     290             :             // TODO: move the computation of ẑ and g(x) to ALM?
     291           0 :             if (stop_status == SolverStatus::Converged ||
     292           0 :                 stop_status == SolverStatus::Interrupted ||
     293           0 :                 always_overwrite_results) {
     294           0 :                 calc_err_z(x̂ₖ, /* in ⟹ out */ err_z);
     295           0 :                 x = std::move(x̂ₖ);
     296           0 :                 y = std::move(ŷₖ);
     297           0 :             }
     298           0 :             s.iterations   = k;
     299           0 :             s.ε            = εₖ;
     300           0 :             s.elapsed_time = duration_cast<microseconds>(time_elapsed);
     301           0 :             s.status       = stop_status;
     302           0 :             return s;
     303             :         }
     304             : 
     305             :         // Standard gradient descent step
     306           0 :         gₖ = xₖ - γₖ * grad_ψₖ;
     307           0 :         rₖ = gₖ - yₖ;
     308             : 
     309             :         // Solve Anderson acceleration least squares problem and update history
     310           0 :         minimize_update_anderson(qr, G, rₖ, rₖ₋₁, gₖ, γ_LS, yₖ);
     311             : 
     312             :         // Project accelerated step onto feasible set
     313           0 :         xₖ = project(yₖ, problem.C);
     314             : 
     315             :         // Calculate the objective at the projected accelerated point
     316           0 :         real_t ψₖ₊₁  = calc_ψ_ŷ(xₖ, /* in ⟹ out */ ŷₖ);
     317           0 :         real_t old_ψ = ψₖ;
     318             : 
     319             :         // Check sufficient decrease condition for Anderson iterate
     320           0 :         bool sufficient_decrease;
     321             :         if (0) // From paper
     322             :             sufficient_decrease = ψₖ₊₁ <= ψₖ - 0.5 * γₖ * grad_ψₖ.squaredNorm();
     323             :         else // Since we compute ψ(x̂ₖ) we might as well pick the best one
     324           0 :             sufficient_decrease = ψₖ₊₁ <= ψx̂ₖ;
     325             : 
     326           0 :         if (sufficient_decrease) {
     327             :             // Accept Anderson step
     328             :             // yₖ and xₖ are already overwritten by yₑₓₜ and Π(yₑₓₜ)
     329           0 :             ψₖ = ψₖ₊₁;
     330           0 :             calc_grad_ψ_from_ŷ(xₖ, ŷₖ, /* in ⟹ out */ grad_ψₖ);
     331           0 :         }
     332             :         // If not satisfied, take normal projected gradient step
     333             :         else {
     334           0 :             yₖ.swap(gₖ);
     335           0 :             xₖ.swap(x̂ₖ);
     336           0 :             ψₖ = ψx̂ₖ;
     337           0 :             grad_ψₖ.swap(grad_ψx̂ₖ);
     338             :         }
     339           0 :         rₖ.swap(rₖ₋₁);
     340           0 :         s.accelerated_steps_accepted += sufficient_decrease;
     341             : 
     342             :         // Check if we made any progress, prevents us from exceeding the maximum
     343             :         // number of iterations doing nothing if the step size gets too small
     344             :         // TODO: is this a valid test?
     345           0 :         no_progress = (ψₖ == old_ψ) ? no_progress + 1 : 0;
     346           0 :     }
     347           0 :     throw std::logic_error("[AAPGA] loop error");
     348           0 : }
     349             : 
     350             : template <class InnerSolverStats>
     351             : struct InnerStatsAccumulator;
     352             : 
     353             : template <>
     354           0 : struct InnerStatsAccumulator<GAAPGASolver::Stats> {
     355             :     std::chrono::microseconds elapsed_time;
     356           0 :     unsigned iterations                 = 0;
     357           0 :     unsigned accelerated_steps_accepted = 0;
     358             : };
     359             : 
     360             : inline InnerStatsAccumulator<GAAPGASolver::Stats> &
     361           0 : operator+=(InnerStatsAccumulator<GAAPGASolver::Stats> &acc,
     362             :            const GAAPGASolver::Stats &s) {
     363           0 :     acc.elapsed_time += s.elapsed_time;
     364           0 :     acc.iterations += s.iterations;
     365           0 :     acc.accelerated_steps_accepted += s.accelerated_steps_accepted;
     366           0 :     return acc;
     367             : }
     368             : 
     369             : } // namespace pa

Generated by: LCOV version 1.15