293 lines
12 KiB
C++
293 lines
12 KiB
C++
/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#define LOG_TAG "ModelUtils"
|
|
|
|
#include "ModelUtils.h"
|
|
|
|
#include <android-base/logging.h>
|
|
|
|
#include <algorithm>
|
|
#include <numeric>
|
|
#include <unordered_set>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "nnapi/TypeUtils.h"
|
|
#include "nnapi/Types.h"
|
|
#include "nnapi/Validation.h"
|
|
|
|
namespace android::nn {
|
|
namespace {
|
|
|
|
// Assign consecutive integers (starting at 0) to the `true` entries of `includes`, in order.
// The result has the same length as `includes`. Entries at `false` positions are not meaningful
// to callers (shown as X below); this implementation stores the number of preceding `true`
// entries there as well. E.g.:
//   includes = {false, true, true, false,  true}
//   returned = {    X,    0,    1,     X,     2}
std::vector<uint32_t> getMapping(const std::vector<bool>& includes) {
    std::vector<uint32_t> mapping(includes.size());
    uint32_t nextIndex = 0u;
    for (size_t position = 0; position < includes.size(); ++position) {
        // Record the count *before* accounting for the current entry (exclusive prefix sum).
        mapping[position] = nextIndex;
        if (includes[position]) {
            ++nextIndex;
        }
    }
    return mapping;
}
|
|
|
|
// Remap indexes in `indexes` by the mapping `mapping`.
|
|
// Precondition: indexes != nullptr
|
|
void remapIndexes(std::vector<uint32_t>* indexes, const std::vector<uint32_t>& mapping) {
|
|
CHECK(indexes != nullptr);
|
|
for (uint32_t& index : (*indexes)) {
|
|
index = mapping.at(index);
|
|
}
|
|
}
|
|
|
|
// Compact `elements` in place so that only the entries whose flag in `elementsToKeep` is `true`
// remain, preserving their relative order.
// Precondition: elements != nullptr
// Precondition: elements->size() == elementsToKeep.size()
template <typename Type>
void keepSelectedElements(std::vector<Type>* elements, const std::vector<bool>& elementsToKeep) {
    CHECK(elements != nullptr);
    CHECK_EQ(elements->size(), elementsToKeep.size());

    std::vector<Type>& items = *elements;
    size_t writePos = 0;
    for (size_t readPos = 0; readPos < elementsToKeep.size(); ++readPos) {
        if (!elementsToKeep[readPos]) {
            continue;
        }
        // Skip the move when the element is already in its final slot (avoids self-move).
        if (writePos != readPos) {
            items[writePos] = std::move(items[readPos]);
        }
        ++writePos;
    }
    // Everything at or beyond `writePos` is either moved-from or unwanted.
    items.resize(writePos);
}
|
|
|
|
// Find which operands in model.main.operands are read or written by model.main.operations and
|
|
// model.main.inputIndexes.
|
|
// Postcondition: returned.size() == model.main.operands.size()
|
|
std::vector<bool> identifyUsedOperands(const Model& model) {
|
|
std::vector<bool> used(model.main.operands.size(), false);
|
|
auto markUsed = [&used](const std::vector<uint32_t>& indexes) {
|
|
std::for_each(indexes.begin(), indexes.end(),
|
|
[&used](uint32_t index) { used.at(index) = true; });
|
|
};
|
|
for (const auto& operation : model.main.operations) {
|
|
markUsed(operation.inputs);
|
|
markUsed(operation.outputs);
|
|
}
|
|
markUsed(model.main.inputIndexes);
|
|
CHECK_EQ(used.size(), model.main.operands.size());
|
|
return used;
|
|
}
|
|
|
|
// Forward declaration.
|
|
void identifyUsedSubgraphs(uint32_t current, const std::vector<Model::Subgraph>& subgraphs,
|
|
std::vector<bool>* used);
|
|
|
|
// Helper function to find which subgraphs are reachable by `operands`.
|
|
// Precondition: used != nullptr
|
|
// Precondition: subgraphs.size() == used->size()
|
|
void identifyUsedSubgraphs(const std::vector<Operand>& operands,
|
|
const std::vector<Model::Subgraph>& subgraphs, std::vector<bool>* used) {
|
|
for (const auto& operand : operands) {
|
|
if (operand.lifetime == Operand::LifeTime::SUBGRAPH) {
|
|
identifyUsedSubgraphs(operand.location.offset, subgraphs, used);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Helper function to find which subgraphs are reachable by the subgraph at the `current` index, and
|
|
// store when a subgraph is used in `used`. `used` also acts as a cache, ensuring each subgraph is
|
|
// processed at most once.
|
|
// Precondition: used != nullptr
|
|
// Precondition: subgraphs.size() == used->size()
|
|
// Precondition: current < subgraphs.size()
|
|
void identifyUsedSubgraphs(uint32_t current, const std::vector<Model::Subgraph>& subgraphs,
|
|
std::vector<bool>* used) {
|
|
CHECK(used != nullptr);
|
|
CHECK_EQ(subgraphs.size(), used->size());
|
|
CHECK_LT(current, subgraphs.size());
|
|
|
|
// If a subgraph was already marked as used, quickly return to avoid redundant processing.
|
|
if ((*used)[current]) {
|
|
return;
|
|
}
|
|
|
|
// Mark the current subgraph as used, then process any subgraph it references recursively.
|
|
(*used)[current] = true;
|
|
identifyUsedSubgraphs(subgraphs[current].operands, subgraphs, used);
|
|
}
|
|
|
|
// Find which subgraphs are reachable by the main operands of `model`.
|
|
// Postcondition: returned.size() == model.referenced.size()
|
|
std::vector<bool> identifyUsedSubgraphs(const Model& model) {
|
|
std::vector<bool> used(model.referenced.size(), false);
|
|
identifyUsedSubgraphs(model.main.operands, model.referenced, &used);
|
|
CHECK_EQ(used.size(), model.referenced.size());
|
|
return used;
|
|
}
|
|
|
|
// Helper function to find which pools are used by `subgraph`, and store when a pool is used in
|
|
// `used`.
|
|
// Precondition: used != nullptr
|
|
void identifyUsedPools(const Model::Subgraph& subgraph, std::vector<bool>* used) {
|
|
CHECK(used != nullptr);
|
|
for (const auto& operand : subgraph.operands) {
|
|
if (operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE) {
|
|
used->at(operand.location.poolIndex) = true;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Find which pools are used by `model`.
|
|
// Postcondition: returned.size() == model.pools.size()
|
|
std::vector<bool> identifyUsedPools(const Model& model) {
|
|
std::vector<bool> used(model.pools.size(), false);
|
|
identifyUsedPools(model.main, &used);
|
|
for (const auto& subgraph : model.referenced) {
|
|
identifyUsedPools(subgraph, &used);
|
|
}
|
|
CHECK_EQ(used.size(), model.pools.size());
|
|
return used;
|
|
}
|
|
|
|
// Fix the DataLocation in `operand` by either remapping an index or by copying constant data.
|
|
// Precondition: operand != nullptr
|
|
// Precondition: newOperandValues != nullptr
|
|
void fixOperandDataLocation(Operand* operand, Model::OperandValues* newOperandValues,
|
|
const Model::OperandValues& oldOperandValues,
|
|
const std::vector<uint32_t>& remappedPoolIndex,
|
|
const std::vector<uint32_t>& remappedSubgraphIndex) {
|
|
CHECK(operand != nullptr);
|
|
CHECK(newOperandValues != nullptr);
|
|
|
|
switch (operand->lifetime) {
|
|
case Operand::LifeTime::CONSTANT_COPY: {
|
|
const uint8_t* data = oldOperandValues.data() + operand->location.offset;
|
|
const uint32_t length = operand->location.length;
|
|
operand->location = newOperandValues->append(data, length);
|
|
break;
|
|
}
|
|
case Operand::LifeTime::CONSTANT_REFERENCE:
|
|
operand->location.poolIndex = remappedPoolIndex.at(operand->location.poolIndex);
|
|
break;
|
|
case Operand::LifeTime::SUBGRAPH: {
|
|
uint32_t& subgraphIndex = operand->location.offset;
|
|
subgraphIndex = remappedSubgraphIndex.at(subgraphIndex);
|
|
break;
|
|
}
|
|
case Operand::LifeTime::TEMPORARY_VARIABLE:
|
|
case Operand::LifeTime::SUBGRAPH_INPUT:
|
|
case Operand::LifeTime::SUBGRAPH_OUTPUT:
|
|
case Operand::LifeTime::NO_VALUE:
|
|
case Operand::LifeTime::POINTER:
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Fix all DataLocations in `operands` by either remapping an index or by copying constant data.
|
|
// Precondition: operands != nullptr
|
|
// Precondition: newOperandValues != nullptr
|
|
void fixOperandDataLocations(std::vector<Operand>* operands, Model::OperandValues* newOperandValues,
|
|
const Model::OperandValues& oldOperandValues,
|
|
const std::vector<uint32_t>& remappedPoolIndex,
|
|
const std::vector<uint32_t>& remappedSubgraphIndex) {
|
|
for (Operand& operand : (*operands)) {
|
|
fixOperandDataLocation(&operand, newOperandValues, oldOperandValues, remappedPoolIndex,
|
|
remappedSubgraphIndex);
|
|
}
|
|
}
|
|
|
|
// Fix all operands' DataLocations in `model` by either remapping an index or by copying constant
|
|
// data.
|
|
// Precondition: model != nullptr
|
|
void fixOperandDataLocations(Model* model, const std::vector<uint32_t>& remappedPoolIndex,
|
|
const std::vector<uint32_t>& remappedSubgraphIndex) {
|
|
const auto operandValues = std::exchange(model->operandValues, Model::OperandValues{});
|
|
fixOperandDataLocations(&model->main.operands, &model->operandValues, operandValues,
|
|
remappedPoolIndex, remappedSubgraphIndex);
|
|
for (auto& subgraph : model->referenced) {
|
|
fixOperandDataLocations(&subgraph.operands, &model->operandValues, operandValues,
|
|
remappedPoolIndex, remappedSubgraphIndex);
|
|
}
|
|
}
|
|
|
|
// Find which extensions are used in `model`.
|
|
// Postcondition: returned.size() == model.extensionNameToPrefix.size()
|
|
std::vector<bool> identifyUsedExtensions(const Model& model) {
|
|
std::unordered_set<uint16_t> prefixes;
|
|
const auto collectPrefix = [&prefixes](const auto& operandOrOperation) {
|
|
const auto prefix = getExtensionPrefix(static_cast<uint32_t>(operandOrOperation.type));
|
|
constexpr uint16_t kStandardPrefix = 0u;
|
|
if (prefix != kStandardPrefix) {
|
|
prefixes.insert(prefix);
|
|
}
|
|
};
|
|
const auto collectPrefixes = [collectPrefix](const Model::Subgraph& subgraph) {
|
|
std::for_each(subgraph.operands.begin(), subgraph.operands.end(), collectPrefix);
|
|
std::for_each(subgraph.operations.begin(), subgraph.operations.end(), collectPrefix);
|
|
};
|
|
|
|
collectPrefixes(model.main);
|
|
for (const auto& subgraph : model.referenced) {
|
|
collectPrefixes(subgraph);
|
|
}
|
|
|
|
std::vector<bool> used;
|
|
used.reserve(model.extensionNameToPrefix.size());
|
|
for (const auto& extension : model.extensionNameToPrefix) {
|
|
used.push_back(prefixes.count(extension.prefix) > 0);
|
|
}
|
|
CHECK_EQ(used.size(), model.extensionNameToPrefix.size());
|
|
return used;
|
|
}
|
|
|
|
} // anonymous namespace
|
|
|
|
// Prune `model` in place: drop the main-subgraph operands that identifyUsedOperands reports as
// unused, then drop referenced subgraphs, pools, and extension entries that are no longer
// needed, rewriting every index that pointed into a compacted container.
//
// The steps below are order-dependent: each "identify" pass inspects the model as left by the
// previous pruning step.
// Precondition: model != nullptr
void removeDeadOperands(Model* model) {
    CHECK(model != nullptr);
    auto& mainGraph = model->main;

    // 1. Compact the main operands.
    const std::vector<bool> liveOperands = identifyUsedOperands(*model);
    keepSelectedElements(&mainGraph.operands, liveOperands);

    // 2. Rewrite every operand index in the main subgraph to match the compacted operand list.
    const std::vector<uint32_t> operandRemap = getMapping(liveOperands);
    for (auto& operation : mainGraph.operations) {
        remapIndexes(&operation.inputs, operandRemap);
        remapIndexes(&operation.outputs, operandRemap);
    }
    remapIndexes(&mainGraph.inputIndexes, operandRemap);
    remapIndexes(&mainGraph.outputIndexes, operandRemap);

    // 3. Compact the referenced subgraphs (reachability is computed from the pruned operands).
    const std::vector<bool> liveSubgraphs = identifyUsedSubgraphs(*model);
    keepSelectedElements(&model->referenced, liveSubgraphs);

    // 4. Compact the pools (only the remaining subgraphs are inspected).
    const std::vector<bool> livePools = identifyUsedPools(*model);
    keepSelectedElements(&model->pools, livePools);

    // 5. Rewrite pool/subgraph indices stored in operand locations and rebuild constant data.
    const std::vector<uint32_t> poolRemap = getMapping(livePools);
    const std::vector<uint32_t> subgraphRemap = getMapping(liveSubgraphs);
    fixOperandDataLocations(model, poolRemap, subgraphRemap);

    // 6. Compact the extension name/prefix table.
    const std::vector<bool> liveExtensions = identifyUsedExtensions(*model);
    keepSelectedElements(&model->extensionNameToPrefix, liveExtensions);
}
|
|
|
|
} // namespace android::nn
|