11 #include <type_traits>
13 #include <pybind11/cast.h>
14 #include <pybind11/pybind11.h>
15 namespace py = pybind11;
19 template <
class InnerSolver>
23 pa::vec y) -> std::tuple<pa::vec, pa::vec, pa::vec, py::dict> {
26 return std::make_tuple(std::move(
x), std::move(
y), std::move(
z),
27 stats.ptr->to_dict());
31 template <
class InnerSolver>
35 -> std::tuple<pa::vec, pa::vec, pa::vec, pa::vec, py::dict> {
38 (void)
solver(
p, Σ1, Σ2,
ε,
true,
x,
y, z1, z2);
39 return std::make_tuple(std::move(
x), std::move(
y), std::move(z1),
40 std::move(z2), py::dict{});
45 :
public std::enable_shared_from_this<
46 PolymorphicInnerSolverStatsAccumulatorBase> {
54 :
public std::enable_shared_from_this<PolymorphicInnerSolverStatsBase> {
58 virtual std::shared_ptr<PolymorphicInnerSolverStatsAccumulatorBase>
63 :
public std::enable_shared_from_this<PolymorphicInnerSolverBase> {
66 std::shared_ptr<PolymorphicInnerSolverStatsBase>
ptr;
75 struct AccStats : PolyAccStats {
76 AccStats(py::dict dict) : dict(std::move(dict)) {}
78 py::dict to_dict()
const override {
return dict; }
79 void accumulate(
const PolyStats &s)
override {
80 if (this->dict.contains(
"accumulate"))
81 this->dict[
"accumulate"](this->dict, s.to_dict());
83 throw py::key_error(
"Stats accumulator does not define "
84 "an accumulate function");
87 struct Stats : PolyStats {
88 Stats(py::dict dict) : dict(std::move(dict)) {}
90 py::dict to_dict()
const override {
return dict; }
91 std::shared_ptr<PolyAccStats> accumulator()
const override {
92 if (this->dict.contains(
"accumulator"))
94 std::make_shared<AccStats>(
95 dict[
"accumulator"].cast<py::dict>()),
99 "Stats do not define an accumulator");
102 bool ok = d.contains(
"status") && d.contains(
"ε") &&
103 d.contains(
"iterations");
106 "Stats should contain status, ε and iterations");
108 std::static_pointer_cast<PolyStats>(std::make_shared<Stats>(d)),
109 d[
"status"].cast<decltype(InnerStats::status)>(),
111 d[
"iterations"].cast<decltype(InnerStats::iterations)>(),
125 bool always_overwrite_results,
139 std::shared_ptr<PolymorphicInnerSolverBase>
solver;
141 std::shared_ptr<PolymorphicInnerSolverBase> &&
solver)
145 bool always_overwrite_results,
rvec x,
rvec y,
155 template <
class InnerSolverStats>
156 struct InnerStatsAccumulator;
160 std::shared_ptr<PolymorphicInnerSolverStatsAccumulatorBase>
ptr;
161 py::dict
to_dict()
const {
return ptr->to_dict(); }
164 inline InnerStatsAccumulator<PolymorphicInnerSolverWrapper::Stats> &
169 acc.
ptr = s.
ptr->accumulator();
170 acc.
ptr->accumulate(*s.
ptr);
177 bool always_overwrite_results,
rvec x,
rvec y,
184 virtual std::tuple<pa::vec, pa::vec, pa::vec, py::dict>
187 using ret = std::tuple<pa::vec, pa::vec, pa::vec, py::dict>;
190 always_overwrite_results,
x,
y);
206 using py::operator
""_a;
222 using py::operator
""_a;
236 using py::operator
""_a;
252 using py::operator
""_a;
262 using py::operator
""_a;
274 using py::operator
""_a;
289 using py::operator
""_a;
298 using py::operator
""_a;
306 template <
class InnerSolver>
313 template <
class... Args>
315 :
innersolver(InnerSolver{std::forward<Args>(args)...}) {}
328 using Stats =
typename InnerSolver::Stats;
331 std::shared_ptr<PolymorphicInnerSolverStatsAccumulatorBase>
333 return std::static_pointer_cast<
335 std::make_shared<WrappedStatsAccumulator>());
348 bool always_overwrite_results,
358 std::static_pointer_cast<PolymorphicInnerSolverStatsBase>(
359 std::make_shared<WrappedStats>(
stats)),
372 std::function<
void(
const typename InnerSolver::ProgressInfo &)> cb) {
373 this->
innersolver.set_progress_callback(std::move(cb));
480 using py::operator
""_a;
491 "inner"_a = s.
inner.to_dict(),
495 template <
class InnerSolver>
498 using py::operator
""_a;
500 "outer_iterations"_a = s.outer_iterations,
501 "elapsed_time"_a = s.elapsed_time,
502 "initial_penalty_reduced"_a = s.initial_penalty_reduced,
503 "penalty_reduced"_a = s.penalty_reduced,
504 "inner_convergence_failures"_a = s.inner_convergence_failures,
508 "norm_penalty₁"_a = s.norm_penalty₁,
509 "norm_penalty₂"_a = s.norm_penalty₂,
510 "penalty₂"_a = s.penalty₂,
511 "status"_a = s.status,