Line data Source code
1 : #pragma once
2 :
3 : #include <panoc-alm/inner/directions/decl/lbfgs.hpp>
4 : #include <stdexcept>
5 : #include <type_traits>
6 :
7 : namespace pa {
8 :
9 190742 : inline bool LBFGS::update_valid(LBFGSParams params, real_t yᵀs, real_t sᵀs,
10 : real_t pᵀp) {
11 : // Smallest number we want to divide by without overflow
12 190742 : const real_t min_divisor = std::sqrt(std::numeric_limits<real_t>::min());
13 :
14 : // Check if this L-BFGS update is accepted
15 190742 : if (not std::isfinite(yᵀs))
16 0 : return false;
17 190742 : if (yᵀs < min_divisor)
18 20009 : return false;
19 170733 : if (sᵀs < min_divisor)
20 0 : return false;
21 :
22 : // CBFGS condition: https://epubs.siam.org/doi/10.1137/S1052623499354242
23 170733 : real_t α = params.cbfgs.α;
24 170733 : real_t ϵ = params.cbfgs.ϵ;
25 : // Condition: yᵀs / sᵀs >= ϵ ‖p‖^α
26 170733 : bool cbfgs_cond = yᵀs / sᵀs >= ϵ * std::pow(pᵀp, α / 2);
27 170733 : if (not cbfgs_cond)
28 0 : return false;
29 :
30 170733 : return true;
31 190742 : }
32 :
33 190726 : inline bool LBFGS::update(crvec xₖ, crvec xₖ₊₁, crvec pₖ, crvec pₖ₊₁, Sign sign,
34 : bool forced) {
35 190726 : const auto s = xₖ₊₁ - xₖ;
36 190726 : const auto y = sign == Sign::Positive ? pₖ₊₁ - pₖ : pₖ - pₖ₊₁;
37 190726 : real_t yᵀs = y.dot(s);
38 190726 : real_t ρ = 1 / yᵀs;
39 190726 : if (not forced) {
40 190726 : real_t sᵀs = s.squaredNorm();
41 190726 : real_t pᵀp = params.cbfgs.ϵ > 0 ? pₖ₊₁.squaredNorm() : 0;
42 190726 : if (not update_valid(params, yᵀs, sᵀs, pᵀp))
43 20009 : return false;
44 190726 : }
45 :
46 : // Store the new s and y vectors
47 170717 : this->s(idx) = s;
48 170717 : this->y(idx) = y;
49 170717 : this->ρ(idx) = ρ;
50 :
51 : // Increment the index in the circular buffer
52 170717 : idx = succ(idx);
53 170717 : full |= idx == 0;
54 :
55 170717 : return true;
56 190726 : }
57 :
58 : template <class Vec>
59 180710 : bool LBFGS::apply(Vec &&q, real_t γ) {
60 : // Only apply if we have previous vectors s and y
61 180710 : if (idx == 0 && not full)
62 0 : return false;
63 :
64 : // If the step size is negative, compute it as sᵀy/yᵀy
65 180710 : if (γ < 0) {
66 180430 : auto new_idx = idx > 0 ? idx - 1 : history() - 1;
67 180430 : real_t yᵀy = y(new_idx).squaredNorm();
68 180430 : γ = 1. / (ρ(new_idx) * yᵀy);
69 180430 : }
70 :
71 1729003 : auto update1 = [&](size_t i) {
72 1548293 : α(i) = ρ(i) * (s(i).dot(q));
73 1548293 : q -= α(i) * y(i);
74 1548293 : };
75 180710 : if (idx)
76 1725692 : for (size_t i = idx; i-- > 0;)
77 1725692 : update1(i);
78 180710 : if (full)
79 3603 : for (size_t i = history(); i-- > idx;)
80 3603 : update1(i);
81 :
82 : // r ← H₀ q
83 180710 : q *= γ;
84 :
85 1729003 : auto update2 = [&](size_t i) {
86 1548293 : real_t β = ρ(i) * (y(i).dot(q));
87 1548293 : q += (α(i) - β) * s(i);
88 1548293 : };
89 180710 : if (full)
90 3603 : for (size_t i = idx; i < history(); ++i)
91 3603 : update2(i);
92 1725725 : for (size_t i = 0; i < idx; ++i)
93 1545015 : update2(i);
94 :
95 180710 : return true;
96 180710 : }
97 :
98 : template <class Vec, class IndexVec>
99 : bool LBFGS::apply(Vec &&q, real_t γ, const IndexVec &J) {
100 : // Only apply if we have previous vectors s and y
101 : if (idx == 0 && not full)
102 : return false;
103 : using Index = typename std::remove_reference_t<Vec>::Index;
104 : bool fullJ = q.size() == Index(J.size());
105 :
106 : // Eigen 3.3.9 doesn't yet support indexing using a vector of indices
107 : // so we'll have to do it manually
108 : // TODO: Abstract this away in an expression template / nullary expression?
109 : // Or wait for Eigen update?
110 :
111 : // Dot product of two vectors, adding only the indices in set J
112 : auto dotJ = [&J, fullJ](const auto &a, const auto &b) {
113 : if (fullJ) {
114 : return a.dot(b);
115 : } else {
116 : real_t acc = 0;
117 : for (auto j : J)
118 : acc += a(j) * b(j);
119 : return acc;
120 : }
121 : };
122 :
123 : auto update1 = [&](size_t i) {
124 : // Recompute ρ, it depends on the index set J. Note that even if ρ was
125 : // positive for the full vectors s and y, that's not necessarily the
126 : // case for the smaller vectors s(J) and y(J).
127 : if (not fullJ)
128 : ρ(i) = 1. / dotJ(s(i), y(i));
129 :
130 : if (ρ(i) <= 0) // Reject negative ρ to ensure positive definiteness
131 : return;
132 :
133 : α(i) = ρ(i) * dotJ(s(i), q);
134 : if (fullJ)
135 : q -= α(i) * y(i);
136 : else
137 : for (auto j : J)
138 : q(j) -= α(i) * y(i)(j);
139 :
140 : if (γ < 0) {
141 : // Compute step size based on most recent yᵀs/yᵀy > 0
142 : real_t yᵀy = dotJ(y(i), y(i));
143 : γ = 1. / (ρ(i) * yᵀy);
144 : }
145 : };
146 : if (idx)
147 : for (size_t i = idx; i-- > 0;)
148 : update1(i);
149 : if (full)
150 : for (size_t i = history(); i-- > idx;)
151 : update1(i);
152 :
153 : // If all ρ <= 0, fail
154 : if (γ < 0)
155 : return false;
156 :
157 : // r ← H₀ q
158 : if (fullJ)
159 : q *= γ;
160 : else
161 : for (auto j : J)
162 : q(j) *= γ;
163 :
164 : auto update2 = [&](size_t i) {
165 : if (ρ(i) <= 0)
166 : return;
167 : real_t β = ρ(i) * dotJ(y(i), q);
168 : if (fullJ)
169 : q += (α(i) - β) * s(i);
170 : else
171 : for (auto j : J)
172 : q(j) += (α(i) - β) * s(i)(j);
173 : };
174 : if (full)
175 : for (size_t i = idx; i < history(); ++i)
176 : update2(i);
177 : for (size_t i = 0; i < idx; ++i)
178 : update2(i);
179 :
180 : return true;
181 : }
182 :
183 10048 : inline void LBFGS::reset() {
184 10048 : idx = 0;
185 10048 : full = false;
186 10048 : }
187 :
188 10037 : inline void LBFGS::resize(size_t n) {
189 10037 : if (params.memory < 1)
190 0 : throw std::invalid_argument("LBFGSParams::memory must be > 1");
191 10037 : sto.resize(n + 1, params.memory * 2);
192 10037 : reset();
193 10037 : }
194 :
195 0 : inline void LBFGS::scale_y(real_t factor) {
196 0 : if (full) {
197 0 : for (size_t i = 0; i < history(); ++i) {
198 0 : y(i) *= factor;
199 0 : ρ(i) *= 1. / factor;
200 0 : }
201 0 : } else {
202 0 : for (size_t i = 0; i < idx; ++i) {
203 0 : y(i) *= factor;
204 0 : ρ(i) *= 1. / factor;
205 0 : }
206 : }
207 0 : }
208 :
209 10036 : inline void PANOCDirection<LBFGS>::initialize(crvec x₀, crvec x̂₀, crvec p₀,
210 : crvec grad₀) {
211 10036 : lbfgs.resize(x₀.size());
212 : (void)x̂₀;
213 : (void)p₀;
214 : (void)grad₀;
215 10036 : }
216 :
217 190716 : inline bool PANOCDirection<LBFGS>::update(crvec xₖ, crvec xₖ₊₁, crvec pₖ,
218 : crvec pₖ₊₁, crvec grad_new,
219 : const Box &C, real_t γ_new) {
220 : (void)grad_new;
221 190716 : (void)C;
222 : (void)γ_new;
223 190716 : return lbfgs.update(xₖ, xₖ₊₁, pₖ, pₖ₊₁, LBFGS::Sign::Negative);
224 0 : }
225 :
226 180683 : inline bool PANOCDirection<LBFGS>::apply(crvec xₖ, crvec x̂ₖ, crvec pₖ, real_t γ,
227 : rvec qₖ) {
228 : (void)xₖ;
229 : (void)x̂ₖ;
230 180683 : qₖ = pₖ;
231 180683 : return lbfgs.apply(qₖ, γ);
232 : }
233 :
234 11 : inline void PANOCDirection<LBFGS>::changed_γ(real_t γₖ, real_t old_γₖ) {
235 11 : if (lbfgs.get_params().rescale_when_γ_changes)
236 0 : lbfgs.scale_y(γₖ / old_γₖ);
237 : else
238 11 : lbfgs.reset();
239 11 : }
240 :
241 0 : inline void PANOCDirection<LBFGS>::reset() { lbfgs.reset(); }
242 :
243 0 : inline std::string PANOCDirection<LBFGS>::get_name() const {
244 0 : return lbfgs.get_name();
245 : }
246 :
247 : inline LBFGSParams PANOCDirection<LBFGS>::get_params() const {
248 : return lbfgs.get_params();
249 : }
250 :
251 : } // namespace pa
|