Index: source/main.cpp =================================================================== --- source/main.cpp +++ source/main.cpp @@ -76,7 +76,7 @@ #include "graphics/TextureManager.h" #include "gui/GUIManager.h" #include "renderer/Renderer.h" -#include "rlinterface/RLInterface.cpp" +#include "rlinterface/RLInterface.h" #include "scriptinterface/ScriptEngine.h" #include "simulation2/Simulation2.h" #include "simulation2/system/TurnManager.h" @@ -484,52 +484,6 @@ debug_printf("RL interface listening on %s\n", server_address.c_str()); } -static void RunRLServer(const bool isNonVisual, const std::vector modsToInstall, const CmdLineArgs args) -{ - int flags = INIT_MODS; - while (!Init(args, flags)) - { - flags &= ~INIT_MODS; - Shutdown(SHUTDOWN_FROM_CONFIG); - } - g_Shutdown = ShutdownType::None; - - std::vector installedMods; - if (!modsToInstall.empty()) - { - Paths paths(args); - CModInstaller installer(paths.UserData() / "mods", paths.Cache()); - - // Install the mods without deleting the pyromod files - for (const OsPath& modPath : modsToInstall) - installer.Install(modPath, g_ScriptRuntime, true); - - installedMods = installer.GetInstalledMods(); - } - - if (isNonVisual) - { - InitNonVisual(args); - StartRLInterface(args); - while (g_Shutdown == ShutdownType::None) - g_RLInterface->TryApplyMessage(); - QuitEngine(); - } - else - { - InitGraphics(args, 0, installedMods); - MainControllerInit(); - StartRLInterface(args); - while (g_Shutdown == ShutdownType::None) - Frame(); - } - - Shutdown(0); - MainControllerShutdown(); - CXeromyces::Terminate(); - delete g_RLInterface; -} - // moved into a helper function to ensure args is destroyed before // exit(), which may result in a memory leak. static void RunGameOrAtlas(int argc, const char* argv[]) @@ -556,6 +510,7 @@ const bool isVisualReplay = args.Has("replay-visual"); const bool isNonVisualReplay = args.Has("replay"); const bool isNonVisual = args.Has("autostart-nonvisual"); + const bool isUsingRLInterface = args.Has("rl-interface"); const OsPath replayFile( isVisualReplay ? args.Get("replay-visual") : @@ -668,12 +623,6 @@ const double res = timer_Resolution(); g_frequencyFilter = CreateFrequencyFilter(res, 30.0); - if (args.Has("rl-interface")) - { - RunRLServer(isNonVisual, modsToInstall, args); - return; - } - // run the game int flags = INIT_MODS; do @@ -703,13 +652,21 @@ if (isNonVisual) { InitNonVisual(args); + if (isUsingRLInterface) + StartRLInterface(args); + while (g_Shutdown == ShutdownType::None) - NonVisualFrame(); + if (isUsingRLInterface) + g_RLInterface->TryApplyMessage(); + else + NonVisualFrame(); } else { InitGraphics(args, 0, installedMods); MainControllerInit(); + if (isUsingRLInterface) + StartRLInterface(args); while (g_Shutdown == ShutdownType::None) Frame(); } @@ -720,6 +677,8 @@ Shutdown(0); MainControllerShutdown(); flags &= ~INIT_MODS; + if (isUsingRLInterface) + delete g_RLInterface; } while (g_Shutdown == ShutdownType::Restart); Index: source/rlinterface/RLInterface.h =================================================================== --- source/rlinterface/RLInterface.h +++ source/rlinterface/RLInterface.h @@ -23,51 +23,60 @@ #include #include -struct ScenarioConfig { +struct ScenarioConfig +{ bool saveReplay; player_id_t playerID; std::string content; }; -struct Command { +struct RLGameCommand +{ int playerID; std::string json_cmd; }; -enum GameMessageType { Reset, Commands }; +enum class GameMessageType { Reset, Commands }; struct GameMessage { GameMessageType type; - std::vector commands; + std::vector commands; }; extern void EndGame(); struct mg_context; -const static std::string EMPTY_STATE; +/** + * Implements an interface providing fundamental capabilities required for reinforcement + * learning (over HTTP). + * + * This consists of enabling an external script to configure the scenario (via Reset) and + * then step the game engine manually and apply player actions (via Step). The interface + * also supports querying unit templates to provide information about max health and other + * potentially relevant game state information. + */ class RLInterface { - public: - std::string Step(const std::vector commands); + std::string Step(const std::vector& commands); std::string Reset(const ScenarioConfig* scenario); - std::vector GetTemplates(const std::vector names) const; + std::vector GetTemplates(const std::vector& names) const; void EnableHTTP(const char* server_address); - std::string SendGameMessage(const GameMessage msg); + std::string SendGameMessage(const GameMessage& msg); bool TryGetGameMessage(GameMessage& msg); void TryApplyMessage(); - std::string GetGameState(); - bool IsGameRunning(); + std::string GetGameState() const; + bool IsGameRunning() const; private: mg_context* m_MgContext = nullptr; const GameMessage* m_GameMessage = nullptr; std::string m_GameState; bool m_NeedsGameState = false; - mutable std::mutex m_lock; - std::mutex m_msgLock; - std::condition_variable m_msgApplied; + mutable std::mutex m_Lock; + std::mutex m_MsgLock; + std::condition_variable m_MsgApplied; ScenarioConfig m_ScenarioConfig; }; Index: source/rlinterface/RLInterface.cpp =================================================================== --- source/rlinterface/RLInterface.cpp +++ source/rlinterface/RLInterface.cpp @@ -42,32 +42,32 @@ // Interactions with the game engine (g_Game) must be done in the main // thread as there are specific checks for this. We will pass our commands // to the main thread to be applied -std::string RLInterface::SendGameMessage(const GameMessage msg) +std::string RLInterface::SendGameMessage(const GameMessage& msg) { - std::unique_lock msgLock(m_msgLock); + std::unique_lock msgLock(m_MsgLock); m_GameMessage = &msg; - m_msgApplied.wait(msgLock); + m_MsgApplied.wait(msgLock); return m_GameState; } -std::string RLInterface::Step(const std::vector commands) +std::string RLInterface::Step(const std::vector& commands) { - std::lock_guard lock(m_lock); - GameMessage msg = { GameMessageType::Commands, commands }; + std::lock_guard lock(m_Lock); + const GameMessage msg = { GameMessageType::Commands, commands }; return SendGameMessage(msg); } std::string RLInterface::Reset(const ScenarioConfig* scenario) { - std::lock_guard lock(m_lock); + std::lock_guard lock(m_Lock); m_ScenarioConfig = *scenario; - struct GameMessage msg = { GameMessageType::Reset }; + const GameMessage msg = { GameMessageType::Reset }; return SendGameMessage(msg); } -std::vector RLInterface::GetTemplates(const std::vector names) const +std::vector RLInterface::GetTemplates(const std::vector& names) const { - std::lock_guard lock(m_lock); + std::lock_guard lock(m_Lock); CSimulation2& simulation = *g_Game->GetSimulation2(); CmpPtr cmpTemplateManager(simulation.GetSimContext().GetSystemEntity()); @@ -78,7 +78,7 @@ if (node != nullptr) { - std::string content = utf8_from_wstring(node->ToXML()); + const std::string content = utf8_from_wstring(node->ToXML()); templates.push_back(content); } } @@ -119,7 +119,7 @@ { std::stringstream stream; - std::string uri = request_info->uri; + const std::string uri = request_info->uri; if (uri == "/reset") { @@ -130,22 +130,22 @@ return handled; } ScenarioConfig scenario; - std::string qs(request_info->query_string); + const std::string qs(request_info->query_string); scenario.saveReplay = qs.find("saveReplay") != std::string::npos; scenario.playerID = 1; char playerID[1]; - int len = mg_get_var(request_info->query_string, qs.length(), "playerID", playerID, 1); + const int len = mg_get_var(request_info->query_string, qs.length(), "playerID", playerID, 1); if (len != -1) scenario.playerID = std::stoi(playerID); - int bufSize = std::atoi(val); + const int bufSize = std::atoi(val); std::unique_ptr buf = std::unique_ptr(new char[bufSize]); mg_read(conn, buf.get(), bufSize); - std::string content(buf.get(), bufSize); + const std::string content(buf.get(), bufSize); scenario.content = content; - std::string gameState = interface->Reset(&scenario); + const std::string gameState = interface->Reset(&scenario); stream << gameState.c_str(); } @@ -166,14 +166,14 @@ int bufSize = std::atoi(val); std::unique_ptr buf = std::unique_ptr(new char[bufSize]); mg_read(conn, buf.get(), bufSize); - std::string postData(buf.get(), bufSize); + const std::string postData(buf.get(), bufSize); std::stringstream postStream(postData); std::string line; - std::vector commands; + std::vector commands; while (std::getline(postStream, line, '\n')) { - Command cmd; + RLGameCommand cmd; const std::size_t splitPos = line.find(";"); if (splitPos != std::string::npos) { @@ -182,7 +182,7 @@ commands.push_back(cmd); } } - std::string gameState = interface->Step(commands); + const std::string gameState = interface->Step(commands); if (gameState.empty()) { mg_printf(conn, "%s", notRunningResponse); @@ -203,10 +203,10 @@ mg_printf(conn, "%s", noPostData); return handled; } - int bufSize = std::atoi(val); + const int bufSize = std::atoi(val); std::unique_ptr buf = std::unique_ptr(new char[bufSize]); mg_read(conn, buf.get(), bufSize); - std::string postData(buf.get(), bufSize); + const std::string postData(buf.get(), bufSize); std::stringstream postStream(postData); std::string line; std::vector templateNames; @@ -224,7 +224,7 @@ } mg_printf(conn, "%s", header200); - std::string str = stream.str(); + const std::string str = stream.str(); mg_write(conn, str.c_str(), str.length()); return handled; } @@ -275,17 +275,18 @@ void RLInterface::TryApplyMessage() { + const static std::string EMPTY_STATE; const bool nonVisual = !g_GUI; const bool isGameStarted = g_Game && g_Game->IsGameStarted(); if (m_NeedsGameState && isGameStarted) { m_GameState = GetGameState(); - m_msgApplied.notify_one(); - m_msgLock.unlock(); + m_MsgApplied.notify_one(); + m_MsgLock.unlock(); m_NeedsGameState = false; } - if (m_msgLock.try_lock()) + if (m_MsgLock.try_lock()) { GameMessage msg; if (TryGetGameMessage(msg)) { @@ -311,8 +312,8 @@ LDR_NonprogressiveLoad(); ENSURE(g_Game->ReallyStartGame() == PSRETURN_OK); m_GameState = GetGameState(); - m_msgApplied.notify_one(); - m_msgLock.unlock(); + m_MsgApplied.notify_one(); + m_MsgLock.unlock(); } else { @@ -335,14 +336,14 @@ if (!g_Game) { m_GameState = EMPTY_STATE; - m_msgApplied.notify_one(); - m_msgLock.unlock(); + m_MsgApplied.notify_one(); + m_MsgLock.unlock(); return; } const ScriptInterface& scriptInterface = g_Game->GetSimulation2()->GetScriptInterface(); CLocalTurnManager* turnMgr = static_cast(g_Game->GetTurnManager()); - for (Command command : msg.commands) + for (const RLGameCommand& command : msg.commands) { JSContext* cx = scriptInterface.GetContext(); JSAutoRequest rq(cx); @@ -355,25 +356,25 @@ if (nonVisual) { const double deltaSimTime = deltaRealTime * g_Game->GetSimRate(); - size_t maxTurns = static_cast(g_Game->GetSimRate()); + const size_t maxTurns = static_cast(g_Game->GetSimRate()); g_Game->GetTurnManager()->Update(deltaSimTime, maxTurns); } else g_Game->Update(deltaRealTime); m_GameState = GetGameState(); - m_msgApplied.notify_one(); - m_msgLock.unlock(); + m_MsgApplied.notify_one(); + m_MsgLock.unlock(); break; } } } else - m_msgLock.unlock(); + m_MsgLock.unlock(); } } -std::string RLInterface::GetGameState() +std::string RLInterface::GetGameState() const { const ScriptInterface& scriptInterface = g_Game->GetSimulation2()->GetScriptInterface(); const CSimContext simContext = g_Game->GetSimulation2()->GetSimContext(); @@ -385,7 +386,7 @@ return scriptInterface.StringifyJSON(&state, false); } -bool RLInterface::IsGameRunning() +bool RLInterface::IsGameRunning() const { return !!g_Game; }