23 const auto C = A_in.cols();
24 const auto R = A_in.rows();
29 using simd =
typename types::simd;
30 using isimd =
typename types::isimd;
34 static const isimd iota{[](
auto i) {
return i; }};
36 [[maybe_unused]]
const auto commit_fast = [&](
auto Sc, index_t c) {
38 for (index_t r = 0; r < R; ++r) {
39 auto Ac = types::aligned_load(&A_in(0, r, c));
44 const auto commit_no_shift = [&](
auto h_commit) {
46 const auto bs =
static_cast<index_t
>(S_in.batch_size());
47 const isimd offsets = (h_commit - isimd{1}) * bs + iota;
48 return gather<T, Abi>(&S_in(0, 0, 0), offsets, h_commit);
51 const auto bs =
static_cast<index_t
>(A_in.batch_size());
52 const auto stride = OAi == StorageOrder::ColMajor ? A_in.outer_stride() : 1;
53 const isimd offsets = (h_commit - isimd{1}) * bs * stride + iota;
54 for (index_t r = 0; r < R; ++r) {
55 auto gather_A = gather<T, Abi>(&A_in(0, r, 0), offsets, h_commit);
56 writeA(gather_A, gather_S, r, j);
60 const auto commit = [&] [[gnu::always_inline]] () {
61 const isimd h = hist[0];
63 hist[k - 1] = hist[k];
69 for (index_t c = 0; c < C; ++c) {
71 const simd Sc = types::aligned_load(&S_in(0, c, 0));
72 auto Sc_msk = !(Sc == 0);
80#if BATMAT_WITH_GSI_HPC_SIMD
82 h = isimd{[&](
int i) -> index_t {
return h[i] == 0 && Sc_msk[i] ? c1_simd[i] : h[i]; }};
83 Sc_msk =
decltype(Sc_msk){[&](
int i) ->
bool {
return Sc_msk[i] && h_[i] != 0; }};
85 const auto msk = (h == 0) && Sc_msk.__cvt();
86 where(msk, h) = c1_simd;
87 Sc_msk = Sc_msk && (!msk).__cvt();
91 bool first_full = none_of(hist[0] == 0);
93 bool mask_left = any_of(Sc_msk);
94 if (first_full || mask_left) {
96 assert(any_of(hist[0] == 0));
100 auto &h = hist[N - 1];
101#if BATMAT_WITH_GSI_HPC_SIMD
102 h = isimd{[&](
int i) -> index_t {
return Sc_msk[i] ? c1_simd[i] : h[i]; }};
104 where(Sc_msk.__cvt(), h) = c1_simd;
110#
if BATMAT_WITH_GSI_HPC_SIMD
111 assert(reduce_count(hist[i] != 0) <= reduce_count(hist[i - 1] != 0));
113 assert(popcount(hist[i] != 0) <= popcount(hist[i - 1] != 0));
124 GUANAQO_TRACE(
"compress_masks_count", 0, S_in.rows() * S_in.depth());
125 const auto C = S_in.rows();
129 using simd =
typename types::simd;
130 using isimd =
typename types::isimd;
134 const auto commit = [&] [[gnu::always_inline]] () {
136 hist[k - 1] = hist[k];
141 for (index_t c = 0; c < C; ++c) {
142 const simd Sc = types::aligned_load(&S_in(0, c, 0));
143 auto Sc_msk = !(Sc == 0);
145#if BATMAT_WITH_GSI_HPC_SIMD
147 h = isimd{[&](
int i) -> index_t {
return h[i] == 0 && Sc_msk[i] ? 1 : h[i]; }};
148 Sc_msk =
decltype(Sc_msk){[&](
int i) ->
bool {
return Sc_msk[i] && h_[i] != 0; }};
150 const auto msk = (h == 0) && Sc_msk.__cvt();
152 Sc_msk = Sc_msk && (!msk).__cvt();
156 bool first_full = none_of(hist[0] == 0);
158 bool mask_left = any_of(Sc_msk);
159 if (first_full || mask_left) {
161 assert(any_of(hist[0] == 0));
165 auto &h = hist[N - 1];
166#if BATMAT_WITH_GSI_HPC_SIMD
167 h = isimd{[&](
int i) -> index_t {
return Sc_msk[i] ? 1 : h[i]; }};
169 where(Sc_msk.__cvt(), h) = isimd{1};
175#
if BATMAT_WITH_GSI_HPC_SIMD
176 assert(reduce_count(hist[i] != 0) <= reduce_count(hist[i - 1] != 0));
178 assert(popcount(hist[i] != 0) <= popcount(hist[i - 1] != 0));
187template <
class T,
class Abi, index_t N = 8, StorageOrder OAi, StorageOrder OAo>
190 GUANAQO_TRACE(
"compress_masks", 0, (A_in.rows() + 1) * A_in.cols() * A_in.depth());
197 auto writeS = [S_out] [[gnu::always_inline]] (
auto gather_S, index_t j) {
200 auto writeA = [A_out] [[gnu::always_inline]] (
auto gather_A,
auto , index_t r,
207template <
class T,
class Abi, index_t N = 8, StorageOrder OAi, StorageOrder OAo>
210 GUANAQO_TRACE(
"compress_masks_sqrt", 0, (A_in.rows() + 1) * A_in.cols() * A_in.depth());
214 BATMAT_ASSERT(S_sign_out.rows() == 0 || A_in.depth() == S_sign_out.depth());
215 BATMAT_ASSERT(S_sign_out.rows() == 0 || A_out.cols() == S_sign_out.rows());
216 BATMAT_ASSERT(S_sign_out.rows() == 0 || S_sign_out.cols() == 1);
220 auto writeS = [S_sign_out] [[gnu::always_inline]] (
auto gather_S, index_t j) {
221 if (S_sign_out.rows() > 0)
224 auto writeA = [A_out] [[gnu::always_inline]] (
auto gather_A,
auto gather_S, index_t r,