166 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			166 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			C++
		
	
	
	
| /*
 | |
|  * Copyright (C) 2017 The Android Open Source Project
 | |
|  *
 | |
|  * Licensed under the Apache License, Version 2.0 (the "License");
 | |
|  * you may not use this file except in compliance with the License.
 | |
|  * You may obtain a copy of the License at
 | |
|  *
 | |
|  *      http://www.apache.org/licenses/LICENSE-2.0
 | |
|  *
 | |
|  * Unless required by applicable law or agreed to in writing, software
 | |
|  * distributed under the License is distributed on an "AS IS" BASIS,
 | |
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|  * See the License for the specific language governing permissions and
 | |
|  * limitations under the License.
 | |
|  */
 | |
| 
 | |
| #define LOG_TAG "ExecutionCallback"
 | |
| 
 | |
| #include "ExecutionCallback.h"
 | |
| 
 | |
| #include <android-base/logging.h>
 | |
| 
 | |
| #include <limits>
 | |
| #include <utility>
 | |
| #include <vector>
 | |
| 
 | |
| namespace android::nn {
 | |
| 
 | |
| void ExecutionCallback::notify(ErrorStatus status, const std::vector<OutputShape>& outputShapes,
 | |
|                                const Timing& timing) {
 | |
|     notifyInternal(status, outputShapes, timing);
 | |
| }
 | |
| 
 | |
| void ExecutionCallback::wait() const {
 | |
|     std::unique_lock<std::mutex> lock(mMutex);
 | |
|     mCondition.wait(lock, [this] { return mNotified; });
 | |
| 
 | |
|     /*
 | |
|      * Note that we cannot call std::thread::join from ExecutionCallback's
 | |
|      * destructor: ExecutionCallback is intended to be reference counted, and it
 | |
|      * is possible that the reference count drops to zero in the bound thread,
 | |
|      * causing the bound thread to call this destructor. If a thread tries to
 | |
|      * join itself, it throws an exception, producing a message like the
 | |
|      * following:
 | |
|      *
 | |
|      *     terminating with uncaught exception of type std::__1::system_error:
 | |
|      *     thread::join failed: Resource deadlock would occur
 | |
|      */
 | |
|     if (mThread.joinable()) {
 | |
|         mThread.join();
 | |
|     }
 | |
| }
 | |
| 
 | |
| ErrorStatus ExecutionCallback::getStatus() const {
 | |
|     wait();
 | |
|     return mErrorStatus;
 | |
| }
 | |
| 
 | |
| const std::vector<OutputShape>& ExecutionCallback::getOutputShapes() const {
 | |
|     wait();
 | |
|     return mOutputShapes;
 | |
| }
 | |
| 
 | |
| Timing ExecutionCallback::getTiming() const {
 | |
|     wait();
 | |
|     return mTiming;
 | |
| }
 | |
| 
 | |
| bool ExecutionCallback::bindThread(std::thread asyncThread) {
 | |
|     std::lock_guard<std::mutex> lock(mMutex);
 | |
| 
 | |
|     // Ensure ExecutionCallback object does not already have a thread bound
 | |
|     if (mThread.joinable()) {
 | |
|         LOG(ERROR) << "ExecutionCallback::bindThread -- a thread has already been bound to this "
 | |
|                       "callback object";
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     // Ensure the new thread is valid
 | |
|     if (!asyncThread.joinable()) {
 | |
|         LOG(ERROR) << "ExecutionCallback::bindThread -- the new thread is not joinable";
 | |
|         return false;
 | |
|     }
 | |
| 
 | |
|     mThread = std::move(asyncThread);
 | |
|     return true;
 | |
| }
 | |
| 
 | |
| void ExecutionCallback::setOnFinish(const ExecutionFinish& finish) {
 | |
|     std::lock_guard<std::mutex> hold(mMutex);
 | |
| 
 | |
|     // Ensure ExecutionCallback object does not already have a "finish" callback
 | |
|     if (mOnFinish != nullptr) {
 | |
|         LOG(ERROR) << "ExecutionCallback::setOnFinish -- object already has a \"finish\" callback";
 | |
|         return;
 | |
|     }
 | |
| 
 | |
|     // Ensure new "finish" callback is valid
 | |
|     if (finish == nullptr) {
 | |
|         LOG(ERROR) << "ExecutionCallback::setOnFinish -- \"finish\" callback is invalid";
 | |
|         return;
 | |
|     }
 | |
| 
 | |
|     // Essure ExecutionCallback object has not already been notified
 | |
|     if (mNotified) {
 | |
|         LOG(ERROR) << "ExecutionCallback::setOnFinish -- ExecutionCallback has already been "
 | |
|                       "notified with results";
 | |
|         return;
 | |
|     }
 | |
| 
 | |
|     mOnFinish = finish;
 | |
| }
 | |
| 
 | |
| void ExecutionCallback::notifyInternal(ErrorStatus errorStatus,
 | |
|                                        std::vector<OutputShape> outputShapes, Timing timing) {
 | |
|     // check results
 | |
|     {
 | |
|         if (errorStatus == ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
 | |
|             // outputShapes must not be empty if OUTPUT_INSUFFICIENT_SIZE.
 | |
|             if (outputShapes.size() == 0) {
 | |
|                 LOG(ERROR)
 | |
|                         << "Notified with empty output shape vector when OUTPUT_INSUFFICIENT_SIZE";
 | |
|                 errorStatus = ErrorStatus::GENERAL_FAILURE;
 | |
|                 outputShapes = {};
 | |
|                 timing = {};
 | |
|             }
 | |
|         } else if (errorStatus != ErrorStatus::NONE) {
 | |
|             // outputShapes must be empty if errorStatus is neither NONE nor
 | |
|             // OUTPUT_INSUFFICIENT_SIZE.
 | |
|             if (outputShapes.size() != 0) {
 | |
|                 LOG(ERROR) << "Notified with non-empty output shape vector when error status is "
 | |
|                               "neither NONE nor OUTPUT_INSUFFICIENT_SIZE";
 | |
|                 errorStatus = ErrorStatus::GENERAL_FAILURE;
 | |
|                 outputShapes = {};
 | |
|                 timing = {};
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     // store results
 | |
|     {
 | |
|         std::lock_guard<std::mutex> hold(mMutex);
 | |
| 
 | |
|         // quick-return if object has already been notified
 | |
|         if (mNotified) {
 | |
|             return;
 | |
|         }
 | |
| 
 | |
|         mErrorStatus = errorStatus;
 | |
|         mOutputShapes = std::move(outputShapes);
 | |
|         mTiming = timing;
 | |
|         mNotified = true;
 | |
| 
 | |
|         if (mOnFinish != nullptr) {
 | |
|             ErrorStatus status = mOnFinish(mErrorStatus, mOutputShapes);
 | |
|             mOnFinish = nullptr;
 | |
|             if (status != ErrorStatus::NONE) {
 | |
|                 mErrorStatus = status;
 | |
|             }
 | |
|         }
 | |
|     }
 | |
|     mCondition.notify_all();
 | |
| }
 | |
| 
 | |
| }  // namespace android::nn
 |