Index: ps/trunk/binaries/data/config/default.cfg
===================================================================
--- ps/trunk/binaries/data/config/default.cfg
+++ ps/trunk/binaries/data/config/default.cfg
@@ -474,6 +474,9 @@
 gpu.ext.enable = true ; Allow GL_EXT_timer_query timing mode when available
 gpu.intel.enable = true ; Allow GL_INTEL_performance_queries timing mode when available
 
+[rlinterface]
+address = "127.0.0.1:6000" ; IP:port the RL interface listens on (overridden by -rl-interface=IP:PORT)
+
 [sound]
 mastergain = 0.9
 musicgain = 0.2
Index: ps/trunk/binaries/system/readme.txt
===================================================================
--- ps/trunk/binaries/system/readme.txt
+++ ps/trunk/binaries/system/readme.txt
@@ -42,6 +42,9 @@
 3) Observe the PetraBot on a triggerscript map:
 -autostart="random/jebel_barkal" -autostart-seed=-1 -autostart-players=2 -autostart-civ=1:athen -autostart-civ=2:brit -autostart-ai=1:petra -autostart-ai=2:petra -autostart-player=-1
 
+RL client:
+-rl-interface[=IP:PORT] run the RL interface (see source/tools/rlclient), listening on IP:PORT or on the rlinterface.address config value
+
 Configuration:
 -conf=KEY:VALUE set a config value
 -nosound disable audio
Index: ps/trunk/build/premake/premake5.lua
===================================================================
--- ps/trunk/build/premake/premake5.lua
+++ ps/trunk/build/premake/premake5.lua
@@ -598,6 +598,15 @@
 	setup_static_lib_project("network", source_dirs, extern_libs, {})
 
 	source_dirs = {
+		"rlinterface",
+	}
+	extern_libs = {
+		"boost", -- dragged in via simulation.h and scriptinterface.h
+		"spidermonkey",
+	}
+	setup_static_lib_project("rlinterface", source_dirs, extern_libs, { no_pch = 1 })
+
+	source_dirs = {
 		"third_party/tinygettext/src",
 	}
 	extern_libs = {
Index: ps/trunk/source/main.cpp
===================================================================
--- ps/trunk/source/main.cpp
+++ ps/trunk/source/main.cpp
@@ -76,6 +76,7 @@
 #include "graphics/TextureManager.h"
 #include "gui/GUIManager.h"
 #include "renderer/Renderer.h"
+#include "rlinterface/RLInterface.h"
 #include "scriptinterface/ScriptEngine.h"
 #include "simulation2/Simulation2.h"
 #include "simulation2/system/TurnManager.h"
@@ -388,9 +389,13 @@
 
 	ogl_WarnIfError();
 
+	if (g_RLInterface)
+		g_RLInterface->TryApplyMessage();
+
 	if (g_Game && g_Game->IsGameStarted() && need_update)
 	{
-		g_Game->Update(realTimeSinceLastFrame);
+		if (!g_RLInterface)
+			g_Game->Update(realTimeSinceLastFrame);
 
 		g_Game->GetView()->Update(float(realTimeSinceLastFrame));
 	}
@@ -462,6 +467,75 @@
 	in_reset_handlers();
 }
 
+// Creates the global RL interface and starts its HTTP server on the address
+// given by -rl-interface=ADDRESS, or else by the rlinterface.address config value.
+static void StartRLInterface(const CmdLineArgs& args)
+{
+	std::string server_address;
+	CFG_GET_VAL("rlinterface.address", server_address);
+
+	const std::string addressArg = args.Get("rl-interface");
+	if (!addressArg.empty())
+		server_address = addressArg;
+
+	g_RLInterface = new RLInterface();
+	g_RLInterface->EnableHTTP(server_address.c_str());
+	debug_printf("RL interface listening on %s\n", server_address.c_str());
+}
+
+// Runs the engine as an RL server: initialises the engine (installing any
+// requested mods), starts the RL interface and lets it drive the game until
+// shutdown is requested. In non-visual mode the loop only applies RL
+// messages; otherwise Frame() renders and applies them.
+static void RunRLServer(const bool isNonVisual, const std::vector<OsPath>& modsToInstall, const CmdLineArgs& args)
+{
+	int flags = INIT_MODS;
+	while (!Init(args, flags))
+	{
+		flags &= ~INIT_MODS;
+		Shutdown(SHUTDOWN_FROM_CONFIG);
+	}
+	g_Shutdown = ShutdownType::None;
+
+	std::vector<CStr> installedMods;
+	if (!modsToInstall.empty())
+	{
+		Paths paths(args);
+		CModInstaller installer(paths.UserData() / "mods", paths.Cache());
+
+		// Install the mods without deleting the pyromod files
+		for (const OsPath& modPath : modsToInstall)
+			installer.Install(modPath, g_ScriptRuntime, true);
+
+		installedMods = installer.GetInstalledMods();
+	}
+
+	if (isNonVisual)
+	{
+		InitNonVisual(args);
+		StartRLInterface(args);
+		while (g_Shutdown == ShutdownType::None)
+			g_RLInterface->TryApplyMessage();
+		QuitEngine();
+	}
+	else
+	{
+		InitGraphics(args, 0, installedMods);
+		MainControllerInit();
+		StartRLInterface(args);
+		while (g_Shutdown == ShutdownType::None)
+			Frame();
+	}
+
+	Shutdown(0);
+	MainControllerShutdown();
+	CXeromyces::Terminate();
+	// NOTE(review): the HTTP server started by EnableHTTP() is not stopped
+	// here, so its worker threads may still reference the interface — confirm.
+	delete g_RLInterface;
+	g_RLInterface = nullptr; // don't leave a dangling global behind
+}
+
 // moved into a helper function to ensure args is destroyed before
 // exit(), which may result
in a memory leak.
 static void RunGameOrAtlas(int argc, const char* argv[])
@@ -476,7 +540,7 @@
 		return;
 	}
 
