Line data Source code
#pragma once

#include <panoc-alm/detail/alm-helpers.hpp>
#include <panoc-alm/util/solverstatus.hpp>

#include <chrono>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <utility>

9 : namespace pa {
10 :
11 : using std::chrono::duration_cast;
12 : using std::chrono::microseconds;
13 :
14 : template <class InnerSolverT>
15 : typename ALMSolver<InnerSolverT>::Stats
16 4 : ALMSolver<InnerSolverT>::operator()(const Problem &problem, rvec y, rvec x) {
17 4 : auto start_time = std::chrono::steady_clock::now();
18 :
19 4 : constexpr auto sigNaN = std::numeric_limits<real_t>::signaling_NaN();
20 4 : vec Σ = vec::Constant(problem.m, sigNaN);
21 4 : vec Σ_old = vec::Constant(problem.m, sigNaN);
22 4 : vec error₁ = vec::Constant(problem.m, sigNaN);
23 4 : vec error₂ = vec::Constant(problem.m, sigNaN);
24 4 : real_t norm_e₁ = sigNaN;
25 4 : real_t norm_e₂ = sigNaN;
26 :
27 4 : Stats s;
28 :
29 4 : Problem prec_problem;
30 4 : real_t prec_f;
31 4 : vec prec_g;
32 :
33 4 : if (params.preconditioning)
34 0 : detail::apply_preconditioning(problem, prec_problem, x, prec_f, prec_g);
35 4 : const auto &p = params.preconditioning ? prec_problem : problem;
36 :
37 : // Initialize the penalty weights
38 4 : if (params.Σ₀ > 0) {
39 3 : Σ.fill(params.Σ₀);
40 3 : }
41 : // Initial penalty weights from problem
42 : else {
43 1 : detail::initialize_penalty(p, params, x, Σ);
44 : }
45 :
46 4 : real_t ε = params.ε₀;
47 4 : real_t ε_old = sigNaN;
48 4 : real_t Δ = params.Δ;
49 4 : real_t ρ = params.ρ;
50 4 : bool first_successful_iter = true;
51 :
52 34 : for (unsigned int i = 0; i < params.max_iter; ++i) {
53 : // TODO: this is unnecessary when the previous iteration lowered the
54 : // penalty update factor.
55 30 : detail::project_y(y, p.D.lowerbound, p.D.upperbound, params.M);
56 : // Check if we're allowed to lower the penalty factor even further.
57 60 : bool out_of_penalty_factor_updates =
58 34 : (first_successful_iter
59 4 : ? s.initial_penalty_reduced == params.max_num_initial_retries
60 60 : : s.penalty_reduced == params.max_num_retries) ||
61 30 : (s.initial_penalty_reduced + s.penalty_reduced ==
62 30 : params.max_total_num_retries);
63 30 : bool out_of_iter = i + 1 == params.max_iter;
64 : // If this is the final iteration, or the final chance to reduce the
65 : // penalty update factor, the inner solver can just return its results,
66 : // even if it doesn't converge.
67 30 : bool overwrite_results = out_of_iter || out_of_penalty_factor_updates;
68 :
69 : // Inner solver
70 : // ------------
71 :
72 : // Call the inner solver to minimize the augmented lagrangian for fixed
73 : // Lagrange multipliers y.
74 30 : auto ps = inner_solver(p, Σ, ε, overwrite_results, x, y, error₂);
75 30 : bool inner_converged = ps.status == SolverStatus::Converged;
76 : // Accumulate the inner solver statistics
77 30 : s.inner_convergence_failures += not inner_converged;
78 30 : s.inner += ps;
79 :
80 30 : auto time_elapsed = std::chrono::steady_clock::now() - start_time;
81 30 : bool out_of_time = time_elapsed > params.max_time;
82 60 : bool backtrack =
83 30 : not inner_converged && not overwrite_results && not out_of_time;
84 :
85 : // Print statistics of current iteration
86 30 : if (params.print_interval != 0 && i % params.print_interval == 0) {
87 11 : real_t δ = backtrack ? NaN : vec_util::norm_inf(error₂);
88 11 : auto color = inner_converged ? "\x1b[0;32m" : "\x1b[0;31m";
89 11 : auto color_end = "\x1b[0m";
90 11 : std::cout << "[\x1b[0;34mALM\x1b[0m] " << std::setw(5) << i
91 11 : << ": ‖Σ‖ = " << std::setw(13) << Σ.norm()
92 11 : << ", ‖y‖ = " << std::setw(13) << y.norm()
93 11 : << ", δ = " << std::setw(13) << δ
94 11 : << ", ε = " << std::setw(13) << ps.ε
95 11 : << ", Δ = " << std::setw(13) << Δ
96 11 : << ", status = " << color << std::setw(13) << ps.status
97 11 : << color_end << ", iter = " << std::setw(13)
98 11 : << ps.iterations << "\r\n";
99 11 : }
100 :
101 : // TODO: check penalty size?
102 30 : if (ps.status == SolverStatus::Interrupted) {
103 0 : s.ε = ps.ε;
104 0 : s.δ = vec_util::norm_inf(error₂);
105 0 : s.norm_penalty = Σ.norm();
106 0 : s.outer_iterations = i + 1;
107 0 : s.elapsed_time = duration_cast<microseconds>(time_elapsed);
108 0 : s.status = ps.status;
109 0 : if (params.preconditioning)
110 0 : y = prec_g.asDiagonal() * y / prec_f;
111 0 : return s;
112 : }
113 :
114 : // Backtrack and lower penalty if inner solver did not converge
115 30 : if (backtrack) {
116 : // This means the inner solver didn't produce a solution that
117 : // satisfies the required tolerance.
118 : // The best thing we can do now is to restore the penalty to its
119 : // previous value (when the inner solver did converge), then lower
120 : // the penalty factor, and update the penalty with this smaller
121 : // factor.
122 : // error₂ was not overwritten by the inner solver, so it still
123 : // contains the error from the iteration before the previous
124 : // successful iteration. error₁ contains the error of the last
125 : // successful iteration.
126 0 : if (not first_successful_iter) {
127 : // We have a previous Σ and error
128 : // Recompute penalty with smaller Δ
129 0 : Δ = std::fmax(1., Δ * params.Δ_lower);
130 0 : detail::update_penalty_weights(params, Δ, first_successful_iter,
131 0 : error₁, error₂, norm_e₁, norm_e₂,
132 0 : Σ_old, Σ);
133 : // Recompute the primal tolerance with larger ρ
134 0 : ρ = std::fmin(0.5, ρ * params.ρ_increase); // keep ρ <= 0.5
135 0 : ε = std::fmax(ρ * ε_old, params.ε);
136 0 : ++s.penalty_reduced;
137 0 : } else {
138 : // We don't have a previous Σ, simply lower the current Σ and
139 : // increase ε
140 0 : Σ *= params.Σ₀_lower;
141 0 : ε *= params.ε₀_increase;
142 0 : ++s.initial_penalty_reduced;
143 : }
144 0 : }
145 :
146 : // If the inner solver did converge, increase penalty
147 : else {
148 : // After this line, error₁ contains the error of the current
149 : // (successful) iteration, and error₂ contains the error of the
150 : // previous successful iteration.
151 30 : error₂.swap(error₁);
152 30 : norm_e₂ = std::exchange(norm_e₁, vec_util::norm_inf(error₁));
153 :
154 : // Check the termination criteria
155 60 : bool alm_converged =
156 30 : ps.ε <= params.ε && inner_converged && norm_e₁ <= params.δ;
157 30 : bool exit = alm_converged || out_of_iter || out_of_time;
158 30 : if (exit) {
159 4 : s.ε = ps.ε;
160 4 : s.δ = norm_e₁;
161 4 : s.norm_penalty = Σ.norm();
162 4 : s.outer_iterations = i + 1;
163 4 : s.elapsed_time = duration_cast<microseconds>(time_elapsed);
164 4 : s.status = alm_converged ? SolverStatus::Converged
165 0 : : out_of_time ? SolverStatus::MaxTime
166 0 : : out_of_iter ? SolverStatus::MaxIter
167 : : SolverStatus::Unknown;
168 4 : if (params.preconditioning)
169 0 : y = prec_g.asDiagonal() * y / prec_f;
170 4 : return s;
171 : }
172 : // After this line, Σ_old contains the penalty used in the current
173 : // (successful) iteration.
174 26 : Σ_old.swap(Σ);
175 : // Update Σ to contain the penalty to use on the next iteration.
176 52 : detail::update_penalty_weights(params, Δ, first_successful_iter,
177 26 : error₁, error₂, norm_e₁, norm_e₂,
178 26 : Σ_old, Σ);
179 : // Lower the primal tolerance for the inner solver.
180 26 : ε_old = std::exchange(ε, std::fmax(ρ * ε, params.ε));
181 26 : first_successful_iter = false;
182 30 : }
183 30 : }
184 0 : throw std::logic_error("[ALM] loop error");
185 4 : }
186 :
187 : } // namespace pa
|