// Copyright (C) 2019 The Android Open Source Project
// Copyright (C) 2019 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "android/base/threads/AndroidWorkPool.h"
|
|
|
|
#include "android/base/threads/AndroidFunctorThread.h"
|
|
#include "android/base/synchronization/AndroidLock.h"
|
|
#include "android/base/synchronization/AndroidConditionVariable.h"
|
|
#include "android/base/synchronization/AndroidMessageChannel.h"
|
|
|
|
#include <atomic>
|
|
#include <memory>
|
|
#include <unordered_map>
|
|
#include <sys/time.h>
|
|
|
|
using android::base::guest::AutoLock;
using android::base::guest::ConditionVariable;
using android::base::guest::FunctorThread;
using android::base::guest::Lock;
using android::base::guest::MessageChannel;

namespace android {
namespace base {
namespace guest {

class WaitGroup { // intrusive refcounted
public:

    WaitGroup(int numTasksRemaining) :
        mNumTasksInitial(numTasksRemaining),
        mNumTasksRemaining(numTasksRemaining) { }

    ~WaitGroup() = default;

    android::base::guest::Lock& getLock() { return mLock; }

    void acquire() {
        if (0 == mRefCount.fetch_add(1, std::memory_order_seq_cst)) {
            ALOGE("%s: goofed, refcount0 acquire\n", __func__);
            abort();
        }
    }

    bool release() {
        if (0 == mRefCount) {
            ALOGE("%s: goofed, refcount0 release\n", __func__);
            abort();
        }
        if (1 == mRefCount.fetch_sub(1, std::memory_order_seq_cst)) {
            std::atomic_thread_fence(std::memory_order_acquire);
            delete this;
            return true;
        }
        return false;
    }

    // Wait for all of, or any of, the associated tasks to complete.
    bool waitAllLocked(WorkPool::TimeoutUs timeout) {
        return conditionalTimeoutLocked(
            [this] { return mNumTasksRemaining > 0; },
            timeout);
    }

    bool waitAnyLocked(WorkPool::TimeoutUs timeout) {
        return conditionalTimeoutLocked(
            [this] { return mNumTasksRemaining == mNumTasksInitial; },
            timeout);
    }

    // Broadcasts to all waiters that another task has completed.
    bool decrementBroadcast() {
        AutoLock<Lock> lock(mLock);
        bool done =
            (1 == mNumTasksRemaining.fetch_sub(1, std::memory_order_seq_cst));
        std::atomic_thread_fence(std::memory_order_acquire);
        mCv.broadcast();
        return done;
    }

private:

    bool doWait(WorkPool::TimeoutUs timeout) {
        if (timeout == ~0ULL) {
            ALOGV("%s: uncond wait\n", __func__);
            mCv.wait(&mLock);
            return true;
        } else {
            return mCv.timedWait(&mLock, getDeadline(timeout));
        }
    }

    struct timespec getDeadline(WorkPool::TimeoutUs relative) {
        struct timeval nowUs;
        struct timespec deadlineNs;
        gettimeofday(&nowUs, 0);

        // Compute the absolute deadline in microseconds with 64-bit math,
        // then split it into seconds/nanoseconds so that tv_nsec always
        // stays below one second, as required for a timespec deadline.
        uint64_t deadlineUs =
            (uint64_t)nowUs.tv_sec * 1000000ULL +
            (uint64_t)nowUs.tv_usec +
            relative;

        deadlineNs.tv_sec = (time_t)(deadlineUs / 1000000ULL);
        deadlineNs.tv_nsec = (long)((deadlineUs % 1000000ULL) * 1000LL);
        return deadlineNs;
    }

    uint64_t currTimeUs() {
        struct timeval tv;
        gettimeofday(&tv, 0);
        return (uint64_t)(tv.tv_sec * 1000000LL + tv.tv_usec);
    }

    bool conditionalTimeoutLocked(std::function<bool()> conditionFunc, WorkPool::TimeoutUs timeout) {
        uint64_t currTime = currTimeUs();
        WorkPool::TimeoutUs currTimeout = timeout;

        while (conditionFunc()) {
            doWait(currTimeout);
            if (conditionFunc()) {
                // Still waiting (spurious wakeup, or only some tasks finished):
                // charge the elapsed time against the remaining timeout budget.
                uint64_t nextTime = currTimeUs();
                WorkPool::TimeoutUs waited =
                    nextTime - currTime;
                currTime = nextTime;

                if (currTimeout > waited) {
                    currTimeout -= waited;
                } else {
                    // Timeout exhausted; report whether the condition was met.
                    return !conditionFunc();
                }
            }
        }

        return true;
    }

    std::atomic<int> mRefCount = { 1 };
    int mNumTasksInitial;
    std::atomic<int> mNumTasksRemaining;

    Lock mLock;
    ConditionVariable mCv;
};

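// Reference-count lifecycle of a WaitGroup, as used by the classes below:
// the pool creates a group with a refcount of 1 (its own reference, dropped
// when the group is swept in WorkPool::Impl::schedule()); each worker thread
// acquire()s the group before running a task and release()s it after
// decrementBroadcast(); waiters acquire()/release() around
// waitAnyLocked()/waitAllLocked(). The final release() deletes the group.
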
class WorkPoolThread {
public:
    // State diagram for each work pool thread
    //
    // Unacquired: (Start state) No one has claimed the thread yet.
    // Acquired: The thread has been claimed for work,
    //           but work has not been issued to it yet.
    // Scheduled: The thread is running tasks from the acquirer.
    // Exiting: The thread is being torn down.
    //
    // Messages:
    //
    // Acquire
    // Run
    // Exit
    //
    // Transitions:
    //
    // Note: While a task is being run, messages will come back with a failure value.
    //
    // Unacquired:
    //     message Acquire -> Acquired. effect: return success value
    //     message Run -> Unacquired. effect: return failure value
    //     message Exit -> Exiting. effect: return success value
    //
    // Acquired:
    //     message Acquire -> Acquired. effect: return failure value
    //     message Run -> Scheduled. effect: run the task, return success
    //     message Exit -> Exiting. effect: return success value
    //
    // Scheduled:
    //     implicit effect: after the task is run, transition back to Unacquired.
    //     message Acquire -> Scheduled. effect: return failure value
    //     message Run -> Scheduled. effect: return failure value
    //     message Exit -> queue up the exit message, then transition to Exiting
    //                     after that is done. effect: return success value
    //
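    // Illustrative lifecycle of one unit of work, following the transitions
    // above: Acquire (Unacquired -> Acquired), Run(task) (Acquired ->
    // Scheduled), the task completes (Scheduled -> Unacquired), and Exit
    // (-> Exiting) when the owning pool is destroyed.
    //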
    enum State {
        Unacquired = 0,
        Acquired = 1,
        Scheduled = 2,
        Exiting = 3,
    };

    WorkPoolThread() : mThread([this] { threadFunc(); }) {
        mThread.start();
    }

    ~WorkPoolThread() {
        exit();
        mThread.wait();
    }

    bool acquire() {
        AutoLock<Lock> lock(mLock);
        switch (mState) {
            case State::Unacquired:
                mState = State::Acquired;
                return true;
            case State::Acquired:
            case State::Scheduled:
            case State::Exiting:
                return false;
        }
    }

    bool run(WorkPool::WaitGroupHandle waitGroupHandle, WaitGroup* waitGroup, WorkPool::Task task) {
        AutoLock<Lock> lock(mLock);
        switch (mState) {
            case State::Unacquired:
                return false;
            case State::Acquired: {
                mState = State::Scheduled;
                mToCleanupWaitGroupHandle = waitGroupHandle;
                waitGroup->acquire();
                mToCleanupWaitGroup = waitGroup;
                mShouldCleanupWaitGroup = false;
                TaskInfo msg = {
                    Command::Run,
                    waitGroup, task,
                };
                mRunMessages.send(msg);
                return true;
            }
            case State::Scheduled:
            case State::Exiting:
                return false;
        }
    }

    bool shouldCleanupWaitGroup(WorkPool::WaitGroupHandle* waitGroupHandle, WaitGroup** waitGroup) {
        AutoLock<Lock> lock(mLock);
        bool res = mShouldCleanupWaitGroup;
        *waitGroupHandle = mToCleanupWaitGroupHandle;
        *waitGroup = mToCleanupWaitGroup;
        mShouldCleanupWaitGroup = false;
        return res;
    }

private:
    enum Command {
        Run = 0,
        Exit = 1,
    };

    struct TaskInfo {
        Command cmd;
        WaitGroup* waitGroup = nullptr;
        WorkPool::Task task = {};
    };

    bool exit() {
        AutoLock<Lock> lock(mLock);
        TaskInfo msg { Command::Exit, };
        mRunMessages.send(msg);
        return true;
    }

    void threadFunc() {
        TaskInfo taskInfo;
        bool done = false;

        while (!done) {
            mRunMessages.receive(&taskInfo);
            switch (taskInfo.cmd) {
                case Command::Run:
                    doRun(taskInfo);
                    break;
                case Command::Exit: {
                    AutoLock<Lock> lock(mLock);
                    mState = State::Exiting;
                    break;
                }
            }
            AutoLock<Lock> lock(mLock);
            done = mState == State::Exiting;
        }
    }

    // Assumption: the wait group refcount is >= 1 when entering this function
    // (before the decrement below); i.e., it cannot reach 0 while the task is
    // being run.
    void doRun(TaskInfo& msg) {
        WaitGroup* waitGroup = msg.waitGroup;

        if (msg.task) msg.task();

        bool lastTask =
            waitGroup->decrementBroadcast();

        AutoLock<Lock> lock(mLock);
        mState = State::Unacquired;

        if (lastTask) {
            mShouldCleanupWaitGroup = true;
        }

        waitGroup->release();
    }

    FunctorThread mThread;
    Lock mLock;
    State mState = State::Unacquired;
    MessageChannel<TaskInfo, 4> mRunMessages;
    WorkPool::WaitGroupHandle mToCleanupWaitGroupHandle = 0;
    WaitGroup* mToCleanupWaitGroup = nullptr;
    bool mShouldCleanupWaitGroup = false;
};

class WorkPool::Impl {
public:
    Impl(int numInitialThreads) : mThreads(numInitialThreads) {
        for (size_t i = 0; i < mThreads.size(); ++i) {
            mThreads[i].reset(new WorkPoolThread);
        }
    }

    ~Impl() = default;

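    // Schedules one task per worker thread: sweeps wait groups whose last
    // task has already completed, creates a new WaitGroup for this batch,
    // grows the thread list until one thread can be acquired per task, then
    // issues each task to its acquired thread. Returns a handle that can be
    // passed to waitAny()/waitAll().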
    WorkPool::WaitGroupHandle schedule(const std::vector<WorkPool::Task>& tasks) {

        if (tasks.empty()) abort();

        AutoLock<Lock> lock(mLock);

        // Sweep old wait groups
        for (size_t i = 0; i < mThreads.size(); ++i) {
            WaitGroupHandle handle;
            WaitGroup* waitGroup;
            bool cleanup = mThreads[i]->shouldCleanupWaitGroup(&handle, &waitGroup);
            if (cleanup) {
                mWaitGroups.erase(handle);
                waitGroup->release();
            }
        }

        WorkPool::WaitGroupHandle resHandle = genWaitGroupHandleLocked();
        WaitGroup* waitGroup =
            new WaitGroup(tasks.size());

        mWaitGroups[resHandle] = waitGroup;

        std::vector<size_t> threadIndices;

        while (threadIndices.size() < tasks.size()) {
            for (size_t i = 0; i < mThreads.size(); ++i) {
                if (!mThreads[i]->acquire()) continue;
                threadIndices.push_back(i);
                if (threadIndices.size() == tasks.size()) break;
            }
            if (threadIndices.size() < tasks.size()) {
                mThreads.resize(mThreads.size() + 1);
                mThreads[mThreads.size() - 1].reset(new WorkPoolThread);
            }
        }

        // every thread here is acquired
        for (size_t i = 0; i < threadIndices.size(); ++i) {
            mThreads[threadIndices[i]]->run(resHandle, waitGroup, tasks[i]);
        }

        return resHandle;
    }

    bool waitAny(WorkPool::WaitGroupHandle waitGroupHandle, WorkPool::TimeoutUs timeout) {
        AutoLock<Lock> lock(mLock);
        auto it = mWaitGroups.find(waitGroupHandle);
        if (it == mWaitGroups.end()) return true;

        auto waitGroup = it->second;
        waitGroup->acquire();
        lock.unlock();

        bool waitRes = false;

        {
            AutoLock<Lock> waitGroupLock(waitGroup->getLock());
            waitRes = waitGroup->waitAnyLocked(timeout);
        }

        waitGroup->release();

        return waitRes;
    }

    bool waitAll(WorkPool::WaitGroupHandle waitGroupHandle, WorkPool::TimeoutUs timeout) {
        auto waitGroup = acquireWaitGroupFromHandle(waitGroupHandle);
        if (!waitGroup) return true;

        bool waitRes = false;

        {
            AutoLock<Lock> waitGroupLock(waitGroup->getLock());
            waitRes = waitGroup->waitAllLocked(timeout);
        }

        waitGroup->release();

        return waitRes;
    }

private:
    // Increments wait group refcount by 1.
    WaitGroup* acquireWaitGroupFromHandle(WorkPool::WaitGroupHandle waitGroupHandle) {
        AutoLock<Lock> lock(mLock);
        auto it = mWaitGroups.find(waitGroupHandle);
        if (it == mWaitGroups.end()) return nullptr;

        auto waitGroup = it->second;
        waitGroup->acquire();

        return waitGroup;
    }

    using WaitGroupStore = std::unordered_map<WorkPool::WaitGroupHandle, WaitGroup*>;

    WorkPool::WaitGroupHandle genWaitGroupHandleLocked() {
        WorkPool::WaitGroupHandle res = mNextWaitGroupHandle;
        ++mNextWaitGroupHandle;
        return res;
    }

    Lock mLock;
    uint64_t mNextWaitGroupHandle = 0;
    WaitGroupStore mWaitGroups;
    std::vector<std::unique_ptr<WorkPoolThread>> mThreads;
};

WorkPool::WorkPool(int numInitialThreads) : mImpl(new WorkPool::Impl(numInitialThreads)) { }
WorkPool::~WorkPool() = default;

WorkPool::WaitGroupHandle WorkPool::schedule(const std::vector<WorkPool::Task>& tasks) {
    return mImpl->schedule(tasks);
}

bool WorkPool::waitAny(WorkPool::WaitGroupHandle waitGroup, WorkPool::TimeoutUs timeout) {
    return mImpl->waitAny(waitGroup, timeout);
}

bool WorkPool::waitAll(WorkPool::WaitGroupHandle waitGroup, WorkPool::TimeoutUs timeout) {
    return mImpl->waitAll(waitGroup, timeout);
}

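// Usage sketch (illustrative only; assumes WorkPool::Task is a
// std::function<void()>-style callable, as the calls above suggest, and that
// a TimeoutUs of ~0ULL means "wait without a timeout", as in doWait()):
//
//     WorkPool pool(4);                // start with four worker threads
//     auto handle = pool.schedule({
//         [] { /* task A */ },
//         [] { /* task B */ },
//     });
//     pool.waitAny(handle, 1000);      // wait up to 1 ms for any task to finish
//     pool.waitAll(handle, ~0ULL);     // then block until both have finished
//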
} // namespace guest
} // namespace base
} // namespace android