aboutsummaryrefslogtreecommitdiff
path: root/meta/base.h
blob: 4eeb88d54dbf4ec4941926ee8e4f63d8eb4821a6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
// Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef GEMMLOWP_META_BASE_H_
#define GEMMLOWP_META_BASE_H_

#include <cassert>
#include <cstdint>

#include "../internal/common.h"

namespace gemmlowp {
namespace meta {

// Rounds `value` up to the nearest multiple of the compile-time constant
// `align` (e.g. AlignTo<4>(5) == 8, AlignTo<4>(8) == 8).
//
// Preconditions:
//   * `align` must be positive (enforced at compile time below).
//   * `value` is expected to be non-negative: C++ integer division truncates
//     toward zero, so negative values do not round up consistently
//     (e.g. AlignTo<4>(-5) == 0 rather than -4).
//   * `value + align - 1` must not overflow `int`.
template <int align>
inline int AlignTo(int value) {
  // AlignTo<0> would divide by zero and a negative alignment is meaningless;
  // reject both at compile time instead of at run time.
  static_assert(align > 0, "AlignTo: align must be positive");
  return ((value + align - 1) / align) * align;
}

// Rounds `value` up to the nearest multiple of the run-time `align`
// (e.g. AlignTo(4, 5) == 8, AlignTo(4, 8) == 8).
//
// Preconditions:
//   * `align` must be positive. `align == 0` would divide by zero; this is
//     checked with assert() in debug builds.
//   * `value` is expected to be non-negative: C++ integer division truncates
//     toward zero, so negative values do not round up consistently.
//   * `value + align - 1` must not overflow `int`.
inline int AlignTo(int align, int value) {
  assert(align > 0);
  return ((value + align - 1) / align) * align;
}

// Aggregate that holds one `Kernel` value and one `OutputStream` value side by
// side. Both template arguments are re-exported as member types so that code
// holding only this params type can still name them.
template <typename Kernel_, typename OutputStream_>
struct FusedKernelParams {
  using Kernel = Kernel_;
  using OutputStream = OutputStream_;

  Kernel kernel;
  OutputStream output_stream;
};

// Parameter bundle for one GEMM call. All six template arguments are
// re-exported as member types. The kernel and output-stream parameters are
// grouped together in a FusedKernelParams aggregate (`fused_kernel`).
template <typename InType_, typename OutType_, typename LeftStream_,
          typename RightStream_, typename Kernel_, typename OutputStream_>
struct GemmParams {
  using InType = InType_;
  using OutType = OutType_;
  using LeftStream = LeftStream_;
  using RightStream = RightStream_;
  using Kernel = Kernel_;
  using OutputStream = OutputStream_;

  using FusedKernel = FusedKernelParams<Kernel, OutputStream>;

  // ---- Fields shared by every configuration ----
  // NOTE(review): presumably the usual GEMM sizes (lhs m x k, rhs k x n);
  // confirm against the callers.
  int m;
  int n;
  int k;

  const InType* lhs;  // read-only operand
  const InType* rhs;  // read-only operand
  OutType* result;
  std::uint8_t* scratch;

  // ---- Configuration specific to the chosen streams and kernel ----
  LeftStream left_stream;
  RightStream right_stream;
  FusedKernel fused_kernel;
};

// Primary template of the stream interface. This header only declares the
// static members and defines none of them. NOTE(review): the implementations
// presumably come from specializations elsewhere in the library; confirm.
// Every member receives the stream's `StreamParams` object.
template <typename InType, int lanes_count, int pack_size, int leftovers,
          typename StreamParams>
class Stream {
 public:
  // Packing entry point: `in` is read-only, `out` is written.
  static void Pack(const InType* in, const StreamParams& params, InType* out);

  // Queries about the unpacked layout.
  static int UnpackedAdvance(const StreamParams& params);
  static int UnpackedStride(const StreamParams& params);

  // Queries about the packed layout.
  static int PackedAdvance(const StreamParams& params);
  static int PackedStride(const StreamParams& params);
};

// Primary template of the stream helper queries. The members are declared
// here and not defined in this header.
template <typename InType, typename StreamType>
class StreamUtil {
 public:
  // Scratch query for a given lane count. NOTE(review): presumably the
  // scratch size the stream needs; confirm.
  static int Scratch(const StreamType& params, int lanes);

  // Returns a pointer derived from `source` using a stride count and an
  // advance count. NOTE(review): presumably `source` moved by those amounts;
  // confirm.
  static const InType* Offset(const StreamType& params, const InType* source,
                              int offset_stride, int offset_advance);
};

// Primary template of the multiply kernel interface, parameterized on element
// types, kernel/output-stream parameter types, and the integers kernel_m,
// kernel_n and pack_size. NOTE(review): those integers are presumably the
// tile shape and packing depth; confirm. Only the declaration is here.
template <typename InType, typename OutType, typename Kernel,
          typename OutputStream, int kernel_m, int kernel_n, int pack_size>
class MulKernel {
 public:
  // Reads the const operands `lhs` and `rhs` and writes to `result`,
  // configured by the fused kernel/output-stream parameters.
  static void Multiply(
      const InType* lhs, const InType* rhs,
      const FusedKernelParams<Kernel, OutputStream>& params, OutType* result);
};

// Aggregate of the parameters for a 1-D transform: a read-only input pointer,
// an output pointer, a scratch byte buffer, and a kernel parameter object.
// The template arguments are re-exported as member types.
template <typename InType_, typename OutType_, typename Kernel_>
struct Transform1DParams {
  using InType = InType_;
  using OutType = OutType_;
  using Kernel = Kernel_;

  const InType* input;
  OutType* output;
  std::uint8_t* scratch;

  Kernel kernel;
};

// Primary template of the 1-D transform kernel interface, parameterized on
// element types, the kernel parameter type, `kernel_size` and `leftovers`.
// The single static member is declared here and not defined in this header.
template <typename InType, typename OutType, typename Kernel, int kernel_size,
          int leftovers>
class Transform1DKernel {
 public:
  // Reads from the const `input`, writes to `output`, and is configured by
  // `params`.
  static void Transform(const InType* input,
                        const Kernel& params,
                        OutType* output);
};

// Primary template of the 1-D transform helper queries. Members are declared
// here and not defined in this header.
template <typename InType, typename OutType, typename Transform>
class Transform1DUtil {
 public:
  // Pointer helpers for the input and output buffers. NOTE(review): presumably
  // these return the buffer advanced by `offset` elements; confirm.
  static const InType* OffsetInput(const Transform& params, const InType* input,
                                   int offset);
  static OutType* OffsetOutput(const Transform& params, OutType* output,
                               int offset);

  // Cost estimate for the transform described by `params`.
  static int EstimateComputeCost(const Transform& params);
};

}  // namespace meta
}  // namespace gemmlowp

#endif  // GEMMLOWP_META_BASE_H_