LCOV - code coverage report
Current view: top level - src/include/panoc-alm - alm.hpp (source / functions)
Test: ecee3ec3a495b05c61f077aa7a236b7e00601437
Date: 2021-11-04 22:49:09
Coverage (hit / total): Lines 81 / 108 (75.0 %); Functions 1 / 2 (50.0 %)
Legend: Lines: hit / not hit

          Line data    Source code
       1             : #pragma once
       2             : 
       3             : #include <panoc-alm/detail/alm-helpers.hpp>
       4             : #include <panoc-alm/util/solverstatus.hpp>
       5             : 
       6             : #include <iomanip>
       7             : #include <iostream>
       8             : 
       9             : namespace pa {
      10             : 
      11             : using std::chrono::duration_cast;
      12             : using std::chrono::microseconds;
      13             : 
      14             : template <class InnerSolverT>
      15             : typename ALMSolver<InnerSolverT>::Stats
      16           4 : ALMSolver<InnerSolverT>::operator()(const Problem &problem, rvec y, rvec x) {
      17           4 :     auto start_time = std::chrono::steady_clock::now();
      18             : 
      19           4 :     constexpr auto sigNaN = std::numeric_limits<real_t>::signaling_NaN();
      20           4 :     vec Σ                 = vec::Constant(problem.m, sigNaN);
      21           4 :     vec Σ_old             = vec::Constant(problem.m, sigNaN);
      22           4 :     vec error₁            = vec::Constant(problem.m, sigNaN);
      23           4 :     vec error₂            = vec::Constant(problem.m, sigNaN);
      24           4 :     real_t norm_e₁        = sigNaN;
      25           4 :     real_t norm_e₂        = sigNaN;
      26             : 
      27           4 :     Stats s;
      28             : 
      29           4 :     Problem prec_problem;
      30           4 :     real_t prec_f;
      31           4 :     vec prec_g;
      32             : 
      33           4 :     if (params.preconditioning)
      34           0 :         detail::apply_preconditioning(problem, prec_problem, x, prec_f, prec_g);
      35           4 :     const auto &p = params.preconditioning ? prec_problem : problem;
      36             : 
      37             :     // Initialize the penalty weights
      38           4 :     if (params.Σ₀ > 0) {
      39           3 :         Σ.fill(params.Σ₀);
      40           3 :     }
      41             :     // Initial penalty weights from problem
      42             :     else {
      43           1 :         detail::initialize_penalty(p, params, x, Σ);
      44             :     }
      45             : 
      46           4 :     real_t ε                   = params.ε₀;
      47           4 :     real_t ε_old               = sigNaN;
      48           4 :     real_t Δ                   = params.Δ;
      49           4 :     real_t ρ                   = params.ρ;
      50           4 :     bool first_successful_iter = true;
      51             : 
      52          34 :     for (unsigned int i = 0; i < params.max_iter; ++i) {
      53             :         // TODO: this is unnecessary when the previous iteration lowered the
      54             :         // penalty update factor.
      55          30 :         detail::project_y(y, p.D.lowerbound, p.D.upperbound, params.M);
      56             :         // Check if we're allowed to lower the penalty factor even further.
      57          60 :         bool out_of_penalty_factor_updates =
      58          34 :             (first_successful_iter
      59           4 :                  ? s.initial_penalty_reduced == params.max_num_initial_retries
      60          60 :                  : s.penalty_reduced == params.max_num_retries) ||
      61          30 :             (s.initial_penalty_reduced + s.penalty_reduced ==
      62          30 :              params.max_total_num_retries);
      63          30 :         bool out_of_iter = i + 1 == params.max_iter;
      64             :         // If this is the final iteration, or the final chance to reduce the
      65             :         // penalty update factor, the inner solver can just return its results,
      66             :         // even if it doesn't converge.
      67          30 :         bool overwrite_results = out_of_iter || out_of_penalty_factor_updates;
      68             : 
      69             :         // Inner solver
      70             :         // ------------
      71             : 
      72             :         // Call the inner solver to minimize the augmented lagrangian for fixed
      73             :         // Lagrange multipliers y.
      74          30 :         auto ps = inner_solver(p, Σ, ε, overwrite_results, x, y, error₂);
      75          30 :         bool inner_converged = ps.status == SolverStatus::Converged;
      76             :         // Accumulate the inner solver statistics
      77          30 :         s.inner_convergence_failures += not inner_converged;
      78          30 :         s.inner += ps;
      79             : 
      80          30 :         auto time_elapsed = std::chrono::steady_clock::now() - start_time;
      81          30 :         bool out_of_time  = time_elapsed > params.max_time;
      82          60 :         bool backtrack =
      83          30 :             not inner_converged && not overwrite_results && not out_of_time;
      84             : 
      85             :         // Print statistics of current iteration
      86          30 :         if (params.print_interval != 0 && i % params.print_interval == 0) {
      87          11 :             real_t δ       = backtrack ? NaN : vec_util::norm_inf(error₂);
      88          11 :             auto color     = inner_converged ? "\x1b[0;32m" : "\x1b[0;31m";
      89          11 :             auto color_end = "\x1b[0m";
      90          11 :             std::cout << "[\x1b[0;34mALM\x1b[0m]   " << std::setw(5) << i
      91          11 :                       << ": ‖Σ‖ = " << std::setw(13) << Σ.norm()
      92          11 :                       << ", ‖y‖ = " << std::setw(13) << y.norm()
      93          11 :                       << ", δ = " << std::setw(13) << δ
      94          11 :                       << ", ε = " << std::setw(13) << ps.ε
      95          11 :                       << ", Δ = " << std::setw(13) << Δ
      96          11 :                       << ", status = " << color << std::setw(13) << ps.status
      97          11 :                       << color_end << ", iter = " << std::setw(13)
      98          11 :                       << ps.iterations << "\r\n";
      99          11 :         }
     100             : 
     101             :         // TODO: check penalty size?
     102          30 :         if (ps.status == SolverStatus::Interrupted) {
     103           0 :             s.ε                = ps.ε;
     104           0 :             s.δ                = vec_util::norm_inf(error₂);
     105           0 :             s.norm_penalty     = Σ.norm();
     106           0 :             s.outer_iterations = i + 1;
     107           0 :             s.elapsed_time     = duration_cast<microseconds>(time_elapsed);
     108           0 :             s.status           = ps.status;
     109           0 :             if (params.preconditioning)
     110           0 :                 y = prec_g.asDiagonal() * y / prec_f;
     111           0 :             return s;
     112             :         }
     113             : 
     114             :         // Backtrack and lower penalty if inner solver did not converge
     115          30 :         if (backtrack) {
     116             :             // This means the inner solver didn't produce a solution that
     117             :             // satisfies the required tolerance.
     118             :             // The best thing we can do now is to restore the penalty to its
     119             :             // previous value (when the inner solver did converge), then lower
     120             :             // the penalty factor, and update the penalty with this smaller
     121             :             // factor.
     122             :             // error₂ was not overwritten by the inner solver, so it still
     123             :             // contains the error from the iteration before the previous
     124             :             // successful iteration. error₁ contains the error of the last
     125             :             // successful iteration.
     126           0 :             if (not first_successful_iter) {
     127             :                 // We have a previous Σ and error
     128             :                 // Recompute penalty with smaller Δ
     129           0 :                 Δ = std::fmax(1., Δ * params.Δ_lower);
     130           0 :                 detail::update_penalty_weights(params, Δ, first_successful_iter,
     131           0 :                                                error₁, error₂, norm_e₁, norm_e₂,
     132           0 :                                                Σ_old, Σ);
     133             :                 // Recompute the primal tolerance with larger ρ
     134           0 :                 ρ = std::fmin(0.5, ρ * params.ρ_increase); // keep ρ <= 0.5
     135           0 :                 ε = std::fmax(ρ * ε_old, params.ε);
     136           0 :                 ++s.penalty_reduced;
     137           0 :             } else {
     138             :                 // We don't have a previous Σ, simply lower the current Σ and
     139             :                 // increase ε
     140           0 :                 Σ *= params.Σ₀_lower;
     141           0 :                 ε *= params.ε₀_increase;
     142           0 :                 ++s.initial_penalty_reduced;
     143             :             }
     144           0 :         }
     145             : 
     146             :         // If the inner solver did converge, increase penalty
     147             :         else {
     148             :             // After this line, error₁ contains the error of the current
     149             :             // (successful) iteration, and error₂ contains the error of the
     150             :             // previous successful iteration.
     151          30 :             error₂.swap(error₁);
     152          30 :             norm_e₂ = std::exchange(norm_e₁, vec_util::norm_inf(error₁));
     153             : 
     154             :             // Check the termination criteria
     155          60 :             bool alm_converged =
     156          30 :                 ps.ε <= params.ε && inner_converged && norm_e₁ <= params.δ;
     157          30 :             bool exit = alm_converged || out_of_iter || out_of_time;
     158          30 :             if (exit) {
     159           4 :                 s.ε                = ps.ε;
     160           4 :                 s.δ                = norm_e₁;
     161           4 :                 s.norm_penalty     = Σ.norm();
     162           4 :                 s.outer_iterations = i + 1;
     163           4 :                 s.elapsed_time     = duration_cast<microseconds>(time_elapsed);
     164           4 :                 s.status           = alm_converged ? SolverStatus::Converged
     165           0 :                                      : out_of_time ? SolverStatus::MaxTime
     166           0 :                                      : out_of_iter ? SolverStatus::MaxIter
     167             :                                                    : SolverStatus::Unknown;
     168           4 :                 if (params.preconditioning)
     169           0 :                     y = prec_g.asDiagonal() * y / prec_f;
     170           4 :                 return s;
     171             :             }
     172             :             // After this line, Σ_old contains the penalty used in the current
     173             :             // (successful) iteration.
     174          26 :             Σ_old.swap(Σ);
     175             :             // Update Σ to contain the penalty to use on the next iteration.
     176          52 :             detail::update_penalty_weights(params, Δ, first_successful_iter,
     177          26 :                                            error₁, error₂, norm_e₁, norm_e₂,
     178          26 :                                            Σ_old, Σ);
     179             :             // Lower the primal tolerance for the inner solver.
     180          26 :             ε_old = std::exchange(ε, std::fmax(ρ * ε, params.ε));
     181          26 :             first_successful_iter = false;
     182          30 :         }
     183          30 :     }
     184           0 :     throw std::logic_error("[ALM]   loop error");
     185           4 : }
     186             : 
     187             : } // namespace pa

Generated by: LCOV version 1.15