summaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuSoftmaxKernel.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuSoftmaxKernel.h')
-rw-r--r--src/cpu/kernels/CpuSoftmaxKernel.h38
1 files changed, 23 insertions, 15 deletions
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.h b/src/cpu/kernels/CpuSoftmaxKernel.h
index 8073a677d..df7d3f7d9 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.h
+++ b/src/cpu/kernels/CpuSoftmaxKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,10 +23,8 @@
*/
#ifndef ARM_COMPUTE_CPU_SOFTMAX_KERNEL_H
#define ARM_COMPUTE_CPU_SOFTMAX_KERNEL_H
-
#include "src/core/common/Macros.h"
#include "src/cpu/ICpuKernel.h"
-
namespace arm_compute
{
namespace cpu
@@ -34,8 +32,11 @@ namespace cpu
namespace kernels
{
/** Interface for the identifying the max value of 1D Logits */
-class CpuLogits1DMaxKernel : public ICpuKernel
+class CpuLogits1DMaxKernel : public ICpuKernel<CpuLogits1DMaxKernel>
{
+private:
+ using SoftmaxLogits1DMaxKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, const Window &)>::type;
+
public:
CpuLogits1DMaxKernel() = default;
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuLogits1DMaxKernel);
@@ -52,27 +53,31 @@ public:
* @return a status
*/
static Status validate(const ITensorInfo *src, const ITensorInfo *dst);
-
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
-
-private:
- using SoftmaxLogits1DMaxKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, const Window &)>::type;
+ struct SoftmaxLogits1DMaxKernel
+ {
+ const char *name;
+ const DataTypeISASelectorPtr is_selected;
+ SoftmaxLogits1DMaxKernelPtr ukernel;
+ };
+ static const std::vector<SoftmaxLogits1DMaxKernel> &get_available_kernels();
private:
SoftmaxLogits1DMaxKernelPtr _run_method{ nullptr };
std::string _name{};
};
-
/** Interface for softmax computation for QASYMM8 with pre-computed max. */
template <bool IS_LOG = false>
-class CpuLogits1DSoftmaxKernel : public ICpuKernel
+class CpuLogits1DSoftmaxKernel : public ICpuKernel<CpuLogits1DSoftmaxKernel<IS_LOG>>
{
+private:
+ using SoftmaxLogits1DKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, void *const, ITensor *, float, bool, const Window &)>::type;
+
public:
CpuLogits1DSoftmaxKernel() = default;
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuLogits1DSoftmaxKernel);
-
/** Set the input and output tensors.
*
* @param[in] src Source tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
@@ -92,13 +97,16 @@ public:
*/
static Status validate(const ITensorInfo *src, const ITensorInfo *max,
const ITensorInfo *dst, const float beta, const ITensorInfo *tmp);
-
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
-
-private:
- using SoftmaxLogits1DKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, void *const, ITensor *, float, bool, const Window &)>::type;
+ struct SoftmaxLogits1DKernel
+ {
+ const char *name;
+ const DataTypeISASelectorPtr is_selected;
+ SoftmaxLogits1DKernelPtr ukernel;
+ };
+ static const std::vector<SoftmaxLogits1DKernel> &get_available_kernels();
private:
float _beta{ 1.0f };