286 lines
10 KiB
C++
286 lines
10 KiB
C++
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_
|
|
#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_
|
|
|
|
#include <algorithm>
|
|
#include <memory>
|
|
#include <string>
|
|
#include <utility>
|
|
|
|
#include "lang_id/common/embedding-network-params.h"
|
|
#include "lang_id/common/flatbuffers/embedding-network_generated.h"
|
|
#include "lang_id/common/lite_base/float16.h"
|
|
#include "lang_id/common/lite_base/logging.h"
|
|
#include "lang_id/common/lite_strings/stringpiece.h"
|
|
|
|
namespace libtextclassifier3 {
|
|
namespace mobile {
|
|
|
|
// EmbeddingNetworkParams implementation backed by a flatbuffer.
|
|
//
|
|
// For info on our flatbuffer schema, see embedding-network.fbs.
|
|
class EmbeddingNetworkParamsFromFlatbuffer : public EmbeddingNetworkParams {
|
|
public:
|
|
// Constructs an EmbeddingNetworkParamsFromFlatbuffer instance, using the
|
|
// flatbuffer from |bytes|.
|
|
//
|
|
// IMPORTANT #1: caller should make sure |bytes| are alive during the lifetime
|
|
// of this EmbeddingNetworkParamsFromFlatbuffer instance. To avoid overhead,
|
|
// this constructor does not copy |bytes|.
|
|
//
|
|
// IMPORTANT #2: immediately after this constructor returns, we suggest you
|
|
// call is_valid() on the newly-constructed object and do not call any other
|
|
// method if the answer is negative (false).
|
|
explicit EmbeddingNetworkParamsFromFlatbuffer(StringPiece bytes);
|
|
|
|
bool UpdateTaskContextParameters(mobile::TaskContext *task_context) override {
|
|
// This class does not provide access to the overall TaskContext. It
|
|
// provides only parameters for the Neurosis neural network.
|
|
SAFTM_LOG(DFATAL) << "Not supported";
|
|
return false;
|
|
}
|
|
|
|
bool is_valid() const override { return valid_; }
|
|
|
|
int embeddings_size() const override { return SafeGetNumInputChunks(); }
|
|
|
|
int embeddings_num_rows(int i) const override {
|
|
const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i);
|
|
return SafeGetNumRows(matrix);
|
|
}
|
|
|
|
int embeddings_num_cols(int i) const override {
|
|
const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i);
|
|
return SafeGetNumCols(matrix);
|
|
}
|
|
|
|
const void *embeddings_weights(int i) const override {
|
|
const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i);
|
|
return SafeGetValuesOfMatrix(matrix);
|
|
}
|
|
|
|
QuantizationType embeddings_quant_type(int i) const override {
|
|
const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i);
|
|
return SafeGetQuantizationType(matrix);
|
|
}
|
|
|
|
const float16 *embeddings_quant_scales(int i) const override {
|
|
const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i);
|
|
return SafeGetScales(matrix);
|
|
}
|
|
|
|
int hidden_size() const override {
|
|
// -1 because last layer is always the softmax layer.
|
|
return std::max(SafeGetNumLayers() - 1, 0);
|
|
}
|
|
|
|
int hidden_num_rows(int i) const override {
|
|
const saft_fbs::Matrix *weights = SafeGetLayerWeights(i);
|
|
return SafeGetNumRows(weights);
|
|
}
|
|
|
|
int hidden_num_cols(int i) const override {
|
|
const saft_fbs::Matrix *weights = SafeGetLayerWeights(i);
|
|
return SafeGetNumCols(weights);
|
|
}
|
|
|
|
QuantizationType hidden_weights_quant_type(int i) const override {
|
|
const saft_fbs::Matrix *weights = SafeGetLayerWeights(i);
|
|
return SafeGetQuantizationType(weights);
|
|
}
|
|
|
|
const void *hidden_weights(int i) const override {
|
|
const saft_fbs::Matrix *weights = SafeGetLayerWeights(i);
|
|
return SafeGetValuesOfMatrix(weights);
|
|
}
|
|
|
|
int hidden_bias_size() const override { return hidden_size(); }
|
|
|
|
int hidden_bias_num_rows(int i) const override {
|
|
const saft_fbs::Matrix *bias = SafeGetLayerBias(i);
|
|
return SafeGetNumRows(bias);
|
|
}
|
|
|
|
int hidden_bias_num_cols(int i) const override {
|
|
const saft_fbs::Matrix *bias = SafeGetLayerBias(i);
|
|
return SafeGetNumCols(bias);
|
|
}
|
|
|
|
const void *hidden_bias_weights(int i) const override {
|
|
const saft_fbs::Matrix *bias = SafeGetLayerBias(i);
|
|
return SafeGetValues(bias);
|
|
}
|
|
|
|
int softmax_size() const override { return (SafeGetNumLayers() > 0) ? 1 : 0; }
|
|
|
|
int softmax_num_rows(int i) const override {
|
|
const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights();
|
|
return SafeGetNumRows(weights);
|
|
}
|
|
|
|
int softmax_num_cols(int i) const override {
|
|
const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights();
|
|
return SafeGetNumCols(weights);
|
|
}
|
|
|
|
QuantizationType softmax_weights_quant_type(int i) const override {
|
|
const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights();
|
|
return SafeGetQuantizationType(weights);
|
|
}
|
|
|
|
const void *softmax_weights(int i) const override {
|
|
const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights();
|
|
return SafeGetValuesOfMatrix(weights);
|
|
}
|
|
|
|
int softmax_bias_size() const override { return softmax_size(); }
|
|
|
|
int softmax_bias_num_rows(int i) const override {
|
|
const saft_fbs::Matrix *bias = SafeGetSoftmaxBias();
|
|
return SafeGetNumRows(bias);
|
|
}
|
|
|
|
int softmax_bias_num_cols(int i) const override {
|
|
const saft_fbs::Matrix *bias = SafeGetSoftmaxBias();
|
|
return SafeGetNumCols(bias);
|
|
}
|
|
|
|
const void *softmax_bias_weights(int i) const override {
|
|
const saft_fbs::Matrix *bias = SafeGetSoftmaxBias();
|
|
return SafeGetValues(bias);
|
|
}
|
|
|
|
int embedding_num_features_size() const override {
|
|
return SafeGetNumInputChunks();
|
|
}
|
|
|
|
int embedding_num_features(int i) const override {
|
|
if (!InRangeIndex(i, embedding_num_features_size(),
|
|
"embedding num features")) {
|
|
return 0;
|
|
}
|
|
const saft_fbs::InputChunk *input_chunk = SafeGetInputChunk(i);
|
|
if (input_chunk == nullptr) {
|
|
return 0;
|
|
}
|
|
return input_chunk->num_features();
|
|
}
|
|
|
|
bool has_is_precomputed() const override { return false; }
|
|
bool is_precomputed() const override { return false; }
|
|
|
|
private:
|
|
// Returns true if and only if index is in [0, limit). info should be a
|
|
// pointer to a zero-terminated array of chars (ideally a literal string,
|
|
// e.g. "layer") indicating what the index refers to; info is used to make log
|
|
// messages more informative.
|
|
static bool InRangeIndex(int index, int limit, const char *info);
|
|
|
|
// Returns network_->input_chunks()->size(), if all dereferences are safe
|
|
// (i.e., no nullptr); otherwise, returns 0.
|
|
int SafeGetNumInputChunks() const;
|
|
|
|
// Returns network_->input_chunks()->Get(i), if all dereferences are safe
|
|
// (i.e., no nullptr) otherwise, returns nullptr.
|
|
const saft_fbs::InputChunk *SafeGetInputChunk(int i) const;
|
|
|
|
// Returns network_->input_chunks()->Get(i)->embedding(), if all dereferences
|
|
// are safe (i.e., no nullptr); otherwise, returns nullptr.
|
|
const saft_fbs::Matrix *SafeGetEmbeddingMatrix(int i) const;
|
|
|
|
// Returns network_->layers()->size(), if all dereferences are safe (i.e., no
|
|
// nullptr); otherwise, returns 0.
|
|
int SafeGetNumLayers() const;
|
|
|
|
// Returns network_->layers()->Get(i), if all dereferences are safe
|
|
// (i.e., no nullptr); otherwise, returns nullptr.
|
|
const saft_fbs::NeuralLayer *SafeGetLayer(int i) const;
|
|
|
|
// Returns network_->layers()->Get(i)->weights(), if all dereferences are safe
|
|
// (i.e., no nullptr); otherwise, returns nullptr.
|
|
const saft_fbs::Matrix *SafeGetLayerWeights(int i) const;
|
|
|
|
// Returns network_->layers()->Get(i)->bias(), if all dereferences are safe
|
|
// (i.e., no nullptr); otherwise, returns nullptr.
|
|
const saft_fbs::Matrix *SafeGetLayerBias(int i) const;
|
|
|
|
static int SafeGetNumRows(const saft_fbs::Matrix *matrix) {
|
|
return (matrix == nullptr) ? 0 : matrix->rows();
|
|
}
|
|
|
|
static int SafeGetNumCols(const saft_fbs::Matrix *matrix) {
|
|
return (matrix == nullptr) ? 0 : matrix->cols();
|
|
}
|
|
|
|
// Returns matrix->values()->data() if all dereferences are safe (i.e., no
|
|
// nullptr); otherwise, returns nullptr.
|
|
static const float *SafeGetValues(const saft_fbs::Matrix *matrix);
|
|
|
|
// Returns matrix->quantized_values()->data() if all dereferences are safe
|
|
// (i.e., no nullptr); otherwise, returns nullptr.
|
|
static const uint8_t *SafeGetQuantizedValues(const saft_fbs::Matrix *matrix);
|
|
|
|
// Returns matrix->scales()->data() if all dereferences are safe (i.e., no
|
|
// nullptr); otherwise, returns nullptr.
|
|
static const float16 *SafeGetScales(const saft_fbs::Matrix *matrix);
|
|
|
|
// Returns network_->layers()->Get(last_index) with last_index =
|
|
// SafeGetNumLayers() - 1, if all dereferences are safe (i.e., no nullptr) and
|
|
// there exists at least one layer; otherwise, returns nullptr.
|
|
const saft_fbs::NeuralLayer *SafeGetSoftmaxLayer() const;
|
|
|
|
const saft_fbs::Matrix *SafeGetSoftmaxWeights() const {
|
|
const saft_fbs::NeuralLayer *layer = SafeGetSoftmaxLayer();
|
|
return (layer == nullptr) ? nullptr : layer->weights();
|
|
}
|
|
|
|
const saft_fbs::Matrix *SafeGetSoftmaxBias() const {
|
|
const saft_fbs::NeuralLayer *layer = SafeGetSoftmaxLayer();
|
|
return (layer == nullptr) ? nullptr : layer->bias();
|
|
}
|
|
|
|
// Returns the quantization type for |matrix|. Returns NONE in case of
|
|
// problems (e.g., matrix is nullptr or unknown quantization type).
|
|
QuantizationType SafeGetQuantizationType(
|
|
const saft_fbs::Matrix *matrix) const;
|
|
|
|
// Returns a pointer to the values (float, uint8, or float16, depending on
|
|
// quantization) from |matrix|, in row-major order. Returns nullptr in case
|
|
// of a problem.
|
|
const void *SafeGetValuesOfMatrix(const saft_fbs::Matrix *matrix) const;
|
|
|
|
// Performs some validity checks. E.g., check that dimensions of the network
|
|
// layers match. Also checks that all pointers we return are inside the
|
|
// |bytes| passed to the constructor, such that client that reads from those
|
|
// pointers will not run into troubles.
|
|
bool ValidityChecking(StringPiece bytes) const;
|
|
|
|
// True if these params are valid. May be false if the original proto was
|
|
// corrupted. We prefer to set this to false to CHECK-failing.
|
|
bool valid_ = false;
|
|
|
|
// EmbeddingNetwork flatbuffer from the bytes passed as parameter to the
|
|
// constructor; see constructor doc.
|
|
const saft_fbs::EmbeddingNetwork *network_ = nullptr;
|
|
};
|
|
|
|
} // namespace mobile
|
|
}  // namespace libtextclassifier3
|
|
|
|
#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_
|