Index: source/lobby/IXmppClient.h =================================================================== --- source/lobby/IXmppClient.h +++ source/lobby/IXmppClient.h @@ -22,10 +22,6 @@ class ScriptRequest; -namespace StunClient { - struct StunEndpoint; -} - class IXmppClient { public: @@ -64,7 +60,7 @@ virtual bool GuiPollHasPlayerListUpdate() = 0; virtual void SendMUCMessage(const std::string& message) = 0; - virtual void SendStunEndpointToHost(const StunClient::StunEndpoint& stunEndpoint, const std::string& hostJID) = 0; + virtual void SendStunEndpointToHost(const std::string& ip, const u16 port, const std::string& hostJID) = 0; }; extern IXmppClient *g_XmppClient; Index: source/lobby/XmppClient.h =================================================================== --- source/lobby/XmppClient.h +++ source/lobby/XmppClient.h @@ -106,7 +106,7 @@ JS::Value GUIGetBoardList(const ScriptInterface& scriptInterface); JS::Value GUIGetProfile(const ScriptInterface& scriptInterface); - void SendStunEndpointToHost(const StunClient::StunEndpoint& stunEndpoint, const std::string& hostJID); + void SendStunEndpointToHost(const std::string& ip, const u16 port, const std::string& hostJID); /** * Convert gloox values to string or time. 
Index: source/lobby/XmppClient.cpp =================================================================== --- source/lobby/XmppClient.cpp +++ source/lobby/XmppClient.cpp @@ -20,19 +20,14 @@ #include "XmppClient.h" #include "StanzaExtensions.h" -#ifdef WIN32 -# include -#endif - #include "i18n/L10n.h" -#include "lib/external_libraries/enet.h" #include "lib/utf8.h" #include "network/NetServer.h" #include "network/NetClient.h" -#include "network/StunClient.h" #include "ps/CLogger.h" #include "ps/ConfigDB.h" #include "ps/GUID.h" +#include "ps/Networking.h" #include "ps/Pyrogenesis.h" #include "scriptinterface/ScriptInterface.h" #include "scriptinterface/StructuredClone.h" @@ -872,7 +867,16 @@ return true; } - g_NetClient->SetupServerData(cd->m_Ip.to_string(), stoi(cd->m_Port.to_string()), !cd->m_UseSTUN.empty()); + // The received IP is (for now) an IPV4-mapped IPV6 address. + // Until Enet supports IPV6, convert to IPV4 explicitly. + Networking::IPAddress addr; + if (!addr.Parse(cd->m_Ip.to_string(), cd->m_Port.to_string()) || addr.GetIPV4().empty()) + { + g_NetClient->HandleGetServerDataFailed("not_ipv4"); + return true; + } + + g_NetClient->SetupServerData(addr.GetIPV4(), addr.GetPort(), !cd->m_UseSTUN.empty()); g_NetClient->TryToConnect(iq.from().full()); } if (gq) @@ -890,7 +894,7 @@ LOGWARNING("XmppClient: Received game with no IP in response to Game Register"); return true; } - g_NetServer->SetConnectionData(publicIP, g_NetServer->GetPublicPort(), false); + g_NetServer->SetConnectionData(publicIP, g_NetServer->GetPublicPort()); return true; } @@ -1000,7 +1004,7 @@ glooxwrapper::IQ response(gloox::IQ::Result, iq.from(), iq.id()); ConnectionData* connectionData = new ConnectionData(); - connectionData->m_Ip = g_NetServer->GetPublicIp();; + connectionData->m_Ip = g_NetServer->GetPublicIp(); connectionData->m_Port = std::to_string(g_NetServer->GetPublicPort()); connectionData->m_UseSTUN = g_NetServer->GetUseSTUN() ? 
"true" : ""; @@ -1475,18 +1479,13 @@ #undef CASE } -void XmppClient::SendStunEndpointToHost(const StunClient::StunEndpoint& stunEndpoint, const std::string& hostJIDStr) +void XmppClient::SendStunEndpointToHost(const std::string& ip, const u16 port, const std::string& hostJIDStr) { DbgXMPP("SendStunEndpointToHost " << hostJIDStr); - char ipStr[256] = "(error)"; - ENetAddress addr; - addr.host = ntohl(stunEndpoint.ip); - enet_address_get_host_ip(&addr, ipStr, ARRAY_SIZE(ipStr)); - glooxwrapper::JID hostJID(hostJIDStr); glooxwrapper::Jingle::Session session = m_sessionManager->createSession(hostJID); - session.sessionInitiate(ipStr, stunEndpoint.port); + session.sessionInitiate(ip.c_str(), port); } void XmppClient::handleSessionAction(gloox::Jingle::Action action, glooxwrapper::Jingle::Session& session, const glooxwrapper::Jingle::Session::Jingle& jingle) Index: source/lobby/glooxwrapper/glooxwrapper.h =================================================================== --- source/lobby/glooxwrapper/glooxwrapper.h +++ source/lobby/glooxwrapper/glooxwrapper.h @@ -673,7 +673,7 @@ Session(gloox::Jingle::Session* wrapped, bool owned); ~Session(); - bool sessionInitiate(char* ipStr, uint16_t port); + bool sessionInitiate(const char* ipStr, uint16_t port); }; class GLOOXWRAPPER_API SessionHandler Index: source/lobby/glooxwrapper/glooxwrapper.cpp =================================================================== --- source/lobby/glooxwrapper/glooxwrapper.cpp +++ source/lobby/glooxwrapper/glooxwrapper.cpp @@ -841,7 +841,7 @@ delete m_Wrapped; } -bool glooxwrapper::Jingle::Session::sessionInitiate(char* ipStr, u16 port) +bool glooxwrapper::Jingle::Session::sessionInitiate(const char* ipStr, u16 port) { gloox::Jingle::ICEUDP::CandidateList candidateList; Index: source/network/NetClient.cpp =================================================================== --- source/network/NetClient.cpp +++ source/network/NetClient.cpp @@ -239,15 +239,8 @@ ENetHost* enetClient = nullptr; if 
(g_XmppClient && m_UseSTUN) { - // Find an unused port - for (int i = 0; i < 5 && !enetClient; ++i) - { - // Ports below 1024 are privileged on unix - u16 port = 1024 + rand() % (UINT16_MAX - 1024); - ENetAddress hostAddr{ ENET_HOST_ANY, port }; - enetClient = enet_host_create(&hostAddr, 1, 1, 0, 0); - ++hostAddr.port; - } + ENetAddress hostAddr{ ENET_HOST_ANY, 0 }; + enetClient = enet_host_create(&hostAddr, 1, 1, 0, 0); if (!enetClient) { @@ -258,8 +251,8 @@ return false; } - StunClient::StunEndpoint stunEndpoint; - if (!StunClient::FindStunEndpointJoin(*enetClient, stunEndpoint)) + Networking::IPAddress publicIP; + if (!StunClient::FindPublicIP(*enetClient, publicIP)) { PushGuiMessage( "type", "netstatus", @@ -268,7 +261,8 @@ return false; } - g_XmppClient->SendStunEndpointToHost(stunEndpoint, hostJID); + // FindPublicIP yields an IPV4-mapped IPV6 address, but the host resolves peer addresses as IPV4 only (until Enet supports IPV6). + g_XmppClient->SendStunEndpointToHost(publicIP.GetIPV4(), publicIP.GetPort(), hostJID); SDL_Delay(1000); Index: source/network/NetServer.h =================================================================== --- source/network/NetServer.h +++ source/network/NetServer.h @@ -148,7 +148,9 @@ void SendHolePunchingMessage(const CStr& ip, u16 port); - void SetConnectionData(const CStr& ip, u16 port, bool useSTUN); + void SetConnectionData(const CStr& ip, u16 port); + + bool SetConnectionDataViaSTUN(); bool GetUseSTUN() const; Index: source/network/NetServer.cpp =================================================================== --- source/network/NetServer.cpp +++ source/network/NetServer.cpp @@ -1649,11 +1649,30 @@ return m_PublicIp; } -void CNetServer::SetConnectionData(const CStr& ip, const u16 port, bool useSTUN) +void CNetServer::SetConnectionData(const CStr& ip, const u16 port) { m_PublicIp = ip; m_PublicPort = port; - m_UseSTUN = useSTUN; +} + +bool CNetServer::SetConnectionDataViaSTUN() +{ + m_UseSTUN = true; + + std::lock_guard lock(m_Worker->m_WorkerMutex); + if (!m_Worker->m_Host) + return false; + + 
Networking::IPAddress publicIP; + if 
(!StunClient::FindPublicIP(*m_Worker->m_Host, publicIP)) + { + LOGERROR("Failed to find public IP via STUN."); + return false; + } + + m_PublicIp = publicIP.GetIP(); + m_PublicPort = publicIP.GetPort(); + return true; } bool CNetServer::CheckPasswordAndIncrement(const CStr& password, const std::string& username) Index: source/network/StunClient.h =================================================================== --- source/network/StunClient.h +++ source/network/StunClient.h @@ -19,26 +19,25 @@ #ifndef STUNCLIENT_H #define STUNCLIENT_H +#include "ps/Networking.h" + #include typedef struct _ENetHost ENetHost; -class ScriptInterface; -class CStr8; namespace StunClient { - -struct StunEndpoint { - u32 ip; - u16 port; -}; - -void SendStunRequest(ENetHost& transactionHost, u32 targetIp, u16 targetPort); - -bool FindStunEndpointHost(CStr8& ip, u16& port); - -bool FindStunEndpointJoin(ENetHost& transactionHost, StunClient::StunEndpoint& stunEndpoint); - +/** + * Find the public IP assigned to @a enetClient by pinging a STUN server. + * Note that this operation is blocking. + * On success, @a publicIP will contain the public IP & port (currently as an IPV4-mapped IPV6 address). + * @return true on success, false on failure. + */ +bool FindPublicIP(ENetHost& enetClient, Networking::IPAddress& publicIP); + +/** + * Send a UDP message to the peer, opening up the @a enetClient port to receive traffic from the server. 
+ */ void SendHolePunchingMessages(ENetHost& enetClient, const std::string& serverAddress, u16 serverPort); } Index: source/network/StunClient.cpp =================================================================== --- source/network/StunClient.cpp +++ source/network/StunClient.cpp @@ -20,78 +20,59 @@ #include "StunClient.h" -#include "lib/sysdep/os.h" - -#include -#include - -#include -#include - -#include -#if OS_WIN -# include -# include -#else -# include -# include -#endif - -#include - #include "lib/external_libraries/enet.h" -#if OS_WIN -#include "lib/sysdep/os/win/wposix/wtime.h" -#endif - #include "ps/CLogger.h" #include "ps/ConfigDB.h" #include "ps/CStr.h" -unsigned int m_StunServerIP; -int m_StunServerPort; +#include +#include +#include +namespace StunClient +{ /** * These constants are defined in Section 6 of RFC 5389. */ -const u32 m_MagicCookie = 0x2112A442; -const u32 m_MethodTypeBinding = 0x0001; -const u32 m_BindingSuccessResponse = 0x0101; +constexpr u32 m_MagicCookie = 0x2112A442; +constexpr u32 m_MethodTypeBinding = 0x0001; +constexpr u32 m_BindingSuccessResponse = 0x0101; /** * Bit determining whether comprehension of an attribute is optional. * Described in Section 15 of RFC 5389. */ -const u16 m_ComprehensionOptional = 0x1 << 15; +constexpr u16 m_ComprehensionOptional = 0x1 << 15; /** * Bit determining whether the bit was assigned by IETF Review. * Described in section 18.1. of RFC 5389. */ -const u16 m_IETFReview = 0x1 << 14; +constexpr u16 m_IETFReview = 0x1 << 14; /** * These constants are defined in Section 15.1 of RFC 5389. */ -const u8 m_IPAddressFamilyIPv4 = 0x01; +constexpr u8 m_IPAddressFamilyIPv4 = 0x01; /** * These constants are defined in Section 18.2 of RFC 5389. */ -const u16 m_AttrTypeMappedAddress = 0x001; -const u16 m_AttrTypeXORMappedAddress = 0x0020; +constexpr u16 m_AttrTypeMappedAddress = 0x001; +constexpr u16 m_AttrTypeXORMappedAddress = 0x0020; /** * Described in section 3 of RFC 5389. 
*/ u8 m_TransactionID[12]; +using StunRequest = Networking::Connection; + /** * Discovered STUN endpoint */ -u32 m_IP; -u16 m_Port; +Networking::IPAddress m_IP; void AddUInt16(std::vector& buffer, const u16 value) { @@ -126,51 +107,7 @@ return true; } -/** - * Creates a STUN request and sends it to a STUN server. - * The request is sent through transactionHost, from which the answer - * will be retrieved by ReceiveStunResponse and interpreted by ParseStunResponse. - */ -bool CreateStunRequest(ENetHost& transactionHost) -{ - CStr server_name; - CFG_GET_VAL("lobby.stun.server", server_name); - CFG_GET_VAL("lobby.stun.port", m_StunServerPort); - - debug_printf("GetPublicAddress: Using STUN server %s:%d\n", server_name.c_str(), m_StunServerPort); - - addrinfo hints; - addrinfo* res; - - memset(&hints, 0, sizeof(hints)); - hints.ai_family = AF_UNSPEC; // AF_INET or AF_INET6 to force version - hints.ai_socktype = SOCK_STREAM; - - // Resolve the stun server name so we can send it a STUN request - int status = getaddrinfo(server_name.c_str(), nullptr, &hints, &res); - if (status != 0) - { -#ifdef UNICODE - LOGERROR("GetPublicAddress: Error in getaddrinfo: %s", utf8_from_wstring(gai_strerror(status))); -#else - LOGERROR("GetPublicAddress: Error in getaddrinfo: %s", gai_strerror(status)); -#endif - return false; - } - - ENSURE(res); - - // Documentation says it points to "one or more addrinfo structures" - sockaddr_in* current_interface = reinterpret_cast(res->ai_addr); - m_StunServerIP = ntohl(current_interface->sin_addr.s_addr); - - StunClient::SendStunRequest(transactionHost, m_StunServerIP, m_StunServerPort); - - freeaddrinfo(res); - return true; -} - -void StunClient::SendStunRequest(ENetHost& transactionHost, u32 targetIp, u16 targetPort) +bool SendStunRequest(StunRequest& request, const Networking::IPAddress& ip) { std::vector buffer; AddUInt16(buffer, m_MethodTypeBinding); @@ -184,75 +121,31 @@ m_TransactionID[i] = random_byte; } - sockaddr_in to; - int to_len = 
sizeof(to); - memset(&to, 0, to_len); - - to.sin_family = AF_INET; - to.sin_port = htons(targetPort); - to.sin_addr.s_addr = htonl(targetIp); - - sendto( - transactionHost.socket, - reinterpret_cast(buffer.data()), - static_cast(buffer.size()), - 0, - reinterpret_cast(&to), - to_len); + return !!request.SendTo(ip, buffer); } /** - * Gets the response from the STUN server and checks it for its validity. + * Creates a STUN request and sends it to a STUN server. + * The request is sent from the socket of @a enetClient. */ -bool ReceiveStunResponse(ENetHost& transactionHost, std::vector& buffer) +StunRequest CreateStunRequest(ENetHost& enetClient) { - // TransportAddress sender; - const int LEN = 2048; - char input_buffer[LEN]; + CStr server_name; + int server_port = 3478; + CFG_GET_VAL("lobby.stun.server", server_name); + CFG_GET_VAL("lobby.stun.port", server_port); - memset(input_buffer, 0, LEN); + debug_printf("GetPublicAddress: Using STUN server %s:%d\n", server_name.c_str(), server_port); - sockaddr_in addr; - socklen_t from_len = sizeof(addr); + Networking::IPAddress ip; + StunRequest request(enetClient.socket); + // TODO: once Enet supports IPV6, switch to IPV6. 
+ if (!request.ResolveIP(ip, server_name.c_str(), static_cast(server_port), true)) + return request; - int len = recvfrom(transactionHost.socket, input_buffer, LEN, 0, reinterpret_cast(&addr), &from_len); + StunClient::SendStunRequest(request, ip); - int delay = 200; - CFG_GET_VAL("lobby.stun.delay", delay); - - // Wait to receive the message because enet sockets are non-blocking - const int max_tries = 5; - for (int count = 0; len < 0 && (count < max_tries || max_tries == -1); ++count) - { - usleep(delay * 1000); - len = recvfrom(transactionHost.socket, input_buffer, LEN, 0, reinterpret_cast(&addr), &from_len); - } - - if (len < 0) - { - LOGERROR("GetPublicAddress: recvfrom error (%d): %s", errno, strerror(errno)); - return false; - } - - u32 sender_ip = ntohl(static_cast(addr.sin_addr.s_addr)); - u16 sender_port = ntohs(addr.sin_port); - - if (sender_ip != m_StunServerIP) - LOGERROR("GetPublicAddress: Received stun response from different address: %d:%d (%d.%d.%d.%d:%d) %s", - addr.sin_addr.s_addr, - addr.sin_port, - (sender_ip >> 24) & 0xff, - (sender_ip >> 16) & 0xff, - (sender_ip >> 8) & 0xff, - (sender_ip >> 0) & 0xff, - sender_port, - input_buffer); - - // Convert to network string. 
- buffer.resize(len); - memcpy(buffer.data(), reinterpret_cast(input_buffer), len); - - return true; + return request; } bool ParseStunResponse(const std::vector& buffer) @@ -337,8 +230,7 @@ ip ^= m_MagicCookie; } - m_Port = port; - m_IP = ip; + m_IP = Networking::IPAddress(ip, port); break; } @@ -361,70 +253,52 @@ return true; } -bool STUNRequestAndResponse(ENetHost& transactionHost) +bool FindPublicIP(ENetHost& enetClient, Networking::IPAddress& publicIP) { - if (!CreateStunRequest(transactionHost)) + StunRequest request = CreateStunRequest(enetClient); + if (!request) return false; - std::vector buffer; - return ReceiveStunResponse(transactionHost, buffer) && - ParseStunResponse(buffer); -} + std::vector response; + int delay = 200; + CFG_GET_VAL("lobby.stun.delay", delay); -bool StunClient::FindStunEndpointHost(CStr8& ip, u16& port) -{ - ENetAddress hostAddr{ENET_HOST_ANY, static_cast(port)}; - ENetHost* transactionHost = enet_host_create(&hostAddr, 1, 1, 0, 0); - if (!transactionHost) - return false; - - bool success = STUNRequestAndResponse(*transactionHost); - enet_host_destroy(transactionHost); - if (!success) - return false; - - // Convert m_IP to string - char ipStr[256] = "(error)"; - ENetAddress addr; - addr.host = ntohl(m_IP); - int result = enet_address_get_host_ip(&addr, ipStr, ARRAY_SIZE(ipStr)); - - ip = ipStr; - port = m_Port; - return result == 0; -} - -bool StunClient::FindStunEndpointJoin(ENetHost& transactionHost, StunClient::StunEndpoint& stunEndpoint) -{ - if (!STUNRequestAndResponse(transactionHost)) + const bool success = request.Receive(response, delay, 5) && ParseStunResponse(response); + // The socket belongs to enetClient: release it on every path so that it is not closed here. + request.Release(); + if (!success) return false; - // Convert m_IP to string - char ipStr[256] = "(error)"; - ENetAddress addr; - addr.host = ntohl(m_IP); - enet_address_get_host_ip(&addr, ipStr, ARRAY_SIZE(ipStr)); 
- stunEndpoint.ip = m_IP; - stunEndpoint.port = m_Port; + publicIP = m_IP; return true; } -void StunClient::SendHolePunchingMessages(ENetHost& enetClient, const std::string& serverAddress, u16 serverPort) +void SendHolePunchingMessages(ENetHost& enetClient, const std::string& serverAddress, u16 serverPort) { - // Convert ip string to int64 - ENetAddress addr; - addr.port = serverPort; - enet_address_set_host(&addr, serverAddress.c_str()); + Networking::Connection conn(enetClient.socket); + Networking::IPAddress target; + // TODO: once Enet supports IPV6, switch to IPV6. + if (!conn.ResolveIP(target, serverAddress.c_str(), serverPort, true)) + { + LOGERROR("Failed to resolve %s for hole punching.", serverAddress.c_str()); + conn.Release(); + return; + } int delay = 200; CFG_GET_VAL("lobby.stun.delay", delay); - // Send an UDP message from enet host to ip:port + std::vector buffer; + + // Send an empty UDP datagram from the ENet socket to ip:port. The socket is not connected, so the target is passed explicitly. for (int i = 0; i < 3; ++i) { - StunClient::SendStunRequest(enetClient, htonl(addr.host), serverPort); - usleep(delay * 1000); + conn.SendTo(target, buffer); + std::this_thread::sleep_for(std::chrono::milliseconds(delay)); } + + conn.Release(); +} } Index: source/network/scripting/JSInterface_Network.cpp =================================================================== --- source/network/scripting/JSInterface_Network.cpp +++ source/network/scripting/JSInterface_Network.cpp @@ -97,29 +97,6 @@ // Always use lobby authentication for lobby matches to prevent impersonation and smurfing, in particular through mods that implemented an UI for arbitrary or other players nicknames. bool hasLobby = !!g_XmppClient; g_NetServer = new CNetServer(hasLobby); - // In lobby, we send our public ip and port on request to the players, who want to connect. - // In either case we need to know our public IP. If using STUN, we'll use that, - // otherwise, the lobby's reponse to the game registration stanza will tell us our public IP. - if (hasLobby) - { - CStr ip; - if (!useSTUN) - // Don't store IP - the lobby bot will send it later. 
- // (if a client tries to connect before it's setup, they'll be disconnected) - g_NetServer->SetConnectionData("", serverPort, false); - else - { - u16 port = serverPort; - // This is using port variable to store return value, do not pass serverPort itself. - if (!StunClient::FindStunEndpointHost(ip, port)) - { - ScriptException::Raise(rq, "Failed to host via STUN."); - SAFE_DELETE(g_NetServer); - return; - } - g_NetServer->SetConnectionData(ip, port, true); - } - } if (!g_NetServer->SetupConnection(serverPort)) { @@ -128,6 +105,16 @@ return; } + // When hosting in the lobby, we'll need to know our public IP to perform hole-punching. + // If using STUN, we can use that to query it, so do this immediately. + // Otherwise, we'll rely on the lobby game registration stanza response. + if (hasLobby && useSTUN && !g_NetServer->SetConnectionDataViaSTUN()) + { + ScriptException::Raise(rq, "Failed to start server"); + SAFE_DELETE(g_NetServer); + return; + } + // Generate a secret to identify the host client. std::string secret = ps_generate_guid(); Index: source/network/tests/test_StunClient.h =================================================================== --- /dev/null +++ source/network/tests/test_StunClient.h @@ -0,0 +1,52 @@ +/* Copyright (C) 2021 Wildfire Games. + * This file is part of 0 A.D. + * + * 0 A.D. is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 2 of the License, or + * (at your option) any later version. + * + * 0 A.D. is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with 0 A.D. If not, see . 
+ */ + +#include "lib/self_test.h" + +#include "ps/ConfigDB.h" +#include "ps/CStr.h" +#include "ps/Networking.h" +#include "network/StunClient.h" + +#include "lib/external_libraries/enet.h" + +class TestStunClient : public CxxTest::TestSuite +{ +public: + void setUp() + { + enet_initialize(); + } + + void tearDown() + { + enet_deinitialize(); + } + + void test_STUN() + { + CConfigDB::Initialise(); + g_ConfigDB.SetValueString(CFG_COMMAND, "lobby.stun.server", "lobby.wildfiregames.com"); + g_ConfigDB.SetValueString(CFG_COMMAND, "lobby.stun.port", "3478"); + Networking::IPAddress publicIP; + ENetAddress hostAddr{ ENET_HOST_ANY, 0 }; + ENetHost* client = enet_host_create(&hostAddr, 1, 1, 0, 0); + TS_ASSERT(client && StunClient::FindPublicIP(*client, publicIP)); + enet_host_destroy(client); + CConfigDB::Shutdown(); + } +}; Index: source/ps/Networking.h =================================================================== --- /dev/null +++ source/ps/Networking.h @@ -0,0 +1,181 @@ +/* Copyright (C) 2021 Wildfire Games. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#ifndef INCLUDED_NETWORKING +#define INCLUDED_NETWORKING + +#include + +#if OS_WIN +// Winsock2 introduces a bunch of defines that conflict with our code. +// To avoid that problem, we won't include it. But we still want +// sockaddr_storage here. +// So in relevant .cpp files, include _before_ this header, and things will work out. +#ifndef _WINSOCK2API_ +// See https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-smb2/4b77102f-769f-414d-b137-47cabfe8be8f +// SOCKADDR_STORAGE is 128 bytes, and must start with ss_family, a short. +typedef struct sockaddr_storage { + unsigned short ss_family; + char __ss_pad1[6]; + __int64 __ss_align; + char __ss_pad2[112]; +} SOCKADDR_STORAGE; +static_assert(sizeof(SOCKADDR_STORAGE) == 128); + +#define SOCK_STREAM 1 +#define SOCK_DGRAM 2 + +#endif +// These are always required. +struct sockaddr; +#define sockaddr_storage SOCKADDR_STORAGE +#else +# include +#endif + +#include +#include +#include + +/** + * Wrappers around networking-related functions. + * They are protocol-independent, implemented on the IPV6 stack. + * Ideally, these ought to be replaced with standard functions if/when C++ gets those. 
+ */ +namespace Networking +{ +using socket_t = int; +using socket_type = int; + +class IPAddress +{ +public: + IPAddress() = default; + IPAddress(sockaddr* addr, size_t size); + // IPV4 entry + IPAddress(std::array ipv4, u16 port); + IPAddress(u32 ipv4, u16 port); + // IPV6 entry + IPAddress(std::array ipv6, u16 port); + + bool Parse(const std::string& host, const std::string& service); + + const sockaddr* ToSocketAddress() const; + size_t AddressSize() const; + + std::string GetIP() const; + std::string GetIPV4() const; + u16 GetPort() const; +protected: + sockaddr_storage m_Address{}; + size_t m_Size = 0; +}; + +/** + * Wrap the C API error codes in a way that avoids mistakes. + */ +class SError +{ + enum class Type + { + ERRNO, + GAI, + NOSOCKET + }; +public: + static SError Errno(int code) { return { Type::ERRNO, code }; } + static SError GaiErrno(int code) { return { Type::GAI, code }; } + static SError NoSocket() { return { Type::NOSOCKET, -1 }; } + static SError OK() { return {}; } + + SError() = default; + + const char* GetString() const; + + /** + * For convenience in C++ code, have a native conversion to bool for ifs. + */ + explicit operator bool() const + { + return m_Code == 0; + } + +protected: + SError(Type t, int code) : m_Type(t), m_Code(code) {} + Type m_Type = Type::ERRNO; + int m_Code = 0; +}; + +template +class Connection +{ +public: + Connection() = default; + Connection(const Connection&) = delete; + Connection(Connection&& other) noexcept : m_Socket(other.Release()) {} + ~Connection(); + + /** + * Create a connection using an existing socket. + * Note that the socket must be manually released or it will be closed on destruction. + */ + Connection(socket_t socket) : m_Socket(socket) {}; + + /** + * Release ownership of the socket. It will have to be closed manually. + */ + socket_t Release() { socket_t socket = m_Socket; m_Socket = -1; return socket; } + + /** + * Open a socket to the target @a address. The connection is ready to send/receive on success. 
+ * If @a from is specified, try binding there & update it (so port 0 will be updated to a real port). + */ + SError OpenSocket(const IPAddress& address, IPAddress* from = nullptr); + + /** + * Close the socket. Leaves Connection in a valid state. Called automatically on destruction. + */ + void CloseSocket(); + + /** + * Parse an address name into an IP address usable for opening the socket. + * This needs the socket type. + */ + static SError ResolveIP(IPAddress& out, const char* serverName, u16 port = 0, bool forceIPV4 = false); + + SError Send(const std::vector& buffer); + SError SendTo(const IPAddress& target, const std::vector& buffer); + SError Receive(std::vector& buffer, int delay = 200, int max_tries = 25); + + explicit operator bool() const { return m_Socket != -1; } + + /** + * Update @a ip with the parameters of the outgoing connection (the 'local IP'). + */ + SError GetOutboundIP(IPAddress& ip) const; + +protected: + socket_t m_Socket = -1; +}; +} + +#endif // INCLUDED_NETWORKING Index: source/ps/Networking.cpp =================================================================== --- /dev/null +++ source/ps/Networking.cpp @@ -0,0 +1,342 @@ +/* Copyright (C) 2021 Wildfire Games. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +#include "precompiled.h" + +// NB : those would usually be included after, but for practical reasons must come before Networking.h +#if OS_WIN +# include +# include +#undef gai_strerror +// Force ANSI variant to simplify code. +#define gai_strerror gai_strerrorA +#endif + +#include "Networking.h" + +#if !OS_WIN +# include +# include +# include +# include +#endif + +#include + +/** + * NB: there are a bunch of suspiciously UB-like reinterpret_cast here because of the C socket API. + * There's no real way to fix it, and I don't expect any reasonable compiler to break. 
+ */ + +namespace Networking +{ +IPAddress::IPAddress(sockaddr* addr, size_t size) +{ + memset(&m_Address, 0, sizeof(m_Address)); + memcpy(&m_Address, addr, size); + m_Size = size; +} + +IPAddress::IPAddress(const u32 ipv4, u16 port) +{ + *this = IPAddress(std::array{ + static_cast(ipv4 >> 24 & 0xff), + static_cast(ipv4 >> 16 & 0xff), + static_cast(ipv4 >> 8 & 0xff), + static_cast(ipv4 & 0xff) + }, port); +} + +IPAddress::IPAddress(std::array ipv4, u16 port) +{ + memset(&m_Address, 0, sizeof(m_Address)); + sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = htons(port); + addr.sin6_addr.s6_addr[10] = 0xff; + addr.sin6_addr.s6_addr[11] = 0xff; + for (size_t i = 0; i < ipv4.size(); ++i) + addr.sin6_addr.s6_addr[12+i] = ipv4[i]; + memcpy(&m_Address, &addr, sizeof(addr)); + m_Size = sizeof(addr); +} + +IPAddress::IPAddress(std::array ipv6, u16 port) +{ + memset(&m_Address, 0, sizeof(m_Address)); + sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = htons(port); + for (size_t i = 0; i < ipv6.size(); ++i) + addr.sin6_addr.s6_addr[i] = ipv6[i]; + memcpy(&m_Address, &addr, sizeof(addr)); + m_Size = sizeof(addr); +} + +bool IPAddress::Parse(const std::string& host, const std::string& service) +{ + return !!Connection::ResolveIP(*this, host.c_str(), strtol(service.c_str(), nullptr, 10)); +} + +const sockaddr* IPAddress::ToSocketAddress() const +{ + return reinterpret_cast(&m_Address); +} + +size_t IPAddress::AddressSize() const +{ + return m_Size; +} + +std::string IPAddress::GetIP() const +{ + char buf[200]; + int status = getnameinfo(reinterpret_cast(&m_Address), m_Size, buf, ARRAY_SIZE(buf), nullptr, 0, NI_NOFQDN | NI_NUMERICHOST); + if (status != 0) + { +#ifndef NDEBUG + // Error logging turned off in release mode because it's unnecessarily noisy. 
+ LOGERROR("GetIP error (%d): %s", status, gai_strerror(status)); +#endif + return ""; + } + return buf; +} + +std::string IPAddress::GetIPV4() const +{ + std::string ret; + if (m_Size != sizeof(sockaddr_in6)) + return ret; + + sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + memcpy(&addr, &m_Address, m_Size); + for (size_t i = 0; i < 10; ++i) + if (addr.sin6_addr.s6_addr[i] != 0x00) + return ret; + if (addr.sin6_addr.s6_addr[10] != 0xff || addr.sin6_addr.s6_addr[11] != 0xff) + return ret; + for (size_t i = 12; i < 16; ++i) + { + ret += std::to_string(addr.sin6_addr.s6_addr[i]); + if (i != 15) + ret += "."; + } + return ret; +} + +u16 IPAddress::GetPort() const +{ + char buf[100]; + int status = getnameinfo(reinterpret_cast(&m_Address), m_Size, nullptr, 0, buf, ARRAY_SIZE(buf), NI_NUMERICSERV); + if (status != 0) + { +#ifndef NDEBUG + // Error logging turned off in release mode because it's unnecessarily noisy. + LOGERROR("GetPort error (%d): %s", status, gai_strerror(status)); +#endif + return 0; + } + return strtol(buf, nullptr, 10); +} + +const char* SError::GetString() const +{ + if (m_Code == 0) + return "No error"; + + if (m_Type == Type::ERRNO) + return strerror(m_Code); + else if (m_Type == Type::GAI) + return gai_strerror(m_Code); + else + return "Invalid socket"; +} + +template +Connection::~Connection() +{ + CloseSocket(); +} + +template +SError Connection::ResolveIP(IPAddress& out, const char* serverName, u16 port, bool forceIPV4) +{ + addrinfo hints; + addrinfo* res; + + memset(&hints, 0, sizeof(hints)); + hints.ai_family = forceIPV4 ? AF_INET : AF_INET6; + hints.ai_socktype = SOCKET_TYPE; + // AI_V4MAPPED means "return an ipv4-mapped Ipv6 address if no ipv6 address is found", + // thus ensuring IPV4 compatibility (and it does nothing in AF_INET mode). + // AI_ADDRCONFIG essentially means "only return what I am able to understand". + hints.ai_flags = AI_V4MAPPED | AI_ADDRCONFIG; + + std::string service = port == 0 ? 
"" : std::to_string(port); + + // Resolve the host name (and port, if one was given) into a socket address. getaddrinfo reports failures as EAI_* codes, not errno. + int status = getaddrinfo(serverName, port == 0 ? nullptr : service.c_str(), &hints, &res); + if (status != 0) + return SError::GaiErrno(status); + + ENSURE(res); + + // Multiple results may be returned - assume the first is OK for what we want, given our flags above. + out = IPAddress(res->ai_addr, res->ai_addrlen); + + freeaddrinfo(res); + + return SError::OK(); +} + +template +SError Connection::OpenSocket(const IPAddress& address, IPAddress* from) +{ + m_Socket = socket(PF_INET6, SOCKET_TYPE, 0); + if (m_Socket == -1) + return SError::NoSocket(); + + if (from) + { + if (bind(m_Socket, from->ToSocketAddress(), from->AddressSize()) == -1) + { + CloseSocket(); + return SError::Errno(errno); + } + } + + { +#if OS_WIN + u_long nonblocking = 1; + ioctlsocket(m_Socket, FIONBIO, &nonblocking); +#elif OS_UNIX + const int flags = fcntl(m_Socket, F_GETFL, 0); + fcntl(m_Socket, F_SETFL, flags == -1 ? O_NONBLOCK : (flags | O_NONBLOCK)); +#else +# warning "No known way to set sockets as non-blocking, sockets will be blocking." +#endif + } + + if (connect(m_Socket, address.ToSocketAddress(), address.AddressSize()) == -1) + { + CloseSocket(); + return SError::Errno(errno); + } + + if (from) + GetOutboundIP(*from); + + return SError::OK(); +} + +template +void Connection::CloseSocket() +{ + if (m_Socket == -1) + return; + +#if OS_WIN + closesocket(m_Socket); +#elif OS_UNIX + close(m_Socket); +#else +# error "This system has no known way to close a socket." 
+#endif + + m_Socket = -1; +} + +template +SError Connection::Send(const std::vector& buffer) +{ + if (m_Socket == -1) + return SError::NoSocket(); + + if (send(m_Socket, reinterpret_cast(buffer.data()), static_cast(buffer.size()), 0) == -1) + return SError::Errno(errno); + + return SError::OK(); +} + +template +SError Connection::SendTo(const IPAddress& ip, const std::vector& buffer) +{ + if (m_Socket == -1) + return SError::NoSocket(); + + if (sendto(m_Socket, reinterpret_cast(buffer.data()), static_cast(buffer.size()), 0, + ip.ToSocketAddress(), ip.AddressSize()) == -1) + return SError::Errno(errno); + + return SError::OK(); +} + +template +SError Connection::Receive(std::vector& buffer, int delay, int max_tries) +{ + if (m_Socket == -1) + return SError::NoSocket(); + + if (!buffer.size()) + buffer.resize(2048); + size_t size = buffer.size(); + + memset(buffer.data(), 0, size); + + int recvlen = recv(m_Socket, reinterpret_cast(buffer.data()), size, 0); + // Retry a few times, sockets are non-blocking. On success the buffer is shrunk to the received datagram length. + for (int count = 0; recvlen < 0 && (count < max_tries || max_tries == -1); ++count) + { + usleep(delay * 1000); + recvlen = recv(m_Socket, reinterpret_cast(buffer.data()), size, 0); + } + + if (recvlen < 0) + return SError::Errno(errno); + buffer.resize(recvlen); + return SError::OK(); +} + +template +SError Connection::GetOutboundIP(IPAddress& ip) const +{ + if (m_Socket == -1) + return SError::NoSocket(); + + sockaddr_in6 addr; + socklen_t addrlen = sizeof(addr); + int status = getsockname(m_Socket, reinterpret_cast(&addr), &addrlen); + if (status == -1) + return SError::Errno(errno); + + ip = IPAddress(reinterpret_cast(&addr), addrlen); + + return SError::OK(); +} +} // namespace Networking + +template class Networking::Connection; +template class Networking::Connection;