Index: source/ps/Future.h =================================================================== --- source/ps/Future.h +++ source/ps/Future.h @@ -24,11 +24,9 @@ #include #include #include +#include #include -template -class PackagedTask; - namespace FutureSharedStateDetail { enum class Status @@ -39,45 +37,44 @@ CANCELED }; -template -class SharedStateResult +class VoidTag {}; + +template +using HandleVoid = std::conditional_t, VoidTag, T>; + +/** + * Abstract base-class of SharedState only exposes the interface for PackagedTask: The function call + * operator. Checks if the task even has to be started. + */ +class ErasedSharedState { +protected: + ~ErasedSharedState() = default; + public: - void ResetResult() + void operator()() { - if (m_HasResult) - m_Result.m_Result.~ResultType(); - m_HasResult = false; + Status expected = Status::PENDING; + if (m_Status.compare_exchange_strong(expected, Status::STARTED)) + Execute(); } - union Result - { - std::aligned_storage_t m_Bytes; - ResultType m_Result; - Result() : m_Bytes() {}; - ~Result() {}; - }; - // We don't use Result directly so the result doesn't have to be default constructible. - Result m_Result; - bool m_HasResult = false; -}; +private: + virtual void Execute() = 0; -// Don't have m_Result for void ReturnType -template<> -class SharedStateResult -{ +protected: + std::atomic m_Status = Status::PENDING; }; /** - * The shared state between futures and packaged state. - * Holds all relevant data. + * Abstract class exposing the interface for Future. A class derived from this holds the actual functor. */ template -class SharedState : public SharedStateResult +class SharedState : public ErasedSharedState { static constexpr bool VoidResult = std::is_same_v; -public: - SharedState(std::function&& func) : m_Func(std::move(func)) {} + +protected: ~SharedState() { // For safety, wait on started task completion, but not on pending ones (auto-cancelled). @@ -86,12 +83,9 @@ Wait(); Cancel(); } - if constexpr (!VoidResult) - SharedStateResult::ResetResult(); } - SharedState(const SharedState&) = delete; - SharedState(SharedState&&) = delete; +public: bool IsDoneOrCanceled() const { @@ -122,7 +116,7 @@ if (m_Status == Status::DONE) m_Status = Status::CANCELED; if constexpr (!VoidResult) - SharedStateResult::ResetResult(); + m_Result.reset(); m_ConditionVariable.notify_all(); return cancelled; } @@ -136,21 +130,80 @@ std::enable_if_t, ResultType> GetResult() { // The caller must ensure that this is only called if we have a result. - ENSURE(SharedStateResult::m_HasResult); + ENSURE(m_Result.has_value()); m_Status = Status::CANCELED; - SharedStateResult::m_HasResult = false; - return std::move(SharedStateResult::m_Result.m_Result); + return *std::move(m_Result); } - - std::atomic m_Status = Status::PENDING; + using ErasedSharedState::m_Status; std::mutex m_Mutex; std::condition_variable m_ConditionVariable; std::function m_Func; + std::optional> m_Result; +}; + +template +class SharedStateImpl final : public SharedState> +{ + using ResultType = std::invoke_result_t; +public: + explicit SharedStateImpl(Fun&& func) : m_Func(std::forward(func)) {} + ~SharedStateImpl() = default; + +private: + void Execute() override + { + if constexpr (std::is_void_v) + { + m_Func(); + SharedState::m_Result.emplace(VoidTag{}); + } + else + SharedState::m_Result.emplace(m_Func()); + + // Because we might have threads waiting on us, we need to make sure that they either: + // - don't wait on our condition variable + // - receive the notification when we're done. + // This requires locking the mutex (@see Wait). + { + std::lock_guard lock(SharedState::m_Mutex); + SharedState::m_Status = Status::DONE; + } + + SharedState::m_ConditionVariable.notify_all(); + + } + + Fun m_Func; }; } // namespace FutureSharedStateDetail +/** + * A lightweight handle to a task can be used for storing in a task queue + * This type is just the shared state and the call operator. + */ +class PackagedTask +{ +public: + PackagedTask() = default; + explicit PackagedTask(std::shared_ptr&& ss) : + m_SharedState(std::move(ss)) + {} + + void operator()() + { + ENSURE(m_SharedState && "PackagedTask was empty probably it was invoked a second time."); + (*m_SharedState)(); + + // We no longer need the shared state, drop it immediately. + m_SharedState.reset(); + } + +protected: + std::shared_ptr m_SharedState; +}; + /** * Corresponds to std::future. * Unlike std::future, Future can request the cancellation of the task that would produce the result. @@ -167,9 +220,6 @@ template class Future { - template - friend class PackagedTask; - static constexpr bool VoidResult = std::is_same_v; using Status = FutureSharedStateDetail::Status; @@ -185,7 +235,14 @@ * Make the future wait for the result of @a func. */ template - PackagedTask Wrap(T&& func); + PackagedTask Wrap(T&& func) + { + static_assert(std::is_same_v, ResultType>, "The return type of the " + "wrapped function is not the same as the type of the Future."); + m_SharedState = std::make_shared>( + std::forward(func)); + return PackagedTask{m_SharedState}; + } /** * Move the result out of the future, and invalidate the future. @@ -260,67 +317,4 @@ std::shared_ptr m_SharedState; }; -/** - * Corresponds somewhat to std::packaged_task. - * Like packaged_task, this holds a function acting as a promise. - * This type is mostly just the shared state and the call operator, - * handling the promise & continuation logic. - */ -template -class PackagedTask -{ - static constexpr bool VoidResult = std::is_same_v; -public: - PackagedTask() = delete; - PackagedTask(std::shared_ptr::SharedState> ss) : m_SharedState(std::move(ss)) {} - - void operator()() - { - typename Future::Status expected = Future::Status::PENDING; - if (!m_SharedState->m_Status.compare_exchange_strong(expected, Future::Status::STARTED)) - return; - - if constexpr (VoidResult) - m_SharedState->m_Func(); - else - { - // To avoid UB, explicitly placement-new the value. - new (&m_SharedState->m_Result) ResultType{std::move(m_SharedState->m_Func())}; - m_SharedState->m_HasResult = true; - } - - // Because we might have threads waiting on us, we need to make sure that they either: - // - don't wait on our condition variable - // - receive the notification when we're done. - // This requires locking the mutex (@see Wait). - { - std::lock_guard lock(m_SharedState->m_Mutex); - m_SharedState->m_Status = Future::Status::DONE; - } - - m_SharedState->m_ConditionVariable.notify_all(); - - // We no longer need the shared state, drop it immediately. - m_SharedState.reset(); - } - - void Cancel() - { - m_SharedState->Cancel(); - m_SharedState.reset(); - } - -protected: - std::shared_ptr::SharedState> m_SharedState; -}; - -template -template -PackagedTask Future::Wrap(T&& func) -{ - static_assert(std::is_convertible_v, ResultType>, "The return type of the wrapped function cannot be converted to the type of the Future."); - m_SharedState = std::make_shared(std::move(func)); - return PackagedTask(m_SharedState); -} - #endif // INCLUDED_FUTURE Index: source/ps/TaskManager.h =================================================================== --- source/ps/TaskManager.h +++ source/ps/TaskManager.h @@ -67,14 +67,14 @@ Future> PushTask(T&& func, TaskPriority priority = TaskPriority::NORMAL) { Future> ret; - DoPushTask(ret.Wrap(std::move(func)), priority); + DoPushTask(ret.Wrap(std::forward(func)), priority); return ret; } private: TaskManager(size_t numberOfWorkers); - void DoPushTask(std::function&& task, TaskPriority priority); + void DoPushTask(PackagedTask&& task, TaskPriority priority); class Impl; std::unique_ptr m; Index: source/ps/TaskManager.cpp =================================================================== --- source/ps/TaskManager.cpp +++ source/ps/TaskManager.cpp @@ -50,7 +50,7 @@ class Thread; -using QueueItem = std::function; +using QueueItem = PackagedTask; /** * Light wrapper around std::thread. Ensures Join has been called. @@ -130,13 +130,13 @@ * Takes ownership of @a task. * May be called from any thread. */ - void PushTask(std::function&& task, TaskPriority priority); + void PushTask(QueueItem&& task, TaskPriority priority); protected: void ClearQueue(); template - bool PopTask(std::function& taskOut); + bool PopTask(QueueItem& taskOut); // Back reference (keep this first). TaskManager& m_TaskManager; @@ -197,12 +197,12 @@ return m->m_Workers.size(); } -void TaskManager::DoPushTask(std::function&& task, TaskPriority priority) +void TaskManager::DoPushTask(QueueItem&& task, TaskPriority priority) { m->PushTask(std::move(task), priority); } -void TaskManager::Impl::PushTask(std::function&& task, TaskPriority priority) +void TaskManager::Impl::PushTask(QueueItem&& task, TaskPriority priority) { std::mutex& mutex = priority == TaskPriority::NORMAL ? m_GlobalMutex : m_GlobalLowPriorityMutex; std::deque& queue = priority == TaskPriority::NORMAL ? m_GlobalQueue : m_GlobalLowPriorityQueue; @@ -218,7 +218,7 @@ } template -bool TaskManager::Impl::PopTask(std::function& taskOut) +bool TaskManager::Impl::PopTask(QueueItem& taskOut) { std::mutex& mutex = Priority == TaskPriority::NORMAL ? m_GlobalMutex : m_GlobalLowPriorityMutex; std::deque& queue = Priority == TaskPriority::NORMAL ? m_GlobalQueue : m_GlobalLowPriorityQueue; @@ -278,7 +278,7 @@ g_Profiler2.RegisterCurrentThread(name); - std::function task; + QueueItem task; bool hasTask = false; std::unique_lock lock(m_Mutex, std::defer_lock); while (!m_Kill) Index: source/ps/tests/test_Future.h =================================================================== --- source/ps/tests/test_Future.h +++ source/ps/tests/test_Future.h @@ -1,4 +1,4 @@ -/* Copyright (C) 2021 Wildfire Games. +/* Copyright (C) 2022 Wildfire Games. * This file is part of 0 A.D. * * 0 A.D. is free software: you can redistribute it and/or modify @@ -30,7 +30,7 @@ int counter = 0; { Future noret; - std::function task = noret.Wrap([&counter]() mutable { counter++; }); + PackagedTask task = noret.Wrap([&counter]() mutable { counter++; }); task(); TS_ASSERT_EQUALS(counter, 1); } @@ -38,7 +38,7 @@ { Future noret; { - std::function task = noret.Wrap([&counter]() mutable { counter++; }); + PackagedTask task = noret.Wrap([&counter]() mutable { counter++; }); // Auto-cancels the task. } } @@ -49,16 +49,10 @@ { { Future future; - std::function task = future.Wrap([]() { return 1; }); - task(); - TS_ASSERT_EQUALS(future.Get(), 1); - } - - // Convertible type. - { - Future future; - std::function task = future.Wrap([]() -> u8 { return 1; }); + PackagedTask task = future.Wrap([]() { return 1; }); + TS_ASSERT(future.Valid()); task(); + TS_ASSERT(future.Valid()); TS_ASSERT_EQUALS(future.Get(), 1); } @@ -80,21 +74,21 @@ TS_ASSERT_EQUALS(destroyed, 0); { Future future; - std::function task = future.Wrap([]() { return 1; }); + PackagedTask task = future.Wrap([]() { return NonDef{1}; }); task(); TS_ASSERT_EQUALS(future.Get().value, 1); } TS_ASSERT_EQUALS(destroyed, 1); { Future future; - std::function task = future.Wrap([]() { return 1; }); + PackagedTask task = future.Wrap([]() { return NonDef{1}; }); } TS_ASSERT_EQUALS(destroyed, 1); /** * TODO: find a way to test this { Future future; - std::function task = future.Wrap([]() { return 1; }); + PackagedTask task = future.Wrap([]() { return NonDef{1}; }); future.Cancel(); future.Wait(); TS_ASSERT_THROWS(future.Get(), const Future::BadFutureAccess&); @@ -117,7 +111,7 @@ c = new (&functionStorage) std::function{}; *c = []() { return 7; }; - std::function task = f->Wrap(std::move(*c)); + PackagedTask task = f->Wrap(std::move(*c)); future = std::move(*f); function = std::move(*c); @@ -129,7 +123,7 @@ memset(&functionStorage, 0xFF, sizeof(decltype(functionStorage))); // Let's move the packaged task while at it. - std::function task2 = std::move(task); + PackagedTask task2 = std::move(task); task2(); TS_ASSERT_EQUALS(future.Get(), 7); }