Index: source/graphics/MapGenerator.h =================================================================== --- source/graphics/MapGenerator.h +++ source/graphics/MapGenerator.h @@ -19,6 +19,7 @@ #define INCLUDED_MAPGENERATOR #include "lib/file/vfs/vfs_path.h" +#include "ps/Future.h" #include "scriptinterface/StructuredClone.h" #include @@ -37,7 +38,7 @@ * data, according to this format: * https://trac.wildfiregames.com/wiki/Random_Map_Generator_Internals#Dataformat */ -Script::StructuredClone RunMapGenerationScript(std::atomic& progress, +Script::StructuredClone RunMapGenerationScript(StopToken st, std::atomic& progress, ScriptInterface& scriptInterface, const VfsPath& script, const std::string& settings, const u16 flags = JSPROP_ENUMERATE | JSPROP_READONLY | JSPROP_PERMANENT); Index: source/graphics/MapGenerator.cpp =================================================================== --- source/graphics/MapGenerator.cpp +++ source/graphics/MapGenerator.cpp @@ -28,6 +28,7 @@ #include "maths/MathUtil.h" #include "ps/CLogger.h" #include "ps/FileIo.h" +#include "ps/Future.h" #include "ps/Profile.h" #include "ps/scripting/JSInterface_VFS.h" #include "ps/TemplateLoader.h" @@ -44,21 +45,9 @@ #include #include -extern bool IsQuitRequested(); - namespace { -bool MapGenerationInterruptCallback(JSContext* UNUSED(cx)) -{ - // This may not use SDL_IsQuitRequested(), because it runs in a thread separate to SDL, see SDL_PumpEvents - if (IsQuitRequested()) - { - LOGWARNING("Quit requested!"); - return false; - } - - return true; -} +bool MapGenerationInterruptCallback(JSContext* cx); /** * Provides callback's for the JavaScript. @@ -68,8 +57,9 @@ public: // Only the constructor and the destructor are called by C++.
- CMapGenerationCallbacks(std::atomic& progress, ScriptInterface& scriptInterface, + CMapGenerationCallbacks(StopToken st, std::atomic& progress, ScriptInterface& scriptInterface, Script::StructuredClone& mapData, const u16 flags) : + m_ST{st}, m_Progress{progress}, m_ScriptInterface{scriptInterface}, m_MapData{mapData} @@ -127,6 +117,8 @@ m_ScriptInterface.SetCallbackData(nullptr); } + StopToken m_ST; + private: // These functions are called by JS. @@ -373,10 +365,17 @@ */ CTemplateLoader m_TemplateLoader; }; + +bool MapGenerationInterruptCallback(JSContext* cx) +{ + // Returns false once a stop was requested through the map-generation task's StopToken. This must not use SDL_IsQuitRequested(), because it runs in a thread separate to SDL, see SDL_PumpEvents. NOTE(review): ~CMapGenerationCallbacks resets the callback data to nullptr; confirm this callback cannot fire after that. + return !ScriptInterface::ObjectFromCBData( + ScriptInterface::CmptPrivate::GetScriptInterface(cx))->m_ST.IsStopRequested(); +} } // anonymous namespace -Script::StructuredClone RunMapGenerationScript(std::atomic& progress, ScriptInterface& scriptInterface, - const VfsPath& script, const std::string& settings, const u16 flags) +Script::StructuredClone RunMapGenerationScript(StopToken st, std::atomic& progress, + ScriptInterface& scriptInterface, const VfsPath& script, const std::string& settings, const u16 flags) { ScriptRequest rq(scriptInterface); @@ -405,7 +404,7 @@ scriptInterface.ReplaceNondeterministicRNG(mapGenRNG); Script::StructuredClone mapData; - CMapGenerationCallbacks callbackData{progress, scriptInterface, mapData, flags}; + CMapGenerationCallbacks callbackData{st, progress, scriptInterface, mapData, flags}; // Copy settings to global variable JS::RootedValue global(rq.cx, rq.globalValue()); Index: source/graphics/MapReader.cpp =================================================================== --- source/graphics/MapReader.cpp +++ source/graphics/MapReader.cpp @@ -58,6 +58,8 @@ #include +extern bool IsQuitRequested(); + #if defined(_MSC_VER) && _MSC_VER > 1900 #pragma warning(disable: 4456) // Declaration hides previous local declaration.
#pragma warning(disable: 4458) // Declaration hides class member. @@ -1346,7 +1348,7 @@ // The settings are stringified to pass them to the task. m_GeneratorState->task = Threading::TaskManager::Instance().PushTask( [&progress = m_GeneratorState->progress, scriptFile, - settings = Script::StringifyJSON(rq, &m_ScriptSettings)] + settings = Script::StringifyJSON(rq, &m_ScriptSettings)](StopToken st) { PROFILE2("Map Generation"); @@ -1356,7 +1358,7 @@ MAP_GENERATION_CONTEXT_SIZE)}; ScriptInterface mapgenInterface{"Engine", "MapGenerator", mapgenContext}; - return RunMapGenerationScript(progress, mapgenInterface, scriptPath, settings); + return RunMapGenerationScript(st, progress, mapgenInterface, scriptPath, settings); }); return 0; @@ -1366,13 +1368,19 @@ { throw PSERROR_Game_World_MapLoadFailed{ "Error generating random map.\nCheck application log for details."}; -}; +} int CMapReader::PollMapGeneration() { ENSURE(m_GeneratorState); - if (!m_GeneratorState->task.IsReady()) + if (IsQuitRequested()) + { + LOGWARNING("Quit requested!"); + return -1; + } + + if (!m_GeneratorState->task.IsDone()) return m_GeneratorState->progress.load(); const Script::StructuredClone results{m_GeneratorState->task.Get()}; Index: source/graphics/tests/test_MapGenerator.h =================================================================== --- source/graphics/tests/test_MapGenerator.h +++ source/graphics/tests/test_MapGenerator.h @@ -17,6 +17,7 @@ #include "graphics/MapGenerator.h" #include "ps/Filesystem.h" +#include "ps/Future.h" #include "simulation2/system/ComponentTest.h" #include @@ -57,9 +58,10 @@ // It's never read in the test so it doesn't matter to what value it's initialized. For // good practice it's initialized to 1.
std::atomic progress{1}; - - const Script::StructuredClone result{RunMapGenerationScript(progress, scriptInterface, - path, "{\"Seed\": 0}", JSPROP_ENUMERATE | JSPROP_PERMANENT)}; + std::atomic stopRequest{false}; + const Script::StructuredClone result{RunMapGenerationScript(StopToken{stopRequest}, + progress, scriptInterface, path, "{\"Seed\": 0}", + JSPROP_ENUMERATE | JSPROP_PERMANENT)}; // The test scripts don't call `ExportMap` so `RunMapGenerationScript` allways returns // `nullptr`. Index: source/ps/Future.h =================================================================== --- source/ps/Future.h +++ source/ps/Future.h @@ -30,16 +30,27 @@ template class PackagedTask; -namespace FutureSharedStateDetail -{ -enum class Status +class StopToken { - PENDING, - STARTED, - DONE, - CANCELED +public: + explicit StopToken(const std::atomic& request) : + m_Request{request} + {} + + bool IsStopRequested() const + { + return m_Request.load(); + } +private: + const std::atomic& m_Request; }; +template +using CallbackResult = typename std::conditional_t, + std::invoke_result, std::invoke_result>::type; + +namespace FutureSharedStateDetail +{ template using ResultHolder = std::conditional_t, std::nullopt_t, std::optional>; @@ -57,51 +68,37 @@ {} ~Receiver() { - // For safety, wait on started task completion, but not on pending ones (auto-cancelled). - if (!Cancel()) - { - Wait(); - Cancel(); - } + // For safety, request a stop and wait on task completion. NOTE(review): m_Done is only set by PackagedTask::operator(); if the task is destroyed without ever being invoked, Wait() never returns. + RequestStop(); + Wait(); } Receiver(const Receiver&) = delete; Receiver(Receiver&&) = delete; - bool IsDoneOrCanceled() const + bool IsDone() const { - return m_Status == Status::DONE || m_Status == Status::CANCELED; + return m_Done.load(); } void Wait() { // Fast path: we're already done. - if (IsDoneOrCanceled()) + if (IsDone()) return; // Slow path: we aren't done when we run the above check. Lock and wait until we are.
std::unique_lock lock(m_Mutex); - m_ConditionVariable.wait(lock, [this]() -> bool { return IsDoneOrCanceled(); }); + m_ConditionVariable.wait(lock, [this]() -> bool { return IsDone(); }); } /** - * If the task is pending, cancel it: the status becomes CANCELED and if the task was completed, the result is destroyed. - * @return true if the task was indeed cancelled, false otherwise (the task is running or already done). + * Requests the executing thread to stop as fast as possible. This is only + * a request; the executing thread might ignore it. + * GetResult must not be called after a call to RequestStop: the callback may have been skipped, leaving no result. */ - bool Cancel() + void RequestStop() { - Status expected = Status::PENDING; - bool cancelled = m_Status.compare_exchange_strong(expected, Status::CANCELED); - // If we're done, invalidate, if we're pending, atomically cancel, otherwise fail. - if (cancelled || m_Status == Status::DONE) - { - if (m_Status == Status::DONE) - m_Status = Status::CANCELED; - if constexpr (!VoidResult) - this->reset(); - m_ConditionVariable.notify_all(); - return cancelled; - } - return false; + m_StopRequest.store(true); } /** @@ -112,13 +109,15 @@ { // The caller must ensure that this is only called if we have a result. ENSURE(this->has_value()); - m_Status = Status::CANCELED; ResultType ret = std::move(**this); this->reset(); return ret; } - std::atomic m_Status = Status::PENDING; + // This is only set by the receiving thread and read by the executing thread. It is never cleared. + std::atomic m_StopRequest{false}; + // This is only set by the executing thread and read by the receiving thread. It is never cleared.
+ std::atomic m_Done{false}; std::mutex m_Mutex; std::condition_variable m_ConditionVariable; }; @@ -131,7 +130,7 @@ {} Callback callback; - Receiver> receiver; + Receiver> receiver; }; } // namespace FutureSharedStateDetail @@ -156,8 +155,6 @@ friend class PackagedTask; static constexpr bool VoidResult = std::is_same_v; - - using Status = FutureSharedStateDetail::Status; public: Future() = default; Future(const Future& o) = delete; @@ -174,7 +171,7 @@ /** * Move the result out of the future, and invalidate the future. * If the future is not complete, calls Wait(). - * If the future is canceled, asserts. + * If the future is invalid, asserts. */ template std::enable_if_t, ResultType> Get() @@ -186,27 +183,27 @@ return; else { - ENSURE(m_Receiver->m_Status != Status::CANCELED); - - // This mark the state invalid - can't call Get again. - return m_Receiver->GetResult(); + SfinaeType result = m_Receiver->GetResult(); + m_Receiver.reset(); + return result; } } /** - * @return true if the shared state is valid and has a result (i.e. Get can be called). + * @return true if the shared state is valid and the task has finished running + * (i.e. Get will not block). */ - bool IsReady() const + bool IsDone() const { - return !!m_Receiver && m_Receiver->m_Status == Status::DONE; + return Valid() && m_Receiver->m_Done.load(); } /** - * @return true if the future has a shared state and it's not been invalidated, ie. pending, started or done. + * @return true if the future has a shared state. */ bool Valid() const { - return !!m_Receiver && m_Receiver->m_Status != Status::CANCELED; + return !!m_Receiver; } void Wait() @@ -215,17 +212,12 @@ m_Receiver->Wait(); } - /** - * Cancels the task, waiting if the task is currently started. - * Use this function over Cancel() if you need to ensure determinism (i.e. in the simulation). - * @see Cancel.
- */ void CancelOrWait() { if (!Valid()) return; - if (!m_Receiver->Cancel()) - m_Receiver->Wait(); + m_Receiver->RequestStop(); + m_Receiver->Wait(); m_Receiver.reset(); } @@ -250,25 +242,32 @@ void operator()() { - FutureSharedStateDetail::Status expected = FutureSharedStateDetail::Status::PENDING; - if (!m_SharedState->receiver.m_Status.compare_exchange_strong(expected, - FutureSharedStateDetail::Status::STARTED)) + if (!m_SharedState->receiver.m_StopRequest.load()) { - return; + const auto wrappedCallback = [&] + { + if constexpr (std::is_invocable_v) + return m_SharedState->callback( + StopToken{m_SharedState->receiver.m_StopRequest}); + else + return m_SharedState->callback(); + }; + static_assert(std::is_same_v, + CallbackResult>); + + if constexpr (std::is_void_v>) + wrappedCallback(); + else + m_SharedState->receiver.emplace(wrappedCallback()); } - if constexpr (std::is_void_v>) - m_SharedState->callback(); - else - m_SharedState->receiver.emplace(m_SharedState->callback()); - // Because we might have threads waiting on us, we need to make sure that they either: // - don't wait on our condition variable // - receive the notification when we're done. // This requires locking the mutex (@see Wait).
{ std::lock_guard lock(m_SharedState->receiver.m_Mutex); - m_SharedState->receiver.m_Status = FutureSharedStateDetail::Status::DONE; + m_SharedState->receiver.m_Done.store(true); } m_SharedState->receiver.m_ConditionVariable.notify_all(); @@ -291,8 +290,10 @@ template PackagedTask Future::Wrap(Callback&& callback) { - static_assert(std::is_same_v, ResultType>, + static_assert(std::is_same_v, ResultType>, "The return type of the wrapped function is not the same as the type the Future expects."); + static_assert(std::is_invocable_v || !std::is_invocable_v, + "Consider taking the `StopToken` by value"); auto temp = std::make_shared>(std::move(callback)); m_Receiver = {temp, &temp->receiver}; return PackagedTask(std::move(temp)); Index: source/ps/TaskManager.h =================================================================== --- source/ps/TaskManager.h +++ source/ps/TaskManager.h @@ -64,9 +64,9 @@ * Push a task to be executed. */ template - Future> PushTask(T&& func, TaskPriority priority = TaskPriority::NORMAL) + Future> PushTask(T&& func, TaskPriority priority = TaskPriority::NORMAL) { - Future> ret; + Future> ret; DoPushTask(ret.Wrap(std::move(func)), priority); return ret; } Index: source/ps/tests/test_Future.h =================================================================== --- source/ps/tests/test_Future.h +++ source/ps/tests/test_Future.h @@ -19,7 +19,9 @@ #include "ps/Future.h" +#include #include +#include #include class TestFuture : public CxxTest::TestSuite @@ -28,20 +30,9 @@ void test_future_basic() { int counter = 0; - { - Future noret; - std::function task = noret.Wrap([&counter]() mutable { counter++; }); - task(); - TS_ASSERT_EQUALS(counter, 1); - } - - { - Future noret; - { - std::function task = noret.Wrap([&counter]() mutable { counter++; }); - // Auto-cancels the task.
- } - } + Future noret; + std::function task = noret.Wrap([&counter]() mutable { counter++; }); + task(); TS_ASSERT_EQUALS(counter, 1); } @@ -77,11 +68,6 @@ TS_ASSERT_EQUALS(future.Get().value, 1); } TS_ASSERT_EQUALS(destroyed, 1); - { - Future future; - std::function task = future.Wrap([]() { return NonDef{1}; }); - } - TS_ASSERT_EQUALS(destroyed, 1); /** * TODO: find a way to test this { @@ -113,16 +99,16 @@ future = std::move(*f); function = std::move(*c); + // Let's move the packaged task while at it. + std::function task2 = std::move(task); + task2(); + TS_ASSERT_EQUALS(future.Get(), 7); + // Destroy and clear the memory f->~Future(); c->~function(); memset(&futureStorage, 0xFF, sizeof(decltype(futureStorage))); memset(&functionStorage, 0xFF, sizeof(decltype(functionStorage))); - - // Let's move the packaged task while at it. - std::function task2 = std::move(task); - task2(); - TS_ASSERT_EQUALS(future.Get(), 7); } void test_move_only_function() @@ -139,6 +125,70 @@ MoveOnlyType& operator=(MoveOnlyType&&) = default; }; - future.Wrap([t = MoveOnlyType{}]{}); + future.Wrap([t = MoveOnlyType{}]{})(); + } + + void test_stop_token_overload() + { + { + class DifferentTypes + { + public: + void operator()() + {} + int operator()(StopToken) + { + return 0; + } + }; + + Future future; + future.Wrap(DifferentTypes{})(); + } + { + class DifferentValues + { + public: + int operator()() + { + return 0; + } + int operator()(StopToken) + { + return 1; + } + }; + + Future future; + future.Wrap(DifferentValues{})(); + TS_ASSERT_EQUALS(future.Get(), 1); + } + Future{}.Wrap([](auto... args) + { + static_assert(sizeof...(args) == 1); + })(); + } + + void test_stop_token() + { + using namespace std::literals; + Future future; + auto task = future.Wrap([](StopToken st) + { + while(!st.IsStopRequested()); + }); + + + std::thread taskThread{task}; + std::this_thread::sleep_for(200ms); + // Without a stop request, the busy-looping task must still be running after 200ms.
+ TS_ASSERT(!future.IsDone()); + + const auto requestedAt = std::chrono::steady_clock::now(); + future.CancelOrWait(); + taskThread.join(); + const auto stopedAt = std::chrono::steady_clock::now(); + // Once a stop is requested, CancelOrWait and join must return promptly (bound: 5.56ms). NOTE(review): a wall-clock bound this tight may be flaky on loaded CI machines. + TS_ASSERT_LESS_THAN(stopedAt - requestedAt, 5.56ms); } };