LCOV - code coverage report
Current view: top level - src/include/guanaqo - callback-streambuf.hpp
Test: c778863a1c324207102b83c4492a01283ed1d0cc
Coverage: Lines 57.7 % (Total: 52, Hit: 30)
Test Date: 2026-03-19 12:13:51
Legend: Lines:     hit not hit

            Line data    Source code
       1              : #pragma once
       2              : 
       3              : /// @file
       4              : /// @ingroup io
       5              : /// Stream buffer and redirector for capturing ostream output.
       6              : 
#include <algorithm>
#include <cstddef>
#include <cstring>
#include <functional>
#include <iterator>
#include <ostream>
#include <span>
#include <streambuf>
#include <vector>
      13              : 
      14              : namespace guanaqo {
      15              : 
      16              : /// An implementation of a `std::streambuf` that calls the given callback
      17              : /// function with the characters that are written.
      18              : /// @note   Not thread-safe.
      19              : ///
      20              : /// Inspired by https://github.com/pybind/pybind11/blob/master/include/pybind11/iostream.h
      21              : /// @ingroup io
      22              : class callback_streambuf : public std::streambuf {
      23              :   public:
      24              :     using write_func_t = std::function<void(std::span<const char>)>;
      25              : 
      26              :   private:
      27              :     /// Computes how many bytes at the end of the buffer are part of an
      28              :     /// incomplete sequence of UTF-8 bytes.
      29              :     /// @pre    `pbase() < pptr()`
      30            3 :     [[nodiscard]] size_t utf8_remainder() const {
      31            3 :         const auto rbase = std::reverse_iterator<char *>(pbase());
      32            3 :         const auto rpptr = std::reverse_iterator<char *>(pptr());
      33            3 :         static auto uch  = [](char c) { return static_cast<unsigned char>(c); };
      34            3 :         auto is_ascii    = [](char c) { return (uch(c) & 0x80) == 0x00; };
      35            0 :         auto is_leading  = [](char c) { return (uch(c) & 0xC0) == 0xC0; };
      36            0 :         auto is_leading_2b = [](char c) { return uch(c) <= 0xDF; };
      37            0 :         auto is_leading_3b = [](char c) { return uch(c) <= 0xEF; };
      38              :         // If the last character is ASCII, there are no incomplete code points
      39            3 :         if (is_ascii(*rpptr))
      40            3 :             return 0;
      41              :         // Otherwise, work back from the end of the buffer and find the first
      42              :         // UTF-8 leading byte
      43            0 :         const auto rpend   = rbase - rpptr >= 3 ? rpptr + 3 : rbase;
      44            0 :         const auto leading = std::find_if(rpptr, rpend, is_leading);
      45            0 :         if (leading == rbase)
      46            0 :             return 0;
      47            0 :         const auto dist  = static_cast<size_t>(leading - rpptr);
      48            0 :         size_t remainder = 0;
      49              : 
      50            0 :         if (dist == 0)
      51            0 :             remainder = 1; // 1-byte code point is impossible
      52            0 :         else if (dist == 1)
      53            0 :             remainder = is_leading_2b(*leading) ? 0 : dist + 1;
      54            0 :         else if (dist == 2)
      55            0 :             remainder = is_leading_3b(*leading) ? 0 : dist + 1;
      56              :         // else if (dist >= 3), at least 4 bytes before encountering an UTF-8
      57              :         // leading byte, either no remainder or invalid UTF-8.
      58              :         // We do not intend to handle invalid UTF-8 here.
      59            0 :         return remainder;
      60              :     }
      61              : 
      62              :     /// Calls @ref write_func if the buffer is not empty.
      63            5 :     int _sync() {
      64            5 :         if (pbase() != pptr()) { // If buffer is not empty
      65              :             // This subtraction cannot be negative, so dropping the sign.
      66            3 :             auto size        = static_cast<size_t>(pptr() - pbase());
      67            3 :             size_t remainder = utf8_remainder();
      68              : 
      69            3 :             if (size > remainder)
      70            3 :                 write_func(std::span{pbase(), size - remainder});
      71              : 
      72              :             // Copy the remainder at the end of the buffer to the beginning:
      73            3 :             if (remainder > 0)
      74            0 :                 std::memmove(pbase(), pptr() - remainder, remainder);
      75            3 :             setp(pbase(), epptr());
      76            3 :             pbump(static_cast<int>(remainder));
      77              :         }
      78            5 :         return 0;
      79              :     }
      80              : 
      81            3 :     int sync() override { return _sync(); }
      82              : 
      83            0 :     int overflow(int c) override {
      84              :         using traits_type = std::streambuf::traits_type;
      85            0 :         if (!traits_type::eq_int_type(c, traits_type::eof())) {
      86            0 :             *pptr() = traits_type::to_char_type(c);
      87            0 :             pbump(1);
      88              :         }
      89            0 :         return _sync() == 0 ? traits_type::not_eof(c) : traits_type::eof();
      90              :     }
      91              : 
      92              :   public:
      93            2 :     callback_streambuf(write_func_t write_func, size_t buffer_size = 1024)
      94            2 :         : write_func{std::move(write_func)} {
      95            2 :         buffer.resize(buffer_size);
      96            2 :         setp(buffer.data(), buffer.data() + buffer.size());
      97            2 :     }
      98              : 
      99              :     /// Syncs before destroy
     100            2 :     ~callback_streambuf() override { _sync(); }
     101              : 
     102              :   private:
     103              :     write_func_t write_func;
     104              :     std::vector<char> buffer;
     105              : };
     106              : 
     107              : /// Temporarily replaces the rdbuf of the given ostream. Flushes and restores
     108              : /// the old rdbuf upon destruction.
     109              : /// @ingroup io
     110              : class scoped_ostream_redirect {
     111              :   private:
     112              :     std::ostream &os;
     113              :     std::streambuf *old_buf;
     114              : 
     115              :   public:
     116            1 :     explicit scoped_ostream_redirect(std::ostream &os, std::streambuf *rdbuf)
     117            1 :         : os{os}, old_buf{os.rdbuf(rdbuf)} {}
     118              : 
     119            1 :     ~scoped_ostream_redirect() {
     120            1 :         os.flush();
     121            1 :         os.rdbuf(old_buf);
     122            1 :     }
     123              : 
     124              :     scoped_ostream_redirect(const scoped_ostream_redirect &) = delete;
     125              :     scoped_ostream_redirect(scoped_ostream_redirect &&other) = default;
     126              :     scoped_ostream_redirect &
     127              :     operator=(const scoped_ostream_redirect &)                     = delete;
     128              :     scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete;
     129              : };
     130              : 
     131              : } // namespace guanaqo
        

Generated by: LCOV version 2.4-0