Index: ps/trunk/source/network/NetClient.cpp =================================================================== --- ps/trunk/source/network/NetClient.cpp +++ ps/trunk/source/network/NetClient.cpp @@ -391,7 +391,7 @@ { // Handle non-FSM messages first - Status status = m_Session->GetFileTransferer().HandleMessageReceive(message); + Status status = m_Session->GetFileTransferer().HandleMessageReceive(*message); if (status == INFO::OK) return true; if (status != INFO::SKIPPED) Index: ps/trunk/source/network/NetFileTransfer.h =================================================================== --- ps/trunk/source/network/NetFileTransfer.h +++ ps/trunk/source/network/NetFileTransfer.h @@ -1,4 +1,4 @@ -/* Copyright (C) 2016 Wildfire Games. +/* Copyright (C) 2019 Wildfire Games. * This file is part of 0 A.D. * * 0 A.D. is free software: you can redistribute it and/or modify @@ -21,6 +21,9 @@ #include class CNetMessage; +class CFileTransferResponseMessage; +class CFileTransferDataMessage; +class CFileTransferAckMessage; class CNetClientSession; class CNetServerSession; class INetSession; @@ -84,7 +87,7 @@ * Returns INFO::OK if the message is handled successfully, * or ERR::FAIL if handled unsuccessfully. */ - Status HandleMessageReceive(const CNetMessage* message); + Status HandleMessageReceive(const CNetMessage& message); /** * Registers a file-receiving task. @@ -105,6 +108,11 @@ void Poll(); private: + // Handlers for each file transfer message type, dispatched from HandleMessageReceive. + Status OnFileTransferResponse(const CFileTransferResponseMessage& message); + Status OnFileTransferData(const CFileTransferDataMessage& message); + Status OnFileTransferAck(const CFileTransferAckMessage& message); + /** * Asynchronous file-sending task.
*/ @@ -121,10 +129,10 @@ u32 m_NextRequestID; - typedef std::map> FileReceiveTasksMap; + using FileReceiveTasksMap = std::map >; FileReceiveTasksMap m_FileReceiveTasks; - typedef std::map FileSendTasksMap; + using FileSendTasksMap = std::map; FileSendTasksMap m_FileSendTasks; double m_LastProgressReportTime; Index: ps/trunk/source/network/NetFileTransfer.cpp =================================================================== --- ps/trunk/source/network/NetFileTransfer.cpp +++ ps/trunk/source/network/NetFileTransfer.cpp @@ -1,4 +1,4 @@ -/* Copyright (C) 2016 Wildfire Games. +/* Copyright (C) 2019 Wildfire Games. * This file is part of 0 A.D. * * 0 A.D. is free software: you can redistribute it and/or modify @@ -24,107 +24,119 @@ #include "network/NetSession.h" #include "ps/CLogger.h" -Status CNetFileTransferer::HandleMessageReceive(const CNetMessage* message) +Status CNetFileTransferer::HandleMessageReceive(const CNetMessage& message) { - if (message->GetType() == NMT_FILE_TRANSFER_RESPONSE) + switch (message.GetType()) { - CFileTransferResponseMessage* respMessage = (CFileTransferResponseMessage*)message; + case NMT_FILE_TRANSFER_RESPONSE: + return OnFileTransferResponse(static_cast(message)); - if (m_FileReceiveTasks.find(respMessage->m_RequestID) == m_FileReceiveTasks.end()) - { - LOGERROR("Net transfer: Unsolicited file transfer response (id=%d)", (int)respMessage->m_RequestID); - return ERR::FAIL; - } - - if (respMessage->m_Length == 0 || respMessage->m_Length > MAX_FILE_TRANSFER_SIZE) - { - LOGERROR("Net transfer: Invalid size for file transfer response (length=%d)", (int)respMessage->m_Length); - return ERR::FAIL; - } - - shared_ptr task = m_FileReceiveTasks[respMessage->m_RequestID]; + case NMT_FILE_TRANSFER_DATA: + return OnFileTransferData(static_cast(message)); - task->m_Length = respMessage->m_Length; - task->m_Buffer.reserve(respMessage->m_Length); + case NMT_FILE_TRANSFER_ACK: + return OnFileTransferAck(static_cast(message)); - 
LOGMESSAGERENDER("Downloading data over network (%d KB) - please wait...", (int)(task->m_Length/1024)); - m_LastProgressReportTime = timer_Time(); + default: + return INFO::SKIPPED; + } +} - return INFO::OK; +Status CNetFileTransferer::OnFileTransferResponse(const CFileTransferResponseMessage& message) +{ + const FileReceiveTasksMap::iterator it = m_FileReceiveTasks.find(message.m_RequestID); + if (it == m_FileReceiveTasks.end()) + { + LOGERROR("Net transfer: Unsolicited file transfer response (id=%lu)", message.m_RequestID); + return ERR::FAIL; } - else if (message->GetType() == NMT_FILE_TRANSFER_DATA) + + if (message.m_Length == 0 || message.m_Length > MAX_FILE_TRANSFER_SIZE) { - CFileTransferDataMessage* dataMessage = (CFileTransferDataMessage*)message; + LOGERROR("Net transfer: Invalid size for file transfer response (length=%lu)", message.m_Length); + return ERR::FAIL; + } - if (m_FileReceiveTasks.find(dataMessage->m_RequestID) == m_FileReceiveTasks.end()) - { - LOGERROR("Net transfer: Unsolicited file transfer data (id=%d)", (int)dataMessage->m_RequestID); - return ERR::FAIL; - } + CNetFileReceiveTask& task = *it->second; - shared_ptr task = m_FileReceiveTasks[dataMessage->m_RequestID]; + task.m_Length = message.m_Length; + task.m_Buffer.reserve(message.m_Length); - task->m_Buffer += dataMessage->m_Data; + LOGMESSAGERENDER("Downloading data over network (%lu KB) - please wait...", task.m_Length / 1024); + m_LastProgressReportTime = timer_Time(); - if (task->m_Buffer.size() > task->m_Length) - { - LOGERROR("Net transfer: Invalid size for file transfer data (length=%d actual=%d)", (int)task->m_Length, (int)task->m_Buffer.size()); - return ERR::FAIL; - } + return INFO::OK; +} - CFileTransferAckMessage ackMessage; - ackMessage.m_RequestID = task->m_RequestID; - ackMessage.m_NumPackets = 1; // TODO: would be nice to send a single ack for multiple packets at once - m_Session->SendMessage(&ackMessage); +Status CNetFileTransferer::OnFileTransferData(const 
CFileTransferDataMessage& message) +{ + const FileReceiveTasksMap::iterator it = m_FileReceiveTasks.find(message.m_RequestID); + if (it == m_FileReceiveTasks.end()) + { + LOGERROR("Net transfer: Unsolicited file transfer data (id=%lu)", message.m_RequestID); + return ERR::FAIL; + } - if (task->m_Buffer.size() == task->m_Length) - { - LOGMESSAGERENDER("Download completed"); + CNetFileReceiveTask& task = *it->second; - task->OnComplete(); - m_FileReceiveTasks.erase(dataMessage->m_RequestID); - return INFO::OK; - } + task.m_Buffer += message.m_Data; - // TODO: should report progress using proper GUI + if (task.m_Buffer.size() > task.m_Length) + { + LOGERROR("Net transfer: Invalid size for file transfer data (length=%lu actual=%zu)", task.m_Length, task.m_Buffer.size()); + return ERR::FAIL; + } - // Report the download status occassionally - double t = timer_Time(); - if (t > m_LastProgressReportTime + 0.5) - { - LOGMESSAGERENDER("Downloading data: %.1f%% of %d KB", 100.f*task->m_Buffer.size()/task->m_Length, (int)(task->m_Length/1024)); - m_LastProgressReportTime = t; - } + CFileTransferAckMessage ackMessage; + ackMessage.m_RequestID = task.m_RequestID; + ackMessage.m_NumPackets = 1; // TODO: would be nice to send a single ack for multiple packets at once + m_Session->SendMessage(&ackMessage); + + if (task.m_Buffer.size() == task.m_Length) + { + LOGMESSAGERENDER("Download completed"); + task.OnComplete(); + m_FileReceiveTasks.erase(message.m_RequestID); return INFO::OK; } - else if (message->GetType() == NMT_FILE_TRANSFER_ACK) - { - CFileTransferAckMessage* ackMessage = (CFileTransferAckMessage*)message; - if (m_FileSendTasks.find(ackMessage->m_RequestID) == m_FileSendTasks.end()) - { - LOGERROR("Net transfer: Unsolicited file transfer ack (id=%d)", (int)ackMessage->m_RequestID); - return ERR::FAIL; - } + // TODO: should report progress using proper GUI - CNetFileSendTask& task = m_FileSendTasks[ackMessage->m_RequestID]; + // Report the download status occasionally + 
double t = timer_Time(); + if (t > m_LastProgressReportTime + 0.5) + { + LOGMESSAGERENDER("Downloading data: %.1f%% of %lu KB", 100.f * task.m_Buffer.size() / task.m_Length, task.m_Length / 1024); + m_LastProgressReportTime = t; + } - if (ackMessage->m_NumPackets > task.packetsInFlight) - { - LOGERROR("Net transfer: Invalid num packets for file transfer ack (num=%d inflight=%d)", - (int)ackMessage->m_NumPackets, (int)task.packetsInFlight); - return ERR::FAIL; - } + return INFO::OK; +} - task.packetsInFlight -= ackMessage->m_NumPackets; +Status CNetFileTransferer::OnFileTransferAck(const CFileTransferAckMessage& message) +{ + const FileSendTasksMap::iterator it = m_FileSendTasks.find(message.m_RequestID); + if (it == m_FileSendTasks.end()) + { + LOGERROR("Net transfer: Unsolicited file transfer ack (id=%lu)", message.m_RequestID); + return ERR::FAIL; + } - return INFO::OK; + CNetFileSendTask& task = it->second; + + if (message.m_NumPackets > task.packetsInFlight) + { + LOGERROR("Net transfer: Invalid num packets for file transfer ack (num=%lu inflight=%lu)", + message.m_NumPackets, task.packetsInFlight); + return ERR::FAIL; } - return INFO::SKIPPED; -} + task.packetsInFlight -= message.m_NumPackets; + + return INFO::OK; +} void CNetFileTransferer::StartTask(const shared_ptr& task) { Index: ps/trunk/source/network/NetServer.cpp =================================================================== --- ps/trunk/source/network/NetServer.cpp +++ ps/trunk/source/network/NetServer.cpp @@ -621,7 +621,7 @@ void CNetServerWorker::HandleMessageReceive(const CNetMessage* message, CNetServerSession* session) { // Handle non-FSM messages first - Status status = session->GetFileTransferer().HandleMessageReceive(message); + Status status = session->GetFileTransferer().HandleMessageReceive(*message); if (status != INFO::SKIPPED) return;