LCOV - code coverage report
Current view: top level - src/include/panoc-alm/inner - panoc.hpp (source / functions)
Test: ecee3ec3a495b05c61f077aa7a236b7e00601437
Date: 2021-11-04 22:49:09
Summary:        Hit   Total   Coverage
  Lines:        155     185     83.8 %
  Functions:      7       9     77.8 %
Legend: Lines: hit | not hit

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include <panoc-alm/inner/decl/panoc.hpp>
       4             : #include <panoc-alm/inner/detail/panoc-helpers.hpp>
       5             : #include <panoc-alm/inner/directions/decl/panoc-direction-update.hpp>
       6             : 
       7             : #include <cassert>
       8             : #include <cmath>
       9             : #include <iomanip>
      10             : #include <iostream>
      11             : #include <stdexcept>
      12             : 
      13             : namespace pa {
      14             : 
      15             : using std::chrono::duration_cast;
      16             : using std::chrono::microseconds;
      17             : 
      18             : template <class DirectionProviderT>
      19           0 : std::string PANOCSolver<DirectionProviderT>::get_name() const {
      20           0 :     return "PANOCSolver<" + direction_provider.get_name() + ">";
      21           0 : }
      22             : 
      23             : template <class DirectionProviderT>
      24             : typename PANOCSolver<DirectionProviderT>::Stats
      25          31 : PANOCSolver<DirectionProviderT>::operator()(
      26             :     /// [in]    Problem description
      27             :     const Problem &problem,
      28             :     /// [in]    Constraint weights @f$ \Sigma @f$
      29             :     crvec Σ,
      30             :     /// [in]    Tolerance @f$ \varepsilon @f$
      31             :     real_t ε,
      32             :     /// [in]    Overwrite @p x, @p y and @p err_z even if not converged
      33             :     bool always_overwrite_results,
      34             :     /// [inout] Decision variable @f$ x @f$
      35             :     rvec x,
      36             :     /// [inout] Lagrange multipliers @f$ y @f$
      37             :     rvec y,
      38             :     /// [out]   Slack variable error @f$ g(x) - z @f$
      39             :     rvec err_z) {
      40             : 
      41          31 :     auto start_time = std::chrono::steady_clock::now();
      42          31 :     Stats s;
      43             : 
      44          31 :     const auto n = problem.n;
      45          31 :     const auto m = problem.m;
      46             : 
      47             :     // Allocate vectors, init L-BFGS -------------------------------------------
      48             : 
      49             :     // TODO: the L-BFGS objects and vectors allocate on each iteration of ALM,
      50             :     //       and there are more vectors than strictly necessary.
      51             : 
      52          31 :     bool need_grad_̂ψₖ = detail::stop_crit_requires_grad_̂ψₖ(params.stop_crit);
      53             : 
      54          62 :     vec xₖ = x,   // Value of x at the beginning of the iteration
      55          31 :         x̂ₖ(n),    // Value of x after a projected gradient step
      56          31 :         xₖ₊₁(n),  // xₖ for next iteration
      57          31 :         x̂ₖ₊₁(n),  // x̂ₖ for next iteration
      58          31 :         ŷx̂ₖ(m),   // ŷ(x̂ₖ) = Σ (g(x̂ₖ) - ẑₖ)
      59          31 :         ŷx̂ₖ₊₁(m), // ŷ(x̂ₖ) for next iteration
      60          31 :         pₖ(n),    // Projected gradient step pₖ = x̂ₖ - xₖ
      61          31 :         pₖ₊₁(n), // Projected gradient step pₖ₊₁ = x̂ₖ₊₁ - xₖ₊₁
      62          31 :         qₖ(n),   // Newton step Hₖ pₖ
      63          31 :         grad_ψₖ(n),                    // ∇ψ(xₖ)
      64          31 :         grad_̂ψₖ(need_grad_̂ψₖ ? n : 0), // ∇ψ(x̂ₖ)
      65          31 :         grad_ψₖ₊₁(n);                  // ∇ψ(xₖ₊₁)
      66             : 
      67          31 :     vec work_n(n), work_m(m);
      68             : 
      69             :     // Keep track of how many successive iterations didn't update the iterate
      70          31 :     unsigned no_progress = 0;
      71             : 
      72             :     // Helper functions --------------------------------------------------------
      73             : 
      74             :     // Wrappers for helper functions that automatically pass along any arguments
      75             :     // that are constant within PANOC (for readability in the main algorithm)
      76         890 :     auto calc_ψ_ŷ = [&problem, &y, &Σ](crvec x, rvec ŷ) {
      77         859 :         return detail::calc_ψ_ŷ(problem, x, y, Σ, ŷ);
      78           0 :     };
      79         811 :     auto calc_ψ_grad_ψ = [&problem, &y, &Σ, &work_n, &work_m](crvec x,
      80             :                                                               rvec grad_ψ) {
      81         780 :         return detail::calc_ψ_grad_ψ(problem, x, y, Σ, grad_ψ, work_n, work_m);
      82           0 :     };
      83         610 :     auto calc_grad_ψ_from_ŷ = [&problem, &work_n](crvec x, crvec ŷ,
      84             :                                                   rvec grad_ψ) {
      85         579 :         detail::calc_grad_ψ_from_ŷ(problem, x, ŷ, grad_ψ, work_n);
      86         579 :     };
      87         890 :     auto calc_x̂ = [&problem](real_t γ, crvec x, crvec grad_ψ, rvec x̂, rvec p) {
      88         859 :         detail::calc_x̂(problem, γ, x, grad_ψ, x̂, p);
      89         859 :     };
      90          62 :     auto calc_err_z = [&problem, &y, &Σ](crvec x̂, rvec err_z) {
      91          31 :         detail::calc_err_z(problem, x̂, y, Σ, err_z);
      92          31 :     };
      93         890 :     auto descent_lemma = [this, &problem, &y,
      94             :                           &Σ](crvec xₖ, real_t ψₖ, crvec grad_ψₖ, rvec x̂ₖ,
      95             :                               rvec pₖ, rvec ŷx̂ₖ, real_t &ψx̂ₖ, real_t &pₖᵀpₖ,
      96             :                               real_t &grad_ψₖᵀpₖ, real_t &Lₖ, real_t &γₖ) {
      97        1718 :         return detail::descent_lemma(
      98         859 :             problem, params.quadratic_upperbound_tolerance_factor, params.L_max,
      99         859 :             xₖ, ψₖ, grad_ψₖ, y, Σ, x̂ₖ, pₖ, ŷx̂ₖ, ψx̂ₖ, pₖᵀpₖ, grad_ψₖᵀpₖ, Lₖ, γₖ);
     100           0 :     };
     101          31 :     auto print_progress = [&](unsigned k, real_t ψₖ, crvec grad_ψₖ,
     102             :                               real_t pₖᵀpₖ, real_t γₖ, real_t εₖ) {
     103           0 :         std::cout << "[PANOC] " << std::setw(6) << k
     104           0 :                   << ": ψ = " << std::setw(13) << ψₖ
     105           0 :                   << ", ‖∇ψ‖ = " << std::setw(13) << grad_ψₖ.norm()
     106           0 :                   << ", ‖p‖ = " << std::setw(13) << std::sqrt(pₖᵀpₖ)
     107           0 :                   << ", γ = " << std::setw(13) << γₖ
     108           0 :                   << ", εₖ = " << std::setw(13) << εₖ << "\r\n";
     109           0 :     };
     110             : 
     111             :     // Estimate Lipschitz constant ---------------------------------------------
     112             : 
     113          31 :     real_t ψₖ, Lₖ;
     114             :     // Finite difference approximation of ∇²ψ in starting point
     115          31 :     if (params.Lipschitz.L₀ <= 0) {
     116          62 :         Lₖ = detail::initial_lipschitz_estimate(
     117          31 :             problem, xₖ, y, Σ, params.Lipschitz.ε, params.Lipschitz.δ,
     118          31 :             params.L_min, params.L_max,
     119          31 :             /* in ⟹ out */ ψₖ, grad_ψₖ, x̂ₖ, grad_ψₖ₊₁, work_n, work_m);
     120          31 :     }
     121             :     // Initial Lipschitz constant provided by the user
     122             :     else {
     123           0 :         Lₖ = params.Lipschitz.L₀;
     124             :         // Calculate ψ(xₖ), ∇ψ(x₀)
     125           0 :         ψₖ = calc_ψ_grad_ψ(xₖ, /* in ⟹ out */ grad_ψₖ);
     126             :     }
     127          31 :     if (not std::isfinite(Lₖ)) {
     128           0 :         s.status = SolverStatus::NotFinite;
     129           0 :         return s;
     130             :     }
     131          31 :     real_t γₖ = params.Lipschitz.Lγ_factor / Lₖ;
     132          31 :     real_t τ  = NaN;
     133             : 
     134             :     // First projected gradient step -------------------------------------------
     135             : 
     136             :     // Calculate x̂₀, p₀ (projected gradient step)
     137          31 :     calc_x̂(γₖ, xₖ, grad_ψₖ, /* in ⟹ out */ x̂ₖ, pₖ);
     138             :     // Calculate ψ(x̂ₖ) and ŷ(x̂ₖ)
     139          31 :     real_t ψx̂ₖ        = calc_ψ_ŷ(x̂ₖ, /* in ⟹ out */ ŷx̂ₖ);
     140          31 :     real_t grad_ψₖᵀpₖ = grad_ψₖ.dot(pₖ);
     141          31 :     real_t pₖᵀpₖ      = pₖ.squaredNorm();
     142             :     // Compute forward-backward envelope
     143          31 :     real_t φₖ = ψₖ + 1 / (2 * γₖ) * pₖᵀpₖ + grad_ψₖᵀpₖ;
     144             : 
     145             :     // Main PANOC loop
     146             :     // =========================================================================
     147         610 :     for (unsigned k = 0; k <= params.max_iter; ++k) {
     148             : 
     149             :         // Quadratic upper bound -----------------------------------------------
     150         579 :         if (k == 0 || params.update_lipschitz_in_linesearch == false) {
     151             :             // Decrease step size until quadratic upper bound is satisfied
     152          62 :             real_t old_γₖ =
     153          62 :                 descent_lemma(xₖ, ψₖ, grad_ψₖ,
     154          31 :                               /* in ⟹ out */ x̂ₖ, pₖ, ŷx̂ₖ,
     155             :                               /* inout */ ψx̂ₖ, pₖᵀpₖ, grad_ψₖᵀpₖ, Lₖ, γₖ);
     156          31 :             if (k > 0 && γₖ != old_γₖ) // Flush L-BFGS if γ changed
     157           0 :                 direction_provider.changed_γ(γₖ, old_γₖ);
     158          31 :             else if (k == 0) // Initialize L-BFGS
     159          31 :                 direction_provider.initialize(xₖ, x̂ₖ, pₖ, grad_ψₖ);
     160          31 :             if (γₖ != old_γₖ)
     161          13 :                 φₖ = ψₖ + 1 / (2 * γₖ) * pₖᵀpₖ + grad_ψₖᵀpₖ;
     162          31 :         }
     163             :         // Calculate ∇ψ(x̂ₖ)
     164         579 :         if (need_grad_̂ψₖ)
     165         579 :             calc_grad_ψ_from_ŷ(x̂ₖ, ŷx̂ₖ, /* in ⟹ out */ grad_̂ψₖ);
     166             : 
     167             :         // Check stop condition ------------------------------------------------
     168        1737 :         real_t εₖ = detail::calc_error_stop_crit(
     169         579 :             problem.C, params.stop_crit, pₖ, γₖ, xₖ, x̂ₖ, ŷx̂ₖ, grad_ψₖ, grad_̂ψₖ);
     170             : 
     171             :         // Print progress
     172         579 :         if (params.print_interval != 0 && k % params.print_interval == 0)
     173           0 :             print_progress(k, ψₖ, grad_ψₖ, pₖᵀpₖ, γₖ, εₖ);
     174         579 :         if (progress_cb)
     175           0 :             progress_cb({k, xₖ, pₖ, pₖᵀpₖ, x̂ₖ, φₖ, ψₖ, grad_ψₖ, ψx̂ₖ, grad_̂ψₖ,
     176           0 :                          Lₖ, γₖ, τ, εₖ, Σ, y, problem, params});
     177             : 
     178         579 :         auto time_elapsed = std::chrono::steady_clock::now() - start_time;
     179        1158 :         auto stop_status  = detail::check_all_stop_conditions(
     180         579 :             params, time_elapsed, k, stop_signal, ε, εₖ, no_progress);
     181         579 :         if (stop_status != SolverStatus::Unknown) {
     182             :             // TODO: We could cache g(x) and ẑ, but would that faster?
     183             :             //       It saves 1 evaluation of g per ALM iteration, but requires
     184             :             //       many extra stores in the inner loops of PANOC.
     185             :             // TODO: move the computation of ẑ and g(x) to ALM?
     186          31 :             if (stop_status == SolverStatus::Converged ||
     187           0 :                 stop_status == SolverStatus::Interrupted ||
     188           0 :                 always_overwrite_results) {
     189          31 :                 calc_err_z(x̂ₖ, /* in ⟹ out */ err_z);
     190          31 :                 x = std::move(x̂ₖ);
     191          31 :                 y = std::move(ŷx̂ₖ);
     192          31 :             }
     193          31 :             s.iterations   = k;
     194          31 :             s.ε            = εₖ;
     195          31 :             s.elapsed_time = duration_cast<microseconds>(time_elapsed);
     196          31 :             s.status       = stop_status;
     197          31 :             return s;
     198             :         }
     199             : 
     200             :         // Calculate quasi-Newton step -----------------------------------------
     201         548 :         real_t step_size =
     202         548 :             params.lbfgs_stepsize == LBFGSStepSize::BasedOnGradientStepSize
     203             :                 ? 1
     204             :                 : -1;
     205         548 :         if (k > 0)
     206         520 :             direction_provider.apply(xₖ, x̂ₖ, pₖ, step_size,
     207         520 :                                      /* in ⟹ out */ qₖ);
     208             : 
     209             :         // Line search initialization ------------------------------------------
     210         548 :         τ                  = 1;
     211         548 :         real_t σₖγₖ⁻¹pₖᵀpₖ = (1 - γₖ * Lₖ) * pₖᵀpₖ / (2 * γₖ);
     212         548 :         real_t φₖ₊₁, ψₖ₊₁, ψx̂ₖ₊₁, grad_ψₖ₊₁ᵀpₖ₊₁, pₖ₊₁ᵀpₖ₊₁;
     213         548 :         real_t Lₖ₊₁, γₖ₊₁;
     214         548 :         real_t ls_cond;
     215             :         // TODO: make separate parameter
     216        1096 :         real_t margin =
     217         548 :             (1 + std::abs(φₖ)) * params.quadratic_upperbound_tolerance_factor;
     218             : 
     219             :         // Make sure quasi-Newton step is valid
     220         548 :         if (k == 0) {
     221          28 :             τ = 0; // Always use prox step on first iteration
     222         548 :         } else if (not qₖ.allFinite()) {
     223           0 :             τ = 0;
     224           0 :             ++s.lbfgs_failures;
     225           0 :             direction_provider.reset(); // Is there anything else we can do?
     226           0 :         }
     227             : 
     228             :         // Line search loop ----------------------------------------------------
     229         548 :         do {
     230         828 :             Lₖ₊₁ = Lₖ;
     231         828 :             γₖ₊₁ = γₖ;
     232             : 
     233             :             // Calculate xₖ₊₁
     234         828 :             if (τ / 2 < params.τ_min) { // line search failed
     235          48 :                 xₖ₊₁.swap(x̂ₖ);          // → safe prox step
     236          48 :                 ψₖ₊₁ = ψx̂ₖ;
     237          48 :                 if (need_grad_̂ψₖ)
     238          48 :                     grad_ψₖ₊₁.swap(grad_̂ψₖ);
     239             :                 else
     240           0 :                     calc_grad_ψ_from_ŷ(xₖ₊₁, ŷx̂ₖ, /* in ⟹ out */ grad_ψₖ₊₁);
     241          48 :             } else {        // line search didn't fail (yet)
     242         780 :                 if (τ == 1) // → faster quasi-Newton step
     243         520 :                     xₖ₊₁ = xₖ + qₖ;
     244             :                 else
     245         260 :                     xₖ₊₁ = xₖ + (1 - τ) * pₖ + τ * qₖ;
     246             :                 // Calculate ψ(xₖ₊₁), ∇ψ(xₖ₊₁)
     247         780 :                 ψₖ₊₁ = calc_ψ_grad_ψ(xₖ₊₁, /* in ⟹ out */ grad_ψₖ₊₁);
     248             :             }
     249             : 
     250             :             // Calculate x̂ₖ₊₁, pₖ₊₁ (projected gradient step in xₖ₊₁)
     251         828 :             calc_x̂(γₖ₊₁, xₖ₊₁, grad_ψₖ₊₁, /* in ⟹ out */ x̂ₖ₊₁, pₖ₊₁);
     252             :             // Calculate ψ(x̂ₖ₊₁) and ŷ(x̂ₖ₊₁)
     253         828 :             ψx̂ₖ₊₁ = calc_ψ_ŷ(x̂ₖ₊₁, /* in ⟹ out */ ŷx̂ₖ₊₁);
     254             : 
     255             :             // Quadratic upper bound -------------------------------------------
     256         828 :             grad_ψₖ₊₁ᵀpₖ₊₁ = grad_ψₖ₊₁.dot(pₖ₊₁);
     257         828 :             pₖ₊₁ᵀpₖ₊₁      = pₖ₊₁.squaredNorm();
     258         828 :             real_t pₖ₊₁ᵀpₖ₊₁_ₖ = pₖ₊₁ᵀpₖ₊₁; // prox step with step size γₖ
     259             : 
     260         828 :             if (params.update_lipschitz_in_linesearch == true) {
     261             :                 // Decrease step size until quadratic upper bound is satisfied
     262        1656 :                 (void)descent_lemma(xₖ₊₁, ψₖ₊₁, grad_ψₖ₊₁,
     263         828 :                                     /* in ⟹ out */ x̂ₖ₊₁, pₖ₊₁, ŷx̂ₖ₊₁,
     264             :                                     /* inout */ ψx̂ₖ₊₁, pₖ₊₁ᵀpₖ₊₁,
     265             :                                     grad_ψₖ₊₁ᵀpₖ₊₁, Lₖ₊₁, γₖ₊₁);
     266         828 :             }
     267             : 
     268             :             // Compute forward-backward envelope
     269         828 :             φₖ₊₁ = ψₖ₊₁ + 1 / (2 * γₖ₊₁) * pₖ₊₁ᵀpₖ₊₁ + grad_ψₖ₊₁ᵀpₖ₊₁;
     270             :             // Compute line search condition
     271         828 :             ls_cond = φₖ₊₁ - (φₖ - σₖγₖ⁻¹pₖᵀpₖ);
     272         828 :             if (params.alternative_linesearch_cond)
     273           0 :                 ls_cond -= (0.5 / γₖ₊₁ - 0.5 / γₖ) * pₖ₊₁ᵀpₖ₊₁_ₖ;
     274             : 
     275         828 :             τ /= 2;
     276         828 :         } while (ls_cond > margin && τ >= params.τ_min);
     277             : 
     278             :         // If τ < τ_min the line search failed and we accepted the prox step
     279         548 :         if (τ < params.τ_min && k != 0) {
     280          20 :             ++s.linesearch_failures;
     281          20 :             τ = 0;
     282          20 :         }
     283         548 :         if (k != 0) {
     284         520 :             s.count_τ += 1;
     285         520 :             s.sum_τ += τ * 2;
     286         520 :             s.τ_1_accepted += τ * 2 == 1;
     287         520 :         }
     288             : 
     289             :         // Update L-BFGS -------------------------------------------------------
     290         548 :         if (γₖ != γₖ₊₁) // Flush L-BFGS if γ changed
     291           9 :             direction_provider.changed_γ(γₖ₊₁, γₖ);
     292             : 
     293        1096 :         s.lbfgs_rejected += not direction_provider.update(
     294         548 :             xₖ, xₖ₊₁, pₖ, pₖ₊₁, grad_ψₖ₊₁, problem.C, γₖ₊₁);
     295             : 
     296             :         // Check if we made any progress
     297         548 :         if (no_progress > 0 || k % params.max_no_progress == 0)
     298          69 :             no_progress = xₖ == xₖ₊₁ ? no_progress + 1 : 0;
     299             : 
     300             :         // Advance step --------------------------------------------------------
     301         548 :         Lₖ = Lₖ₊₁;
     302         548 :         γₖ = γₖ₊₁;
     303             : 
     304         548 :         ψₖ  = ψₖ₊₁;
     305         548 :         ψx̂ₖ = ψx̂ₖ₊₁;
     306         548 :         φₖ  = φₖ₊₁;
     307             : 
     308         548 :         xₖ.swap(xₖ₊₁);
     309         548 :         x̂ₖ.swap(x̂ₖ₊₁);
     310         548 :         ŷx̂ₖ.swap(ŷx̂ₖ₊₁);
     311         548 :         pₖ.swap(pₖ₊₁);
     312         548 :         grad_ψₖ.swap(grad_ψₖ₊₁);
     313         548 :         grad_ψₖᵀpₖ = grad_ψₖ₊₁ᵀpₖ₊₁;
     314         548 :         pₖᵀpₖ      = pₖ₊₁ᵀpₖ₊₁;
     315         579 :     }
     316           0 :     throw std::logic_error("[PANOC] loop error");
     317          31 : }
     318             : 
     319             : } // namespace pa

Generated by: LCOV version 1.15