quala main
Quasi-Newton and other accelerators
kwargs-to-struct.hpp
Go to the documentation of this file.
1/**
2 * @file
3 * This file defines mappings from Python dicts (kwargs) to simple parameter
4 * structs.
5 */
6
7#pragma once
8
9#include <functional>
10#include <map>
11#include <variant>
12
13#include <pybind11/detail/typeid.h>
14#include <pybind11/pybind11.h>
15namespace py = pybind11;
16
17struct cast_error_with_types : py::cast_error {
18 cast_error_with_types(const py::cast_error &e, std::string from, std::string to)
19 : py::cast_error(e), from(std::move(from)), to(std::move(to)) {}
20 std::string from;
21 std::string to;
22};
23
24template <class T, class A>
25auto attr_setter(A T::*attr) {
26 return [attr](T &t, const py::handle &h) {
27 try {
28 t.*attr = h.cast<A>();
29 } catch (const py::cast_error &e) {
30 throw cast_error_with_types(e, py::str(py::type::handle_of(h)), py::type_id<A>());
31 }
32 };
33}
34template <class T, class A>
35auto attr_getter(A T::*attr) {
36 return [attr](const T &t) { return py::cast(t.*attr); };
37}
38
39template <class T>
41 public:
42 template <class A>
43 attr_setter_fun_t(A T::*attr) : set(attr_setter(attr)), get(attr_getter(attr)) {}
44
45 std::function<void(T &, const py::handle &)> set;
46 std::function<py::object(const T &)> get;
47};
48
/// Lookup table type: maps a parameter name (a string) to the
/// attr_setter_fun_t accessor pair for the corresponding member of T.
49template <class T>
50using kwargs_to_struct_table_t = std::map<std::string, attr_setter_fun_t<T>>;
51
52template <class T>
54
55template <class T>
56void kwargs_to_struct_helper(T &t, const py::kwargs &kwargs) {
57 const auto &m = kwargs_to_struct_table<T>;
58 for (auto &&[key, val] : kwargs) {
59 auto skey = key.template cast<std::string>();
60 auto it = m.find(skey);
61 if (it == m.end())
62 throw py::key_error("Unknown parameter " + skey);
63 try {
64 it->second.set(t, val);
65 } catch (const cast_error_with_types &e) {
66 throw std::runtime_error("Error converting parameter '" + skey + "' from " + e.from +
67 " to '" + e.to + "': " + e.what());
68 } catch (const std::runtime_error &e) {
69 throw std::runtime_error("Error setting parameter '" + skey + "': " + e.what());
70 }
71 }
72}
73
74template <class T>
75py::dict struct_to_dict_helper(const T &t) {
76 const auto &m = kwargs_to_struct_table<T>;
77 py::dict d;
78 for (auto &&[key, val] : m) {
79 py::object o = val.get(t);
80 if (py::hasattr(o, "to_dict"))
81 o = o.attr("to_dict")();
82 d[key.c_str()] = std::move(o);
83 }
84 return d;
85}
86
87template <class T>
88T kwargs_to_struct(const py::kwargs &kwargs) {
89 T t{};
90 kwargs_to_struct_helper(t, kwargs);
91 return t;
92}
93
94template <class T>
95py::dict struct_to_dict(const T &t) {
96 return struct_to_dict_helper<T>(t);
97}
98
99template <class T>
100T var_kwargs_to_struct(const std::variant<T, py::dict> &p) {
101 return std::holds_alternative<T>(p) ? std::get<T>(p)
102 : kwargs_to_struct<T>(std::get<py::dict>(p));
103}
104
105#include <quala/lbfgs.hpp>
106
// Specialization of kwargs_to_struct_table for quala::LBFGSParams: each entry
// pairs a parameter name with a pointer to the corresponding member.
// NOTE(review): the line between `template <>` and the variable name (the one
// carrying the declared type) appears to be missing from this listing;
// restore it from the original header.
107template <>
109 kwargs_to_struct_table<quala::LBFGSParams>{
110 {"memory", &quala::LBFGSParams::memory},
111 {"min_div_fac", &quala::LBFGSParams::min_div_fac},
112 {"min_abs_s", &quala::LBFGSParams::min_abs_s},
113 {"force_pos_def", &quala::LBFGSParams::force_pos_def},
114 {"cbfgs", &quala::LBFGSParams::cbfgs},
115 };
116
// Table entries for the members α and ϵ of the type of
// quala::LBFGSParams::cbfgs.
// NOTE(review): the declaration lines between `template <>` and the first
// entry appear to be missing from this listing; presumably this specializes
// kwargs_to_struct_table for decltype(quala::LBFGSParams::cbfgs) — confirm
// against the original header.
117template <>
120 {"α", &decltype(quala::LBFGSParams::cbfgs)::α},
121 {"ϵ", &decltype(quala::LBFGSParams::cbfgs)::ϵ},
122 };
123
125
// Specialization of kwargs_to_struct_table for quala::AndersonAccelParams.
// NOTE(review): the type line and the table entries appear to be missing from
// this listing (as does the include that declares AndersonAccelParams);
// restore them from the original header.
126template <>
128 kwargs_to_struct_table<quala::AndersonAccelParams>{
130 };
131
132#include <quala/broyden-good.hpp>
133
// Specialization of kwargs_to_struct_table for quala::BroydenGoodParams: each
// entry pairs a parameter name with a pointer to the corresponding member.
// NOTE(review): the type line and several entries appear to be missing from
// this listing (the embedded line numbers skip); restore them from the
// original header.
134template <>
136 kwargs_to_struct_table<quala::BroydenGoodParams>{
139 {"force_pos_def", &quala::BroydenGoodParams::force_pos_def},
141 {"powell_damping_factor", &quala::BroydenGoodParams::powell_damping_factor},
142 };
std::function< py::object(const T &)> get
attr_setter_fun_t(A T::*attr)
std::function< void(T &, const py::handle &)> set
py::dict struct_to_dict(const T &t)
void kwargs_to_struct_helper(T &t, const py::kwargs &kwargs)
std::map< std::string, attr_setter_fun_t< T > > kwargs_to_struct_table_t
kwargs_to_struct_table_t< T > kwargs_to_struct_table
auto attr_setter(A T::*attr)
T kwargs_to_struct(const py::kwargs &kwargs)
T var_kwargs_to_struct(const std::variant< T, py::dict > &p)
auto attr_getter(A T::*attr)
py::dict struct_to_dict_helper(const T &t)
real_t powell_damping_factor
Powell's trick: damping that prevents the estimate from becoming singular.
real_t min_abs_s
Reject update if $s^\top s \le \text{min\_abs\_s}$.
Definition: decl/lbfgs.hpp:15
real_t min_div_abs
Reject update if .
CBFGSParams cbfgs
Parameters in the cautious BFGS update condition.
Definition: decl/lbfgs.hpp:26
bool restarted
If set to true, the buffer is cleared after memory iterations.
length_t memory
Length of the history to keep.
Definition: decl/lbfgs.hpp:11
bool force_pos_def
If set to true, the inverse Hessian estimate should remain definite, i.e.
Definition: decl/lbfgs.hpp:33
real_t min_div_fac
Reject update if $y^\top s \le \text{min\_div\_fac} \cdot s^\top s$.
Definition: decl/lbfgs.hpp:13
cast_error_with_types(const py::cast_error &e, std::string from, std::string to)