Line data Source code
1 : #pragma once
2 :
3 : /// @file
4 : /// @ingroup io
5 : /// Stream buffer and redirector for capturing ostream output.
6 :
#include <algorithm>
#include <cstddef>
#include <cstring>
#include <functional>
#include <iterator>
#include <ostream>
#include <span>
#include <streambuf>
#include <utility>
#include <vector>
13 :
14 : namespace guanaqo {
15 :
16 : /// An implementation of a `std::streambuf` that calls the given callback
17 : /// function with the characters that are written.
18 : /// @note Not thread-safe.
19 : ///
20 : /// Inspired by https://github.com/pybind/pybind11/blob/master/include/pybind11/iostream.h
21 : /// @ingroup io
22 : class callback_streambuf : public std::streambuf {
23 : public:
24 : using write_func_t = std::function<void(std::span<const char>)>;
25 :
26 : private:
27 : /// Computes how many bytes at the end of the buffer are part of an
28 : /// incomplete sequence of UTF-8 bytes.
29 : /// @pre `pbase() < pptr()`
30 3 : [[nodiscard]] size_t utf8_remainder() const {
31 3 : const auto rbase = std::reverse_iterator<char *>(pbase());
32 3 : const auto rpptr = std::reverse_iterator<char *>(pptr());
33 3 : static auto uch = [](char c) { return static_cast<unsigned char>(c); };
34 3 : auto is_ascii = [](char c) { return (uch(c) & 0x80) == 0x00; };
35 0 : auto is_leading = [](char c) { return (uch(c) & 0xC0) == 0xC0; };
36 0 : auto is_leading_2b = [](char c) { return uch(c) <= 0xDF; };
37 0 : auto is_leading_3b = [](char c) { return uch(c) <= 0xEF; };
38 : // If the last character is ASCII, there are no incomplete code points
39 3 : if (is_ascii(*rpptr))
40 3 : return 0;
41 : // Otherwise, work back from the end of the buffer and find the first
42 : // UTF-8 leading byte
43 0 : const auto rpend = rbase - rpptr >= 3 ? rpptr + 3 : rbase;
44 0 : const auto leading = std::find_if(rpptr, rpend, is_leading);
45 0 : if (leading == rbase)
46 0 : return 0;
47 0 : const auto dist = static_cast<size_t>(leading - rpptr);
48 0 : size_t remainder = 0;
49 :
50 0 : if (dist == 0)
51 0 : remainder = 1; // 1-byte code point is impossible
52 0 : else if (dist == 1)
53 0 : remainder = is_leading_2b(*leading) ? 0 : dist + 1;
54 0 : else if (dist == 2)
55 0 : remainder = is_leading_3b(*leading) ? 0 : dist + 1;
56 : // else if (dist >= 3), at least 4 bytes before encountering an UTF-8
57 : // leading byte, either no remainder or invalid UTF-8.
58 : // We do not intend to handle invalid UTF-8 here.
59 0 : return remainder;
60 : }
61 :
62 : /// Calls @ref write_func if the buffer is not empty.
63 5 : int _sync() {
64 5 : if (pbase() != pptr()) { // If buffer is not empty
65 : // This subtraction cannot be negative, so dropping the sign.
66 3 : auto size = static_cast<size_t>(pptr() - pbase());
67 3 : size_t remainder = utf8_remainder();
68 :
69 3 : if (size > remainder)
70 3 : write_func(std::span{pbase(), size - remainder});
71 :
72 : // Copy the remainder at the end of the buffer to the beginning:
73 3 : if (remainder > 0)
74 0 : std::memmove(pbase(), pptr() - remainder, remainder);
75 3 : setp(pbase(), epptr());
76 3 : pbump(static_cast<int>(remainder));
77 : }
78 5 : return 0;
79 : }
80 :
81 3 : int sync() override { return _sync(); }
82 :
83 0 : int overflow(int c) override {
84 : using traits_type = std::streambuf::traits_type;
85 0 : if (!traits_type::eq_int_type(c, traits_type::eof())) {
86 0 : *pptr() = traits_type::to_char_type(c);
87 0 : pbump(1);
88 : }
89 0 : return _sync() == 0 ? traits_type::not_eof(c) : traits_type::eof();
90 : }
91 :
92 : public:
93 2 : callback_streambuf(write_func_t write_func, size_t buffer_size = 1024)
94 2 : : write_func{std::move(write_func)} {
95 2 : buffer.resize(buffer_size);
96 2 : setp(buffer.data(), buffer.data() + buffer.size());
97 2 : }
98 :
99 : /// Syncs before destroy
100 2 : ~callback_streambuf() override { _sync(); }
101 :
102 : private:
103 : write_func_t write_func;
104 : std::vector<char> buffer;
105 : };
106 :
/// Temporarily replaces the rdbuf of the given ostream. Flushes and restores
/// the old rdbuf upon destruction.
/// @ingroup io
class scoped_ostream_redirect {
  private:
    /// Stream whose buffer is redirected, or null once this object has been
    /// moved from (the moved-to object is then responsible for restoring it).
    std::ostream *stream;
    /// Buffer that was installed in the stream before the redirect.
    std::streambuf *old_buf;

  public:
    /// Installs @p rdbuf as the stream buffer of @p os and remembers the
    /// previous buffer so that it can be restored later.
    /// @param os    Stream to redirect; must outlive this object.
    /// @param rdbuf New stream buffer; must stay valid while installed.
    explicit scoped_ostream_redirect(std::ostream &os, std::streambuf *rdbuf)
        : stream{&os}, old_buf{os.rdbuf(rdbuf)} {}

    /// Flushes the stream and reinstalls the original stream buffer (no-op
    /// for a moved-from object).
    ~scoped_ostream_redirect() {
        if (stream) {
            stream->flush();
            stream->rdbuf(old_buf);
        }
    }

    scoped_ostream_redirect(const scoped_ostream_redirect &) = delete;
    /// Transfers the responsibility of restoring the stream to the new
    /// object. Destroying the moved-from object does not end the redirect.
    scoped_ostream_redirect(scoped_ostream_redirect &&other) noexcept
        : stream{std::exchange(other.stream, nullptr)},
          old_buf{other.old_buf} {}
    scoped_ostream_redirect &
    operator=(const scoped_ostream_redirect &) = delete;
    scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete;
};
130 :
131 : } // namespace guanaqo
|