Index: ps/trunk/source/network/NetClient.h =================================================================== --- ps/trunk/source/network/NetClient.h +++ ps/trunk/source/network/NetClient.h @@ -1,4 +1,4 @@ -/* Copyright (C) 2022 Wildfire Games. +/* Copyright (C) 2024 Wildfire Games. * This file is part of 0 A.D. * * 0 A.D. is free software: you can redistribute it and/or modify @@ -60,8 +60,6 @@ { NONCOPYABLE(CNetClient); - friend class CNetFileReceiveTask_ClientRejoin; - public: /** * Construct a client associated with the given game object. Index: ps/trunk/source/network/NetClient.cpp =================================================================== --- ps/trunk/source/network/NetClient.cpp +++ ps/trunk/source/network/NetClient.cpp @@ -1,4 +1,4 @@ -/* Copyright (C) 2023 Wildfire Games. +/* Copyright (C) 2024 Wildfire Games. * This file is part of 0 A.D. * * 0 A.D. is free software: you can redistribute it and/or modify @@ -53,37 +53,6 @@ CNetClient *g_NetClient = NULL; -/** - * Async task for receiving the initial game state when rejoining an in-progress network game.
- */ -class CNetFileReceiveTask_ClientRejoin : public CNetFileReceiveTask -{ - NONCOPYABLE(CNetFileReceiveTask_ClientRejoin); -public: - CNetFileReceiveTask_ClientRejoin(CNetClient& client, const CStr& initAttribs) - : m_Client(client), m_InitAttributes(initAttribs) - { - } - - virtual void OnComplete() - { - // We've received the game state from the server - - // Save it so we can use it after the map has finished loading - m_Client.m_JoinSyncBuffer = m_Buffer; - - // Pretend the server told us to start the game - CGameStartMessage start; - start.m_InitAttributes = m_InitAttributes; - m_Client.HandleMessage(&start); - } - -private: - CNetClient& m_Client; - CStr m_InitAttributes; -}; - CNetClient::CNetClient(CGame* game) : m_Session(NULL), m_UserName(L"anonymous"), @@ -834,8 +803,19 @@ // The server wants us to start downloading the game state from it, so do so client->m_Session->GetFileTransferer().StartTask( - std::shared_ptr<CNetFileReceiveTask>(new CNetFileReceiveTask_ClientRejoin(*client, joinSyncStartMessage->m_InitAttributes)) - ); + [client, initAttributes = std::move(joinSyncStartMessage->m_InitAttributes)](std::string buffer) + mutable + { + // We've received the game state from the server. + + // Save it so we can use it after the map has finished loading. + client->m_JoinSyncBuffer = std::move(buffer); + + // Pretend the server told us to start the game. + CGameStartMessage start; + start.m_InitAttributes = std::move(initAttributes); + client->HandleMessage(&start); + }); return true; } Index: ps/trunk/source/network/NetFileTransfer.h =================================================================== --- ps/trunk/source/network/NetFileTransfer.h +++ ps/trunk/source/network/NetFileTransfer.h @@ -1,4 +1,4 @@ -/* Copyright (C) 2021 Wildfire Games. +/* Copyright (C) 2024 Wildfire Games. * This file is part of 0 A.D. * * 0 A.D.
is free software: you can redistribute it and/or modify @@ -18,8 +18,10 @@ #ifndef NETFILETRANSFER_H #define NETFILETRANSFER_H +#include <functional> #include <map> #include <string> +#include <unordered_map> class CNetMessage; class CFileTransferResponseMessage; @@ -41,35 +43,6 @@ static const size_t MAX_FILE_TRANSFER_SIZE = 8*MiB; /** - * Asynchronous file-receiving task. - * Other code should subclass this, implement OnComplete(), - * then pass it to CNetFileTransferer::StartTask. - */ -class CNetFileReceiveTask -{ -public: - CNetFileReceiveTask() : m_RequestID(0), m_Length(0) { } - virtual ~CNetFileReceiveTask() {} - - /** - * Called when m_Buffer contains the full received data. - */ - virtual void OnComplete() = 0; - - // TODO: Ought to have an OnFailure, e.g. when the session drops or there's another error - - /** - * Uniquely identifies the request within the scope of its CNetFileTransferer. - * Set automatically by StartTask. - */ - u32 m_RequestID; - - size_t m_Length; - - std::string m_Buffer; -}; - -/** * Handles transferring files between clients and servers. */ class CNetFileTransferer @@ -91,7 +64,7 @@ /** * Registers a file-receiving task. */ - void StartTask(const std::shared_ptr<CNetFileReceiveTask>& task); + void StartTask(std::function<void(std::string)> task); /** * Registers data to be sent in response to a request. @@ -127,7 +100,22 @@ u32 m_NextRequestID; - using FileReceiveTasksMap = std::map<u32, std::shared_ptr<CNetFileReceiveTask>>; + + struct AsyncFileReceiveTask + { + /** + * Called with the full received data once the transfer has completed. + */ + std::function<void(std::string)> onComplete; + + // TODO: Ought to have a failure channel, e.g. when the session drops or there's another error.
+ + size_t length{0}; + + std::string buffer; + }; + + using FileReceiveTasksMap = std::unordered_map<u32, AsyncFileReceiveTask>; FileReceiveTasksMap m_FileReceiveTasks; using FileSendTasksMap = std::map<u32, CNetFileSendTask>; Index: ps/trunk/source/network/NetFileTransfer.cpp =================================================================== --- ps/trunk/source/network/NetFileTransfer.cpp +++ ps/trunk/source/network/NetFileTransfer.cpp @@ -1,4 +1,4 @@ -/* Copyright (C) 2021 Wildfire Games. +/* Copyright (C) 2024 Wildfire Games. * This file is part of 0 A.D. * * 0 A.D. is free software: you can redistribute it and/or modify @@ -19,6 +19,7 @@ #include "NetFileTransfer.h" +#include "lib/alignment.h" #include "lib/timer.h" #include "network/NetMessage.h" #include "network/NetSession.h" @@ -57,12 +58,12 @@ return ERR::FAIL; } - CNetFileReceiveTask& task = *it->second; + AsyncFileReceiveTask& task = it->second; - task.m_Length = message.m_Length; - task.m_Buffer.reserve(message.m_Length); + task.length = message.m_Length; + task.buffer.reserve(message.m_Length); - LOGMESSAGERENDER("Downloading data over network (%lu KB) - please wait...", task.m_Length / 1024); + LOGMESSAGERENDER("Downloading data over network (%zu KiB) - please wait...", task.length / KiB); m_LastProgressReportTime = timer_Time(); return INFO::OK; @@ -77,27 +78,32 @@ return ERR::FAIL; } - CNetFileReceiveTask& task = *it->second; + AsyncFileReceiveTask& task = it->second; - task.m_Buffer += message.m_Data; + task.buffer += message.m_Data; - if (task.m_Buffer.size() > task.m_Length) + if (task.buffer.size() > task.length) { - LOGERROR("Net transfer: Invalid size for file transfer data (length=%lu actual=%zu)", task.m_Length, task.m_Buffer.size()); + LOGERROR("Net transfer: Invalid size for file transfer data (length=%zu actual=%zu)", + task.length, task.buffer.size()); return ERR::FAIL; } CFileTransferAckMessage ackMessage; - ackMessage.m_RequestID = task.m_RequestID; + ackMessage.m_RequestID = message.m_RequestID; ackMessage.m_NumPackets = 1; // 
TODO: would be nice to send a single ack for multiple packets at once m_Session->SendMessage(&ackMessage); - if (task.m_Buffer.size() == task.m_Length) + if (task.buffer.size() == task.length) { LOGMESSAGERENDER("Download completed"); - task.OnComplete(); - m_FileReceiveTasks.erase(message.m_RequestID); + // Take everything out of the entry and erase it before invoking the callback: the + // callback may (indirectly) start new tasks, which can rehash the map and invalidate `it`. + auto onComplete = std::move(task.onComplete); + std::string buffer = std::move(task.buffer); + m_FileReceiveTasks.erase(it); + onComplete(std::move(buffer)); return INFO::OK; } @@ -107,7 +113,8 @@ double t = timer_Time(); if (t > m_LastProgressReportTime + 0.5) { - LOGMESSAGERENDER("Downloading data: %.1f%% of %lu KB", 100.f * task.m_Buffer.size() / task.m_Length, task.m_Length / 1024); + LOGMESSAGERENDER("Downloading data: %.1f%% of %zu KiB", + 100.f * task.buffer.size() / task.length, task.length / KiB); m_LastProgressReportTime = t; } @@ -138,12 +145,11 @@ } -void CNetFileTransferer::StartTask(const std::shared_ptr<CNetFileReceiveTask>& task) +void CNetFileTransferer::StartTask(std::function<void(std::string)> task) { u32 requestID = m_NextRequestID++; - task->m_RequestID = requestID; - m_FileReceiveTasks[requestID] = task; + m_FileReceiveTasks.emplace(requestID, AsyncFileReceiveTask{std::move(task)}); CFileTransferRequestMessage request; request.m_RequestID = requestID; Index: ps/trunk/source/network/NetServer.h =================================================================== --- ps/trunk/source/network/NetServer.h +++ ps/trunk/source/network/NetServer.h @@ -1,4 +1,4 @@ -/* Copyright (C) 2022 Wildfire Games. +/* Copyright (C) 2024 Wildfire Games. * This file is part of 0 A.D. * * 0 A.D.
is free software: you can redistribute it and/or modify @@ -233,7 +233,6 @@ private: friend class CNetServer; - friend class CNetFileReceiveTask_ServerRejoin; CNetServerWorker(bool useLobbyAuth); ~CNetServerWorker(); Index: ps/trunk/source/network/NetServer.cpp =================================================================== --- ps/trunk/source/network/NetServer.cpp +++ ps/trunk/source/network/NetServer.cpp @@ -90,58 +90,6 @@ return "[" + session->GetGUID().substr(0, 8) + "...]"; } -/** - * Async task for receiving the initial game state to be forwarded to another - * client that is rejoining an in-progress network game. - */ -class CNetFileReceiveTask_ServerRejoin : public CNetFileReceiveTask -{ - NONCOPYABLE(CNetFileReceiveTask_ServerRejoin); -public: - CNetFileReceiveTask_ServerRejoin(CNetServerWorker& server, u32 hostID) - : m_Server(server), m_RejoinerHostID(hostID) - { - } - - virtual void OnComplete() - { - // We've received the game state from an existing player - now - // we need to send it onwards to the newly rejoining player - - // Find the session corresponding to the rejoining host (if any) - CNetServerSession* session = NULL; - for (CNetServerSession* serverSession : m_Server.m_Sessions) - { - if (serverSession->GetHostID() == m_RejoinerHostID) - { - session = serverSession; - break; - } - } - - if (!session) - { - LOGMESSAGE("Net server: rejoining client disconnected before we sent to it"); - return; - } - - // Store the received state file, and tell the client to start downloading it from us - // TODO: this will get kind of confused if there's multiple clients downloading in parallel; - // they'll race and get whichever happens to be the latest received by the server, - // which should still work but isn't great - m_Server.m_JoinSyncFile = m_Buffer; - - // Send the init attributes alongside - these should be correct since the game should be started.
- CJoinSyncStartMessage message; - message.m_InitAttributes = Script::StringifyJSON(ScriptRequest(m_Server.GetScriptInterface()), &m_Server.m_InitAttributes); - session->SendMessage(&message); - } - -private: - CNetServerWorker& m_Server; - u32 m_RejoinerHostID; -}; - /* * XXX: We use some non-threadsafe functions from the worker thread. * See http://trac.wildfiregames.com/ticket/654 @@ -1151,24 +1099,50 @@ server.OnUserJoin(session); - if (isRejoining) - { - ENSURE(server.m_State != SERVER_STATE_UNCONNECTED && server.m_State != SERVER_STATE_PREGAME); + if (!isRejoining) + return true; - // Request a copy of the current game state from an existing player, - // so we can send it on to the new player + ENSURE(server.m_State != SERVER_STATE_UNCONNECTED && server.m_State != SERVER_STATE_PREGAME); - // Assume session 0 is most likely the local player, so they're - // the most efficient client to request a copy from - CNetServerSession* sourceSession = server.m_Sessions.at(0); + // Request a copy of the current game state from an existing player, so we can send it on to the new + // player. - sourceSession->GetFileTransferer().StartTask( - std::shared_ptr<CNetFileReceiveTask>(new CNetFileReceiveTask_ServerRejoin(server, newHostID)) - ); + // Assume session 0 is most likely the local player, so they're the most efficient client to request a + // copy from. + CNetServerSession* sourceSession = server.m_Sessions.at(0); - session->SetNextState(NSS_JOIN_SYNCING); - } + sourceSession->GetFileTransferer().StartTask([&server, newHostID](std::string buffer) + { + // We've received the game state from an existing player - now we need to send it onwards + // to the newly rejoining player.
+ + const auto sessionIt = std::find_if(server.m_Sessions.begin(), server.m_Sessions.end(), + [newHostID](const CNetServerSession* serverSession) + { + return serverSession->GetHostID() == newHostID; + }); + + if (sessionIt == server.m_Sessions.end()) + { + LOGMESSAGE("Net server: rejoining client disconnected before we sent to it"); + return; + } + + // Store the received state file, and tell the client to start downloading it from us. + // TODO: The server will get kind of confused if there are multiple clients downloading in + // parallel; they'll race and get whichever happens to be the latest received by the + // server, which should still work but isn't great. + server.m_JoinSyncFile = std::move(buffer); + + // Send the init attributes alongside - these should be correct since the game should be + // started. + CJoinSyncStartMessage message; + message.m_InitAttributes = Script::StringifyJSON( + ScriptRequest{server.GetScriptInterface()}, &server.m_InitAttributes); + (*sessionIt)->SendMessage(&message); + }); + session->SetNextState(NSS_JOIN_SYNCING); return true; } bool CNetServerWorker::OnSimulationCommand(void* context, CFsmEvent* event) Index: ps/trunk/source/network/tests/test_FileTransfer.h =================================================================== --- ps/trunk/source/network/tests/test_FileTransfer.h +++ ps/trunk/source/network/tests/test_FileTransfer.h @@ -0,0 +1,124 @@ +/* Copyright (C) 2024 Wildfire Games. + * This file is part of 0 A.D. + * + * 0 A.D. is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * 0 A.D. is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details.
+ * + * You should have received a copy of the GNU General Public License + * along with 0 A.D. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "lib/self_test.h" + +#include "network/NetFileTransfer.h" +#include "network/NetMessage.h" +#include "network/NetSession.h" + +#include <utility> +#include <vector> + +namespace +{ +constexpr const char* MESSAGECONTENT{"Some example message content"}; + +class MessageQueues : public INetSession +{ +public: + ~MessageQueues() final = default; + bool SendMessage(const CNetMessage* message) final + { + switch (message->GetType()) + { + case NMT_FILE_TRANSFER_REQUEST: + requests.push_back(*static_cast<const CFileTransferRequestMessage*>(message)); + break; + case NMT_FILE_TRANSFER_RESPONSE: + responses.push_back(*static_cast<const CFileTransferResponseMessage*>(message)); + break; + case NMT_FILE_TRANSFER_DATA: + data.push_back(*static_cast<const CFileTransferDataMessage*>(message)); + break; + case NMT_FILE_TRANSFER_ACK: + acknowledgements.push_back(*static_cast<const CFileTransferAckMessage*>(message)); + break; + default: + TS_FAIL("Unhandled message type"); + } + + return true; + } + + std::vector<CFileTransferRequestMessage> requests; + std::vector<CFileTransferResponseMessage> responses; + std::vector<CFileTransferDataMessage> data; + std::vector<CFileTransferAckMessage> acknowledgements; +}; + +void CheckSizes(const MessageQueues& queues, size_t requestSize, size_t responseSize, size_t dataSize, + size_t acknowledgementSize) +{ + TS_ASSERT_EQUALS(queues.requests.size(), requestSize); + TS_ASSERT_EQUALS(queues.responses.size(), responseSize); + TS_ASSERT_EQUALS(queues.data.size(), dataSize); + TS_ASSERT_EQUALS(queues.acknowledgements.size(), acknowledgementSize); +} + +struct Participant +{ + MessageQueues queues; + CNetFileTransferer transferer{&queues}; +}; +} + +class TestFileTransfer : public CxxTest::TestSuite +{ +public: + void test_transfer() + { + // The client requests some data from the server. + + Participant server; + Participant client; + + bool complete{false}; + + client.transferer.StartTask([&complete](std::string buffer) + { + // This callback is executed exactly once.
+ const bool previousComplete{std::exchange(complete, true)}; + TS_ASSERT(!previousComplete); + TS_ASSERT_STR_EQUALS(buffer, MESSAGECONTENT); + }); + CheckSizes(client.queues, 1, 0, 0, 0); + + server.transferer.StartResponse(client.queues.requests.at(0).m_RequestID, MESSAGECONTENT); + CheckSizes(server.queues, 0, 1, 0, 0); + + client.transferer.HandleMessageReceive(server.queues.responses.at(0)); + CheckSizes(client.queues, 1, 0, 0, 0); + + server.transferer.Poll(); + CheckSizes(server.queues, 0, 1, 1, 0); + + server.transferer.Poll(); + // If `MESSAGECONTENT` were longer, another message would be sent. + CheckSizes(server.queues, 0, 1, 1, 0); + + TS_ASSERT(!complete); + + client.transferer.HandleMessageReceive(server.queues.data.at(0)); + CheckSizes(client.queues, 1, 0, 0, 1); + + TS_ASSERT(complete); + + server.transferer.HandleMessageReceive(client.queues.acknowledgements.at(0)); + CheckSizes(server.queues, 0, 1, 1, 0); + } +};