// Cache the column count (C) and row count (R) reported by the input A_in.
24 const auto C = A_in.cols();
25 const auto R = A_in.rows();
// Local aliases for the vector types exposed by `types`
// (`types` itself is declared outside this excerpt).
30 using simd =
typename types::simd;
31 using isimd =
typename types::isimd;
// Function-local static index vector, initialized from a generator that
// returns its argument. NOTE(review): presumably lane i ends up holding i
// (0, 1, 2, ...) -- confirm against isimd's generator constructor.
35 static const isimd iota{[](
auto i) {
return i; }};
// Helper lambda taking a value Sc and an index c. It is tagged
// [[maybe_unused]], so the surrounding code tolerates it never being called.
// Only the head of its body is visible in this excerpt: it iterates r over
// [0, R) and performs an aligned load starting at A_in(0, r, c).
// NOTE(review): what happens to Ac and Sc afterwards is on lines not shown.
37 [[maybe_unused]]
const auto commit_fast = [&](
auto Sc,
index_t c) {
39 for (
index_t r = 0; r < R; ++r) {
40 auto Ac = types::aligned_load(&A_in(0, r, c));
// Lambda receiving an index vector h_commit. The visible lines use it to
// gather values from S_in and then, row by row, from A_in.
45 const auto commit_no_shift = [&](
auto h_commit) {
// Gather from S_in: per-lane offset = (h_commit - 1) * batch size + iota.
// NOTE(review): the "- 1" suggests h_commit stores 1-based indices, with 0
// reserved for "empty" -- confirm against where hist is filled.
// NOTE(review): `bs` is declared again further down, so this gather
// presumably lives in a nested scope whose opening line is not shown here.
47 const auto bs =
static_cast<index_t>(S_in.batch_size());
48 const isimd offsets = (h_commit - isimd{1}) * bs + iota;
// h_commit is also passed as gather's third argument.
// NOTE(review): presumably it selects the active lanes -- confirm against
// gather's signature.
49 return gather<T, Abi>(&S_in(0, 0, 0), offsets, h_commit);
// Gather from A_in: same offset formula, but for column-major input
// (OAi == StorageOrder::ColMajor) the index is additionally scaled by
// A_in.outer_stride(); otherwise the scale factor is 1.
52 const auto bs =
static_cast<index_t>(A_in.batch_size());
53 const auto stride = OAi == StorageOrder::ColMajor ? A_in.outer_stride() : 1;
54 const isimd offsets = (h_commit - isimd{1}) * bs * stride + iota;
// For each r in [0, R): gather starting from A_in(0, r, 0) and hand the
// result, together with gather_S, to writeA. gather_S and j are defined on
// lines that are not part of this excerpt.
55 for (
index_t r = 0; r < R; ++r) {
56 auto gather_A = gather<T, Abi>(&A_in(0, r, 0), offsets, h_commit);
57 writeA(gather_A, gather_S, r, j);
// Lambda `commit` (annotated [[gnu::always_inline]]): copies the first
// history entry hist[0] into h and assigns hist[k] to hist[k - 1].
// NOTE(review): the loop over k and the use of h are on lines not shown;
// presumably h is passed to one of the commit_* helpers and the history is
// shifted down by one slot -- confirm.
61 const auto commit = [&] [[gnu::always_inline]] () {
62 const isimd h = hist[0];
64 hist[k - 1] = hist[k];
// Main loop: c runs over [0, C).
70 for (
index_t c = 0; c < C; ++c) {
// Aligned load from S_in at (0, c, 0), then flag the lanes whose value is
// not equal to zero.
72 const simd Sc = types::aligned_load(&S_in(0, c, 0));
73 auto Sc_msk = !(Sc == simd{0});
// Two variants of the same update follow, selected by
// BATMAT_WITH_GSI_HPC_SIMD (the matching #else/#endif lines are not part of
// this excerpt). c1_simd and h_ are defined on lines not shown.
// NOTE(review): presumably c1_simd is c + 1 in every lane and h_ is a copy
// of h taken before the update -- confirm.
// Variant 1 (lane-by-lane generators): lanes of h that are 0 and flagged in
// Sc_msk receive c1_simd; Sc_msk then keeps only lanes where h_ is nonzero.
81#if BATMAT_WITH_GSI_HPC_SIMD
83 h = isimd{[&](
int i) ->
index_t {
return h[i] == 0 && Sc_msk[i] ? c1_simd[i] : h[i]; }};
84 Sc_msk =
decltype(Sc_msk){[&](
int i) ->
bool {
return Sc_msk[i] && h_[i] != 0; }};
// Variant 2 (masked assignment): msk marks lanes where h is 0 and Sc_msk is
// set; those lanes of h are overwritten with c1_simd and cleared in Sc_msk.
// NOTE(review): __cvt() presumably converts between the mask types of simd
// and isimd -- confirm against the SIMD library in use.
86 const auto msk = (h == 0) && Sc_msk.__cvt();
87 where(msk, h) = c1_simd;
88 Sc_msk = Sc_msk && (!msk).__cvt();
// first_full: no lane of hist[0] is zero.
92 bool first_full = none_of(hist[0] == 0);
// mask_left: at least one lane is still flagged in Sc_msk.
94 bool mask_left = any_of(Sc_msk);
95 if (first_full || mask_left) {
// Checked here: hist[0] has at least one zero lane.
// NOTE(review): the statement between the `if` and this assert is not
// shown; presumably it calls commit() -- confirm.
97 assert(any_of(hist[0] == 0));
// Write c1_simd into the lanes of the last history entry, hist[N - 1], that
// are still flagged in Sc_msk (again in two variants).
101 auto &h = hist[N - 1];
102#if BATMAT_WITH_GSI_HPC_SIMD
103 h = isimd{[&](
int i) ->
index_t {
return Sc_msk[i] ? c1_simd[i] : h[i]; }};
105 where(Sc_msk.__cvt(), h) = c1_simd;
// Sanity check: hist[i] has no more nonzero lanes than hist[i - 1]
// (the loop over i is on a line not shown).
112 assert(reduce_count(hist[i] != 0) <= reduce_count(hist[i - 1] != 0));
// Trace entry labelled "compress_masks_count"; the last argument is
// S_in.rows() * S_in.depth().
123 GUANAQO_TRACE(
"compress_masks_count", 0, S_in.rows() * S_in.depth());
// Note that C is taken from S_in.rows() in this routine.
124 const auto C = S_in.rows();
// Local aliases for the vector types exposed by `types`.
128 using simd =
typename types::simd;
129 using isimd =
typename types::isimd;
// Lambda `commit` (annotated [[gnu::always_inline]]): the only visible
// statement assigns hist[k] to hist[k - 1]; the loop over k is not shown.
133 const auto commit = [&] [[gnu::always_inline]] () {
135 hist[k - 1] = hist[k];
// Main loop: c runs over [0, C).
140 for (
index_t c = 0; c < C; ++c) {
// Aligned load from S_in at (0, c, 0), then flag the lanes whose value is
// not equal to zero.
141 const simd Sc = types::aligned_load(&S_in(0, c, 0));
142 auto Sc_msk = !(Sc == simd{0});
// Two variants of the same update follow, selected by
// BATMAT_WITH_GSI_HPC_SIMD (the matching #else/#endif lines are not part of
// this excerpt). Here the history only records the constant 1.
// Variant 1 (lane-by-lane generators): lanes of h that are 0 and flagged in
// Sc_msk become 1; Sc_msk then keeps only lanes where h_ is nonzero.
// NOTE(review): h_ is defined on a line not shown; presumably a copy of h
// taken before the update -- confirm.
144#if BATMAT_WITH_GSI_HPC_SIMD
146 h = isimd{[&](
int i) ->
index_t {
return h[i] == 0 && Sc_msk[i] ? 1 : h[i]; }};
147 Sc_msk =
decltype(Sc_msk){[&](
int i) ->
bool {
return Sc_msk[i] && h_[i] != 0; }};
// Variant 2 (mask arithmetic): msk marks lanes where h is 0 and Sc_msk is
// set; those lanes are then cleared in Sc_msk.
// NOTE(review): the statement that writes h for this variant is on a line
// not shown. __cvt() presumably converts between the mask types of simd and
// isimd -- confirm against the SIMD library in use.
149 const auto msk = (h == 0) && Sc_msk.__cvt();
151 Sc_msk = Sc_msk && (!msk).__cvt();
// first_full: no lane of hist[0] is zero.
155 bool first_full = none_of(hist[0] == 0);
// mask_left: at least one lane is still flagged in Sc_msk.
157 bool mask_left = any_of(Sc_msk);
158 if (first_full || mask_left) {
// Checked here: hist[0] has at least one zero lane.
// NOTE(review): the statement between the `if` and this assert is not
// shown; presumably it calls commit() -- confirm.
160 assert(any_of(hist[0] == 0));
// Set the lanes of the last history entry, hist[N - 1], that are still
// flagged in Sc_msk to 1 (again in two variants).
164 auto &h = hist[N - 1];
165#if BATMAT_WITH_GSI_HPC_SIMD
166 h = isimd{[&](
int i) ->
index_t {
return Sc_msk[i] ? 1 : h[i]; }};
168 where(Sc_msk.__cvt(), h) = isimd{1};
// Sanity check: hist[i] has no more nonzero lanes than hist[i - 1]
// (the loop over i is on a line not shown).
175 assert(reduce_count(hist[i] != 0) <= reduce_count(hist[i - 1] != 0));
// Template header: element type T, SIMD ABI tag Abi, a count N defaulting
// to 8, and the storage orders OAi / OAo.
// NOTE(review): OAi and OAo follow a defaulted parameter, so they presumably
// get deduced from the function arguments; the signature is not shown here.
184template <
class T,
class Abi, index_t N = 8, StorageOrder OAi, StorageOrder OAo>
// Trace entry labelled "compress_masks"; the last argument is
// (A_in.rows() + 1) * A_in.cols() * A_in.depth().
187 GUANAQO_TRACE(
"compress_masks", 0, (A_in.rows() + 1) * A_in.cols() * A_in.depth());
// writeS callback: captures S_out by copy and receives a gathered value plus
// an index j. Its body is on lines not shown.
194 auto writeS = [S_out] [[gnu::always_inline]] (
auto gather_S,
index_t j) {
// writeA callback: captures A_out by copy. Its second parameter is left
// unnamed, so that argument cannot be used inside this variant's body.
197 auto writeA = [A_out] [[gnu::always_inline]] (
auto gather_A,
auto ,
index_t r,
// Template header: element type T, SIMD ABI tag Abi, a count N defaulting
// to 8, and the storage orders OAi / OAo.
// NOTE(review): OAi and OAo follow a defaulted parameter, so they presumably
// get deduced from the function arguments; the signature is not shown here.
204template <
class T,
class Abi, index_t N = 8, StorageOrder OAi, StorageOrder OAo>
// Trace entry labelled "compress_masks_sqrt"; the last argument is
// (A_in.rows() + 1) * A_in.cols() * A_in.depth().
207 GUANAQO_TRACE(
"compress_masks_sqrt", 0, (A_in.rows() + 1) * A_in.cols() * A_in.depth());
// Shape checks on S_sign_out. Each one is skipped when S_sign_out has zero
// rows; otherwise its depth must match A_in, its row count must match
// A_out.cols(), and it must have exactly one column.
211 BATMAT_ASSERT(S_sign_out.rows() == 0 || A_in.depth() == S_sign_out.depth());
212 BATMAT_ASSERT(S_sign_out.rows() == 0 || A_out.cols() == S_sign_out.rows());
213 BATMAT_ASSERT(S_sign_out.rows() == 0 || S_sign_out.cols() == 1);
// writeS callback: captures S_sign_out by copy; the visible guard runs the
// following statement (not shown) only when S_sign_out has at least one row.
217 auto writeS = [S_sign_out] [[gnu::always_inline]] (
auto gather_S,
index_t j) {
218 if (S_sign_out.rows() > 0)
// writeA callback: captures A_out by copy. Here the second parameter is
// named gather_S. NOTE(review): presumably it is used to scale gather_A
// before storing -- the body is not part of this excerpt, confirm there.
221 auto writeA = [A_out] [[gnu::always_inline]] (
auto gather_A,
auto gather_S,
index_t r,