637 lines
29 KiB
C++
637 lines
29 KiB
C++
/*
|
|
* Copyright (C) 2021 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "ShimPreparedModel.h"
|
|
|
|
#include <aidl/android/hardware/neuralnetworks/BnBurst.h>
|
|
#include <aidl/android/hardware/neuralnetworks/BnExecution.h>
|
|
#include <aidl/android/hardware/neuralnetworks/BnFencedExecutionCallback.h>
|
|
#include <aidl/android/hardware/neuralnetworks/ErrorStatus.h>
|
|
#include <aidl/android/hardware/neuralnetworks/OutputShape.h>
|
|
#include <aidl/android/hardware/neuralnetworks/RequestMemoryPool.h>
|
|
#include <android-base/chrono_utils.h>
|
|
#include <android-base/logging.h>
|
|
#include <android-base/scopeguard.h>
|
|
#include <android/binder_auto_utils.h>
|
|
#include <nnapi/TypeUtils.h>
|
|
#include <nnapi/hal/aidl/Conversions.h>
|
|
#include <nnapi/hal/aidl/Utils.h>
|
|
|
|
#include <algorithm>
|
|
#include <chrono>
|
|
#include <limits>
|
|
#include <memory>
|
|
#include <thread>
|
|
#include <unordered_map>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
#include "ShimConverter.h"
|
|
#include "ShimUtils.h"
|
|
|
|
namespace aidl::android::hardware::neuralnetworks {
|
|
|
|
// Translates an AIDL Request and its execution settings into calls on the support-library (SL)
// execution object |execution|.
//
// Each entry of request.pools is resolved into an SL memory object and appended to
// |requestMemoryPools|: plain memory pools are converted with convertFromHAL(), token pools are
// looked up in mBufferTracker. Every request input/output is then bound either to a region of one
// of those pools or to "no value". Finally timing measurement, the deadline, the loop timeout and
// the extension execution hints are applied to |execution|.
//
// Returns ErrorStatus::NONE on success, otherwise the first error encountered. Malformed request
// fields (unknown tokens, out-of-range pool indices or regions, unconvertible dimensions, unknown
// extension prefixes, too many arguments) are reported as INVALID_ARGUMENT.
ErrorStatus ShimPreparedModel::parseInputs(
        const Request& request, bool measure, int64_t deadlineNs, int64_t loopTimeoutDurationNs,
        ::android::nn::sl_wrapper::Execution* execution,
        std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>>* requestMemoryPools,
        const std::vector<TokenValuePair>& executionHints,
        const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix) {
    for (const auto& requestPool : request.pools) {
        switch (requestPool.getTag()) {
            case RequestMemoryPool::pool: {
                const auto& memoryPool = requestPool.get<RequestMemoryPool::pool>();
                std::shared_ptr<::android::nn::sl_wrapper::Memory> mem =
                        convertFromHAL(mNnapi.get(), memoryPool);
                if (!mem) {
                    LOG(ERROR) << "Failed to convert request HAL memory pools into SL memory";
                    return ErrorStatus::INVALID_ARGUMENT;
                }
                requestMemoryPools->push_back(std::move(mem));
                break;
            }
            case RequestMemoryPool::token: {
                const int token = requestPool.get<RequestMemoryPool::token>();
                auto memory = mBufferTracker->get(static_cast<uint32_t>(token));
                if (memory == nullptr) {
                    return ErrorStatus::INVALID_ARGUMENT;
                }
                requestMemoryPools->push_back(std::move(memory));
                break;
            }
        }
    }

    // enable input and output padding
    const auto enablePaddingResult = execution->enableInputAndOutputPadding(true);
    if (enablePaddingResult != Result::NO_ERROR) {
        return convertResultToErrorStatus(enablePaddingResult);
    }

    const auto& model = mMainAndReferencedModels[0];

    // Validates the data location of a request argument that has a value and returns the SL
    // memory it refers to, or nullptr if the location is malformed.
    //  - The pool index comes from the client, so it is bounds-checked here instead of relying on
    //    std::vector::at(), whose std::out_of_range would not be caught anywhere in this file.
    //  - NOTE(review): offset/length are range-checked against uint32_t on the assumption that the
    //    sl_wrapper setInput/OutputFromMemory take 32-bit offset/length (as NeuralNetworksWrapper.h
    //    does) -- confirm. Without the check, negative or >4GiB values would be wrapped/truncated.
    const auto resolvePoolMemory =
            [requestMemoryPools](const auto& location) -> ::android::nn::sl_wrapper::Memory* {
        constexpr int64_t kMaxUint32 = std::numeric_limits<uint32_t>::max();
        const bool poolIndexInRange =
                location.poolIndex >= 0 &&
                static_cast<size_t>(location.poolIndex) < requestMemoryPools->size();
        const bool regionInRange = location.offset >= 0 && location.offset <= kMaxUint32 &&
                                   location.length >= 0 && location.length <= kMaxUint32;
        if (!poolIndexInRange || !regionInRange) {
            return nullptr;
        }
        return (*requestMemoryPools)[static_cast<size_t>(location.poolIndex)].get();
    };

    if (request.inputs.size() > model.getInputs().size()) {
        return ErrorStatus::INVALID_ARGUMENT;
    }

    // set inputs
    for (uint32_t i = 0; i < request.inputs.size(); ++i) {
        const auto& input = request.inputs[i];
        if (input.hasNoValue) {
            const auto result = execution->setInput(i, nullptr, 0);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
            continue;
        }

        ::android::nn::wrapper::OperandType operandType = model.getOperands()[model.getInputs()[i]];
        if (!input.dimensions.empty()) {
            // toUnsigned() returns an error result when it cannot convert the dimensions
            // (presumably negative values); reject the request instead of calling value() on it.
            auto dimensions = ::android::nn::toUnsigned(input.dimensions);
            if (!dimensions.has_value()) {
                LOG(ERROR) << "Request input " << i << " has invalid dimensions";
                return ErrorStatus::INVALID_ARGUMENT;
            }
            operandType.updateDimensions(std::move(dimensions).value());
        }

        auto* memory = resolvePoolMemory(input.location);
        if (memory == nullptr) {
            LOG(ERROR) << "Request input " << i << " has an invalid data location";
            return ErrorStatus::INVALID_ARGUMENT;
        }
        const auto result =
                execution->setInputFromMemory(i, memory, input.location.offset,
                                              input.location.length, &operandType.operandType);
        if (result != Result::NO_ERROR) {
            return convertResultToErrorStatus(result);
        }
    }

    if (request.outputs.size() > model.getOutputs().size()) {
        return ErrorStatus::INVALID_ARGUMENT;
    }

    // set outputs
    for (uint32_t i = 0; i < request.outputs.size(); ++i) {
        const auto& output = request.outputs[i];
        if (output.hasNoValue) {
            const auto result = execution->setOutput(i, nullptr, 0);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
            continue;
        }

        ::android::nn::wrapper::OperandType operandType =
                model.getOperands()[model.getOutputs()[i]];
        if (!output.dimensions.empty()) {
            auto dimensions = ::android::nn::toUnsigned(output.dimensions);
            if (!dimensions.has_value()) {
                LOG(ERROR) << "Request output " << i << " has invalid dimensions";
                return ErrorStatus::INVALID_ARGUMENT;
            }
            operandType.updateDimensions(std::move(dimensions).value());
        }

        auto* memory = resolvePoolMemory(output.location);
        if (memory == nullptr) {
            LOG(ERROR) << "Request output " << i << " has an invalid data location";
            return ErrorStatus::INVALID_ARGUMENT;
        }
        const auto result =
                execution->setOutputFromMemory(i, memory, output.location.offset,
                                               output.location.length, &operandType.operandType);
        if (result != Result::NO_ERROR) {
            return convertResultToErrorStatus(result);
        }
    }

    // NOTE(review): the results of setMeasureTiming() and setLoopTimeout() are not checked.
    if (measure) {
        execution->setMeasureTiming(true);
    }

    // Convert the absolute boot-clock deadline into a relative timeout; a deadline that has
    // already passed is reported as MISSED_DEADLINE_TRANSIENT without starting anything.
    if (deadlineNs > -1) {
        std::chrono::time_point<::android::base::boot_clock> deadlinePoint(
                std::chrono::nanoseconds{deadlineNs});
        const auto currentTime = ::android::base::boot_clock::now();
        const auto timeoutDuration = std::chrono::nanoseconds(deadlinePoint - currentTime);
        if (timeoutDuration <= std::chrono::nanoseconds::zero()) {
            return ErrorStatus::MISSED_DEADLINE_TRANSIENT;
        }
        auto result = execution->setTimeout(std::max<uint64_t>(1, timeoutDuration.count()));
        if (result != Result::NO_ERROR) {
            return convertResultToErrorStatus(result);
        }
    }

    if (loopTimeoutDurationNs > 0) {
        execution->setLoopTimeout(loopTimeoutDurationNs);
    }

    if (!executionHints.empty() || !extensionNameToPrefix.empty()) {
        // Map each extension prefix back to its extension name so hint tokens can be routed.
        std::unordered_map<uint16_t, std::string> prefixToName;
        for (const auto& [name, prefix] : extensionNameToPrefix) {
            prefixToName.emplace(prefix, name);
        }

        for (const auto& [token, value] : executionHints) {
            const auto uToken = static_cast<uint32_t>(token);
            const auto prefix = ::android::nn::getExtensionPrefix(uToken);
            const auto attributeCodeWithinExtension = ::android::nn::getTypeWithinExtension(uToken);

            const auto it = prefixToName.find(prefix);
            if (it == prefixToName.end()) {
                return ErrorStatus::INVALID_ARGUMENT;
            }
            const std::string& extensionName = it->second;

            const auto result = execution->addExtensionAttribute(
                    extensionName, attributeCodeWithinExtension, value);
            if (result != Result::NO_ERROR) {
                return convertResultToErrorStatus(result);
            }
        }
    }

    return ErrorStatus::NONE;
}
|
|
|
|
class ShimFencedExecutionCallback : public BnFencedExecutionCallback {
|
|
public:
|
|
ShimFencedExecutionCallback(
|
|
std::shared_ptr<::android::nn::sl_wrapper::Execution> execution, Event e,
|
|
std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> memoryPools,
|
|
bool measureTiming)
|
|
: mMemoryPools(std::move(memoryPools)),
|
|
mExecution(std::move(execution)),
|
|
mEvent(std::move(e)),
|
|
mMeasureTiming(measureTiming) {}
|
|
|
|
ndk::ScopedAStatus getExecutionInfo(Timing* timingLaunched, Timing* timingFenced,
|
|
ErrorStatus* errorStatus) override {
|
|
auto status = mEvent.wait();
|
|
*errorStatus = convertResultToErrorStatus(status);
|
|
|
|
if (mMeasureTiming) {
|
|
uint64_t duration;
|
|
constexpr int64_t int64cap = std::numeric_limits<int64_t>::max();
|
|
// Special value used for "no measurements"
|
|
constexpr uint64_t uint64cap = std::numeric_limits<uint64_t>::max();
|
|
auto result = mExecution->getDuration(Duration::ON_HARDWARE, &duration);
|
|
SLW2SAS_RETURN_IF_ERROR(result);
|
|
timingLaunched->timeOnDeviceNs = (duration == uint64cap) ? -1
|
|
: (duration > int64cap)
|
|
? int64cap
|
|
: static_cast<int64_t>(duration);
|
|
|
|
result = mExecution->getDuration(Duration::IN_DRIVER, &duration);
|
|
SLW2SAS_RETURN_IF_ERROR(result);
|
|
timingLaunched->timeInDriverNs = (duration == uint64cap) ? -1
|
|
: (duration > int64cap)
|
|
? int64cap
|
|
: static_cast<int64_t>(duration);
|
|
|
|
result = mExecution->getDuration(Duration::FENCED_ON_HARDWARE, &duration);
|
|
SLW2SAS_RETURN_IF_ERROR(result);
|
|
timingFenced->timeOnDeviceNs = (duration == uint64cap) ? -1
|
|
: (duration > int64cap) ? int64cap
|
|
: static_cast<int64_t>(duration);
|
|
|
|
result = mExecution->getDuration(Duration::FENCED_IN_DRIVER, &duration);
|
|
SLW2SAS_RETURN_IF_ERROR(result);
|
|
timingFenced->timeInDriverNs = (duration == uint64cap) ? -1
|
|
: (duration > int64cap) ? int64cap
|
|
: static_cast<int64_t>(duration);
|
|
} else {
|
|
timingFenced->timeOnDeviceNs = -1;
|
|
timingFenced->timeInDriverNs = -1;
|
|
timingLaunched->timeOnDeviceNs = -1;
|
|
timingLaunched->timeInDriverNs = -1;
|
|
}
|
|
|
|
return ndk::ScopedAStatus::ok();
|
|
}
|
|
|
|
private:
|
|
std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> mMemoryPools;
|
|
std::shared_ptr<::android::nn::sl_wrapper::Execution> mExecution;
|
|
::android::nn::wrapper::Event mEvent;
|
|
bool mMeasureTiming;
|
|
};
|
|
|
|
// Launches |execution| as a fenced computation that starts once every sync fence in |waitFor|
// has signaled.
//
// |durationNs| bounds the time the computation may take after the fences signal. -1 means no
// bound; any other negative value is rejected with INVALID_ARGUMENT. On success,
// |fencedExecutionResult| receives a sync fence fd for the computation (or -1 if the SL event
// could not provide one) and a ShimFencedExecutionCallback that keeps |execution|, the completion
// event and |requestMemoryPools| alive.
static ndk::ScopedAStatus executeFencedInternal(
        const std::shared_ptr<const NnApiSupportLibrary>& nnapi,
        const std::shared_ptr<::android::nn::sl_wrapper::Execution>& execution,
        std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools,
        const std::vector<ndk::ScopedFileDescriptor>& waitFor, int64_t durationNs,
        bool measureTiming, FencedExecutionResult* fencedExecutionResult) {
    CHECK(execution != nullptr);
    CHECK(fencedExecutionResult != nullptr);

    if (durationNs < -1) {
        LOG(ERROR) << "Invalid duration value, must be >= -1";
        return ndk::ScopedAStatus::fromServiceSpecificError(
                static_cast<int>(ErrorStatus::INVALID_ARGUMENT));
    }
    // NNAPI's startComputeWithDependencies takes a uint64_t duration in which 0 means "no
    // timeout". Map the AIDL "omitted" value -1 to 0 rather than letting it wrap to UINT64_MAX.
    const uint64_t slDurationNs = durationNs == -1 ? 0 : static_cast<uint64_t>(durationNs);

    // Wrap each incoming sync fence in an SL event. After the first failure no further events
    // are created and the remaining entries stay null.
    std::vector<const ANeuralNetworksEvent*> deps(waitFor.size());
    auto createResult = Result::NO_ERROR;
    std::transform(waitFor.begin(), waitFor.end(), deps.begin(),
                   [&](const ::ndk::ScopedFileDescriptor& e) {
                       ANeuralNetworksEvent* r = nullptr;
                       if (createResult == Result::NO_ERROR) {
                           createResult = static_cast<Result>(
                                   nnapi->getFL5()->ANeuralNetworksEvent_createFromSyncFenceFd(
                                           e.get(), &r));
                       }
                       return r;
                   });

    // Free every event that was created, on all return paths below. The guard is declared after
    // |deps| and is therefore destroyed first, so capturing by reference is safe.
    const auto guard = ::android::base::make_scope_guard([&nnapi, &deps] {
        for (const auto* dep : deps) {
            if (dep != nullptr) {
                nnapi->getFL5()->ANeuralNetworksEvent_free(const_cast<ANeuralNetworksEvent*>(dep));
            }
        }
    });

    SLW2SAS_RETURN_IF_ERROR(createResult);

    Event e(nnapi.get());
    auto result = execution->startComputeWithDependencies(deps, slDurationNs, &e);
    SLW2SAS_RETURN_IF_ERROR(result);

    int syncFence = -1;
    fencedExecutionResult->syncFence = ndk::ScopedFileDescriptor(
            (e.getSyncFenceFd(&syncFence) == Result::NO_ERROR) ? syncFence : -1);
    // |requestMemoryPools| is owned by this call: move it into the callback rather than copying
    // every shared_ptr.
    fencedExecutionResult->callback = ndk::SharedRefBase::make<ShimFencedExecutionCallback>(
            execution, std::move(e), std::move(requestMemoryPools), measureTiming);

    return ndk::ScopedAStatus::ok();
}
|
|
|
|
// Shared implementation of executeFenced and executeFencedWithConfig. It rejects an invalid
// deadline, builds a fresh SL execution for |request| via parseInputs(), and hands that execution
// off to executeFencedInternal().
::ndk::ScopedAStatus ShimPreparedModel::executeFencedCommon(
        const Request& request, const std::vector<::ndk::ScopedFileDescriptor>& waitFor,
        bool measureTiming, int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs,
        const std::vector<TokenValuePair>& executionHints,
        const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix,
        FencedExecutionResult* fencedExecutionResult) {
    CHECK(fencedExecutionResult != nullptr);

    const bool deadlineIsValid = deadlineNs >= -1;
    if (!deadlineIsValid) {
        LOG(ERROR) << "Invalid deadline value, must be >= -1";
        return ndk::ScopedAStatus::fromServiceSpecificError(
                static_cast<int>(ErrorStatus::INVALID_ARGUMENT));
    }

    std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> pools;
    auto slExecution =
            std::make_shared<::android::nn::sl_wrapper::Execution>(mNnapi.get(), &mCompilation);

    if (const auto parseStatus =
                parseInputs(request, measureTiming, deadlineNs, loopTimeoutDurationNs,
                            slExecution.get(), &pools, executionHints, extensionNameToPrefix);
        parseStatus != ErrorStatus::NONE) {
        return toAStatus(parseStatus);
    }

    return executeFencedInternal(mNnapi, slExecution, std::move(pools), waitFor, durationNs,
                                 measureTiming, fencedExecutionResult);
}
|
|
|
|
// IPreparedModel::executeFenced: the pre-config entry point. It forwards to executeFencedCommon
// with an empty list of execution hints and an empty extension-prefix table.
::ndk::ScopedAStatus ShimPreparedModel::executeFenced(
        const ::aidl::android::hardware::neuralnetworks::Request& request,
        const std::vector<::ndk::ScopedFileDescriptor>& waitFor, bool measureTiming,
        int64_t deadlineNs, int64_t loopTimeoutDurationNs, int64_t durationNs,
        FencedExecutionResult* fencedExecutionResult) {
    const std::vector<TokenValuePair> noExecutionHints;
    const std::vector<ExtensionNameAndPrefix> noExtensionPrefixes;
    return executeFencedCommon(request, waitFor, measureTiming, deadlineNs, loopTimeoutDurationNs,
                               durationNs, noExecutionHints, noExtensionPrefixes,
                               fencedExecutionResult);
}
|
|
|
|
// Computes |execution| synchronously and fills |executionResult| with:
//   - the shapes of the first |numOutputs| outputs,
//   - whether every output buffer was large enough, and
//   - optionally, timing.
//
// Returns an ok status when the computation succeeded or failed only with
// OUTPUT_INSUFFICIENT_SIZE. Any other compute failure is returned as an error status. If compute
// succeeded but reading an output's dimensions fails, the result is GENERAL_FAILURE.
//
// Timing fields are -1 unless |measureTiming| is set and the status is still NONE. A failed
// duration query returns early without touching |executionResult|.
static ndk::ScopedAStatus executeSynchronouslyInternal(
        const std::shared_ptr<::android::nn::sl_wrapper::Execution>& execution, bool measureTiming,
        int numOutputs, ExecutionResult* executionResult) {
    CHECK(execution != nullptr);
    CHECK(executionResult != nullptr);

    const auto computeResult = execution->compute();
    auto status = convertResultToErrorStatus(computeResult);

    bool allSufficient = true;
    std::vector<OutputShape> shapes;
    shapes.reserve(numOutputs);
    for (int index = 0; index < numOutputs; ++index) {
        std::vector<uint32_t> dims;
        const auto dimsResult = execution->getOutputOperandDimensions(index, &dims);
        OutputShape& shape = shapes.emplace_back();

        const bool gotDims = dimsResult == Result::NO_ERROR ||
                             dimsResult == Result::OUTPUT_INSUFFICIENT_SIZE;
        if (!gotDims) {
            // Keep an earlier compute error; only flag a failure if compute itself succeeded.
            if (status == ErrorStatus::NONE) {
                status = ErrorStatus::GENERAL_FAILURE;
            }
            continue;
        }
        shape.isSufficient = dimsResult == Result::NO_ERROR;
        shape.dimensions.assign(dims.begin(), dims.end());
        allSufficient = allSufficient && shape.isSufficient;
    }

    // UINT64_MAX is the "no measurement" value and maps to -1; larger-than-int64 values saturate.
    const auto toTimingNs = [](uint64_t ns) -> int64_t {
        constexpr uint64_t kNoMeasurement = std::numeric_limits<uint64_t>::max();
        constexpr int64_t kMaxNs = std::numeric_limits<int64_t>::max();
        if (ns == kNoMeasurement) {
            return -1;
        }
        return ns > static_cast<uint64_t>(kMaxNs) ? kMaxNs : static_cast<int64_t>(ns);
    };

    int64_t onDeviceNs = -1;
    int64_t inDriverNs = -1;
    if (measureTiming && status == ErrorStatus::NONE) {
        uint64_t ns;
        auto durationResult = execution->getDuration(Duration::ON_HARDWARE, &ns);
        SLW2SAS_RETURN_IF_ERROR(durationResult);
        onDeviceNs = toTimingNs(ns);

        durationResult = execution->getDuration(Duration::IN_DRIVER, &ns);
        SLW2SAS_RETURN_IF_ERROR(durationResult);
        inDriverNs = toTimingNs(ns);
    }

    *executionResult = ExecutionResult{allSufficient,
                                       std::move(shapes),
                                       {.timeOnDeviceNs = onDeviceNs, .timeInDriverNs = inDriverNs}};

    const bool reportOk =
            status == ErrorStatus::NONE || status == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
    if (!reportOk) {
        return toAStatus(status);
    }
    return ndk::ScopedAStatus::ok();
}
|
|
|
|
// Shared implementation of executeSynchronously and executeSynchronouslyWithConfig. It rejects an
// invalid deadline, builds a fresh SL execution for |request| via parseInputs(), and computes it
// with executeSynchronouslyInternal().
::ndk::ScopedAStatus ShimPreparedModel::executeSynchronouslyCommon(
        const Request& request, bool measureTiming, int64_t deadlineNs,
        int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& executionHints,
        const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix,
        ExecutionResult* executionResult) {
    CHECK(executionResult != nullptr);

    const bool deadlineIsValid = deadlineNs >= -1;
    if (!deadlineIsValid) {
        LOG(ERROR) << "Invalid deadline value, must be >= -1";
        return ndk::ScopedAStatus::fromServiceSpecificError(
                static_cast<int>(ErrorStatus::INVALID_ARGUMENT));
    }

    std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> pools;
    auto slExecution =
            std::make_shared<::android::nn::sl_wrapper::Execution>(mNnapi.get(), &mCompilation);

    if (const auto parseStatus =
                parseInputs(request, measureTiming, deadlineNs, loopTimeoutDurationNs,
                            slExecution.get(), &pools, executionHints, extensionNameToPrefix);
        parseStatus != ErrorStatus::NONE) {
        return toAStatus(parseStatus);
    }

    const int numOutputs = static_cast<int>(request.outputs.size());
    return executeSynchronouslyInternal(slExecution, measureTiming, numOutputs, executionResult);
}
|
|
|
|
// IPreparedModel::executeSynchronously: the pre-config entry point. It forwards to
// executeSynchronouslyCommon with no execution hints and no extension prefixes.
::ndk::ScopedAStatus ShimPreparedModel::executeSynchronously(
        const Request& request, bool measureTiming, int64_t deadlineNs,
        int64_t loopTimeoutDurationNs,
        ::aidl::android::hardware::neuralnetworks::ExecutionResult* executionResult) {
    const std::vector<TokenValuePair> noExecutionHints;
    const std::vector<ExtensionNameAndPrefix> noExtensionPrefixes;
    return executeSynchronouslyCommon(request, measureTiming, deadlineNs, loopTimeoutDurationNs,
                                      noExecutionHints, noExtensionPrefixes, executionResult);
}
|
|
|
|
// IPreparedModel::executeSynchronouslyWithConfig: unpacks |config| (timing flag, loop timeout,
// hints, extension prefixes) and forwards to executeSynchronouslyCommon.
::ndk::ScopedAStatus ShimPreparedModel::executeSynchronouslyWithConfig(
        const Request& request, const ExecutionConfig& config, int64_t deadlineNs,
        ExecutionResult* executionResult) {
    const auto& hints = config.executionHints;
    const auto& prefixes = config.extensionNameToPrefix;
    return executeSynchronouslyCommon(request, config.measureTiming, deadlineNs,
                                      config.loopTimeoutDurationNs, hints, prefixes,
                                      executionResult);
}
|
|
|
|
// IPreparedModel::executeFencedWithConfig: unpacks |config| (timing flag, loop timeout, hints,
// extension prefixes) and forwards to executeFencedCommon.
::ndk::ScopedAStatus ShimPreparedModel::executeFencedWithConfig(
        const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
        const ExecutionConfig& config, int64_t deadlineNs, int64_t durationNs,
        FencedExecutionResult* executionResult) {
    const auto& hints = config.executionHints;
    const auto& prefixes = config.extensionNameToPrefix;
    return executeFencedCommon(request, waitFor, config.measureTiming, deadlineNs,
                               config.loopTimeoutDurationNs, durationNs, hints, prefixes,
                               executionResult);
}
|
|
|
|
// TODO(183397380): make it use ANNBurst object
|
|
class ShimBurst : public BnBurst {
|
|
public:
|
|
// Precondition: preparedModel != nullptr
|
|
explicit ShimBurst(std::shared_ptr<ShimPreparedModel> preparedModel);
|
|
|
|
ndk::ScopedAStatus executeSynchronously(const Request& request,
|
|
const std::vector<int64_t>& memoryIdentifierTokens,
|
|
bool measureTiming, int64_t deadlineNs,
|
|
int64_t loopTimeoutDurationNs,
|
|
ExecutionResult* executionResult) override;
|
|
ndk::ScopedAStatus executeSynchronouslyWithConfig(
|
|
const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
|
|
const ExecutionConfig& config, int64_t deadlineNs,
|
|
ExecutionResult* executionResult) override;
|
|
ndk::ScopedAStatus releaseMemoryResource(int64_t memoryIdentifierToken) override;
|
|
|
|
protected:
|
|
std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT;
|
|
const std::shared_ptr<ShimPreparedModel> kPreparedModel;
|
|
};
|
|
|
|
// IPreparedModel::configureExecutionBurst: creates a ShimBurst that holds a strong reference to
// this prepared model, so the model stays alive at least as long as the burst.
ndk::ScopedAStatus ShimPreparedModel::configureExecutionBurst(std::shared_ptr<IBurst>* burst) {
    *burst = ndk::SharedRefBase::make<ShimBurst>(this->template ref<ShimPreparedModel>());
    return ndk::ScopedAStatus::ok();
}
|
|
|
|
ShimBurst::ShimBurst(std::shared_ptr<ShimPreparedModel> preparedModel)
|
|
: kPreparedModel(std::move(preparedModel)) {
|
|
CHECK(kPreparedModel != nullptr);
|
|
}
|
|
|
|
// IBurst::executeSynchronously.
//
// Requires exactly one memory identifier token per request pool, each >= -1. Only one execution
// may run on this burst at a time; a concurrent call fails with GENERAL_FAILURE. The tokens are
// not forwarded; the request is run through the prepared model's executeSynchronously.
ndk::ScopedAStatus ShimBurst::executeSynchronously(
        const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
        bool measureTiming, int64_t deadlineNs, int64_t loopTimeoutDurationNs,
        ExecutionResult* executionResult) {
    if (memoryIdentifierTokens.size() != request.pools.size()) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT,
                         "request.pools.size() != memoryIdentifierTokens.size()");
    }
    const bool hasInvalidToken =
            std::any_of(memoryIdentifierTokens.begin(), memoryIdentifierTokens.end(),
                        [](int64_t token) { return token < -1; });
    if (hasInvalidToken) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT, "Invalid memoryIdentifierTokens");
    }

    // Claim the burst; the guard releases it on every return path.
    if (mExecutionInFlight.test_and_set()) {
        return toAStatus(ErrorStatus::GENERAL_FAILURE,
                         "Burst object supports at most one execution at a time");
    }
    const auto releaseInFlight =
            ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });

    return kPreparedModel->executeSynchronously(request, measureTiming, deadlineNs,
                                                loopTimeoutDurationNs, executionResult);
}
|
|
|
|
// IBurst::executeSynchronouslyWithConfig.
//
// Applies the same token validation and single-execution rule as ShimBurst::executeSynchronously,
// then forwards the request and |config| to the prepared model's executeSynchronouslyWithConfig.
ndk::ScopedAStatus ShimBurst::executeSynchronouslyWithConfig(
        const Request& request, const std::vector<int64_t>& memoryIdentifierTokens,
        const ExecutionConfig& config, int64_t deadlineNs, ExecutionResult* executionResult) {
    if (memoryIdentifierTokens.size() != request.pools.size()) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT,
                         "request.pools.size() != memoryIdentifierTokens.size()");
    }
    const bool hasInvalidToken =
            std::any_of(memoryIdentifierTokens.begin(), memoryIdentifierTokens.end(),
                        [](int64_t token) { return token < -1; });
    if (hasInvalidToken) {
        return toAStatus(ErrorStatus::INVALID_ARGUMENT, "Invalid memoryIdentifierTokens");
    }

    // Claim the burst; the guard releases it on every return path.
    if (mExecutionInFlight.test_and_set()) {
        return toAStatus(ErrorStatus::GENERAL_FAILURE,
                         "Burst object supports at most one execution at a time");
    }
    const auto releaseInFlight =
            ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });

    return kPreparedModel->executeSynchronouslyWithConfig(request, config, deadlineNs,
                                                          executionResult);
}
|
|
|
|
// IBurst::releaseMemoryResource. This burst keeps no per-token memory, so there is nothing to
// release; the call only rejects tokens below -1.
ndk::ScopedAStatus ShimBurst::releaseMemoryResource(int64_t memoryIdentifierToken) {
    const bool tokenIsValid = memoryIdentifierToken >= -1;
    if (tokenIsValid) {
        return ndk::ScopedAStatus::ok();
    }
    return toAStatus(ErrorStatus::INVALID_ARGUMENT, "Invalid memoryIdentifierToken");
}
|
|
|
|
class ShimExecution : public BnExecution {
|
|
public:
|
|
explicit ShimExecution(
|
|
std::shared_ptr<const NnApiSupportLibrary> nnapi,
|
|
std::shared_ptr<::android::nn::sl_wrapper::Execution> execution,
|
|
std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools,
|
|
bool measureTiming, int numberOfOutputs);
|
|
|
|
ndk::ScopedAStatus executeSynchronously(int64_t deadlineNs,
|
|
ExecutionResult* executionResult) override;
|
|
ndk::ScopedAStatus executeFenced(const std::vector<ndk::ScopedFileDescriptor>& waitFor,
|
|
int64_t deadlineNs, int64_t durationNs,
|
|
FencedExecutionResult* fencedExecutionResult) override;
|
|
|
|
protected:
|
|
std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT;
|
|
std::shared_ptr<const NnApiSupportLibrary> mNnapi;
|
|
std::shared_ptr<::android::nn::sl_wrapper::Execution> mExecution;
|
|
const std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> kRequestMemoryPools;
|
|
const bool kMeasureTiming;
|
|
const int kNumberOfOutputs;
|
|
};
|
|
|
|
// IPreparedModel::createReusableExecution.
//
// Parses |request| once, with no deadline and the settings from |config|, into a new SL
// execution. It then marks that execution reusable and wraps it, together with its memory pools,
// in a ShimExecution returned through |execution|.
ndk::ScopedAStatus ShimPreparedModel::createReusableExecution(
        const Request& request, const ExecutionConfig& config,
        std::shared_ptr<IExecution>* execution) {
    std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> pools;
    auto slExecution =
            std::make_shared<::android::nn::sl_wrapper::Execution>(mNnapi.get(), &mCompilation);

    const auto parseStatus = parseInputs(request, config.measureTiming, kNoDeadline,
                                         config.loopTimeoutDurationNs, slExecution.get(), &pools,
                                         config.executionHints, config.extensionNameToPrefix);
    if (parseStatus != ErrorStatus::NONE) {
        return toAStatus(parseStatus);
    }

    auto reusableResult = slExecution->setReusable(true);
    SLW2SAS_RETURN_IF_ERROR(reusableResult);

    const int numOutputs = static_cast<int>(request.outputs.size());
    *execution = ndk::SharedRefBase::make<ShimExecution>(
            mNnapi, std::move(slExecution), std::move(pools), config.measureTiming, numOutputs);
    return ndk::ScopedAStatus::ok();
}
|
|
|
|
ShimExecution::ShimExecution(
|
|
std::shared_ptr<const NnApiSupportLibrary> nnapi,
|
|
std::shared_ptr<::android::nn::sl_wrapper::Execution> execution,
|
|
std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>> requestMemoryPools,
|
|
bool measureTiming, int numberOfOutputs)
|
|
: mNnapi(std::move(nnapi)),
|
|
mExecution(std::move(execution)),
|
|
kRequestMemoryPools(std::move(requestMemoryPools)),
|
|
kMeasureTiming(measureTiming),
|
|
kNumberOfOutputs(numberOfOutputs) {}
|
|
|
|
// IExecution::executeSynchronously: computes the reusable SL execution once. A concurrent call on
// the same object fails with GENERAL_FAILURE.
// NOTE(review): deadlineNs is only range-checked here; it is not applied to the computation.
ndk::ScopedAStatus ShimExecution::executeSynchronously(int64_t deadlineNs,
                                                       ExecutionResult* executionResult) {
    const bool deadlineIsValid = deadlineNs >= -1;
    if (!deadlineIsValid) {
        LOG(ERROR) << "Invalid deadline value, must be >= -1";
        return ndk::ScopedAStatus::fromServiceSpecificError(
                static_cast<int>(ErrorStatus::INVALID_ARGUMENT));
    }

    // Claim the execution; the guard releases it on every return path.
    if (mExecutionInFlight.test_and_set()) {
        return toAStatus(ErrorStatus::GENERAL_FAILURE,
                         "Execution object supports at most one execution at a time");
    }
    const auto releaseInFlight =
            ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });

    return executeSynchronouslyInternal(mExecution, kMeasureTiming, kNumberOfOutputs,
                                        executionResult);
}
|
|
|
|
// IExecution::executeFenced: launches the reusable SL execution as a fenced computation that
// waits on |waitFor|. The in-flight flag is cleared when this call returns, which may be before
// the launched computation finishes.
// NOTE(review): deadlineNs is only range-checked here; it is not applied to the computation.
ndk::ScopedAStatus ShimExecution::executeFenced(
        const std::vector<ndk::ScopedFileDescriptor>& waitFor, int64_t deadlineNs,
        int64_t durationNs, FencedExecutionResult* fencedExecutionResult) {
    const bool deadlineIsValid = deadlineNs >= -1;
    if (!deadlineIsValid) {
        LOG(ERROR) << "Invalid deadline value, must be >= -1";
        return ndk::ScopedAStatus::fromServiceSpecificError(
                static_cast<int>(ErrorStatus::INVALID_ARGUMENT));
    }

    // Claim the execution; the guard releases it on every return path.
    if (mExecutionInFlight.test_and_set()) {
        return toAStatus(ErrorStatus::GENERAL_FAILURE,
                         "Execution object supports at most one execution at a time");
    }
    const auto releaseInFlight =
            ::android::base::make_scope_guard([this] { mExecutionInFlight.clear(); });

    // The pools are copied because this object must keep its own reference for later runs.
    return executeFencedInternal(mNnapi, mExecution, kRequestMemoryPools, waitFor, durationNs,
                                 kMeasureTiming, fencedExecutionResult);
}
|
|
|
|
} // namespace aidl::android::hardware::neuralnetworks
|