/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
// Generic utils similar to those from the C++ header <algorithm>.
|
|
|
|
#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
|
|
#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
|
|
|
|
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <queue>
#include <vector>
|
|
|
|
namespace libtextclassifier3 {
|
|
namespace mobile {
|
|
|
|
// Returns the index of the max element from the vector |elements|.  If several
// elements compare equal to the max, the index of the first one is returned.
// Returns 0 if |elements| is empty.  T should be a type that can be compared by
// operator<.
template<typename T>
inline int GetArgMax(const std::vector<T> &elements) {
  // std::distance yields a std::ptrdiff_t; make the narrowing to the int
  // return type explicit instead of relying on an implicit conversion.
  return static_cast<int>(std::distance(
      elements.begin(),
      std::max_element(elements.begin(), elements.end())));
}
|
|
|
|
// Returns the index of the min element from the vector |elements|.  If several
// elements compare equal to the min, the index of the first one is returned.
// Returns 0 if |elements| is empty.  T should be a type that can be compared by
// operator<.
template<typename T>
inline int GetArgMin(const std::vector<T> &elements) {
  // std::distance yields a std::ptrdiff_t; make the narrowing to the int
  // return type explicit instead of relying on an implicit conversion.
  return static_cast<int>(std::distance(
      elements.begin(),
      std::min_element(elements.begin(), elements.end())));
}
|
|
|
|
// Returns indices of greatest k elements from |v|.
//
// The order between elements is indicated by |smaller|, which should be an
// object like std::less<T>, std::greater<T>, etc.  If smaller(a, b) is true,
// that means that "a is smaller than b".  Intuitively, |smaller| is a
// generalization of operator<.  Formally, it is a strict weak ordering, see
// https://en.cppreference.com/w/cpp/named_req/Compare
//
// Calling this function with std::less<T>() returns the indices of the largest
// k elements; calling it with std::greater<T>() returns the indices of the
// smallest k elements.  This is similar to e.g., std::priority_queue: using the
// default std::less gives you a max-heap, while using std::greater results in a
// min-heap.
//
// Returned indices are sorted in decreasing order of the corresponding elements
// (e.g., first element of the returned array is the index of the largest
// element).  In case of ties (e.g., equal elements) we select the one with the
// smallest index.  E.g., getting the indices of the top-2 elements from [3, 2,
// 1, 3, 0, 3] returns [0, 3] (the indices of the first and the second 3).
//
// Corner cases: If k <= 0, this function returns an empty vector.  If |v| has
// only n < k elements, this function returns all n indices [0, 1, 2, ..., n -
// 1], sorted according to the |smaller| order of the indicated elements.
//
// Assuming each comparison is O(1), this function uses O(k) auxiliary space,
// and runs in O(n * log k) time.  Note: it is possible to use std::nth_element
// and obtain an O(n + k * log k) time algorithm, but that uses O(n) auxiliary
// space.  In our case, k << n, e.g., we may want to select the top-3 most
// likely classes from a set of 100 classes, so the time complexity difference
// should not matter in practice.
template <typename T, typename Smaller>
std::vector<int> GetTopKIndices(int k, const std::vector<T> &v,
                                Smaller smaller) {
  // Indices are reported as int, so convert the size once and do all index
  // arithmetic in int; this avoids mixed signed/unsigned comparisons.
  const int num_elements = static_cast<int>(v.size());

  // We cannot return more indices than there are elements.
  k = std::min(k, num_elements);
  if (k <= 0) {
    return std::vector<int>();
  }

  // An order between indices.  Intuitively, rev_vcomp(i1, i2) iff v[i2] is
  // smaller than v[i1].  No typo: this inversion is necessary for Invariant B
  // below.  "vcomp" stands for "value comparator" (we compare the values
  // indicated by the two indices) and "rev_" stands for the reverse order.
  const auto rev_vcomp = [&v, &smaller](int i1, int i2) -> bool {
    if (smaller(v[i2], v[i1])) return true;
    if (smaller(v[i1], v[i2])) return false;

    // Break ties in favor of earlier elements.
    return i1 < i2;
  };

  // Indices of the top-k elements seen so far; initially the first k indices.
  std::vector<int> heap(k);
  std::iota(heap.begin(), heap.end(), 0);
  std::make_heap(heap.begin(), heap.end(), rev_vcomp);

  // Next, we explore the rest of the vector v.  Loop invariants:
  //
  // Invariant A: |heap| contains the indices of the top-k elements from v[0:i].
  //
  // Invariant B: heap[0] is the index of the smallest element from all elements
  // indicated by the indices from |heap|.
  //
  // Invariant C: |heap| is a max heap, according to order rev_vcomp.
  for (int i = k; i < num_elements; ++i) {
    // We have to update |heap| iff v[i] is larger than the smallest of the
    // top-k seen so far.  This test is easy to do, due to Invariant B above.
    if (smaller(v[heap.front()], v[i])) {
      // Replace heap[0] with i.  std::pop_heap requires the whole range to be
      // a valid heap on entry, so we first pop the current minimum to the back
      // (the range is still a heap at this point), overwrite it with i, and
      // then let std::push_heap sift i into place.  Appending i before
      // calling std::pop_heap would violate that precondition.
      std::pop_heap(heap.begin(), heap.end(), rev_vcomp);
      heap.back() = i;
      std::push_heap(heap.begin(), heap.end(), rev_vcomp);
    }
  }

  // Arrange indices from |heap| in decreasing order of corresponding elements.
  // std::sort_heap sorts in increasing rev_vcomp order, which, because
  // rev_vcomp is inverted, puts the index of the largest element first.
  std::sort_heap(heap.begin(), heap.end(), rev_vcomp);
  return heap;
}
|
|
|
|
// Convenience overload: forwards to the three-argument GetTopKIndices using
// std::less<T> as the order, i.e., it asks for the indices of the k largest
// entries of |elements|.
template <typename T>
std::vector<int> GetTopKIndices(int k, const std::vector<T> &elements) {
  const std::less<T> ascending{};
  return GetTopKIndices(k, elements, ascending);
}
|
|
|
|
} // namespace mobile
|
|
}  // namespace libtextclassifier3
|
|
|
|
#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_MATH_ALGORITHM_H_
|