-	if (args.Has("autostart-nonvisual") && args.Get("autostart").empty())
+	if (args.Has("autostart-nonvisual") && args.Get("autostart").empty() && !args.Has("rl-interface"))
 	{
 		LOGERROR("-autostart-nonvisual cant be used alone. A map with -autostart=\"TYPEDIR/MAPNAME\" is needed.");
 		return;
@@ -600,6 +664,12 @@
 	const double res = timer_Resolution();
 	g_frequencyFilter = CreateFrequencyFilter(res, 30.0);
 
+	if (args.Has("rl-interface"))
+	{
+		RunRLServer(isNonVisual, modsToInstall, args);
+		return;
+	}
+
 	// run the game
 	int flags = INIT_MODS;
 	do
Index: ps/trunk/source/rlinterface/RLInterface.h
===================================================================
--- ps/trunk/source/rlinterface/RLInterface.h
+++ ps/trunk/source/rlinterface/RLInterface.h
@@ -0,0 +1,102 @@
+/* Copyright (C) 2020 Wildfire Games.
+ * This file is part of 0 A.D.
+ *
+ * 0 A.D. is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * 0 A.D. is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with 0 A.D.  If not, see <http://www.gnu.org/licenses/>.
+ */
+#ifndef INCLUDED_RLINTERFACE
+#define INCLUDED_RLINTERFACE
+
+#include "simulation2/helpers/Player.h"
+
+#include <condition_variable>
+#include <mutex>
+#include <string>
+#include <vector>
+
+// Settings for a new game, received with a /reset request.
+struct ScenarioConfig {
+	bool saveReplay;        // passed to the CGame constructor
+	player_id_t playerID;   // player ID assigned to the new game
+	std::string content;    // game attributes, as a JSON string
+};
+// A single /step command: the issuing player and the command as JSON.
+struct Command {
+	int playerID;
+	std::string json_cmd;
+};
+
+// Request handed from an HTTP worker thread to the main thread.
+enum GameMessageType { Reset, Commands };
+struct GameMessage {
+	GameMessageType type;
+	std::vector<Command> commands; // only used by Commands messages
+};
+
+extern void EndGame();
+
+struct mg_context;
+// Game state reported when no game is running.
+const static std::string EMPTY_STATE;
+
+/**
+ * HTTP interface letting reinforcement-learning clients (see source/tools/rlclient)
+ * reset the game, step it with commands and query entity templates.
+ *
+ * Requests are received on mongoose worker threads; Step() and Reset() hand
+ * them to the main thread, which applies them in TryApplyMessage() and then
+ * wakes the waiting worker with the resulting game state.
+ */
+class RLInterface
+{
+
+	public:
+
+	// Queue commands and advance the game; returns the new game state as JSON.
+	std::string Step(const std::vector<Command> commands);
+	// Start a new game from @p scenario; returns the initial game state as JSON.
+	std::string Reset(const ScenarioConfig* scenario);
+	// XML of the named entity templates.
+	std::vector<std::string> GetTemplates(const std::vector<std::string> names) const;
+
+	// Start the HTTP server on @p server_address (e.g. "127.0.0.1:6000").
+	void EnableHTTP(const char* server_address);
+	// Hand @p msg to the main thread and block until it has been applied.
+	std::string SendGameMessage(const GameMessage msg);
+	// Take the pending message, if any; called with m_msgLock held.
+	bool TryGetGameMessage(GameMessage& msg);
+	// Main-thread entry point: apply a pending message, if any.
+	void TryApplyMessage();
+	// Full AI-interface representation of the current game, as JSON.
+	std::string GetGameState();
+	bool IsGameRunning();
+
+	private:
+	mg_context* m_MgContext = nullptr;
+	// Message waiting for the main thread (points to the sender's argument).
+	const GameMessage* m_GameMessage = nullptr;
+	// Reply to the last applied message.
+	std::string m_GameState;
+	// True while a visual reset is loading and its reply is still pending.
+	bool m_NeedsGameState = false;
+	// Serialises the requests coming from the HTTP worker threads.
+	mutable std::mutex m_lock;
+	// Guards the hand-over of m_GameMessage / m_GameState between threads.
+	std::mutex m_msgLock;
+	std::condition_variable m_msgApplied;
+	ScenarioConfig m_ScenarioConfig;
+};
+
+extern RLInterface* g_RLInterface;
+
+#endif // INCLUDED_RLINTERFACE
Index: ps/trunk/source/rlinterface/RLInterface.cpp
===================================================================
--- ps/trunk/source/rlinterface/RLInterface.cpp
+++ ps/trunk/source/rlinterface/RLInterface.cpp
@@ -0,0 +1,391 @@
+/* Copyright (C) 2020 Wildfire Games.
+ * This file is part of 0 A.D.
+ *
+ * 0 A.D. is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * 0 A.D.
is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with 0 A.D.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+// Pull in the headers from the default precompiled header,
+// even if rlinterface doesn't use precompiled headers.
+#include "lib/precompiled.h"
+
+#include "rlinterface/RLInterface.h"
+
+#include "gui/GUIManager.h"
+#include "ps/Game.h"
+#include "ps/GameSetup/GameSetup.h"
+#include "ps/Loader.h"
+#include "ps/CLogger.h"
+#include "simulation2/components/ICmpAIInterface.h"
+#include "simulation2/components/ICmpTemplateManager.h"
+#include "simulation2/Simulation2.h"
+#include "simulation2/system/LocalTurnManager.h"
+#include "third_party/mongoose/mongoose.h"
+
+#include <cstdlib>
+#include <sstream>
+#include <string>
+
+// Globally accessible pointer to the RL Interface.
+RLInterface* g_RLInterface = nullptr;
+
+// Interactions with the game engine (g_Game) must be done in the main
+// thread as there are specific checks for this. We will pass our commands
+// to the main thread to be applied
+std::string RLInterface::SendGameMessage(const GameMessage msg)
+{
+	std::unique_lock<std::mutex> msgLock(m_msgLock);
+	m_GameMessage = &msg;
+	// The main thread clears m_GameMessage when it takes the message and keeps
+	// m_msgLock locked until the message has been applied, so once the lock is
+	// reacquired with a null pointer, m_GameState holds the reply. The predicate
+	// also guards against spurious wake-ups, which would otherwise return a
+	// stale state and leave m_GameMessage pointing at a destroyed argument.
+	m_msgApplied.wait(msgLock, [this]() { return m_GameMessage == nullptr; });
+	return m_GameState;
+}
+
+// Queue @p commands for the main thread and return the resulting game state.
+std::string RLInterface::Step(const std::vector<Command> commands)
+{
+	std::lock_guard<std::mutex> lock(m_lock);
+	GameMessage msg = { GameMessageType::Commands, commands };
+	return SendGameMessage(msg);
+}
+
+// Ask the main thread to start a new game from @p scenario; returns its initial state.
+std::string RLInterface::Reset(const ScenarioConfig* scenario)
+{
+	std::lock_guard<std::mutex> lock(m_lock);
+	m_ScenarioConfig = *scenario;
+	GameMessage msg = { GameMessageType::Reset, {} };
+	return SendGameMessage(msg);
+}
+
+// Returns the XML of each requested template, in request order. Unknown names
+// yield an empty string so the reply lines stay aligned with the requested names.
+// NOTE(review): unlike Step()/Reset(), this reads g_Game from the HTTP worker
+// thread instead of handing the work to the main thread — confirm this is safe.
+std::vector<std::string> RLInterface::GetTemplates(const std::vector<std::string> names) const
+{
+	std::lock_guard<std::mutex> lock(m_lock);
+	CSimulation2& simulation = *g_Game->GetSimulation2();
+	CmpPtr<ICmpTemplateManager> cmpTemplateManager(simulation.GetSimContext().GetSystemEntity());
+
+	std::vector<std::string> templates;
+	if (!cmpTemplateManager)
+		return templates;
+
+	templates.reserve(names.size());
+	for (const std::string& templateName : names)
+	{
+		const CParamNode* node = cmpTemplateManager->GetTemplate(templateName);
+		templates.push_back(node ? utf8_from_wstring(node->ToXML()) : std::string());
+	}
+
+	return templates;
+}
+
+// Reads the request body announced by the Content-Length header into @p postData.
+// Returns false if the header is missing or is not a non-negative number.
+// mg_read() may return fewer bytes than requested, so keep reading until the
+// body is complete or no more data arrives.
+static bool ReadPostData(mg_connection* conn, std::string& postData)
+{
+	const char* val = mg_get_header(conn, "Content-Length");
+	if (!val)
+		return false;
+
+	char* end = nullptr;
+	const long length = std::strtol(val, &end, 10);
+	if (end == val || length < 0)
+		return false;
+
+	postData.resize(static_cast<size_t>(length));
+	size_t received = 0;
+	while (received < postData.size())
+	{
+		const int bytesRead = mg_read(conn, &postData[received], postData.size() - received);
+		if (bytesRead <= 0)
+			break;
+		received += static_cast<size_t>(bytesRead);
+	}
+	postData.resize(received);
+	return true;
+}
+
+// Parses the "playerID" variable of @p queryString, returning @p fallback when it
+// is absent or not an integer. The buffer must have room for the digits plus the
+// terminating NUL that mg_get_var() writes.
+static player_id_t ParsePlayerID(const std::string& queryString, player_id_t fallback)
+{
+	char buf[16];
+	const int len = mg_get_var(queryString.c_str(), queryString.length(), "playerID", buf, sizeof(buf));
+	if (len <= 0)
+		return fallback;
+
+	char* end = nullptr;
+	const long value = std::strtol(buf, &end, 10);
+	if (end == buf || *end != '\0')
+		return fallback;
+	return static_cast<player_id_t>(value);
+}
+
+// Parses the "<playerID>;<JSON command>" lines of a /step request body. Lines
+// without ';' or with a non-numeric player ID are skipped; nothing here throws
+// into the mongoose worker thread.
+static std::vector<Command> ParseCommands(const std::string& postData)
+{
+	std::vector<Command> commands;
+	std::stringstream postStream(postData);
+	std::string line;
+	while (std::getline(postStream, line, '\n'))
+	{
+		const std::size_t splitPos = line.find(';');
+		if (splitPos == std::string::npos)
+			continue;
+
+		const std::string playerStr = line.substr(0, splitPos);
+		char* end = nullptr;
+		const long playerID = std::strtol(playerStr.c_str(), &end, 10);
+		if (end == playerStr.c_str())
+			continue;
+
+		Command cmd;
+		cmd.playerID = static_cast<int>(playerID);
+		cmd.json_cmd = line.substr(splitPos + 1);
+		commands.push_back(cmd);
+	}
+	return commands;
+}
+
+// Mongoose request handler for /reset, /step and /templates; runs on the
+// HTTP worker threads.
+static void* RLMgCallback(mg_event event, struct mg_connection* conn, const struct mg_request_info* request_info)
+{
+	RLInterface* rlInterface = static_cast<RLInterface*>(request_info->user_data);
+	ENSURE(rlInterface);
+
+	void* handled = (void*)""; // arbitrary non-NULL pointer to indicate successful handling
+
+	const char* header200 =
+		"HTTP/1.1 200 OK\r\n"
+		"Access-Control-Allow-Origin: *\r\n"
+		"Content-Type: text/plain; charset=utf-8\r\n\r\n";
+
+	const char* header404 =
+		"HTTP/1.1 404 Not Found\r\n"
+		"Content-Type: text/plain; charset=utf-8\r\n\r\n"
+		"Unrecognised URI";
+
+	const char* noPostData =
+		"HTTP/1.1 400 Bad Request\r\n"
+		"Content-Type: text/plain; charset=utf-8\r\n\r\n"
+		"No POST data found.";
+
+	const char* notRunningResponse =
+		"HTTP/1.1 400 Bad Request\r\n"
+		"Content-Type: text/plain; charset=utf-8\r\n\r\n"
+		"Game not running. Please create a scenario first.";
+
+	switch (event)
+	{
+	case MG_NEW_REQUEST:
+	{
+		std::stringstream stream;
+		const std::string uri = request_info->uri;
+
+		if (uri == "/reset")
+		{
+			std::string postData;
+			if (!ReadPostData(conn, postData))
+			{
+				mg_printf(conn, "%s", noPostData);
+				return handled;
+			}
+
+			// query_string is NULL when the URI has no '?' part.
+			const std::string qs = request_info->query_string ? request_info->query_string : "";
+
+			ScenarioConfig scenario;
+			scenario.saveReplay = qs.find("saveReplay") != std::string::npos;
+			scenario.playerID = ParsePlayerID(qs, 1);
+			scenario.content = postData;
+
+			stream << rlInterface->Reset(&scenario);
+		}
+		else if (uri == "/step")
+		{
+			if (!rlInterface->IsGameRunning())
+			{
+				mg_printf(conn, "%s", notRunningResponse);
+				return handled;
+			}
+
+			std::string postData;
+			if (!ReadPostData(conn, postData))
+			{
+				mg_printf(conn, "%s", noPostData);
+				return handled;
+			}
+
+			const std::string gameState = rlInterface->Step(ParseCommands(postData));
+			if (gameState.empty())
+			{
+				mg_printf(conn, "%s", notRunningResponse);
+				return handled;
+			}
+			stream << gameState;
+		}
+		else if (uri == "/templates")
+		{
+			if (!rlInterface->IsGameRunning())
+			{
+				mg_printf(conn, "%s", notRunningResponse);
+				return handled;
+			}
+
+			std::string postData;
+			if (!ReadPostData(conn, postData))
+			{
+				mg_printf(conn, "%s", noPostData);
+				return handled;
+			}
+
+			std::stringstream postStream(postData);
+			std::string line;
+			std::vector<std::string> templateNames;
+			while (std::getline(postStream, line, '\n'))
+				templateNames.push_back(line);
+
+			for (const std::string& templateStr : rlInterface->GetTemplates(templateNames))
+				stream << templateStr << "\n";
+		}
+		else
+		{
+			mg_printf(conn, "%s", header404);
+			return handled;
+		}
+
+		mg_printf(conn, "%s", header200);
+		const std::string str = stream.str();
+		mg_write(conn, str.c_str(), str.length());
+		return handled;
+	}
+
+	case MG_HTTP_ERROR:
+		return nullptr;
+
+	case MG_EVENT_LOG:
+		// Called by Mongoose's cry()
+		LOGERROR("Mongoose error: %s", request_info->log_message);
+		return nullptr;
+
+	case MG_INIT_SSL:
+		return nullptr;
+
+	default:
+		debug_warn(L"Invalid Mongoose event type");
+		return nullptr;
+	}
+}
+
+// Starts the mongoose HTTP server on @p server_address; later calls are ignored.
+void RLInterface::EnableHTTP(const char* server_address)
+{
+	LOGMESSAGERENDER("Starting RL interface HTTP server");
+
+	// Ignore multiple enablings
+	if (m_MgContext)
+		return;
+
+	const char* options[] = {
+		"listening_ports", server_address,
+		"num_threads", "6", // enough for the browser's parallel connection limit
+		nullptr
+	};
+	m_MgContext = mg_start(RLMgCallback, this, options);
+	ENSURE(m_MgContext);
+}
+
+// Copies the pending message into @p msg and clears it; returns false if none.
+bool RLInterface::TryGetGameMessage(GameMessage& msg)
+{
+	if (m_GameMessage != nullptr)
+	{
+		msg = *m_GameMessage;
+		m_GameMessage = nullptr;
+		return true;
+	}
+	return false;
+}
+
+// Called from the main thread (from Frame(), or in a loop in non-visual mode).
+// Applies a pending message, if any, and wakes the waiting HTTP worker once
+// the resulting game state is available.
+void RLInterface::TryApplyMessage()
+{
+	const bool nonVisual = !g_GUI;
+	const bool isGameStarted = g_Game && g_Game->IsGameStarted();
+
+	if (m_NeedsGameState)
+	{
+		// A visual reset is still loading. This thread has held m_msgLock since
+		// it took the reset message; std::mutex is not recursive, so it must
+		// not call try_lock() again until the lock is released below.
+		if (!isGameStarted)
+			return;
+
+		m_NeedsGameState = false;
+		m_GameState = GetGameState();
+		m_msgApplied.notify_one();
+		m_msgLock.unlock();
+	}
+
+	if (!m_msgLock.try_lock())
+		return;
+
+	GameMessage msg;
+	if (!TryGetGameMessage(msg))
+	{
+		m_msgLock.unlock();
+		return;
+	}
+
+	switch (msg.type)
+	{
+	case GameMessageType::Reset:
+	{
+		if (isGameStarted)
+			EndGame();
+
+		g_Game = new CGame(m_ScenarioConfig.saveReplay);
+		ScriptInterface& scriptInterface = g_Game->GetSimulation2()->GetScriptInterface();
+		JSContext* cx = scriptInterface.GetContext();
+		JSAutoRequest rq(cx);
+		JS::RootedValue attrs(cx);
+		scriptInterface.ParseJSON(m_ScenarioConfig.content, &attrs);
+
+		g_Game->SetPlayerID(m_ScenarioConfig.playerID);
+		g_Game->StartGame(&attrs, "");
+
+		if (nonVisual)
+		{
+			LDR_NonprogressiveLoad();
+			ENSURE(g_Game->ReallyStartGame() == PSRETURN_OK);
+			m_GameState = GetGameState();
+			m_msgApplied.notify_one();
+			m_msgLock.unlock();
+		}
+		else
+		{
+			JS::RootedValue initData(cx);
+			scriptInterface.CreateObject(cx, &initData);
+			scriptInterface.SetProperty(initData, "attribs", attrs);
+
+			JS::RootedValue playerAssignments(cx);
+			scriptInterface.CreateObject(cx, &playerAssignments);
+			scriptInterface.SetProperty(initData, "playerAssignments", playerAssignments);
+
+			g_GUI->SwitchPage(L"page_loading.xml", &scriptInterface, initData);
+			// Keep m_msgLock locked: the reply is sent by a later call, once
+			// the loading page has started the game.
+			m_NeedsGameState = true;
+		}
+		break;
+	}
+
+	case GameMessageType::Commands:
+	{
+		if (!g_Game)
+		{
+			m_GameState = EMPTY_STATE;
+			m_msgApplied.notify_one();
+			m_msgLock.unlock();
+			return;
+		}
+		const ScriptInterface& scriptInterface = g_Game->GetSimulation2()->GetScriptInterface();
+		CLocalTurnManager* turnMgr = static_cast<CLocalTurnManager*>(g_Game->GetTurnManager());
+
+		for (const Command& command : msg.commands)
+		{
+			JSContext* cx = scriptInterface.GetContext();
+			JSAutoRequest rq(cx);
+			JS::RootedValue commandJSON(cx);
+			// Skip malformed commands instead of posting an undefined value.
+			if (!scriptInterface.ParseJSON(command.json_cmd, &commandJSON))
+			{
+				LOGERROR("RL interface: ignoring invalid command for player %d: %s", command.playerID, command.json_cmd.c_str());
+				continue;
+			}
+			turnMgr->PostCommand(command.playerID, commandJSON);
+		}
+
+		// NOTE(review): this value is passed to CGame::Update() as a real-time
+		// delta; confirm DEFAULT_TURN_LENGTH_SP is in the unit Update() expects.
+		const double deltaRealTime = DEFAULT_TURN_LENGTH_SP;
+		if (nonVisual)
+		{
+			const double deltaSimTime = deltaRealTime * g_Game->GetSimRate();
+			const size_t maxTurns = static_cast<size_t>(g_Game->GetSimRate());
+			g_Game->GetTurnManager()->Update(deltaSimTime, maxTurns);
+		}
+		else
+			g_Game->Update(deltaRealTime);
+
+		m_GameState = GetGameState();
+		m_msgApplied.notify_one();
+		m_msgLock.unlock();
+		break;
+	}
+	}
+}
+
+// Returns the AI interface's full representation of the current game as JSON,
+// or EMPTY_STATE if that component is unavailable.
+std::string RLInterface::GetGameState()
+{
+	const ScriptInterface& scriptInterface = g_Game->GetSimulation2()->GetScriptInterface();
+	const CSimContext& simContext = g_Game->GetSimulation2()->GetSimContext();
+	CmpPtr<ICmpAIInterface> cmpAIInterface(simContext.GetSystemEntity());
+	if (!cmpAIInterface)
+		return EMPTY_STATE;
+
+	JSContext* cx = scriptInterface.GetContext();
+	JSAutoRequest rq(cx);
+	JS::RootedValue state(cx);
+	cmpAIInterface->GetFullRepresentation(&state, true);
+	return scriptInterface.StringifyJSON(&state, false);
+}
+
+bool RLInterface::IsGameRunning()
+{
+	return !!g_Game;
+}
Index: ps/trunk/source/simulation2/system/LocalTurnManager.h
===================================================================
--- ps/trunk/source/simulation2/system/LocalTurnManager.h
+++ ps/trunk/source/simulation2/system/LocalTurnManager.h
@@ -31,6 +31,8 @@
 	void OnSimulationMessage(CSimulationMessage* msg) override;
 
 	void PostCommand(JS::HandleValue data) override;
+	// Queue a command issued by player @p playerid for the next turn (used by the RL interface).
+	void PostCommand(player_id_t playerid, JS::HandleValue data);
 
 protected:
 	void NotifyFinishedOwnCommands(u32 turn) override;
Index: ps/trunk/source/simulation2/system/LocalTurnManager.cpp
===================================================================
--- ps/trunk/source/simulation2/system/LocalTurnManager.cpp
+++ ps/trunk/source/simulation2/system/LocalTurnManager.cpp
@@ -24,6 +24,12 @@
 {
 }
 
+// Adds the command directly to the next turn, issued by player @p playerid.
+void CLocalTurnManager::PostCommand(player_id_t playerid, JS::HandleValue data)
+{
+	AddCommand(m_ClientId, playerid, data, m_CurrentTurn + 1);
+}
+
 void CLocalTurnManager::PostCommand(JS::HandleValue data)
 {
 	// Add directly to the next turn, ignoring
COMMAND_DELAY,
Index: ps/trunk/source/simulation2/system/TurnManager.h
===================================================================
--- ps/trunk/source/simulation2/system/TurnManager.h
+++ ps/trunk/source/simulation2/system/TurnManager.h
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+#include
 
 class CSimulationMessage;
 class CSimulation2;
Index: ps/trunk/source/tools/rlclient/python/README.md
===================================================================
--- ps/trunk/source/tools/rlclient/python/README.md
+++ ps/trunk/source/tools/rlclient/python/README.md
@@ -0,0 +1,52 @@
+# 0 AD Python Client
+This directory contains `zero_ad`, a python client for 0 AD which enables users to control the environment headlessly.
+
+## Installation
+`zero_ad` can be installed with `pip` by running the following from the current directory:
+```
+pip install .
+```
+
+Development dependencies can be installed with `pip install -r requirements-dev.txt`. Tests use pytest and can be run with `python -m pytest`; they connect to a 0 AD instance whose RL interface listens on `http://localhost:6000`.
+
+## Basic Usage
+If 0 AD is not already running with the RL interface enabled, start it with:
+```
+pyrogenesis --rl-interface=127.0.0.1:6000
+```
+
+Next, the python client can be connected with:
+```python
+import zero_ad
+from zero_ad import ZeroAD
+
+game = ZeroAD('http://localhost:6000')
+```
+
+A map can be loaded with:
+
+```python
+with open('./samples/arcadia.json', 'r') as f:
+    arcadia_config = f.read()
+
+state = game.reset(arcadia_config)
+```
+
+where `./samples/arcadia.json` is the path to a game configuration JSON (included in the first line of the commands.txt file in a game replay directory) and `state` contains the initial game state for the given map.
+
+The game engine can be stepped (optionally applying actions at each step) with:
+
+```python
+state = game.step()
+```
+
+For example, enemy units could be attacked with:
+
+```python
+my_units = state.units(owner=1)
+enemy_units = state.units(owner=2)
+actions = [zero_ad.actions.attack(my_units, enemy_units[0])]
+state = game.step(actions)
+```
+
+For a more thorough example, see `samples/simple-example.py`.
Index: ps/trunk/source/tools/rlclient/python/requirements-dev.txt
===================================================================
--- ps/trunk/source/tools/rlclient/python/requirements-dev.txt
+++ ps/trunk/source/tools/rlclient/python/requirements-dev.txt
@@ -0,0 +1 @@
+pytest
Index: ps/trunk/source/tools/rlclient/python/samples/arcadia.json
===================================================================
--- ps/trunk/source/tools/rlclient/python/samples/arcadia.json
+++ ps/trunk/source/tools/rlclient/python/samples/arcadia.json
@@ -0,0 +1,53 @@
+{
+  "settings": {
+    "TriggerScripts": [
+      "scripts/TriggerHelper.js",
+      "scripts/ConquestCommon.js",
+      "scripts/ConquestUnits.js"
+    ],
+    "VictoryConditions": [
+      "conquest_units"
+    ],
+    "Name": "Arcadia",
+    "mapType": "scenario",
+    "AISeed": 0,
+    "Seed": 0,
+    "CheatsEnabled": true,
+    "Ceasefire": 0,
+    "WonderDuration": 10,
+    "RelicDuration": 10,
+    "RelicCount": 2,
+    "Size": 256,
+    "PlayerData": [
+      {
+        "Name": "Player 1",
+        "Civ": "spart",
+        "Color": {
+          "r": 150,
+          "g": 20,
+          "b": 20
+        },
+        "AI": "",
+        "AIDiff": 3,
+        "AIBehavior": "random",
+        "Team": 1
+      },
+      {
+        "Name": "Player 2",
+        "Civ": "spart",
+        "Color": {
+          "r": 150,
+          "g": 20,
+          "b": 20
+        },
+        "AI": "",
+        "AIDiff": 3,
+        "AIBehavior": "random",
+        "Team": 2
+      }
+    ]
+  },
+  "mapType": "scenario",
+  "map": "maps/scenarios/arcadia",
+  "gameSpeed": 1
+}
Index: ps/trunk/source/tools/rlclient/python/samples/simple-example.py
===================================================================
--- ps/trunk/source/tools/rlclient/python/samples/simple-example.py
+++
ps/trunk/source/tools/rlclient/python/samples/simple-example.py
@@ -0,0 +1,96 @@
+# This script provides an overview of the zero_ad wrapper for 0 AD
+import math
+from os import path
+import zero_ad
+
+# First, we will define some helper functions we will use later.
+def dist(p1, p2):
+    """Euclidean distance between two points given as coordinate sequences."""
+    return math.sqrt(sum((x2 - x1) ** 2 for (x1, x2) in zip(p1, p2)))
+
+def center(units):
+    """Mean position of a non-empty list of units."""
+    sum_position = map(sum, zip(*map(lambda u: u.position(), units)))
+    return [x/len(units) for x in sum_position]
+
+def closest(units, position):
+    """Return the unit of the non-empty `units` list nearest to `position`."""
+    return min(units, key=lambda unit: dist(unit.position(), position))
+
+# Connect to a 0 AD game server listening at localhost:6000
+game = zero_ad.ZeroAD('http://localhost:6000')
+
+# Load the Arcadia map
+samples_dir = path.dirname(path.realpath(__file__))
+scenario_config_path = path.join(samples_dir, 'arcadia.json')
+with open(scenario_config_path, 'r') as f:
+    arcadia_config = f.read()
+
+state = game.reset(arcadia_config)
+
+# The game is paused and will only progress upon calling "step"
+state = game.step()
+
+# Units can be queried from the game state
+citizen_soldiers = state.units(owner=1, type='infantry')
+# (including gaia units like trees or other resources)
+nearby_tree = closest(state.units(owner=0, type='tree'), center(citizen_soldiers))
+
+# Action commands can be created using zero_ad.actions
+collect_wood = zero_ad.actions.gather(citizen_soldiers, nearby_tree)
+
+female_citizens = state.units(owner=1, type='female_citizen')
+house_tpl = 'structures/spart_house'
+x = 680
+z = 640
+build_house = zero_ad.actions.construct(female_citizens, house_tpl, x, z, autocontinue=True)
+
+# These commands can then be applied to the game in a `step` command
+state = game.step([collect_wood, build_house])
+
+# We can also fetch units by id using the `unit` function on the game state
+female_id = female_citizens[0].id()
+female_citizen = state.unit(female_id)
+
+# A variety of unit information can be queried from the unit:
+print('female citizen\'s max health is', female_citizen.max_health())
+
+# Raw data for units and game states are available via the data attribute
+print(female_citizen.data)
+
+# Units can be built using the "train action"
+civic_center = state.units(owner=1, type="civil_centre")[0]
+spearman_type = 'units/spart_infantry_spearman_b'
+train_spearmen = zero_ad.actions.train([civic_center], spearman_type)
+
+state = game.step([train_spearmen])
+
+# Let's step the engine until the house has been built
+def is_unit_busy(state, unit_id):
+    """True while the unit's `unitAIOrderData` list is non-empty."""
+    return len(state.unit(unit_id).data['unitAIOrderData']) > 0
+
+while is_unit_busy(state, female_id):
+    state = game.step()
+
+# The units for the other army can also be controlled
+enemy_units = state.units(owner=2)
+walk = zero_ad.actions.walk(enemy_units, *civic_center.position())
+game.step([walk], player=[2])
+
+# Step the game engine a bit to give them some time to walk
+for _ in range(150):
+    state = game.step()
+
+# Let's attack with our entire military
+state = game.step([zero_ad.actions.chat('An attack is coming!')])
+
+while len(state.units(owner=2, type='unit')) > 0:
+    attack_units = [ unit for unit in state.units(owner=1, type='unit') if 'female' not in unit.type() ]
+    target = closest(state.units(owner=2, type='unit'), center(attack_units))
+    state = game.step([zero_ad.actions.attack(attack_units, target)])
+
+    while state.unit(target.id()):
+        state = game.step()
+
+game.step([zero_ad.actions.chat('The enemies have been vanquished.
Our home is safe again.')])
Index: ps/trunk/source/tools/rlclient/python/setup.py
===================================================================
--- ps/trunk/source/tools/rlclient/python/setup.py
+++ ps/trunk/source/tools/rlclient/python/setup.py
@@ -0,0 +1,13 @@
+import os
+from setuptools import setup
+
+setup(name='zero_ad',
+      version='0.0.1',
+      description='Python client for 0 AD',
+      url='https://code.wildfiregames.com',
+      author='Brian Broll',
+      author_email='brian.broll@gmail.com',
+      install_requires=[],
+      license='MIT',
+      packages=['zero_ad'],
+      zip_safe=False)
Index: ps/trunk/source/tools/rlclient/python/tests/test_actions.py
===================================================================
--- ps/trunk/source/tools/rlclient/python/tests/test_actions.py
+++ ps/trunk/source/tools/rlclient/python/tests/test_actions.py
@@ -0,0 +1,94 @@
+import zero_ad
+import json
+import math
+from os import path
+
+game = zero_ad.ZeroAD('http://localhost:6000')
+scriptdir = path.dirname(path.realpath(__file__))
+with open(path.join(scriptdir, '..', 'samples', 'arcadia.json'), 'r') as f:
+    config = f.read()
+
+def dist(p1, p2):
+    """Euclidean distance between two points given as coordinate sequences."""
+    return math.sqrt(sum((x2 - x1) ** 2 for (x1, x2) in zip(p1, p2)))
+
+def center(units):
+    """Mean position of a non-empty list of units."""
+    sum_position = map(sum, zip(*map(lambda u: u.position(), units)))
+    return [x/len(units) for x in sum_position]
+
+def closest(units, position):
+    """Return the unit of the non-empty `units` list nearest to `position`."""
+    return min(units, key=lambda unit: dist(unit.position(), position))
+
+def test_construct():
+    state = game.reset(config)
+    female_citizens = state.units(owner=1, type='female_citizen')
+    house_tpl = 'structures/spart_house'
+    house_count = len(state.units(owner=1, type=house_tpl))
+    x = 680
+    z = 640
+    build_house = zero_ad.actions.construct(female_citizens, house_tpl, x, z, autocontinue=True)
+    # Check that they start building the house
+    state = game.step([build_house])
+    while len(state.units(owner=1, type=house_tpl)) == house_count:
+        state = game.step()
+
+def test_gather():
+    state = game.reset(config)
+    female_citizen = state.units(owner=1, type='female_citizen')[0]
+    nearby_tree = closest(state.units(owner=0, type='tree'), female_citizen.position())
+
+    collect_wood = zero_ad.actions.gather([female_citizen], nearby_tree)
+    state = game.step([collect_wood])
+    while len(state.unit(female_citizen.id()).data['resourceCarrying']) == 0:
+        state = game.step()
+
+def test_train():
+    state = game.reset(config)
+    civic_centers = state.units(owner=1, type="civil_centre")
+    spearman_type = 'units/spart_infantry_spearman_b'
+    spearman_count = len(state.units(owner=1, type=spearman_type))
+    train_spearmen = zero_ad.actions.train(civic_centers, spearman_type)
+
+    state = game.step([train_spearmen])
+    while len(state.units(owner=1, type=spearman_type)) == spearman_count:
+        state = game.step()
+
+def test_walk():
+    state = game.reset(config)
+    female_citizens = state.units(owner=1, type='female_citizen')
+    x = 680
+    z = 640
+    initial_distance = dist(center(female_citizens), [x, z])
+
+    walk = zero_ad.actions.walk(female_citizens, x, z)
+    state = game.step([walk])
+    distance = initial_distance
+    while distance >= initial_distance:
+        state = game.step()
+        female_citizens = state.units(owner=1, type='female_citizen')
+        distance = dist(center(female_citizens), [x, z])
+
+def test_attack():
+    state = game.reset(config)
+    units = state.units(owner=1, type='cavalry')
+    target = state.units(owner=2, type='female_citizen')[0]
+    initial_health = target.health()
+
+    state = game.step([zero_ad.actions.reveal_map()])
+
+    attack = zero_ad.actions.attack(units, target)
+    state = game.step([attack])
+    while state.unit(target.id()).health() >= initial_health:
+        state = game.step()
+
+def test_debug_print():
+    state = game.reset(config)
+    debug_print = zero_ad.actions.debug_print('hello world!!')
+    state = game.step([debug_print])
+
+def test_chat():
+    state = game.reset(config)
+    chat = zero_ad.actions.chat('hello world!!')
+    state = game.step([chat])
Index: ps/trunk/source/tools/rlclient/python/zero_ad/__init__.py
===================================================================
--- ps/trunk/source/tools/rlclient/python/zero_ad/__init__.py
+++ ps/trunk/source/tools/rlclient/python/zero_ad/__init__.py
@@ -0,0 +1,4 @@
+from . import actions
+from . import environment
+ZeroAD = environment.ZeroAD
+GameState = environment.GameState
Index: ps/trunk/source/tools/rlclient/python/zero_ad/actions.py
===================================================================
--- ps/trunk/source/tools/rlclient/python/zero_ad/actions.py
+++ ps/trunk/source/tools/rlclient/python/zero_ad/actions.py
@@ -0,0 +1,76 @@
+def construct(units, template, x, z, angle=0, autorepair=True, autocontinue=True, queued=False):
+    """'construct' command ordering `units` to build a `template` structure at (x, z)."""
+    unit_ids = [ unit.id() for unit in units ]
+    return {
+        'type': 'construct',
+        'entities': unit_ids,
+        'template': template,
+        'x': x,
+        'z': z,
+        'angle': angle,
+        'autorepair': autorepair,
+        'autocontinue': autocontinue,
+        'queued': queued,
+    }
+
+def gather(units, target, queued=False):
+    """'gather' command ordering `units` to gather from the `target` entity."""
+    unit_ids = [ unit.id() for unit in units ]
+    return {
+        'type': 'gather',
+        'entities': unit_ids,
+        'target': target.id(),
+        'queued': queued,
+    }
+
+def train(entities, unit_type, count=1):
+    """'train' command asking `entities` to train `count` units of template `unit_type`."""
+    entity_ids = [ unit.id() for unit in entities ]
+    return {
+        'type': 'train',
+        'entities': entity_ids,
+        'template': unit_type,
+        'count': count,
+    }
+
+def debug_print(message):
+    """'debug-print' command carrying `message`."""
+    return {
+        'type': 'debug-print',
+        'message': message
+    }
+
+def chat(message):
+    """'aichat' command carrying `message`."""
+    return {
+        'type': 'aichat',
+        'message': message
+    }
+
+def reveal_map():
+    """'reveal-map' command with reveal enabled."""
+    return {
+        'type': 'reveal-map',
+        'enable': True
+    }
+
+def walk(units, x, z, queued=False):
+    """'walk' command ordering `units` to move to (x, z)."""
+    ids = [ unit.id() for unit in units ]
+    return {
+        'type': 'walk',
+        'entities': ids,
+        'x': x,
+        'z': z,
+        'queued': queued
+    }
+
+def
attack(units, target, queued=False, allow_capture=True):
+    """'attack' command ordering `units` to attack the `target` entity."""
+    return {
+        'type': 'attack',
+        'entities': [ unit.id() for unit in units ],
+        'target': target.id(),
+        'allowCapture': allow_capture,
+        'queued': queued
+    }
Index: ps/trunk/source/tools/rlclient/python/zero_ad/api.py
===================================================================
--- ps/trunk/source/tools/rlclient/python/zero_ad/api.py
+++ ps/trunk/source/tools/rlclient/python/zero_ad/api.py
@@ -0,0 +1,35 @@
+import urllib
+from urllib import request
+import json
+
+class RLAPI():
+    """Thin HTTP client for the RL interface served by pyrogenesis --rl-interface."""
+    def __init__(self, url):
+        self.url = url
+
+    def post(self, route, data):
+        """POST the string `data` to `route` and return the raw response body."""
+        with request.urlopen(url=f'{self.url}/{route}', data=bytes(data, 'utf8')) as response:
+            return response.read()
+
+    def step(self, commands):
+        """Send (player_id, action_dict) `commands` to /step and return the JSON reply."""
+        post_data = '\n'.join((f'{player};{json.dumps(action)}' for (player, action) in commands))
+        return self.post('step', post_data)
+
+    def reset(self, scenario_config, player_id, save_replay):
+        """Start a new game from the JSON `scenario_config` string."""
+        path = 'reset?'
+        if save_replay:
+            path += 'saveReplay=1&'
+        # Compare against None so that player 0 is sent as well.
+        if player_id is not None:
+            path += f'playerID={player_id}&'
+
+        return self.post(path, scenario_config)
+
+    def get_templates(self, names):
+        """Return (name, template XML) pairs, one response line per requested name."""
+        post_data = '\n'.join(names)
+        response = self.post('templates', post_data)
+        return zip(names, response.decode().split('\n'))
Index: ps/trunk/source/tools/rlclient/python/zero_ad/environment.py
===================================================================
--- ps/trunk/source/tools/rlclient/python/zero_ad/environment.py
+++ ps/trunk/source/tools/rlclient/python/zero_ad/environment.py
@@ -0,0 +1,136 @@
+from .api import RLAPI
+import json
+import math
+from xml.etree import ElementTree
+from itertools import cycle
+
+class ZeroAD():
+    """Client for a 0 AD instance running the RL interface."""
+    def __init__(self, uri='http://localhost:6000'):
+        self.api = RLAPI(uri)
+        self.current_state = None
+        self.cache = {}
+        self.player_id = 1
+
+    def step(self, actions=None, player=None):
+        """Apply `actions` (None entries are skipped), advance the game and return the new GameState.
+
+        `player` optionally lists the ids of the players issuing the actions and is
+        cycled over `actions`; by default every action is issued by `self.player_id`.
+        """
+        if actions is None:
+            actions = []
+        player_ids = cycle([self.player_id]) if player is None else cycle(player)
+
+        cmds = zip(player_ids, actions)
+        cmds = ((player, action) for (player, action) in cmds if action is not None)
+        state_json = self.api.step(cmds)
+        self.current_state = GameState(json.loads(state_json), self)
+        return self.current_state
+
+    def reset(self, config='', save_replay=False, player_id=1):
+        """Start a new game from the JSON scenario `config` and return its initial GameState."""
+        state_json = self.api.reset(config, player_id, save_replay)
+        self.current_state = GameState(json.loads(state_json), self)
+        return self.current_state
+
+    def get_template(self, name):
+        """Return the (name, EntityTemplate) pair for template `name`."""
+        return self.get_templates([name])[0]
+
+    def get_templates(self, names):
+        """Fetch the named templates and return a list of (name, EntityTemplate) pairs."""
+        templates = self.api.get_templates(names)
+        return [ (name, EntityTemplate(content)) for (name, content) in templates ]
+
+    def update_templates(self, types=None):
+        """Rebuild the template cache from the unit types in the current state plus `types`."""
+        all_types = list(set([unit.type() for unit in self.current_state.units()]))
+        all_types += types or []
+        template_pairs = self.get_templates(all_types)
+
+        self.cache = {}
+        for (name, tpl) in template_pairs:
+            self.cache[name] = tpl
+
+        return template_pairs
+
+class GameState():
+    """Game state returned by the server; the parsed JSON is kept in `data`."""
+    def __init__(self, data, game):
+        self.data = data
+        self.game = game
+        self.mapSize = self.data['mapSize']
+
+    def units(self, owner=None, type=None):
+        """Entities owned by `owner` whose template name contains `type` (None matches all)."""
+        filter_fn = lambda e: (owner is None or e['owner'] == owner) and \
+                (type is None or type in e['template'])
+        return [ Entity(e, self.game) for e in self.data['entities'].values() if filter_fn(e) ]
+
+    def unit(self, id):
+        """Entity with the given id, or None if it is not in this state."""
+        id = str(id)
+        return Entity(self.data['entities'][id], self.game) if id in self.data['entities'] else None
+
+class Entity():
+    """A single entity of a GameState."""
+
+    def __init__(self, data, game):
+        self.data = data
+        self.game = game
+        self.template = self.game.cache.get(self.type(), None)
+
+    def type(self):
+        return self.data['template']
+
+    def id(self):
+        return self.data['id']
+
+    def owner(self):
+        return self.data['owner']
+
+    def max_health(self):
+        template = self.get_template()
+        return float(template.get('Health/Max'))
+
+    def health(self, ratio=False):
+        """Current hitpoints, or hitpoints divided by max health when `ratio` is true."""
+        if ratio:
+            return self.data['hitpoints']/self.max_health()
+
+        return self.data['hitpoints']
+
+    def position(self):
+        return self.data['position']
+
+    def get_template(self):
+        """EntityTemplate of this entity, fetched from the server if not cached."""
+        if self.template is None:
+            self.game.update_templates([self.type()])
+            self.template = self.game.cache[self.type()]
+
+        return self.template
+
+class EntityTemplate():
+    """Entity template XML, queried with ElementTree paths such as 'Health/Max'."""
+    def __init__(self, xml):
+        # Wrap the XML in a root element so several top-level nodes parse as one document.
+        self.data = ElementTree.fromstring(f'<Template>{xml}</Template>')
+
+    def get(self, path):
+        """Text of the node at `path`, or None if there is no such node."""
+        node = self.data.find(path)
+        return node.text if node is not None else None
+
+    def set(self, path, value):
+        """Set the text of the node at `path`; return whether the node exists."""
+        node = self.data.find(path)
+        # An Element without children is falsy, so compare against None explicitly.
+        if node is not None:
+            node.text = str(value)
+
+        return node is not None
+
+    def __str__(self):
+        return ElementTree.tostring(self.data).decode('utf-8')