summaryrefslogtreecommitdiff
path: root/native/lang_id/common/math/algorithm.h
blob: e2f7179212d692768849983b2332927e92b5c6fd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Generic utils similar to those from the C++ header <algorithm>.

#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_

#include <algorithm>
#include <queue>
#include <vector>

namespace libtextclassifier3 {
namespace mobile {

// Returns the position of the largest entry of |elements|, as judged by
// operator< (T must support it).  If several entries tie for largest, the
// position of the first one is returned.  An empty |elements| yields 0.
template <typename T>
inline int GetArgMax(const std::vector<T> &elements) {
  const auto first = elements.begin();
  const auto best = std::max_element(first, elements.end());
  return static_cast<int>(best - first);
}

// Returns the position of the smallest entry of |elements|, as judged by
// operator< (T must support it).  If several entries tie for smallest, the
// position of the first one is returned.  An empty |elements| yields 0.
template <typename T>
inline int GetArgMin(const std::vector<T> &elements) {
  const auto first = elements.begin();
  const auto best = std::min_element(first, elements.end());
  return static_cast<int>(best - first);
}

// Returns indices of greatest k elements from |v|.
//
// The order between elements is indicated by |smaller|, which should be an
// object like std::less<T>, std::greater<T>, etc.  If smaller(a, b) is true,
// that means that "a is smaller than b".  Intuitively, |smaller| is a
// generalization of operator<.  Formally, it is a strict weak ordering, see
// https://en.cppreference.com/w/cpp/named_req/Compare
//
// Calling this function with std::less<T>() returns the indices of the largest
// k elements; calling it with std::greater<T>() returns the indices of the
// smallest k elements.  This is similar to e.g., std::priority_queue: using the
// default std::less gives you a max-heap, while using std::greater results in a
// min-heap.
//
// Returned indices are sorted in decreasing order of the corresponding elements
// (e.g., first element of the returned array is the index of the largest
// element).  In case of ties (e.g., equal elements) we select the one with the
// smallest index, and tied indices are returned in increasing order.  E.g.,
// getting the indices of the top-2 elements from [3, 2, 1, 3, 0, 3] returns
// [0, 3] (the indices of the first and the second 3).
//
// Corner cases: If k <= 0, this function returns an empty vector.  If |v| has
// only n < k elements, this function returns all n indices [0, 1, 2, ..., n -
// 1], sorted according to the |smaller| order of the indicated elements.
//
// Assuming each comparison is O(1), this function uses O(k) auxiliary space,
// and runs in O(n * log k) time.  Note: it is possible to use std::nth_element
// and obtain an O(n + k * log k) time algorithm, but that uses O(n) auxiliary
// space.  In our case, k << n, e.g., we may want to select the top-3 most
// likely classes from a set of 100 classes, so the time complexity difference
// should not matter in practice.
template <typename T, typename Smaller>
std::vector<int> GetTopKIndices(int k, const std::vector<T> &v,
                                Smaller smaller) {
  if (k <= 0) {
    return std::vector<int>();
  }

  if (static_cast<size_t>(k) > v.size()) {
    k = static_cast<int>(v.size());
  }

  // An order between indices.  Intuitively, rev_vcomp(i1, i2) iff v[i2] is
  // smaller than v[i1].  No typo: this inversion is necessary for Invariant B
  // below.  "vcomp" stands for "value comparator" (we compare the values
  // indicated by the two indices) and "rev_" stands for the reverse order.
  const auto rev_vcomp = [&v, &smaller](int i1, int i2) -> bool {
    if (smaller(v[i2], v[i1])) return true;
    if (smaller(v[i1], v[i2])) return false;

    // Break ties in favor of earlier elements.
    return i1 < i2;
  };

  // Indices of the top-k elements seen so far.
  std::vector<int> heap(k);

  // First, we fill |heap| with the first k indices.
  for (int i = 0; i < k; ++i) {
    heap[i] = i;
  }
  std::make_heap(heap.begin(), heap.end(), rev_vcomp);

  // Next, we explore the rest of the vector v.  Loop invariants:
  //
  // Invariant A: |heap| contains the indices of the top-k elements from v[0:i].
  //
  // Invariant B: heap[0] is the index of the smallest element from all elements
  // indicated by the indices from |heap| (the latest such index on ties).
  //
  // Invariant C: |heap| is a max heap, according to order rev_vcomp.
  for (size_t i = k; i < v.size(); ++i) {
    // We have to update |heap| iff v[i] is larger than the smallest of the
    // top-k seen so far.  This test is easy to do, due to Invariant B above.
    // A tie keeps the earlier index already in |heap|.
    if (smaller(v[heap[0]], v[i])) {
      // Replace heap[0] with i.  pop_heap moves the root to heap[k-1] and
      // leaves heap[0:k-1) a valid heap; we overwrite that freed slot with i
      // and push_heap restores Invariant C.  Every heap call therefore sees a
      // valid heap.  Appending i first and calling pop_heap over all k+1
      // entries would violate pop_heap's precondition, because the range would
      // not be a heap.  That produces wrong results on some standard libraries.
      std::pop_heap(heap.begin(), heap.end(), rev_vcomp);
      heap.back() = static_cast<int>(i);
      std::push_heap(heap.begin(), heap.end(), rev_vcomp);
    }
  }

  // Arrange indices from |heap| in decreasing order of corresponding elements.
  // sort_heap yields ascending order under rev_vcomp, i.e., largest element
  // first, and on ties the smaller index first.
  std::sort_heap(heap.begin(), heap.end(), rev_vcomp);
  return heap;
}

// Convenience overload: returns the indices of the k largest entries of
// |elements|, ordered with std::less<T> (i.e., operator<).  See the
// three-argument GetTopKIndices for ordering, tie-breaking and corner cases.
template <typename T>
std::vector<int> GetTopKIndices(int k, const std::vector<T> &elements) {
  const std::less<T> natural_order{};
  return GetTopKIndices(k, elements, natural_order);
}

}  // namespace mobile
}  // namespace libtextclassifier3

#endif  // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_