Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/cpu/kernels/topkv/generic/neon/fp32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "src/cpu/kernels/topkv/generic/neon/impl.h"

#include <arm_neon.h>
#include <limits>

namespace arm_compute
{
Expand All @@ -44,21 +45,24 @@ static inline uint32_t reduce_u32x4(uint32x4_t v)
#endif
}

// Explicit specialization for float: may use float32x4_t etc (only in this TU)
// Explicit specialization for float: may use float32x4_t
/** Count the lanes of one 128-bit float block that exceed a threshold by more than epsilon.
 *
 * Explicit specialization for float. One vector is loaded from @p ptr and every lane is
 * compared against (threshold + std::numeric_limits<float>::epsilon()).
 *
 * @param[in] ptr       Pointer to the values to load; one full vector is read via wrapper::vloadq.
 * @param[in] threshold Value the lanes are compared against.
 *
 * @return The per-lane 0/1 flags folded into a scalar by reduce_u32x4
 *         (presumably the number of lanes that passed the compare - confirm against reduce_u32x4).
 */
template <>
uint32_t count_gt_block<float>(const float *ptr, float threshold)
{
    using Tag = wrapper::traits::neon_bitvector_tag_t<float, wrapper::traits::BitWidth::W128>;

    // Epsilon-aware compare: a lane counts only when v > (threshold + eps).
    // NOTE(review): eps is the float spacing at 1.0 and the sum is evaluated in float, so under
    // round-to-nearest the bias vanishes once |threshold| >= 4 (threshold + eps == threshold).
    // This is therefore not strictly equivalent to (v - threshold) > eps; confirm this matches
    // the behaviour expected by the reference implementation.
    constexpr float eps_val      = std::numeric_limits<float>::epsilon();
    const float     thr_with_eps = threshold + eps_val;

    const auto v           = wrapper::vloadq(ptr);
    const auto thr_eps_vec = wrapper::vdup_n(thr_with_eps, Tag{});
    const auto mask        = wrapper::vcgt(v, thr_eps_vec); // underlying uint32x4_t

    // Shift each lane right by 31 so only its top bit remains (0 or 1), then reduce across lanes.
    const uint32x4_t m = mask;
    const uint32x4_t b = vshrq_n_u32(m, 31);
    return reduce_u32x4(b);
}

} // namespace detail

void topkv_fp32_neon(const ITensor *predictions, const ITensor *targets, ITensor *out, uint32_t k, const Window &win)
Expand Down