Index: binaries/data/config/default.cfg =================================================================== --- binaries/data/config/default.cfg +++ binaries/data/config/default.cfg @@ -467,6 +467,9 @@ gpu.ext.enable = true ; Allow GL_EXT_timer_query timing mode when available gpu.intel.enable = true ; Allow GL_INTEL_performance_queries timing mode when available +[rlinterface] +address = "127.0.0.1:50050" ; Address the RL interface RPC server listens on (loopback by default, the server uses insecure credentials) + [sound] mastergain = 0.9 musicgain = 0.2 Index: build/premake/extern_libs5.lua =================================================================== --- build/premake/extern_libs5.lua +++ build/premake/extern_libs5.lua @@ -198,13 +198,28 @@ if os.istarget("windows") or os.istarget("macosx") then add_default_lib_paths("boost") end + add_default_links({ -- The following are not strictly link dependencies on all systems, but -- are included for compatibility with different versions of Boost android_names = { "boost_filesystem-gcc-mt", "boost_system-gcc-mt" }, - unix_names = { os.findlib("boost_filesystem-mt") and "boost_filesystem-mt" or "boost_filesystem", os.findlib("boost_system-mt") and "boost_system-mt" or "boost_system" }, + unix_names = { + os.findlib("boost_filesystem-mt") and "boost_filesystem-mt" or "boost_filesystem", + os.findlib("boost_system-mt") and "boost_system-mt" or "boost_system" + }, osx_names = { "boost_filesystem-mt", "boost_system-mt" }, }) + + if _OPTIONS["with-rlinterface"] then + add_default_links({ + unix_names = { + os.findlib("boost_fiber-mt") and "boost_fiber-mt" or "boost_fiber", + os.findlib("boost_context-mt") and "boost_context-mt" or "boost_context", + }, + osx_names = { "boost_fiber-mt" }, + }) + end + end, }, comsuppw = { @@ -677,6 +692,21 @@ }, } +if _OPTIONS["with-rlinterface"] then + extern_lib_defs['grpc'] = { + compile_settings = function() + pkgconfig.add_includes("grpc") + pkgconfig.add_includes("grpc++") + pkgconfig.add_includes("protobuf") + end, + link_settings = function() + pkgconfig.add_links("protobuf") +
pkgconfig.add_links("grpc") + pkgconfig.add_links("grpc++") + end + } +end + -- add a set of external libraries to the project; takes care of -- include / lib path and linking against the import library. Index: build/premake/premake5.lua =================================================================== --- build/premake/premake5.lua +++ build/premake/premake5.lua @@ -13,6 +13,7 @@ newoption { trigger = "without-miniupnpc", description = "Disable use of miniupnpc for port forwarding" } newoption { trigger = "without-nvtt", description = "Disable use of NVTT" } newoption { trigger = "without-pch", description = "Disable generation and usage of precompiled headers" } +newoption { trigger = "with-rlinterface", description = "Enable RPC interface for reinforcement learning" } newoption { trigger = "without-tests", description = "Disable generation of test projects" } -- Linux/BSD specific options @@ -35,6 +36,32 @@ -- Root directory of project checkout relative to this .lua file rootdir = "../.." +if _OPTIONS["with-rlinterface"] then + if os.istarget("windows") then + error("The RL interface is currently not supported on Windows.") + else + -- Generate cpp files + protodir = rootdir .. "/source/rlinterface/proto/" + gencppfiles = "protoc --grpc_out=. --plugin=protoc-gen-grpc=`which grpc_cpp_plugin` RLAPI.proto && protoc --cpp_out=. RLAPI.proto" + updatefilenames = "mv RLAPI.pb.cc RLAPI.pb.cpp && mv RLAPI.grpc.pb.cc RLAPI.grpc.pb.cpp" + _, _, exitcode = os.execute("cd " .. protodir .. " && " .. gencppfiles .. " && " .. updatefilenames) + if exitcode > 0 then + error("Unable to generate GRPC files. Is grpc and protobuf (protoc) installed?") + end + + -- Generate python files + clientsdir = rootdir .. 
"/source/tools/clients/" + preparepython = "mkdir -p source/tools/clients/python/zero_ad/proto/zero_ad && cp source/rlinterface/proto/RLAPI.proto source/tools/clients/python/zero_ad" + genpythonfiles = "python -m grpc_tools.protoc --python_out=python --grpc_python_out=python -Ipython python/zero_ad/RLAPI.proto" + + os.execute("cd " .. rootdir .. " && " .. preparepython) + + _, _, exitcode = os.execute("cd " .. clientsdir .. " && " .. genpythonfiles) + if exitcode > 0 then + print("WARNING: Unable to generate GRPC files for Python client. Is grpcio-tools installed?") + end + end +end dofile("extern_libs5.lua") @@ -174,6 +201,12 @@ defines { "CONFIG2_NVTT=0" } end + if _OPTIONS["with-rlinterface"] then + defines { "WITH_RLINTERFACE=1" } + else + defines { "WITH_RLINTERFACE=0" } + end + if _OPTIONS["without-lobby"] then defines { "CONFIG2_LOBBY=0" } end @@ -581,6 +614,20 @@ end setup_static_lib_project("network", source_dirs, extern_libs, {}) + if _OPTIONS["with-rlinterface"] then + source_dirs = { + "rlinterface", + "rlinterface/proto" + } + extern_libs = { + "boost", + "spidermonkey", + "sdl", -- key definitions + "grpc", + } + setup_static_lib_project("rlinterface", source_dirs, extern_libs, { no_pch = 1 }) + end + source_dirs = { "third_party/tinygettext/src", } @@ -935,6 +982,10 @@ table.insert(used_extern_libs, "miniupnpc") end +if _OPTIONS["with-rlinterface"] then + table.insert(used_extern_libs, "grpc") +end + -- Bundles static libs together with main.cpp and builds game executable. 
function setup_main_exe () Index: libraries/osx/build-osx-libs.sh =================================================================== --- libraries/osx/build-osx-libs.sh +++ libraries/osx/build-osx-libs.sh @@ -138,6 +138,7 @@ do case $i in --force-rebuild ) force_rebuild=true;; + --with-rlinterface ) with_rlinterface=true;; -j* ) JOBS=$i ;; esac done @@ -375,8 +376,14 @@ tar -xf $LIB_ARCHIVE pushd $LIB_DIRECTORY + BOOST_LIBS="filesystem,system" + if [[ "$with_rlinterface" = "true" ]] + then + BOOST_LIBS="$BOOST_LIBS,fiber" + fi + # Can't use macosx-version, see above comment. - (./bootstrap.sh --with-libraries=filesystem,system \ + (./bootstrap.sh --with-libraries=$BOOST_LIBS\ --prefix=$INSTALL_DIR \ && ./b2 cflags="$CFLAGS" \ toolset=clang \ Index: source/main.cpp =================================================================== --- source/main.cpp +++ source/main.cpp @@ -79,6 +79,9 @@ #include "scriptinterface/ScriptEngine.h" #include "simulation2/Simulation2.h" #include "simulation2/system/TurnManager.h" +#if WITH_RLINTERFACE +#include "rlinterface/RLInterface.cpp" +#endif #include "soundmanager/ISoundManager.h" #if OS_UNIX @@ -316,8 +319,15 @@ while (more && timer_Time() - startTime < maxTime); } +#if WITH_RLINTERFACE +static void Frame(RLInterface* service=nullptr) +#else static void Frame() +#endif { +#if WITH_RLINTERFACE + bool using_interface = service != nullptr; +#endif g_Profiler2.RecordFrameStart(); PROFILE2("frame"); g_Profiler2.IncrementFrameNumber(); @@ -388,9 +398,17 @@ ogl_WarnIfError(); +#if WITH_RLINTERFACE + if (using_interface) + service->ApplyEvents(); +#endif + if (g_Game && g_Game->IsGameStarted() && need_update) { - g_Game->Update(realTimeSinceLastFrame); +#if WITH_RLINTERFACE + if (!using_interface) +#endif + g_Game->Update(realTimeSinceLastFrame); g_Game->GetView()->Update(float(realTimeSinceLastFrame)); } @@ -460,6 +478,71 @@ in_reset_handlers(); } +#if WITH_RLINTERFACE +static std::unique_ptr StartRLInterface(CmdLineArgs args) +{ + 
std::string server_address; + CFG_GET_VAL("rlinterface.address", server_address); + + if (!args.Get("rpc-server").empty()) + server_address = args.Get("rpc-server"); + + std::unique_ptr service(new RLInterface); + service.get()->Listen(server_address); + debug_printf("RL interface listening on %s\n", server_address.c_str()); + return service; +} + +static void RunRLServer(const bool isNonVisual, std::vector modsToInstall, CmdLineArgs args) +{ + int flags = INIT_MODS; + while (!Init(args, flags)) + { + flags &= ~INIT_MODS; + Shutdown(SHUTDOWN_FROM_CONFIG); + } + g_Shutdown = ShutdownType::None; + + std::vector installedMods; + if (!modsToInstall.empty()) + { + Paths paths(args); + CModInstaller installer(paths.UserData() / "mods", paths.Cache()); + + // Install the mods without deleting the pyromod files + for (const OsPath& modPath : modsToInstall) + installer.Install(modPath, g_ScriptRuntime, true); + + installedMods = installer.GetInstalledMods(); + } + + if (isNonVisual) + { + InitNonVisual(args); + std::unique_ptr service = StartRLInterface(args); + while (g_Shutdown == ShutdownType::None) + { + service.get()->ApplyEvents(); + } + QuitEngine(); + } + else + { + InitGraphics(args, 0, installedMods); + MainControllerInit(); + std::unique_ptr service = StartRLInterface(args); + while (g_Shutdown == ShutdownType::None) + { + Frame(service.get()); + } + } + + Shutdown(0); + MainControllerShutdown(); + CXeromyces::Terminate(); +} +#endif + // moved into a helper function to ensure args is destroyed before // exit(), which may result in a memory leak. static void RunGameOrAtlas(int argc, const char* argv[]) @@ -474,7 +557,11 @@ return; } +#if WITH_RLINTERFACE + if (args.Has("autostart-nonvisual") && args.Get("autostart").empty() && !args.Has("rpc-server")) +#else if (args.Has("autostart-nonvisual") && args.Get("autostart").empty()) +#endif { LOGERROR("-autostart-nonvisual cant be used alone. 
A map with -autostart=\"TYPEDIR/MAPNAME\" is needed."); return; @@ -598,8 +685,16 @@ const double res = timer_Resolution(); g_frequencyFilter = CreateFrequencyFilter(res, 30.0); +#if WITH_RLINTERFACE + if (args.Has("rpc-server")) { + RunRLServer(isNonVisual, modsToInstall, args); + return; + } +#endif + // run the game int flags = INIT_MODS; + do { g_Shutdown = ShutdownType::None; Index: source/rlinterface/RLInterface.h =================================================================== --- /dev/null +++ source/rlinterface/RLInterface.h @@ -0,0 +1,77 @@ +/* Copyright (C) 2019 Wildfire Games. + * This file is part of 0 A.D. + * + * 0 A.D. is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * 0 A.D. is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with 0 A.D. If not, see . 
+ */ +#ifndef INCLUDED_RLINTERFACE +#define INCLUDED_RLINTERFACE +#include +#include +#include +#include +#include +#include +#include +#include +#include "rlinterface/proto/RLAPI.grpc.pb.h" + +#include "lib/precompiled.h" +#include "lib/external_libraries/libsdl.h" +#include "simulation2/Simulation2.h" +#include "simulation2/components/ICmpAIInterface.h" +#include "simulation2/components/ICmpTemplateManager.h" +#include "simulation2/system/TurnManager.h" +#include "ps/Game.h" +#include "ps/Loader.h" +#include "gui/GUIManager.h" +#include "ps/VideoMode.h" +#include "ps/GameSetup/GameSetup.h" +#include "ps/ThreadUtil.h" +#include +#include + +using grpc::ServerContext; +using boost::fibers::unbuffered_channel; +using boost::fibers::buffered_channel; + +enum GameMessageType { Reset, Commands }; +struct GameMessage { + GameMessageType type; + std::queue> data; +}; +extern void EndGame(); + +class RLInterface final : public RLAPI::Service +{ + + public: + + grpc::Status Step(ServerContext* context, const Actions* commands, Observation* obs) override; + grpc::Status Reset(ServerContext* context, const ResetRequest* req, Observation* obs) override; + grpc::Status GetTemplates(ServerContext* context, const GetTemplateRequest* req, Templates* res) override; + + void Listen(std::string server_address); + void ApplyEvents(); // Apply RPC messages to the game engine + std::string GetGameState(); + + private: + std::unique_ptr m_Server; + unsigned int m_Turn = 0; + std::mutex m_lock; + buffered_channel m_GameMessages{2}; + unbuffered_channel m_GameStates; + bool m_NeedsGameState = false; + ScenarioConfig m_ScenarioConfig; +}; +#endif // INCLUDED_RLINTERFACE Index: source/rlinterface/RLInterface.cpp =================================================================== --- /dev/null +++ source/rlinterface/RLInterface.cpp @@ -0,0 +1,199 @@ +/* Copyright (C) 2019 Wildfire Games. + * This file is part of 0 A.D. + * + * 0 A.D. 
is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * 0 A.D. is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with 0 A.D. If not, see . + */ +#include "rlinterface/RLInterface.h" +#include "simulation2/system/LocalTurnManager.h" + +using grpc::ServerContext; +using boost::fibers::channel_op_status; + +grpc::Status RLInterface::Step(ServerContext* UNUSED(context), const Actions* commands, Observation* obs) +{ + std::lock_guard lock(m_lock); + + // Interactions with the game engine (g_Game) must be done in the main + // thread as there are specific checks for this. 
We will pass our commands + // to the main thread to be applied + + GameMessage msg = { GameMessageType::Commands }; + const int size = commands->actions_size(); + for (int i = 0; i < size; i++) + { + std::string json_cmd = commands->actions(i).content(); + player_id_t playerid = commands->actions(i).playerid(); + msg.data.push(std::make_tuple(playerid, json_cmd)); + } + m_GameMessages.push(msg); + std::string state; + m_GameStates.pop(state); + obs->set_content(state); + + return grpc::Status::OK; +} + +grpc::Status RLInterface::Reset(ServerContext* UNUSED(context), const ResetRequest* req, Observation* obs) +{ + std::lock_guard lock(m_lock); + if (req->has_scenario()) + m_ScenarioConfig = req->scenario(); + + struct GameMessage msg = { GameMessageType::Reset }; + m_GameMessages.push(msg); + + std::string state; + m_GameStates.pop(state); + obs->set_content(state); + + return grpc::Status::OK; +} + +grpc::Status RLInterface::GetTemplates(ServerContext* UNUSED(context), const GetTemplateRequest* req, Templates* res) +{ + std::lock_guard lock(m_lock); + if (!g_Game) + { + LOGERROR("Game not running. 
Have you started a scenario with a Reset message?"); + return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Game not running."); + } + + CSimulation2& simulation = *g_Game->GetSimulation2(); + CmpPtr cmpTemplateManager(simulation.GetSimContext().GetSystemEntity()); + + const int size = req->names_size(); + for (int i = 0; i < size; i++) + { + const std::string templateName = req->names(i); + const CParamNode* node = cmpTemplateManager->GetTemplate(templateName); + + Template* tpl = res->add_templates(); + tpl->set_name(templateName); + if (node != nullptr) + { + std::string content = utf8_from_wstring(node->ToXML()); + tpl->set_content(content); + } + } + + return grpc::Status::OK; +} + +void RLInterface::Listen(std::string server_address) +{ + grpc::ServerBuilder builder; + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + builder.RegisterService(this); + m_Server = builder.BuildAndStart(); +} + +void RLInterface::ApplyEvents() +{ + const bool nonVisual = !g_GUI; + const bool isGameStarted = g_Game && g_Game->IsGameStarted(); + if (m_NeedsGameState && isGameStarted) + { + m_GameStates.push(GetGameState()); // Send the game state back to the request + m_NeedsGameState = false; + } + + GameMessage msg; + while (m_GameMessages.try_pop(msg) == channel_op_status::success) + { + switch (msg.type) + { + case GameMessageType::Reset: + { + if (isGameStarted) + EndGame(); + + g_Game = new CGame(nonVisual, m_ScenarioConfig.savereplay()); + ScriptInterface& scriptInterface = g_Game->GetSimulation2()->GetScriptInterface(); + JSContext* cx = scriptInterface.GetContext(); + JS::RootedValue attrs(cx); + scriptInterface.ParseJSON(m_ScenarioConfig.content(), &attrs); + + g_Game->SetPlayerID(m_ScenarioConfig.playerid()); + g_Game->StartGame(&attrs, ""); + + if (nonVisual) + { + LDR_NonprogressiveLoad(); + ENSURE(g_Game->ReallyStartGame() == PSRETURN_OK); + m_GameStates.push(GetGameState()); // Send the game state back to the request + } + else + { 
+ JS::RootedValue initData(cx); + scriptInterface.CreateObject(&initData); + scriptInterface.SetProperty(initData, "attribs", attrs); + + JS::RootedValue playerAssignments(cx); + scriptInterface.CreateObject(&playerAssignments); + scriptInterface.SetProperty(initData, "playerAssignments", playerAssignments); + + g_GUI->SwitchPage(L"page_loading.xml", &scriptInterface, initData); + m_NeedsGameState = true; + } + } + break; + + case GameMessageType::Commands: + if (!isGameStarted) + { + LOGWARNING("Cannot apply game commands w/o running game. Ignoring..."); + m_GameStates.push(std::string()); // Reply with an empty state so the Step() call blocked in m_GameStates.pop() returns instead of hanging while holding m_lock + continue; + } + const ScriptInterface& scriptInterface = g_Game->GetSimulation2()->GetScriptInterface(); + CLocalTurnManager* turnMgr = static_cast(g_Game->GetTurnManager()); + + while (msg.data.size() > 0) + { + int playerid = std::get<0>(msg.data.front()); + std::string json_cmd = std::get<1>(msg.data.front()); + msg.data.pop(); + + JSContext* cx = scriptInterface.GetContext(); + JS::RootedValue command(cx); + scriptInterface.ParseJSON(json_cmd, &command); + turnMgr->PostCommand(playerid, command); + } + + const double deltaRealTime = DEFAULT_TURN_LENGTH_SP; + if (nonVisual) + { + const double deltaSimTime = deltaRealTime * g_Game->GetSimRate(); + size_t maxTurns = static_cast(g_Game->GetSimRate()); + g_Game->GetTurnManager()->Update(deltaSimTime, maxTurns); + } + else + g_Game->Update(deltaRealTime); + + m_GameStates.push(GetGameState()); + break; + } + } +} + +std::string RLInterface::GetGameState() +{ + const ScriptInterface& scriptInterface = g_Game->GetSimulation2()->GetScriptInterface(); + const auto simContext = g_Game->GetSimulation2()->GetSimContext(); + CmpPtr cmpAIInterface(simContext.GetSystemEntity()); + JSContext* cx = scriptInterface.GetContext(); + JS::RootedValue state(cx); + cmpAIInterface->GetFullRepresentation(&state, true); + return scriptInterface.StringifyJSON(&state, false); +} Index: source/rlinterface/proto/RLAPI.proto
=================================================================== --- /dev/null +++ source/rlinterface/proto/RLAPI.proto @@ -0,0 +1,65 @@ +/* Copyright (C) 2019 Wildfire Games. + * This file is part of 0 A.D. + * + * 0 A.D. is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * 0 A.D. is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with 0 A.D. If not, see . + */ +syntax = "proto3"; + +service RLAPI { + rpc Step(Actions) returns (Observation) {} + rpc Reset(ResetRequest) returns (Observation) {} + rpc GetTemplates(GetTemplateRequest) returns (Templates) {} +} + +message GetTemplateRequest { + repeated string names = 1; +} + +message Templates { + repeated Template templates = 1; +} + +message Template { + string name = 1; + string content = 2; +} + +message Actions { + repeated Action actions = 1; +} + +message Action { + int32 playerID = 1; + string content = 2; +} + +message Observation { + string content = 1; +} + +message ResetRequest { + ScenarioConfig scenario = 1; +} + +message AIPlayer { + int32 id = 1; + string type = 2; + uint32 difficulty = 3; +} + +message ScenarioConfig { + bool saveReplay = 1; + int32 playerID = 2; + string content = 3; +} Index: source/simulation2/system/ComponentManager.cpp =================================================================== --- source/simulation2/system/ComponentManager.cpp +++ source/simulation2/system/ComponentManager.cpp @@ -31,6 +31,7 @@ #include "ps/CLogger.h" #include "ps/Filesystem.h" #include "ps/scripting/JSInterface_VFS.h" +#include 
"simulation2/scripting/JSInterface_Simulation.h" /** * Used for script-only message types. @@ -69,6 +70,7 @@ if (!skipScriptFunctions) { JSI_VFS::RegisterScriptFunctions_Simulation(m_ScriptInterface); + m_ScriptInterface.RegisterFunction("GetInitAttributes"); m_ScriptInterface.RegisterFunction ("RegisterComponentType"); m_ScriptInterface.RegisterFunction ("RegisterSystemComponentType"); m_ScriptInterface.RegisterFunction ("ReRegisterComponentType"); Index: source/simulation2/system/LocalTurnManager.h =================================================================== --- source/simulation2/system/LocalTurnManager.h +++ source/simulation2/system/LocalTurnManager.h @@ -31,6 +31,7 @@ void OnSimulationMessage(CSimulationMessage* msg) override; void PostCommand(JS::HandleValue data) override; + void PostCommand(player_id_t playerid, JS::HandleValue data); protected: void NotifyFinishedOwnCommands(u32 turn) override; Index: source/simulation2/system/LocalTurnManager.cpp =================================================================== --- source/simulation2/system/LocalTurnManager.cpp +++ source/simulation2/system/LocalTurnManager.cpp @@ -24,6 +24,11 @@ { } +void CLocalTurnManager::PostCommand(player_id_t playerid, JS::HandleValue data) +{ + AddCommand(m_ClientId, playerid, data, m_CurrentTurn + 1); +} + void CLocalTurnManager::PostCommand(JS::HandleValue data) { // Add directly to the next turn, ignoring COMMAND_DELAY, Index: source/tools/clients/python/.gitignore =================================================================== --- /dev/null +++ source/tools/clients/python/.gitignore @@ -0,0 +1 @@ +__pycache__/ Index: source/tools/clients/python/README.md =================================================================== --- /dev/null +++ source/tools/clients/python/README.md @@ -0,0 +1,51 @@ +# 0 AD Python Client +This directory contains `zero_ad`, a python client for 0 AD which enables users to control the environment headlessly. 
+ +## Installation +`zero_ad` can be installed with `pip` by running the following from the current directory: +``` +pip install . +``` + +Development dependencies can be installed with `pip install -r requirements-dev.txt`. Tests use pytest and can be run with `python -m pytest`. + +## Basic Usage +If no instance of 0 AD is running yet, first start 0 AD with the RL interface enabled: +``` +pyrogenesis --rpc-server=0.0.0.0:50050 +``` + +Next, the Python client can be connected with: +``` +import zero_ad +from zero_ad import ZeroAD + +game = ZeroAD('localhost:50050') +``` + +A map can be loaded with: + +``` +with open('./samples/arcadia.json', 'r') as f: + arcadia_config = f.read() + +config = zero_ad.ScenarioConfig(playerID=1, content=arcadia_config) +state = game.reset(config) +``` + +where `./samples/arcadia.json` is the path to a game configuration JSON (such a configuration is stored on the first line of the commands.txt file in a game replay directory) and `state` contains the initial game state for the given map. The game engine can be stepped (optionally applying actions at each step) with: + +``` +state = game.step() +``` + +For example, enemy units could be attacked with: + +``` +my_units = state.units(owner=1) +enemy_units = state.units(owner=2) +actions = [zero_ad.actions.attack(my_units, enemy_units[0])] +state = game.step(actions) +``` + +For a more thorough example, see samples/simple-example.py.
Index: source/tools/clients/python/requirements-dev.txt =================================================================== --- /dev/null +++ source/tools/clients/python/requirements-dev.txt @@ -0,0 +1 @@ +pytest Index: source/tools/clients/python/requirements.txt =================================================================== --- /dev/null +++ source/tools/clients/python/requirements.txt @@ -0,0 +1,2 @@ +grpcio +protobuf Index: source/tools/clients/python/samples/arcadia.json =================================================================== --- /dev/null +++ source/tools/clients/python/samples/arcadia.json @@ -0,0 +1 @@ +{"settings": {"TriggerScripts": ["scripts/TriggerHelper.js", "scripts/ConquestCommon.js", "scripts/ConquestUnits.js"], "VictoryConditions": ["conquest_units"], "Name": "Arcadia", "mapType": "scenario", "AISeed": 0, "Seed": 0, "CheatsEnabled": true, "Ceasefire": 0, "WonderDuration": 10, "RelicDuration": 10, "RelicCount": 2, "Size": 256, "PlayerData": [{"Name": "Player 1", "Civ": "spart", "Color": {"r": 150, "g": 20, "b": 20}, "AI": "", "AIDiff": 3, "AIBehavior": "random", "Team": 1}, {"Name": "Player 2", "Civ": "spart", "Color": {"r": 150, "g": 20, "b": 20}, "AI": "", "AIDiff": 3, "AIBehavior": "random", "Team": 2}]}, "mapType": "scenario", "map": "maps/scenarios/Arcadia", "gameSpeed": 1} Index: source/tools/clients/python/samples/simple-example.py =================================================================== --- /dev/null +++ source/tools/clients/python/samples/simple-example.py @@ -0,0 +1,76 @@ +# This script provides an overview of the zero_ad wrapper for 0 AD +from os import path +import zero_ad + +# Connect to a 0 AD game server listening at localhost:50050 +game = zero_ad.ZeroAD('localhost:50050') + +# Load the Arcadia map +samples_dir = path.dirname(path.realpath(__file__)) +scenario_config_path = path.join(samples_dir, 'arcadia.json') +with open(scenario_config_path, 'r') as f: + arcadia_config = f.read() + +config = 
zero_ad.ScenarioConfig(playerID=1, content=arcadia_config) +state = game.reset(config) + +# The game is paused and will only progress upon calling "step" +state = game.step() + +# units can be queried from the game state +citizen_soldiers = state.units(owner=1, type='infantry') +# (including gaia units like trees or other resources) +nearby_tree = state.closest(state.units(owner=0, type='tree')) + +# Action commands can be created using zero_ad.actions +collect_wood = zero_ad.actions.gather(citizen_soldiers, nearby_tree) + +female_citizens = state.units(owner=1, type='female_citizen') +house_tpl = 'structures/spart_house' +x = 680 +z = 640 +build_house = zero_ad.actions.construct(female_citizens, house_tpl, x, z, autocontinue=True) + +# These commands can then be applied to the game in a `step` command +state = game.step([collect_wood, build_house]) + +# We can also fetch units by id using the `unit` function on the game state +female_id = female_citizens[0].id() +female_citizen = state.unit(female_id) + +# Raw data for units and game states are available via the data attribute +print(female_citizen.data) + +# Units can be built using the "train action" +civic_center = state.units(owner=1, type="civil_centre")[0] +spearman_type = 'units/spart_infantry_spearman_b' +train_spearmen = zero_ad.actions.train([civic_center], spearman_type) + +state = game.step([train_spearmen]) + +# Let's step the engine until the house has been built +is_unit_busy = lambda state, unit_id: len(state.unit(unit_id).data['unitAIOrderData']) > 0 +while is_unit_busy(state, female_id): + state = game.step() + +# The units for the other army can also be controlled +enemy_units = state.units(owner=2) +walk = zero_ad.actions.walk(enemy_units, *civic_center.position()) +game.step([walk], player=[2]) + +# Step the game engine a bit to give them some time to walk +for _ in range(150): + state = game.step() + +# Let's attack with our entire military +state = game.step([zero_ad.actions.chat('An attack 
is coming!')]) + +while len(state.units(owner=2, type='unit')) > 0: + attack_units = [ unit for unit in state.units(owner=1, type='unit') if 'female' not in unit.type() ] + target = state.closest(state.units(owner=2), state.center(attack_units)) + state = game.step([zero_ad.actions.attack(attack_units, target)]) + + while state.unit(target.id()): + state = game.step() + +game.step([zero_ad.actions.chat('The enemies have been vanquished. Our home is safe again.')]) Index: source/tools/clients/python/setup.py =================================================================== --- /dev/null +++ source/tools/clients/python/setup.py @@ -0,0 +1,19 @@ +import os +from setuptools import setup + +project_root = os.path.dirname(os.path.realpath(__file__)) +requirementPath = os.path.join(project_root, 'requirements.txt') +install_requires = [] +with open(requirementPath) as f: + install_requires = f.read().splitlines() + +setup(name='zero_ad', + version='0.0.1', + description='Python client for 0 AD', + url='https://code.wildfiregames.com', + author='Brian Broll', + author_email='brian.broll@gmail.com', + install_requires=install_requires, + license='MIT', + packages=['zero_ad'], + zip_safe=False) Index: source/tools/clients/python/tests/test_actions.py =================================================================== --- /dev/null +++ source/tools/clients/python/tests/test_actions.py @@ -0,0 +1,80 @@ +import zero_ad +import json +from os import path + +game = zero_ad.ZeroAD('localhost:50050') +# Use the two-player Arcadia scenario from samples/; ScenarioConfig is a plain protobuf message (saveReplay, playerID, content) +samples_dir = path.join(path.dirname(path.realpath(__file__)), '..', 'samples') +with open(path.join(samples_dir, 'arcadia.json'), 'r') as f: + config = zero_ad.ScenarioConfig(playerID=1, content=f.read()) + +def test_construct(): + state = game.reset(config) + female_citizens = state.units(owner=1, type='female_citizen') + house_tpl = 'structures/spart_house' + house_count = len(state.units(owner=1, type=house_tpl)) + x = 680 + z = 640 + build_house = zero_ad.actions.construct(female_citizens, house_tpl, x, z,
autocontinue=True)
+    # Check that they start building the house
+    state = game.step([build_house])
+    while len(state.units(owner=1, type=house_tpl)) == house_count:
+        state = game.step()
+
+def test_gather():
+    state = game.reset(config)
+    female_citizen = state.units(owner=1, type='female_citizen')[0]
+    nearby_tree = state.closest(state.units(owner=0, type='tree'))
+    collect_wood = zero_ad.actions.gather([female_citizen], nearby_tree)
+    state = game.step([collect_wood])
+    while len(state.unit(female_citizen.id()).data['resourceCarrying']) == 0:
+        state = game.step()
+
+def test_train():
+    state = game.reset(config)
+    civic_centers = state.units(owner=1, type="civil_centre")
+    spearman_type = 'units/spart_infantry_spearman_b'
+    spearman_count = len(state.units(owner=1, type=spearman_type))
+    train_spearmen = zero_ad.actions.train(civic_centers, spearman_type)
+
+    state = game.step([train_spearmen])
+    while len(state.units(owner=1, type=spearman_type)) == spearman_count:
+        state = game.step()
+
+def test_walk():
+    state = game.reset(config)
+    female_citizens = state.units(owner=1, type='female_citizen')
+    x = 680
+    z = 640
+    initial_distance = state.dist(state.center(female_citizens), [x, z])
+
+    walk = zero_ad.actions.walk(female_citizens, x, z)
+    state = game.step([walk])
+    distance = initial_distance
+    while distance >= initial_distance:
+        state = game.step()
+        female_citizens = state.units(owner=1, type='female_citizen')
+        distance = state.dist(state.center(female_citizens), [x, z])
+
+def test_attack():
+    state = game.reset(config)
+    units = state.units(owner=1, type='cavalry')
+    target = state.units(owner=2, type='female_citizen')[0]
+    initial_health = target.health()
+
+    state = game.step([zero_ad.actions.reveal_map()])
+
+    attack = zero_ad.actions.attack(units, target)
+    state = game.step([attack])
+    while state.unit(target.id()).health() >= initial_health:
+        state = game.step()
+
+def test_debug_print():
+    state = game.reset(config)
+    debug_print = zero_ad.actions.debug_print('hello world!!')
+    state = game.step([debug_print])
+
+def test_chat():
+    state = game.reset(config)
+    chat = zero_ad.actions.chat('hello world!!')
+    state = game.step([chat])
Index: source/tools/clients/python/zero_ad/__init__.py
===================================================================
--- /dev/null
+++ source/tools/clients/python/zero_ad/__init__.py
@@ -0,0 +1,5 @@
+from . import actions
+from . import environment
+from .RLAPI_pb2 import ScenarioConfig
+ZeroAD = environment.ZeroAD
+GameState = environment.GameState
Index: source/tools/clients/python/zero_ad/actions.py
===================================================================
--- /dev/null
+++ source/tools/clients/python/zero_ad/actions.py
@@ -0,0 +1,77 @@
+def construct(units, template, x, z, angle=0, autorepair=True, autocontinue=True, queued=False):
+    """Build the 'construct' command dict for `units`, `template` and the point (x, z)."""
+    unit_ids = [unit.id() for unit in units]
+    return {
+        'type': 'construct',
+        'entities': unit_ids,
+        'template': template,
+        'x': x,
+        'z': z,
+        'angle': angle,
+        'autorepair': autorepair,
+        'autocontinue': autocontinue,
+        'queued': queued,
+    }
+
+def gather(units, target, queued=False):
+    """Build the 'gather' command dict for `units` and the `target` entity."""
+    unit_ids = [unit.id() for unit in units]
+    return {
+        'type': 'gather',
+        'entities': unit_ids,
+        'target': target.id(),
+        'queued': queued,
+    }
+
+def train(entities, unit_type, count=1):
+    """Build the 'train' command dict for `entities`, `unit_type` and `count`."""
+    entity_ids = [unit.id() for unit in entities]
+    return {
+        'type': 'train',
+        'entities': entity_ids,
+        'template': unit_type,
+        'count': count,
+    }
+
+def debug_print(message):
+    """Build the 'debug-print' command dict carrying `message`."""
+    return {
+        'type': 'debug-print',
+        'message': message
+    }
+
+def chat(message):
+    """Build the 'aichat' command dict carrying `message`."""
+    return {
+        'type': 'aichat',
+        'message': message
+    }
+
+def reveal_map():
+    """Build the 'reveal-map' command dict with `enable` set to True."""
+    return {
+        'type': 'reveal-map',
+        'enable': True
+    }
+
+def walk(units, x, z, queued=False):
+    """Build the 'walk' command dict for `units` and the point (x, z)."""
+    ids = [unit.id() for unit in units]
+    return {
+        'type': 'walk',
+        'entities': ids,
+        'x': x,
+        'z': z,
+        'queued': queued
+    }
+
+def attack(units, target, queued=False, allow_capture=True):
+    """Build the 'attack' command dict for `units` and the `target` entity."""
+    unit_ids = [unit.id() for unit in units]
+    return {
+        'type': 'attack',
+        'entities': unit_ids,
+        'target': target.id(),
+        'allowCapture': allow_capture,
+        'queued': queued
+    }
Index: source/tools/clients/python/zero_ad/environment.py
===================================================================
--- /dev/null
+++ source/tools/clients/python/zero_ad/environment.py
@@ -0,0 +1,209 @@
+from .RLAPI_pb2 import Actions, Action, ResetRequest, GetTemplateRequest
+from .RLAPI_pb2_grpc import RLAPIStub
+import grpc
+import json
+import math
+from xml.etree import ElementTree
+from itertools import cycle
+
+class ZeroAD():
+    """Client for the RL interface, talking to the game over a gRPC channel."""
+
+    def __init__(self, uri='localhost:50050'):
+        channel = grpc.insecure_channel(uri)
+        self.stub = RLAPIStub(channel)
+        self.current_state = None
+        # Template name -> EntityTemplate, filled by update_templates()
+        self.cache = {}
+        self.player_id = None
+
+    def step(self, actions=None, player=None):
+        """Send actions via the Step RPC and return the resulting GameState.
+
+        actions: iterable of action dicts (see the actions module); None
+            entries are skipped. Defaults to no actions.
+        player: iterable of player ids paired with the actions in order and
+            repeated cyclically; defaults to the player id stored by reset().
+        """
+        # A None default avoids sharing one mutable list between calls
+        if actions is None:
+            actions = []
+        player_ids = cycle([self.player_id]) if player is None else cycle(player)
+
+        cmds = Actions()
+        cmds.actions.extend([
+            Action(content=json.dumps(a), playerID=pid)
+            for (a, pid) in zip(actions, player_ids) if a is not None
+        ])
+        res = self.stub.Step(cmds)
+        self.current_state = GameState(json.loads(res.content), self)
+        return self.current_state
+
+    def reset(self, config=None):
+        """Send a Reset RPC with the given scenario and return the new GameState.
+
+        When a config is given, its playerID (or 1 if it is not positive)
+        becomes the default player for later step() calls.
+        """
+        if config is not None:
+            self.player_id = config.playerID if config.playerID > 0 else 1
+
+        req = ResetRequest(scenario=config)
+        res = self.stub.Reset(req)
+        self.current_state = GameState(json.loads(res.content), self)
+        return self.current_state
+
+    def get_template(self, name):
+        """Return the (name, EntityTemplate) pair for a single template name."""
+        return self.get_templates([name])[0]
+
+    def get_templates(self, names):
+        """Fetch templates by name; returns a list of (name, EntityTemplate) pairs."""
+        req = GetTemplateRequest(names=names)
+        res = self.stub.GetTemplates(req)
+        return [(t.name, EntityTemplate(t.content)) for t in res.templates]
+
+    def update_templates(self, types=None):
+        """Rebuild the template cache and return the fetched (name, template) pairs.
+
+        The cache is replaced with the templates of every unit in the current
+        state plus any extra template names passed in `types`.
+        """
+        # A None default avoids sharing one mutable list between calls
+        wanted = set(types or [])
+        # Before the first reset()/step() there is no state to collect types from
+        if self.current_state is not None:
+            wanted.update(unit.type() for unit in self.current_state.units())
+        template_pairs = self.get_templates(list(wanted))
+
+        self.cache = dict(template_pairs)
+        return template_pairs
+
+class GameState():
+    """Wrapper around the decoded JSON state returned by the game."""
+
+    def __init__(self, data, game):
+        self.data = data
+        self.game = game
+        self.mapSize = self.data['mapSize']
+
+    def units(self, owner=None, type=None):
+        """Return entities whose owner equals `owner` and whose template contains `type`.
+
+        A filter left as None matches every entity.
+        """
+        return [
+            Entity(e, self.game) for e in self.data['entities'].values()
+            if (owner is None or e['owner'] == owner)
+            and (type is None or type in e['template'])
+        ]
+
+    def unit(self, id):
+        """Return the entity with the given id, or None if it is not in the state."""
+        entity = self.data['entities'].get(str(id))
+        return Entity(entity, self.game) if entity is not None else None
+
+    def center(self, units=None):
+        """Return the mean [x, z] position of units (default: units owned by player 1).
+
+        Raises ZeroDivisionError when there are no units to average.
+        """
+        if units is None:
+            units = self.units(owner=1)
+
+        positions = [unit.position() for unit in units]
+        count = len(positions)
+        return [
+            sum(x for [x, z] in positions) / count,
+            sum(z for [x, z] in positions) / count,
+        ]
+
+    def closest(self, units, position=None):
+        """Return the unit nearest to position, or None when units is empty.
+
+        position defaults to self.center().
+        """
+        if position is None:
+            position = self.center()
+
+        return min(units, key=lambda unit: self.dist(unit.position(), position), default=None)
+
+    def offset(self, p1, p2):
+        """Return the [dx, dz] vector pointing from p1 to p2."""
+        [x, z] = p1
+        [x2, z2] = p2
+        return [x2 - x, z2 - z]
+
+    def magnitude(self, vec):
+        """Return the Euclidean length of a 2D [x, z] vector."""
+        [x, z] = vec
+        return math.hypot(x, z)
+
+    def dist(self, p1, p2):
+        """Return the Euclidean distance between two [x, z] points."""
+        return self.magnitude(self.offset(p1, p2))
+
+class Entity():
+    """Wrapper around one entry of the game state's 'entities' table."""
+
+    def __init__(self, data, game):
+        self.data = data
+        self.game = game
+        # May be None until get_template() fetches it through the game client
+        self.template = self.game.cache.get(self.type(), None)
+
+    def type(self):
+        return self.data['template']
+
+    def id(self):
+        return self.data['id']
+
+    def owner(self):
+        return self.data['owner']
+
+    def max_health(self):
+        """Return the 'Health/Max' value of this entity's template as a float."""
+        template = self.get_template()
+        return float(template.get('Health/Max'))
+
+    def health(self, ratio=False):
+        """Return the current hitpoints, or hitpoints / max_health() if ratio is set."""
+        if ratio:
+            return self.data['hitpoints']/self.max_health()
+
+        return self.data['hitpoints']
+
+    def position(self):
+        return self.data['position']
+
+    def get_template(self):
+        """Return this entity's EntityTemplate, refreshing the game's cache on a miss."""
+        if self.template is None:
+            self.game.update_templates([self.type()])
+            self.template = self.game.cache[self.type()]
+
+        return self.template
+
+class EntityTemplate():
+    """Parsed XML entity template with path-based accessors."""
+
+    def __init__(self, xml):
+        self.data = ElementTree.fromstring(str(xml))
+
+    def get(self, path):
+        """Return the text of the first node matching path, or None if none matches."""
+        node = self.data.find(path)
+        return node.text if node is not None else None
+
+    def set(self, path, value):
+        """Set the text of the first node matching path; return whether it exists."""
+        node = self.data.find(path)
+        # Compare against None: an Element without children is falsy, so a
+        # plain truth test would skip exactly the leaf nodes that hold text
+        if node is not None:
+            node.text = str(value)
+
+        return node is not None
+
+    def __str__(self):
+        return ElementTree.tostring(self.data).decode('utf-8')