diff --git a/Common/src/asserts.cpp b/Common/src/asserts.cpp index b64456a..bb2df55 100644 --- a/Common/src/asserts.cpp +++ b/Common/src/asserts.cpp @@ -26,7 +26,7 @@ namespace Common // Try to break if a debugger exists and then exit the application if ( IsDebuggerPresent() ) { - DebugBreak(); + DEBUG_BREAK(); } exit( -1 ); } diff --git a/Common/src/asserts.h b/Common/src/asserts.h index 414c331..ea9e0c7 100644 --- a/Common/src/asserts.h +++ b/Common/src/asserts.h @@ -1,12 +1,24 @@ #pragma once #ifdef DEBUG + #if defined( _MSC_VER ) + #define DEBUG_BREAK() __debugbreak() + #elif defined( __GNUC__ ) || defined( __clang__ ) + #include + #define DEBUG_BREAK() raise( SIGTRAP ) + #else + // raise(SIGTRAP) is like the most standard way to do a debug break in unix-like systems + #include + #define DEBUG_BREAK() raise( SIGTRAP ) + #endif + #define ASSERT( expression, text_format, ... ) \ if ( !( expression ) ) \ { \ Common::ForceCrash( #expression, text_format, ##__VA_ARGS__ ); \ } #else + #define DEBUG_BREAK() #define ASSERT( expression, text_format, ... ) #endif namespace Common diff --git a/DemoGame/src/client/systems/client_local_player_predictor_system.cpp b/DemoGame/src/client/systems/client_local_player_predictor_system.cpp index 190cde9..13a99f1 100644 --- a/DemoGame/src/client/systems/client_local_player_predictor_system.cpp +++ b/DemoGame/src/client/systems/client_local_player_predictor_system.cpp @@ -131,7 +131,7 @@ void ClientLocalPlayerPredictorSystem::ExecuteLocalPrediction( Engine::ECS::Worl void ClientLocalPlayerPredictorSystem::Execute( Engine::ECS::World& world, float32 elapsed_time ) { const NetworkPeerGlobalComponent& networkPeerComponent = world.GetGlobalComponent< NetworkPeerGlobalComponent >(); - if ( networkPeerComponent.peer->GetConnectionState() != NetLib::PCS_Connected ) + if ( networkPeerComponent.peer->GetConnectionState() != NetLib::PeerConnectionState::Connected ) { return; } diff --git a/DemoGame/src/client/systems/client_local_player_server_reconciliator_system.cpp b/DemoGame/src/client/systems/client_local_player_server_reconciliator_system.cpp index 20ab675..38d6c27 100644 --- a/DemoGame/src/client/systems/client_local_player_server_reconciliator_system.cpp +++ b/DemoGame/src/client/systems/client_local_player_server_reconciliator_system.cpp @@ -139,7 +139,7 @@ static void EvaluateReconciliation( Engine::ECS::GameEntity& entity, const Netwo void ClientLocalPlayerServerReconciliatorSystem::Execute( Engine::ECS::World& world, float32 elapsed_time ) { const NetworkPeerGlobalComponent& networkPeerComponent = world.GetGlobalComponent< NetworkPeerGlobalComponent >(); - if ( networkPeerComponent.peer->GetConnectionState() != NetLib::PCS_Connected ) + if ( networkPeerComponent.peer->GetConnectionState() != NetLib::PeerConnectionState::Connected ) { return; } diff --git a/DemoGame/src/server/systems/server_hit_registration_system.cpp b/DemoGame/src/server/systems/server_hit_registration_system.cpp index dba1f4e..ae8382d 100644 --- a/DemoGame/src/server/systems/server_hit_registration_system.cpp +++ b/DemoGame/src/server/systems/server_hit_registration_system.cpp @@ -1,6 +1,7 @@ #include "server_hit_registration_system.h" #include "logger.h" +#include "asserts.h" #include "AlgorithmUtils.h" #include "ecs/world.h" @@ -202,12 +203,9 @@ static void RollbackEntities( Engine::ECS::World& world, float32 serverTime ) int32 previousIndex = -1; int32 nextIndex = -1; FindPreviousAndNextTimeIndexes( transformHistoryComponent, serverTime, previousIndex, nextIndex ); - if ( nextIndex < 0 ) - { - bool a = true; - } + // TODO Investigate hit reg issue that is hitting this assert. - assert( nextIndex >= 0 ); + ASSERT( nextIndex >= 0, "ServerHitRegistrationSystem.%s Couldn't find next index" ); Engine::TransformComponent& transform = it->GetComponent< Engine::TransformComponent >(); diff --git a/DemoGame/src/shared/InputState.cpp b/DemoGame/src/shared/InputState.cpp index 1f332de..8377107 100644 --- a/DemoGame/src/shared/InputState.cpp +++ b/DemoGame/src/shared/InputState.cpp @@ -1,5 +1,7 @@ #include "InputState.h" +#include "logger.h" + #include "core/buffer.h" InputState::InputState() @@ -32,14 +34,27 @@ void InputState::Serialize( NetLib::Buffer& buffer ) const buffer.WriteByte( isShooting ? 1 : 0 ); } -void InputState::Deserialize( NetLib::Buffer& buffer ) +bool InputState::Deserialize( NetLib::Buffer& buffer ) { + if ( buffer.GetRemainingSize() < GetSize() ) + { + LOG_ERROR( "Not enough data in buffer to read InputState." ); + return false; + } + tick = buffer.ReadInteger(); serverTime = buffer.ReadFloat(); movement.X( buffer.ReadFloat() ); movement.Y( buffer.ReadFloat() ); + if ( movement.X() >= MAXIMUM_MOVEMENT_VALUE || movement.Y() >= MAXIMUM_MOVEMENT_VALUE ) + { + LOG_ERROR( "[InputState::%s] Movement value too high, possible cheating detected. X: %.3f, Y: %.3f", + THIS_FUNCTION_NAME, movement.X(), movement.Y() ); + return false; + } + virtualMousePosition.X( buffer.ReadFloat() ); virtualMousePosition.Y( buffer.ReadFloat() ); @@ -48,4 +63,6 @@ void InputState::Deserialize( NetLib::Buffer& buffer ) const uint8 isShootingByte = buffer.ReadByte(); isShooting = ( isShootingByte == 1 ) ? true : false; + + return true; } diff --git a/DemoGame/src/shared/InputState.h b/DemoGame/src/shared/InputState.h index e1071eb..af04498 100644 --- a/DemoGame/src/shared/InputState.h +++ b/DemoGame/src/shared/InputState.h @@ -12,7 +12,9 @@ class InputState : public NetLib::IInputState int32 GetSize() const override; void Serialize( NetLib::Buffer& buffer ) const override; - void Deserialize( NetLib::Buffer& buffer ) override; + bool Deserialize( NetLib::Buffer& buffer ) override; + + static constexpr float32 MAXIMUM_MOVEMENT_VALUE = 10.f; // Header fields uint32 tick; diff --git a/DemoGame/src/shared/systems/pre_tick_network_system.cpp b/DemoGame/src/shared/systems/pre_tick_network_system.cpp index cbf4b94..138d6a5 100644 --- a/DemoGame/src/shared/systems/pre_tick_network_system.cpp +++ b/DemoGame/src/shared/systems/pre_tick_network_system.cpp @@ -25,7 +25,7 @@ void PreTickNetworkSystem::Execute( Engine::ECS::World& world, float32 elapsed_t { NetworkPeerGlobalComponent& networkPeerComponent = world.GetGlobalComponent< NetworkPeerGlobalComponent >(); - if ( networkPeerComponent.peer->GetConnectionState() == NetLib::PCS_Disconnected ) + if ( networkPeerComponent.peer->GetConnectionState() == NetLib::PeerConnectionState::Disconnected ) { networkPeerComponent.peer->Start( _ip, _port ); } diff --git a/NetworkLibrary/src/Core/Address.cpp b/NetworkLibrary/src/Core/Address.cpp index d46c750..37203d9 100644 --- a/NetworkLibrary/src/Core/Address.cpp +++ b/NetworkLibrary/src/Core/Address.cpp @@ -27,6 +27,11 @@ namespace NetLib return !( *this == other ); } + bool Address::IsValid() const + { + return *this != Address::GetInvalid(); + } + void Address::GetFull( std::string& buffer ) const { buffer.append( _ip.c_str() ); diff --git a/NetworkLibrary/src/Core/Address.h b/NetworkLibrary/src/Core/Address.h index 850c9ef..bfcd7bc 100644 --- a/NetworkLibrary/src/Core/Address.h +++ b/NetworkLibrary/src/Core/Address.h @@ -22,12 +22,13 @@ namespace NetLib static Address GetInvalid() { return Address( "0.0.0.0", 0 ); } Address( const std::string& ip, uint32 port ); - Address( const Address& other ) = default; bool operator==( const Address& other ) const; bool operator!=( const Address& other ) const; + bool IsValid() const; + uint32 GetPort() const { return _port; } const std::string& GetIP() const { return _ip; } void GetFull( std::string& buffer ) const; @@ -48,4 +49,13 @@ namespace NetLib friend class Socket; }; + + struct AddressHasher + { + size_t operator()( const Address& address ) const noexcept + { + return std::hash< std::string >()( address.GetIP() ) ^ + ( std::hash< uint32 >()( address.GetPort() ) << 1 ); + } + }; } // namespace NetLib diff --git a/NetworkLibrary/src/Core/Buffer.cpp b/NetworkLibrary/src/Core/Buffer.cpp index 8c666e0..5c83247 100644 --- a/NetworkLibrary/src/Core/Buffer.cpp +++ b/NetworkLibrary/src/Core/Buffer.cpp @@ -1,6 +1,8 @@ #include "buffer.h" -#include +#include "Logger.h" +#include "asserts.h" + #include #include @@ -20,14 +22,16 @@ namespace NetLib void Buffer::CopyUsedData( uint8* dst, uint32 dst_size ) const { - assert( dst_size >= _index ); + ASSERT( dst_size >= _index, "Buffer.%s The destination buffer is smaller than the used data to copy.", + THIS_FUNCTION_NAME ); std::memcpy( dst, _data, _index ); } void Buffer::WriteLong( uint64 value ) { - assert( _index + 8 <= _size ); + ASSERT( _index + 8 <= _size, "Buffer.%s Write operation exceeds buffer bounds. Size: %u, Remaining: %u", + THIS_FUNCTION_NAME, _size, _size - _index ); *( ( uint64* ) ( _data + _index ) ) = value; _index += 8; @@ -35,7 +39,8 @@ namespace NetLib void Buffer::WriteInteger( uint32 value ) { - assert( _index + 4 <= _size ); + ASSERT( _index + 4 <= _size, "Buffer.%s Write operation exceeds buffer bounds. Size: %u, Remaining: %u", + THIS_FUNCTION_NAME, _size, _size - _index ); *( ( uint32* ) ( _data + _index ) ) = value; _index += 4; @@ -43,7 +48,8 @@ namespace NetLib void Buffer::WriteShort( uint16 value ) { - assert( _index + 2 <= _size ); + ASSERT( _index + 2 <= _size, "Buffer.%s Write operation exceeds buffer bounds. Size: %u, Remaining: %u", + THIS_FUNCTION_NAME, _size, _size - _index ); *( ( uint16* ) ( _data + _index ) ) = value; _index += 2; @@ -51,7 +57,8 @@ namespace NetLib void Buffer::WriteByte( uint8 value ) { - assert( _index + 1 <= _size ); + ASSERT( _index + 1 <= _size, "Buffer.%s Write operation exceeds buffer bounds. Size: %u, Remaining: %u", + THIS_FUNCTION_NAME, _size, _size - _index ); *( ( uint8* ) ( _data + _index ) ) = value; ++_index; @@ -59,15 +66,26 @@ namespace NetLib void Buffer::WriteFloat( float32 value ) { - assert( _index + 4 <= _size ); + ASSERT( _index + 4 <= _size, "Buffer.%s Write operation exceeds buffer bounds. Size: %u, Remaining: %u", + THIS_FUNCTION_NAME, _size, _size - _index ); // This memcpy needs to be done using memcpy to keep the bits configuration too std::memcpy( ( _data + _index ), &value, sizeof( uint32 ) ); _index += 4; } + void Buffer::WriteData( const uint8* data, uint32 size ) + { + ASSERT( data != nullptr, "Buffer.%s Can't write nullptr data.", THIS_FUNCTION_NAME ); + ASSERT( _index + size <= _size, "Buffer.%s Write operation exceeds buffer bounds. Size: %u, Remaining: %u", + THIS_FUNCTION_NAME, _size, _size - _index ); + std::memcpy( _data + _index, data, size ); + _index += size; + } + uint64 Buffer::ReadLong() { - assert( _index + 8 <= _size ); + ASSERT( _index + 8 <= _size, "Buffer.%s Read operation exceeds buffer bounds. Size: %u, Remaining: %u", + THIS_FUNCTION_NAME, _size, _size - _index ); uint64 value; value = *( ( uint64* ) ( _data + _index ) ); @@ -75,9 +93,24 @@ namespace NetLib return value; } + bool Buffer::ReadLong( uint64& value ) + { + if ( _index + 8 <= _size ) + { + value = *( ( uint64* ) ( _data + _index ) ); + _index += 8; + return true; + } + else + { + return false; + } + } + uint32 Buffer::ReadInteger() { - assert( _index + 4 <= _size ); + ASSERT( _index + 4 <= _size, "Buffer.%s Read operation exceeds buffer bounds. Size: %u, Remaining: %u", + THIS_FUNCTION_NAME, _size, _size - _index ); uint32 value; value = *( ( uint32* ) ( _data + _index ) ); @@ -85,9 +118,25 @@ namespace NetLib return value; } + bool Buffer::ReadInteger( uint32& value ) + { + if ( _index + 4 <= _size ) + { + value = *( ( uint32* ) ( _data + _index ) ); + + _index += 4; + return true; + } + else + { + return false; + } + } + uint16 Buffer::ReadShort() { - assert( _index + 2 <= _size ); + ASSERT( _index + 2 <= _size, "Buffer.%s Read operation exceeds buffer bounds. Size: %u, Remaining: %u", + THIS_FUNCTION_NAME, _size, _size - _index ); uint16 value; value = *( ( uint16* ) ( _data + _index ) ); @@ -97,9 +146,24 @@ namespace NetLib return value; } + bool Buffer::ReadShort( uint16& value ) + { + if ( _index + 2 <= _size ) + { + value = *( ( uint16* ) ( _data + _index ) ); + _index += 2; + return true; + } + else + { + return false; + } + } + uint8 Buffer::ReadByte() { - assert( _index + 1 <= _size ); + ASSERT( _index + 1 <= _size, "Buffer.%s Read operation exceeds buffer bounds. Size: %u, Remaining: %u", + THIS_FUNCTION_NAME, _size, _size - _index ); uint8 value; value = *( ( uint8* ) ( _data + _index ) ); @@ -109,9 +173,24 @@ namespace NetLib return value; } + bool Buffer::ReadByte( uint8& value ) + { + if ( _index + 1 <= _size ) + { + value = *( ( uint8* ) ( _data + _index ) ); + ++_index; + return true; + } + else + { + return false; + } + } + float32 Buffer::ReadFloat() { - assert( _index + 4 <= _size ); + ASSERT( _index + 4 <= _size, "Buffer.%s Read operation exceeds buffer bounds. Size: %u, Remaining: %u", + THIS_FUNCTION_NAME, _size, _size - _index ); float32 value; // This memcpy needs to be done using memcpy to recover the bits configuration too std::memcpy( &value, ( _data + _index ), sizeof( float32 ) ); @@ -120,6 +199,21 @@ namespace NetLib return value; } + bool Buffer::ReadData( uint8* data, uint32 size ) + { + ASSERT( data != nullptr, "Buffer.%s Can't read into nullptr data.", THIS_FUNCTION_NAME ); + if ( _index + size <= _size ) + { + std::memcpy( data, _data + _index, size ); + _index += size; + return true; + } + else + { + return false; + } + } + void Buffer::ResetAccessIndex() { _index = 0; diff --git a/NetworkLibrary/src/Core/Buffer.h b/NetworkLibrary/src/Core/Buffer.h index d2d8b27..6aa2be1 100644 --- a/NetworkLibrary/src/Core/Buffer.h +++ b/NetworkLibrary/src/Core/Buffer.h @@ -12,6 +12,7 @@ namespace NetLib ~Buffer() {} uint32 GetSize() const { return _size; } + uint32 GetRemainingSize() const { return _size - _index; } uint8* GetData() const { return _data; } uint32 GetAccessIndex() const { return _index; } void Clear(); @@ -23,12 +24,18 @@ namespace NetLib void WriteShort( uint16 value ); void WriteByte( uint8 value ); void WriteFloat( float32 value ); + void WriteData( const uint8* data, uint32 size ); uint64 ReadLong(); + bool ReadLong( uint64& value ); uint32 ReadInteger(); + bool ReadInteger( uint32& value ); uint16 ReadShort(); + bool ReadShort( uint16& value ); uint8 ReadByte(); + bool ReadByte( uint8& value ); float32 ReadFloat(); + bool ReadData( uint8* data, uint32 size ); void ResetAccessIndex(); diff --git a/NetworkLibrary/src/Core/Client.cpp b/NetworkLibrary/src/Core/Client.cpp index eaf5015..dbe04e8 100644 --- a/NetworkLibrary/src/Core/Client.cpp +++ b/NetworkLibrary/src/Core/Client.cpp @@ -3,6 +3,7 @@ #include #include "logger.h" +#include "asserts.h" #include "core/remote_peer.h" #include "core/time_clock.h" @@ -19,17 +20,12 @@ namespace NetLib : Peer( PeerType::CLIENT, 1, 1024, 1024 ) , _serverAddress( "127.0.0.1", 54000 ) , inGameMessageID( 0 ) - , _currentState( ClientState::CS_Disconnected ) , _replicationMessagesProcessor() , _clientIndex( 0 ) , _timeSyncer() { } - Client::~Client() - { - } - bool Client::StartClient( const std::string& server_ip, uint32 server_port ) { return Start( server_ip, 0 ); // Port is zero so the system picks up a random port number @@ -63,7 +59,7 @@ namespace NetLib uint32 Client::GetLocalClientId() const { - if ( GetConnectionState() != PeerConnectionState::PCS_Connected ) + if ( GetConnectionState() != PeerConnectionState::Connected ) { LOG_WARNING( "Can not get the local client ID if the peer is not connected" ); return 0; @@ -76,10 +72,7 @@ namespace NetLib { BindSocket( Address( "127.0.0.1", 0 ) ); // Port is zero so the system picks up a random port number - _currentState = ClientState::CS_SendingConnectionRequest; - - uint64 clientSalt = GenerateClientSaltNumber(); - AddRemotePeer( _serverAddress, 0, clientSalt, 0 ); + _connectionManager.StartConnectingToAddress( _serverAddress ); SubscribeToOnRemotePeerDisconnect( [ this ]( uint32 ) @@ -92,49 +85,14 @@ namespace NetLib return true; } - uint64 Client::GenerateClientSaltNumber() - { - // TODO Change this for a better generator. rand is not generating a full 64bit integer since its maximum is - // roughly 32767. I have tried to use mt19937_64 but I think I get a conflict with winsocks and - // std::uniform_int_distribution - srand( static_cast< uint32 >( time( NULL ) ) ); - return rand(); - } - void Client::ProcessMessageFromPeer( const Message& message, RemotePeer& remotePeer ) { MessageType messageType = message.GetHeader().type; switch ( messageType ) { - case MessageType::ConnectionChallenge: - if ( _currentState == ClientState::CS_SendingConnectionRequest || - _currentState == ClientState::CS_SendingConnectionChallengeResponse ) - { - const ConnectionChallengeMessage& connectionChallengeMessage = - static_cast< const ConnectionChallengeMessage& >( message ); - ProcessConnectionChallenge( connectionChallengeMessage, remotePeer ); - } - break; - case MessageType::ConnectionAccepted: - if ( _currentState == ClientState::CS_SendingConnectionChallengeResponse ) - { - const ConnectionAcceptedMessage& connectionAcceptedMessage = - static_cast< const ConnectionAcceptedMessage& >( message ); - ProcessConnectionRequestAccepted( connectionAcceptedMessage, remotePeer ); - } - break; - case MessageType::ConnectionDenied: - if ( _currentState == ClientState::CS_SendingConnectionChallengeResponse || - _currentState == ClientState::CS_SendingConnectionRequest ) - { - const ConnectionDeniedMessage& connectionDeniedMessage = - static_cast< const ConnectionDeniedMessage& >( message ); - ProcessConnectionRequestDenied( connectionDeniedMessage ); - } - break; case MessageType::Disconnection: - if ( _currentState == ClientState::CS_Connected ) + if ( GetConnectionState() == PeerConnectionState::Connected ) { const DisconnectionMessage& disconnectionMessage = static_cast< const DisconnectionMessage& >( message ); @@ -142,7 +100,7 @@ namespace NetLib } break; case MessageType::TimeResponse: - if ( _currentState == ClientState::CS_Connected ) + if ( GetConnectionState() == PeerConnectionState::Connected ) { const TimeResponseMessage& timeResponseMessage = static_cast< const TimeResponseMessage& >( message ); @@ -150,7 +108,7 @@ namespace NetLib } break; case MessageType::Replication: - if ( _currentState == ClientState::CS_Connected ) + if ( GetConnectionState() == PeerConnectionState::Connected ) { const ReplicationMessage& replicationMessage = static_cast< const ReplicationMessage& >( message ); ProcessReplicationAction( replicationMessage ); @@ -166,29 +124,9 @@ namespace NetLib } } - void Client::ProcessMessageFromUnknownPeer( const Message& message, const Address& address ) - { - LOG_WARNING( "Client does not process messages from unknown peers. Ignoring it..." ); - } - void Client::TickConcrete( float32 elapsedTime ) { - if ( _currentState == ClientState::CS_SendingConnectionRequest || - _currentState == ClientState::CS_SendingConnectionChallengeResponse ) - { - RemotePeer* remotePeer = _remotePeersHandler.GetRemotePeerFromAddress( _serverAddress ); - if ( remotePeer == nullptr ) - { - LOG_ERROR( "Can't create new Connection Request Message because there is no remote peer corresponding " - "to IP: %s", - _serverAddress.GetIP() ); - return; - } - - CreateConnectionRequestMessage( *remotePeer ); - } - - if ( _currentState == ClientState::CS_Connected ) + if ( GetConnectionState() == PeerConnectionState::Connected ) { RemotePeer* serverRemotePeer = _remotePeersHandler.GetRemotePeerFromAddress( _serverAddress ); if ( serverRemotePeer == nullptr ) @@ -204,51 +142,15 @@ namespace NetLib bool Client::StopConcrete() { - _currentState = ClientState::CS_Disconnected; return true; } - void Client::ProcessConnectionChallenge( const ConnectionChallengeMessage& message, RemotePeer& remotePeer ) + void Client::OnPendingConnectionAccepted( const Connection::SuccessConnectionData& data ) { - LOG_INFO( "Challenge packet received from server" ); + ASSERT( data.startedLocally, + "Client-side can't receive a connection accepted apart from the one that was started locally." ); - uint64 clientSalt = message.clientSalt; - uint64 serverSalt = message.serverSalt; - if ( remotePeer.GetClientSalt() != clientSalt ) - { - LOG_WARNING( "The generated salt number does not match the server's challenge client salt number. Aborting " - "operation" ); - return; - } - - remotePeer.SetServerSalt( serverSalt ); - - _currentState = ClientState::CS_SendingConnectionChallengeResponse; - - CreateConnectionChallengeResponse( remotePeer ); - - LOG_INFO( "Sending challenge response packet to server..." ); - } - - void Client::ProcessConnectionRequestAccepted( const ConnectionAcceptedMessage& message, RemotePeer& remotePeer ) - { - if ( remotePeer.GeturrentState() == RemotePeerState::Connected ) - { - LOG_INFO( "The server's remote peer is already connected. Ignoring message" ); - return; - } - - uint64 remoteDataPrefix = message.prefix; - if ( remoteDataPrefix != remotePeer.GetDataPrefix() ) - { - LOG_WARNING( "Packet prefix does not match. Skipping packet..." ); - return; - } - - ConnectRemotePeer( remotePeer ); - - _clientIndex = message.clientIndexAssigned; - _currentState = ClientState::CS_Connected; + _clientIndex = data.clientSideId; // TODO Do not hardcode it like this. It might looks weird _replicationMessagesProcessor.SetLocalClientId( _clientIndex ); @@ -257,12 +159,10 @@ namespace NetLib ExecuteOnLocalPeerConnect(); } - void Client::ProcessConnectionRequestDenied( const ConnectionDeniedMessage& message ) + void Client::OnPendingConnectionDenied( const Connection::FailedConnectionData& data ) { - LOG_INFO( "Processing connection denied" ); - ConnectionFailedReasonType reason = static_cast< ConnectionFailedReasonType >( message.reason ); - - RequestStop( false, reason ); + LOG_INFO( "Client.%s Connection denied. Reason: %u", THIS_FUNCTION_NAME, static_cast< uint8 >( data.reason ) ); + RequestStop( false, data.reason ); } void Client::ProcessDisconnection( const DisconnectionMessage& message, RemotePeer& remotePeer ) @@ -277,7 +177,8 @@ namespace NetLib LOG_INFO( "Disconnection message received from server with reason code equal to %hhu. Disconnecting...", message.reason ); - StartDisconnectingRemotePeer( remotePeer.GetClientIndex(), false, ConnectionFailedReasonType::CFR_UNKNOWN ); + StartDisconnectingRemotePeer( remotePeer.GetClientIndex(), false, + Connection::ConnectionFailedReasonType::UNKNOWN ); } void Client::ProcessTimeResponse( const TimeResponseMessage& message ) @@ -291,51 +192,6 @@ namespace NetLib _replicationMessagesProcessor.Client_ProcessReceivedReplicationMessage( message ); } - void Client::CreateConnectionRequestMessage( RemotePeer& remotePeer ) - { - // Get a connection request message - std::unique_ptr< Message > message = _messageFactory.LendMessage( MessageType::ConnectionRequest ); - - if ( message == nullptr ) - { - LOG_ERROR( - "Can't create new Connection Request Message because the MessageFactory has returned a null message" ); - return; - } - - std::unique_ptr< ConnectionRequestMessage > connectionRequestMessage( - static_cast< ConnectionRequestMessage* >( message.release() ) ); - - // Set connection request fields - connectionRequestMessage->clientSalt = remotePeer.GetClientSalt(); - - // Store message in server's pending connection in order to send it - remotePeer.AddMessage( std::move( connectionRequestMessage ) ); - - LOG_INFO( "Connection request created." ); - } - - void Client::CreateConnectionChallengeResponse( RemotePeer& remotePeer ) - { - // Get a connection challenge message - std::unique_ptr< Message > message = _messageFactory.LendMessage( MessageType::ConnectionChallengeResponse ); - if ( message == nullptr ) - { - LOG_ERROR( "Can't create new Connection Challenge Response Message because the MessageFactory has returned " - "a null message" ); - return; - } - - std::unique_ptr< ConnectionChallengeResponseMessage > connectionChallengeResponseMessage( - static_cast< ConnectionChallengeResponseMessage* >( message.release() ) ); - - // Set connection challenge fields - connectionChallengeResponseMessage->prefix = remotePeer.GetDataPrefix(); - - // Store message in server's pending connection in order to send it - remotePeer.AddMessage( std::move( connectionChallengeResponseMessage ) ); - } - void Client::OnServerDisconnect() { LOG_INFO( "ON SERVER DISCONNECT" ); diff --git a/NetworkLibrary/src/Core/Client.h b/NetworkLibrary/src/Core/Client.h index e3131ca..6484480 100644 --- a/NetworkLibrary/src/Core/Client.h +++ b/NetworkLibrary/src/Core/Client.h @@ -22,16 +22,6 @@ namespace NetLib class ReplicationMessage; class IInputState; - // TODO There's a redundancy between ClientState and PeerConnectionState. We should make the ClientState to only be - // for the different connection states and let the Peer enum to decide the final connect disconnect - enum ClientState - { - CS_Disconnected = 0, - CS_Connected = 1, - CS_SendingConnectionRequest = 2, - CS_SendingConnectionChallengeResponse = 3, - }; - class Client : public Peer { public: @@ -40,8 +30,6 @@ namespace NetLib Client& operator=( const Client& ) = delete; - ~Client() override; - bool StartClient( const std::string& server_ip, uint32 server_port ); void SendInputs( const IInputState& inputState ); @@ -56,29 +44,22 @@ namespace NetLib protected: bool StartConcrete( const std::string& ip, uint32 port ) override; void ProcessMessageFromPeer( const Message& message, RemotePeer& remotePeer ) override; - void ProcessMessageFromUnknownPeer( const Message& message, const Address& address ) override; void TickConcrete( float32 elapsedTime ) override; bool StopConcrete() override; - void InternalOnRemotePeerConnect( RemotePeer& remote_peer ) override {}; + void OnPendingConnectionAccepted( const Connection::SuccessConnectionData& data ) override; + void OnPendingConnectionDenied( const Connection::FailedConnectionData& data ) override; + void InternalOnRemotePeerDisconnect( const RemotePeer& remote_peer ) override {}; private: - uint64 GenerateClientSaltNumber(); - void ProcessConnectionChallenge( const ConnectionChallengeMessage& message, RemotePeer& remotePeer ); - void ProcessConnectionRequestAccepted( const ConnectionAcceptedMessage& message, RemotePeer& remotePeer ); - void ProcessConnectionRequestDenied( const ConnectionDeniedMessage& message ); void ProcessDisconnection( const DisconnectionMessage& message, RemotePeer& remotePeer ); void ProcessTimeResponse( const TimeResponseMessage& message ); void ProcessReplicationAction( const ReplicationMessage& message ); - void CreateConnectionRequestMessage( RemotePeer& remotePeer ); - void CreateConnectionChallengeResponse( RemotePeer& remotePeer ); - void OnServerDisconnect(); Address _serverAddress; - ClientState _currentState; // TODO We can probably make this a var within the Peer.h as it's shared by client and server uint32 _clientIndex; diff --git a/NetworkLibrary/src/Core/Peer.cpp b/NetworkLibrary/src/Core/Peer.cpp index b6c931a..55b3367 100644 --- a/NetworkLibrary/src/Core/Peer.cpp +++ b/NetworkLibrary/src/Core/Peer.cpp @@ -8,7 +8,11 @@ #include "communication/message_factory.h" #include "communication/network_packet_utils.h" +#include "connection/client_connection_pipeline.h" +#include "connection/server_connection_pipeline.h" + #include "logger.h" +#include "asserts.h" #include "core/buffer.h" #include "core/remote_peer.h" @@ -18,25 +22,46 @@ namespace NetLib { bool Peer::Start( const std::string& ip, uint32 port ) { - if ( _connectionState != PeerConnectionState::PCS_Disconnected ) + if ( _connectionState != PeerConnectionState::Disconnected ) { LOG_WARNING( "You are trying to call Peer::Start on a Peer that has already started" ); return true; } - SetConnectionState( PeerConnectionState::PCS_Connecting ); + SetConnectionState( PeerConnectionState::Connecting ); if ( _socket.Start() != SocketResult::SOKT_SUCCESS ) { LOG_ERROR( "Error while starting peer, aborting operation..." ); - SetConnectionState( PeerConnectionState::PCS_Disconnected ); + SetConnectionState( PeerConnectionState::Disconnected ); + return false; + } + + // TODO, This is hardcoded + Connection::ConnectionConfiguration connectionConfiguration; + connectionConfiguration.canStartConnections = ( _type == PeerType::CLIENT ); + connectionConfiguration.connectionTimeoutSeconds = 5.f; + connectionConfiguration.sendDenialOnTimeout = ( _type == PeerType::SERVER ); + if ( _type == PeerType::CLIENT ) + { + connectionConfiguration.connectionPipeline = new Connection::ClientConnectionPipeline(); + } + else if ( _type == PeerType::SERVER ) + { + connectionConfiguration.connectionPipeline = new Connection::ServerConnectionPipeline(); + } + + if ( !_connectionManager.StartUp( connectionConfiguration, &_messageFactory, &_remotePeersHandler ) ) + { + LOG_ERROR( "Error while starting peer connection manager, aborting operation..." ); + SetConnectionState( PeerConnectionState::Disconnected ); return false; } if ( !StartConcrete( ip, port ) ) { LOG_ERROR( "Error while starting peer, aborting operation..." ); - SetConnectionState( PeerConnectionState::PCS_Disconnected ); + SetConnectionState( PeerConnectionState::Disconnected ); return false; } @@ -47,12 +72,13 @@ namespace NetLib bool Peer::PreTick() { - if ( _connectionState == PeerConnectionState::PCS_Disconnected ) + if ( _connectionState == PeerConnectionState::Disconnected ) { - LOG_WARNING( "You are trying to call Peer::PreTick on a Peer that is disconnected" ); + LOG_WARNING( "You are trying to call PreTick on a Peer that is disconnected" ); return false; } + ReadReceivedData(); ProcessReceivedData(); return true; @@ -60,17 +86,21 @@ namespace NetLib bool Peer::Tick( float32 elapsedTime ) { - if ( _connectionState == PeerConnectionState::PCS_Disconnected ) + if ( _connectionState == PeerConnectionState::Disconnected ) { LOG_WARNING( "You are trying to call Peer::Tick on a Peer that is disconnected" ); return false; } + TickPendingConnections( elapsedTime ); TickRemotePeers( elapsedTime ); TickConcrete( elapsedTime ); FinishRemotePeersDisconnection(); + SendDataToPendingConnections(); SendDataToRemotePeers(); + ConvertSuccessfulConnectionsInRemotePeers(); + ProcessDeniedConnections(); if ( _isStopRequested ) { @@ -85,8 +115,7 @@ namespace NetLib bool Peer::Stop() { - RequestStop( true, ConnectionFailedReasonType::CFR_PEER_SHUT_DOWN ); - StopInternal(); + RequestStop( true, Connection::ConnectionFailedReasonType::PEER_SHUT_DOWN ); return true; } @@ -103,13 +132,14 @@ namespace NetLib return result; } - uint32 Peer::GetMetric( uint32 remote_peer_id, const std::string& metric_name, const std::string& value_type ) const + uint32 Peer::GetMetric( uint32 remote_peer_id, Metrics::MetricType metric_type, + Metrics::ValueType value_type ) const { uint32 result = 0; const RemotePeer* remotePeer = _remotePeersHandler.GetRemotePeerFromId( remote_peer_id ); if ( remotePeer != nullptr ) { - result = remotePeer->GetMetric( metric_name, value_type ); + result = remotePeer->GetMetric( metric_type, value_type ); } else { @@ -152,7 +182,7 @@ namespace NetLib Peer::Peer( PeerType type, uint32 maxConnections, uint32 receiveBufferSize, uint32 sendBufferSize ) : _type( type ) - , _connectionState( PeerConnectionState::PCS_Disconnected ) + , _connectionState( PeerConnectionState::Disconnected ) , _socket() , _address( Address::GetInvalid() ) , _receiveBufferSize( receiveBufferSize ) @@ -162,9 +192,10 @@ namespace NetLib , _onLocalPeerDisconnect() , _isStopRequested( false ) , _stopRequestShouldNotifyRemotePeers( false ) - , _stopRequestReason( ConnectionFailedReasonType::CFR_UNKNOWN ) + , _stopRequestReason( Connection::ConnectionFailedReasonType::UNKNOWN ) , _currentTick( 0 ) , _messageFactory( 3 ) + , _connectionManager() { _receiveBuffer = new uint8[ _receiveBufferSize ]; _sendBuffer = new uint8[ _sendBufferSize ]; @@ -181,16 +212,7 @@ namespace NetLib bool Peer::AddRemotePeer( const Address& addressInfo, uint16 id, uint64 clientSalt, uint64 serverSalt ) { - bool addedSuccesfully = _remotePeersHandler.AddRemotePeer( addressInfo, id, clientSalt, serverSalt ); - - return addedSuccesfully; - } - - void Peer::ConnectRemotePeer( RemotePeer& remotePeer ) - { - remotePeer.SetConnected(); - InternalOnRemotePeerConnect( remotePeer ); - ExecuteOnRemotePeerConnect( remotePeer.GetClientIndex() ); + return _remotePeersHandler.AddRemotePeer( addressInfo, id, clientSalt, serverSalt ); } bool Peer::BindSocket( const Address& address ) const @@ -204,7 +226,7 @@ namespace NetLib return true; } - void Peer::DisconnectAllRemotePeers( bool shouldNotify, ConnectionFailedReasonType reason ) + void Peer::DisconnectAllRemotePeers( bool shouldNotify, Connection::ConnectionFailedReasonType reason ) { if ( shouldNotify ) { @@ -222,7 +244,7 @@ namespace NetLib } void Peer::DisconnectRemotePeer( const RemotePeer& remotePeer, bool shouldNotify, - ConnectionFailedReasonType reason ) + Connection::ConnectionFailedReasonType reason ) { if ( shouldNotify ) { @@ -236,7 +258,7 @@ namespace NetLib ExecuteOnRemotePeerDisconnect( id ); } - void Peer::CreateDisconnectionPacket( const RemotePeer& remotePeer, ConnectionFailedReasonType reason ) + void Peer::CreateDisconnectionPacket( const RemotePeer& remotePeer, Connection::ConnectionFailedReasonType reason ) { NetworkPacket packet; packet.SetHeaderChannelType( TransmissionChannelType::UnreliableUnordered ); @@ -248,19 +270,25 @@ namespace NetLib disconenctionMessage->SetOrdered( false ); disconenctionMessage->SetReliability( false ); disconenctionMessage->prefix = remotePeer.GetDataPrefix(); - disconenctionMessage->reason = reason; + disconenctionMessage->reason = static_cast< uint8 >( reason ); packet.AddMessage( std::move( disconenctionMessage ) ); SendPacketToAddress( packet, remotePeer.GetAddress() ); + + while ( packet.GetNumberOfMessages() > 0 ) + { + std::unique_ptr< Message > message = packet.TryGetNextMessage(); + _messageFactory.ReleaseMessage( std::move( message ) ); + } } void Peer::ExecuteOnLocalPeerConnect() { - SetConnectionState( PeerConnectionState::PCS_Connected ); + SetConnectionState( PeerConnectionState::Connected ); _onLocalPeerConnect.Execute(); } - void Peer::ExecuteOnLocalPeerDisconnect( ConnectionFailedReasonType reason ) + void Peer::ExecuteOnLocalPeerDisconnect( Connection::ConnectionFailedReasonType reason ) { _onLocalPeerDisconnect.Execute( reason ); } @@ -271,12 +299,12 @@ namespace NetLib } bool Peer::UnsubscribeToOnPeerDisconnected( - const Common::Delegate< ConnectionFailedReasonType >::SubscriptionHandler& handler ) + const Common::Delegate< Connection::ConnectionFailedReasonType >::SubscriptionHandler& handler ) { return _onLocalPeerDisconnect.DeleteSubscriber( handler ); } - void Peer::ProcessReceivedData() + void Peer::ReadReceivedData() { Address remoteAddress = Address::GetInvalid(); uint32 numberOfBytesRead = 0; @@ -291,7 +319,7 @@ namespace NetLib { // Data read succesfully. Keep going! Buffer buffer = Buffer( _receiveBuffer, numberOfBytesRead ); - ProcessDatagram( buffer, remoteAddress ); + ReadDatagram( buffer, remoteAddress ); } else if ( result == SocketResult::SOKT_ERR || result == SocketResult::SOKT_WOULDBLOCK ) { @@ -305,20 +333,37 @@ namespace NetLib if ( remotePeer != nullptr ) { StartDisconnectingRemotePeer( remotePeer->GetClientIndex(), false, - ConnectionFailedReasonType::CFR_UNKNOWN ); + Connection::ConnectionFailedReasonType::UNKNOWN ); } } } while ( arePendingDatagramsToRead ); - - ProcessNewRemotePeerMessages(); } - void Peer::ProcessDatagram( Buffer& buffer, const Address& address ) + void Peer::ReadDatagram( Buffer& buffer, const Address& address ) { // TODO Add validation for tampered or corrupted packets so it doesn't crash when a tampered message arrives. // Read incoming packet - NetworkPacket packet = NetworkPacket(); - packet.Read( _messageFactory, buffer ); + // Read Network packet + NetworkPacket packet; + const bool readSuccessfully = NetworkPacketUtils::ReadNetworkPacket( buffer, _messageFactory, packet ); + if ( !readSuccessfully ) + { + std::string ip_and_port; + address.GetFull( ip_and_port ); + LOG_WARNING( "Received corrupted or invalid packet from %s. Discarding packet.", ip_and_port.c_str() ); + return; + } + + // Store messages within transmission channels for being processed + StoreReceivedMessages( packet, address ); + + // Clean up packet + NetworkPacketUtils::CleanPacket( _messageFactory, packet ); + } + + void Peer::StoreReceivedMessages( NetworkPacket& packet, const Address& address ) + { + bool processedSuccessfully = false; RemotePeer* remotePeer = _remotePeersHandler.GetRemotePeerFromAddress( address ); bool isPacketFromRemotePeer = ( remotePeer != nullptr ); @@ -328,17 +373,11 @@ namespace NetLib } else { - const std::vector< std::unique_ptr< Message > >& packetMessages = packet.GetAllMessages(); - for ( auto cit = packetMessages.cbegin(); cit != packetMessages.cend(); ++cit ) - { - ProcessMessageFromUnknownPeer( **cit, address ); - } - - NetworkPacketUtils::CleanPacket( _messageFactory, packet ); + _connectionManager.ProcessPacket( address, packet ); } } - void Peer::ProcessNewRemotePeerMessages() + void Peer::ProcessReceivedData() { auto validRemotePeersIt = _remotePeersHandler.GetValidRemotePeersIterator(); auto pastTheEndIt = _remotePeersHandler.GetValidRemotePeersPastTheEndIterator(); @@ -358,6 +397,48 @@ namespace NetLib } } + void Peer::TickPendingConnections( float32 elapsed_time ) + { + _connectionManager.Tick( elapsed_time ); + } + + void Peer::ConvertSuccessfulConnectionsInRemotePeers() + { + std::vector< Connection::SuccessConnectionData > successfulConnections; + _connectionManager.GetSuccessConnectionsData( successfulConnections ); + + for ( auto& cit = successfulConnections.cbegin(); cit != successfulConnections.cend(); ++cit ) + { + // TODO Change this and get rid of both salts in remote peer. Just keep the data prefix + if ( AddRemotePeer( cit->address, cit->id, cit->dataPrefix, 0 ) ) + { + RemotePeer* remotePeer = _remotePeersHandler.GetRemotePeerFromId( cit->id ); + ASSERT( remotePeer != nullptr, "Remote peer cannot be nullptr after its creation" ); + OnPendingConnectionAccepted( *cit ); + ExecuteOnRemotePeerConnect( cit->id ); + } + else + { + LOG_ERROR( "%s Error while adding new remote peer after successful connection", THIS_FUNCTION_NAME ); + } + } + + _connectionManager.RemoveSuccessConnections(); + } + + void Peer::ProcessDeniedConnections() + { + std::vector< Connection::FailedConnectionData > deniedConnections; + _connectionManager.GetFailedConnectionsData( deniedConnections ); + + for ( auto& cit = deniedConnections.cbegin(); cit != deniedConnections.cend(); ++cit ) + { + OnPendingConnectionDenied( *cit ); + } + + _connectionManager.RemoveFailedConnections(); + } + void Peer::TickRemotePeers( float32 elapsedTime ) { auto validRemotePeersIt = _remotePeersHandler.GetValidRemotePeersIterator(); @@ -372,7 +453,7 @@ namespace NetLib if ( remotePeer.IsInactive() ) { StartDisconnectingRemotePeer( remotePeer.GetClientIndex(), true, - ConnectionFailedReasonType::CFR_TIMEOUT ); + Connection::ConnectionFailedReasonType::TIMEOUT ); } } } @@ -388,12 +469,13 @@ namespace NetLib } } - void Peer::SendDataToAddress( const Buffer& buffer, const Address& address ) const + void Peer::SendDataToPendingConnections() { - _socket.SendTo( buffer.GetData(), buffer.GetSize(), address ); + _connectionManager.SendData( _socket ); } - void Peer::StartDisconnectingRemotePeer( uint32 id, bool shouldNotify, ConnectionFailedReasonType reason ) + void Peer::StartDisconnectingRemotePeer( uint32 id, bool shouldNotify, + Connection::ConnectionFailedReasonType reason ) { RemotePeer* remotePeer = _remotePeersHandler.GetRemotePeerFromId( id ); @@ -447,7 +529,7 @@ namespace NetLib } } - void Peer::RequestStop( bool shouldNotifyRemotePeers, ConnectionFailedReasonType reason ) + void Peer::RequestStop( bool shouldNotifyRemotePeers, Connection::ConnectionFailedReasonType reason ) { _isStopRequested = true; _stopRequestShouldNotifyRemotePeers = shouldNotifyRemotePeers; @@ -471,7 +553,7 @@ namespace NetLib void Peer::StopInternal() { - if ( _connectionState == PeerConnectionState::PCS_Disconnected ) + if ( _connectionState == PeerConnectionState::Disconnected ) { LOG_WARNING( "You are trying to call Peer::Stop on a Peer that is disconnected" ); return; @@ -480,11 +562,12 @@ namespace NetLib StopConcrete(); DisconnectAllRemotePeers( _stopRequestShouldNotifyRemotePeers, _stopRequestReason ); _socket.Close(); + _connectionManager.ShutDown(); _isStopRequested = false; PeerConnectionState previousConnectionState = _connectionState; - SetConnectionState( PeerConnectionState::PCS_Disconnected ); + SetConnectionState( PeerConnectionState::Disconnected ); LOG_INFO( "Peer stopped succesfully" ); ExecuteOnLocalPeerDisconnect( _stopRequestReason ); diff --git a/NetworkLibrary/src/Core/Peer.h b/NetworkLibrary/src/Core/Peer.h index 98069ac..12b2b19 100644 --- a/NetworkLibrary/src/Core/Peer.h +++ b/NetworkLibrary/src/Core/Peer.h @@ -13,6 +13,9 @@ #include "communication/message_factory.h" +#include "connection/connection_manager.h" +#include "connection/connection_failed_reason_type.h" + #include "transmission_channels/transmission_channel.h" class Buffer; @@ -24,20 +27,17 @@ namespace NetLib class RemotePeer; class Buffer; - enum ConnectionFailedReasonType : uint8 + namespace Connection { - CFR_UNKNOWN = 0, // Unexpect - CFR_TIMEOUT = 1, // The peer is inactive - CFR_SERVER_FULL = 2, // The server can't handle more connections, it has reached its maximum - CFR_PEER_SHUT_DOWN = 3, // The peer has shut down its Network system - CFR_CONNECTION_TIMEOUT = 4 // The in process connection has taken too long - }; + struct SuccessConnectionData; + struct FailedConnectionData; + } struct RemotePeerDisconnectionData { uint32 id; bool shouldNotify; - ConnectionFailedReasonType reason; + Connection::ConnectionFailedReasonType reason; }; enum class PeerType : uint8 @@ -47,11 +47,11 @@ namespace NetLib SERVER = 2 }; - enum PeerConnectionState : uint8 + enum class PeerConnectionState : uint8 { - PCS_Disconnected = 0, - PCS_Connecting = 1, - PCS_Connected = 2 + Disconnected = 0, + Connecting = 1, + Connected = 2 }; // TODO Set ordered and reliable flags in all the connection messages such as challenge response, connection @@ -88,12 +88,12 @@ namespace NetLib /// value of 0 is returned. /// /// The remote peer id - /// The name of the metric. See metrics/metric_names.h for more info - /// The type of value you want to get. See metrics/metric_names.h for more + /// The type of the metric. See metrics/metric_types.h for more info + /// The type of value you want to get. See metrics/metric_types.h for more /// info /// The metric value on success or 0 on failure - uint32 GetMetric( uint32 remote_peer_id, const std::string& metric_name, - const std::string& value_type ) const; + uint32 GetMetric( uint32 remote_peer_id, Metrics::MetricType metric_type, + Metrics::ValueType value_type ) const; float64 GetLocalTime() const; float64 GetServerTime() const; @@ -104,10 +104,10 @@ namespace NetLib bool UnsubscribeToOnPeerConnected( const Common::Delegate<>::SubscriptionHandler& handler ); template < typename Functor > - Common::Delegate< ConnectionFailedReasonType >::SubscriptionHandler SubscribeToOnLocalPeerDisconnect( - Functor&& functor ); + Common::Delegate< Connection::ConnectionFailedReasonType >::SubscriptionHandler + SubscribeToOnLocalPeerDisconnect( Functor&& functor ); bool UnsubscribeToOnPeerDisconnected( - const Common::Delegate< ConnectionFailedReasonType >::SubscriptionHandler& handler ); + const Common::Delegate< Connection::ConnectionFailedReasonType >::SubscriptionHandler& handler ); template < typename Functor > Common::Delegate< uint32 >::SubscriptionHandler SubscribeToOnRemotePeerDisconnect( Functor&& functor ); @@ -127,18 +127,17 @@ namespace NetLib virtual bool StartConcrete( const std::string& ip, uint32 port ) = 0; virtual void ProcessMessageFromPeer( const Message& message, RemotePeer& remotePeer ) = 0; - virtual void ProcessMessageFromUnknownPeer( const Message& message, const Address& address ) = 0; virtual void TickConcrete( float32 elapsedTime ) = 0; virtual bool StopConcrete() = 0; void SendPacketToAddress( const NetworkPacket& packet, const Address& address ) const; bool AddRemotePeer( const Address& addressInfo, uint16 id, uint64 clientSalt, uint64 serverSalt ); - void ConnectRemotePeer( RemotePeer& remotePeer ); bool BindSocket( const Address& address ) const; - void StartDisconnectingRemotePeer( uint32 id, bool shouldNotify, ConnectionFailedReasonType reason ); + void StartDisconnectingRemotePeer( uint32 id, bool shouldNotify, + Connection::ConnectionFailedReasonType reason ); - void RequestStop( bool shouldNotifyRemotePeers, ConnectionFailedReasonType reason ); + void RequestStop( bool shouldNotifyRemotePeers, Connection::ConnectionFailedReasonType reason ); // Delegates related @@ -146,35 +145,58 @@ namespace NetLib /// Called during OnRemotePeerConnect. This function is used for events happening inside the network library /// code. This function will be called before teh OnRemotePeerConnect Delegate /// - virtual void InternalOnRemotePeerConnect( RemotePeer& remote_peer ) = 0; + virtual void OnPendingConnectionAccepted( const Connection::SuccessConnectionData& data ) = 0; + virtual void OnPendingConnectionDenied( const Connection::FailedConnectionData& data ) = 0; virtual void InternalOnRemotePeerDisconnect( const RemotePeer& remote_peer ) = 0; void ExecuteOnLocalPeerConnect(); - void ExecuteOnLocalPeerDisconnect( ConnectionFailedReasonType reason ); + void ExecuteOnLocalPeerDisconnect( Connection::ConnectionFailedReasonType reason ); RemotePeersHandler _remotePeersHandler; MessageFactory _messageFactory; + Connection::ConnectionManager _connectionManager; + private: + /// + /// Reads all the incoming received data from the socket + /// + void ReadReceivedData(); + + /// + /// Reads an incoming datagram received from the specified address + /// + void ReadDatagram( Buffer& buffer, const Address& address ); + + /// + /// Stores all the messages from a network packet into the corresponding transmission channels + /// + void StoreReceivedMessages( NetworkPacket& packet, const Address& address ); + + /// + /// Process all the pending received data from all the remote peers + /// void ProcessReceivedData(); - void ProcessDatagram( Buffer& buffer, const Address& address ); - void ProcessNewRemotePeerMessages(); void SetConnectionState( PeerConnectionState state ); + void TickPendingConnections( float32 elapsed_time ); + void ConvertSuccessfulConnectionsInRemotePeers(); + void ProcessDeniedConnections(); + // Remote peer related void TickRemotePeers( float32 elapsedTime ); - void DisconnectAllRemotePeers( bool shouldNotify, ConnectionFailedReasonType reason ); + void DisconnectAllRemotePeers( bool shouldNotify, Connection::ConnectionFailedReasonType reason ); void DisconnectRemotePeer( const RemotePeer& remotePeer, bool shouldNotify, - ConnectionFailedReasonType reason ); + Connection::ConnectionFailedReasonType reason ); - void CreateDisconnectionPacket( const RemotePeer& remotePeer, ConnectionFailedReasonType reason ); + void CreateDisconnectionPacket( const RemotePeer& remotePeer, + Connection::ConnectionFailedReasonType reason ); /// /// Sends pending data to all the connected remote peers /// void SendDataToRemotePeers(); - - void SendDataToAddress( const Buffer& buffer, const Address& address ) const; + void SendDataToPendingConnections(); bool DoesRemotePeerIdExistInPendingDisconnections( uint32 id ) const; void FinishRemotePeersDisconnection(); @@ -200,12 +222,12 @@ namespace NetLib // Stop request bool _isStopRequested; bool _stopRequestShouldNotifyRemotePeers; - ConnectionFailedReasonType _stopRequestReason; + Connection::ConnectionFailedReasonType _stopRequestReason; std::list< RemotePeerDisconnectionData > _remotePeerPendingDisconnections; Common::Delegate<> _onLocalPeerConnect; - Common::Delegate< ConnectionFailedReasonType > _onLocalPeerDisconnect; + Common::Delegate< Connection::ConnectionFailedReasonType > _onLocalPeerDisconnect; Common::Delegate< uint32 > _onRemotePeerConnect; Common::Delegate< uint32 > _onRemotePeerDisconnect; }; @@ -217,8 +239,8 @@ namespace NetLib } template < typename Functor > - inline Common::Delegate< ConnectionFailedReasonType >::SubscriptionHandler Peer::SubscribeToOnLocalPeerDisconnect( - Functor&& functor ) + inline Common::Delegate< Connection::ConnectionFailedReasonType >::SubscriptionHandler Peer:: + SubscribeToOnLocalPeerDisconnect( Functor&& functor ) { return _onLocalPeerDisconnect.AddSubscriber( std::forward< Functor >( functor ) ); } diff --git a/NetworkLibrary/src/Core/Server.cpp b/NetworkLibrary/src/Core/Server.cpp index c0c5ca7..6c4fad4 100644 --- a/NetworkLibrary/src/Core/Server.cpp +++ b/NetworkLibrary/src/Core/Server.cpp @@ -4,6 +4,7 @@ #include #include "logger.h" +#include "asserts.h" #include "core/time_clock.h" @@ -35,7 +36,7 @@ namespace NetLib bool Server::CreateNetworkEntity( uint32 entityType, uint32 controlledByPeerId, float32 posX, float32 posY ) { - if ( GetConnectionState() != PeerConnectionState::PCS_Connected ) + if ( GetConnectionState() != PeerConnectionState::Connected ) { LOG_WARNING( "Can't create Network entity of type %d because the server is not connected.", static_cast< int >( entityType ) ); @@ -48,7 +49,7 @@ namespace NetLib void Server::DestroyNetworkEntity( uint32 entityId ) { - if ( GetConnectionState() != PeerConnectionState::PCS_Connected ) + if ( GetConnectionState() != PeerConnectionState::Connected ) { LOG_WARNING( "Can't destroy Network entity with ID: %d because the server is not connected.", static_cast< int >( entityId ) ); @@ -106,14 +107,6 @@ namespace NetLib TickReplication(); } - uint64 Server::GenerateServerSalt() const - { - // TODO Change this in order to get another random generator that generates 64bit numbers - srand( static_cast< uint32 >( time( NULL ) ) + 3589 ); - uint64 serverSalt = rand(); - return serverSalt; - } - void Server::ProcessMessageFromPeer( const Message& message, RemotePeer& remotePeer ) { MessageType messageType = message.GetHeader().type; @@ -121,20 +114,6 @@ namespace NetLib switch ( messageType ) { - case MessageType::ConnectionRequest: - { - const ConnectionRequestMessage& connectionRequestMessage = - static_cast< const ConnectionRequestMessage& >( message ); - ProcessConnectionRequest( connectionRequestMessage, remotePeer.GetAddress() ); - break; - } - case MessageType::ConnectionChallengeResponse: - { - const ConnectionChallengeResponseMessage& connectionChallengeResponseMessage = - static_cast< const ConnectionChallengeResponseMessage& >( message ); - ProcessConnectionChallengeResponse( connectionChallengeResponseMessage, remotePeer ); - break; - } case MessageType::TimeRequest: { const TimeRequestMessage& timeRequestMessage = static_cast< const TimeRequestMessage& >( message ); @@ -164,67 +143,6 @@ namespace NetLib } } - void Server::ProcessMessageFromUnknownPeer( const Message& message, const Address& address ) - { - if ( message.GetHeader().type == MessageType::ConnectionRequest ) - { - const ConnectionRequestMessage& connectionRequestMessage = - static_cast< const ConnectionRequestMessage& >( message ); - ProcessConnectionRequest( connectionRequestMessage, address ); - } - else - { - LOG_WARNING( "Server only process Connection request messages from unknown peers. Any other type of " - "message will be discarded." ); - } - } - - void Server::ProcessConnectionRequest( const ConnectionRequestMessage& message, const Address& address ) - { - std::string ip_and_port; - address.GetFull( ip_and_port ); - LOG_INFO( "Processing connection request from [%s] with salt number %d", ip_and_port.c_str(), - message.clientSalt ); - - RemotePeersHandlerResult isAbleToConnectResult = _remotePeersHandler.IsRemotePeerAbleToConnect( address ); - - if ( isAbleToConnectResult == - RemotePeersHandlerResult::RPH_SUCCESS ) // If there is green light keep with the connection pipeline. - { - uint64 clientSalt = message.clientSalt; - uint64 serverSalt = GenerateServerSalt(); - AddRemotePeer( address, _nextAssignedRemotePeerID, clientSalt, serverSalt ); - ++_nextAssignedRemotePeerID; - - RemotePeer* remotePeer = _remotePeersHandler.GetRemotePeerFromAddress( address ); - CreateConnectionChallengeMessage( *remotePeer ); - } - else if ( isAbleToConnectResult == - RemotePeersHandlerResult::RPH_ALREADYEXIST ) // If the client is already connected just send a - // connection approved message - { - RemotePeer* remotePeer = _remotePeersHandler.GetRemotePeerFromAddress( address ); - - RemotePeerState remotePeerState = remotePeer->GeturrentState(); - if ( remotePeerState == RemotePeerState::Connected ) - { - CreateConnectionApprovedMessage( *remotePeer ); - LOG_INFO( "The client is already connected, sending connection approved..." ); - } - else if ( remotePeerState == RemotePeerState::Connecting ) - { - CreateConnectionChallengeMessage( *remotePeer ); - LOG_INFO( "The client is already trying to connect, sending connection challenge..." ); - } - } - else if ( isAbleToConnectResult == - RemotePeersHandlerResult::RPH_FULL ) // If all the client slots are full deny the connection - { - SendConnectionDeniedPacket( address, ConnectionFailedReasonType::CFR_SERVER_FULL ); - LOG_WARNING( "All available connection slots are full. Denying incoming connection..." ); - } - } - void Server::CreateDisconnectionMessage( RemotePeer& remotePeer ) { std::unique_ptr< Message > message = _messageFactory.LendMessage( MessageType::Disconnection ); @@ -260,74 +178,6 @@ namespace NetLib remotePeer.AddMessage( std::move( timeResponseMessage ) ); } - void Server::CreateConnectionChallengeMessage( RemotePeer& remotePeer ) - { - std::unique_ptr< Message > message = _messageFactory.LendMessage( MessageType::ConnectionChallenge ); - if ( message == nullptr ) - { - LOG_ERROR( "Can't create new Connection Challenge Message because the MessageFactory has returned a null " - "message" ); - return; - } - - std::unique_ptr< ConnectionChallengeMessage > connectionChallengePacket( - static_cast< ConnectionChallengeMessage* >( message.release() ) ); - connectionChallengePacket->clientSalt = remotePeer.GetClientSalt(); - connectionChallengePacket->serverSalt = remotePeer.GetServerSalt(); - remotePeer.AddMessage( std::move( connectionChallengePacket ) ); - - LOG_INFO( "Connection challenge message created." ); - } - - void Server::SendConnectionDeniedPacket( const Address& address, ConnectionFailedReasonType reason ) - { - std::unique_ptr< Message > message = _messageFactory.LendMessage( MessageType::ConnectionDenied ); - std::unique_ptr< ConnectionDeniedMessage > connectionDeniedMessage( - static_cast< ConnectionDeniedMessage* >( message.release() ) ); - connectionDeniedMessage->reason = reason; - - NetworkPacket packet = NetworkPacket(); - packet.AddMessage( std::move( connectionDeniedMessage ) ); - - LOG_INFO( "Sending connection denied..." ); - SendPacketToAddress( packet, address ); - - NetworkPacketUtils::CleanPacket( _messageFactory, packet ); - } - - void Server::ProcessConnectionChallengeResponse( const ConnectionChallengeResponseMessage& message, - RemotePeer& remotePeer ) - { - std::string ip_and_port; - remotePeer.GetAddress().GetFull( ip_and_port ); - LOG_INFO( "Processing connection challenge response from [%s]", ip_and_port.c_str() ); - - if ( remotePeer.GeturrentState() == RemotePeerState::Connected ) - { - LOG_INFO( "The remote peer is already connected. Sending connection approved..." ); - CreateConnectionApprovedMessage( remotePeer ); - return; - } - - uint64 dataPrefix = message.prefix; - - if ( remotePeer.GetDataPrefix() == dataPrefix ) - { - ConnectRemotePeer( remotePeer ); - - // Send connection approved packet - CreateConnectionApprovedMessage( remotePeer ); - LOG_INFO( "Connection approved" ); - } - else - { - LOG_INFO( "Connection denied due to not wrong data prefix" ); - SendConnectionDeniedPacket( remotePeer.GetAddress(), ConnectionFailedReasonType::CFR_UNKNOWN ); - - StartDisconnectingRemotePeer( remotePeer.GetClientIndex(), false, ConnectionFailedReasonType::CFR_UNKNOWN ); - } - } - // TODO REFACTOR THIS METHOD void Server::ProcessTimeRequest( const TimeRequestMessage& message, RemotePeer& remotePeer ) { @@ -345,8 +195,15 @@ namespace NetLib assert( inputState != nullptr ); Buffer buffer( message.data, message.dataSize ); - inputState->Deserialize( buffer ); - _remotePeerInputsHandler.AddInputState( inputState, remotePeer.GetClientIndex() ); + if ( !inputState->Deserialize( buffer ) ) + { + LOG_ERROR( "Server::%s, Failed to deserialize input state from remote peer %u. Ignoring input...", + THIS_FUNCTION_NAME, remotePeer.GetClientIndex() ); + } + else + { + _remotePeerInputsHandler.AddInputState( inputState, remotePeer.GetClientIndex() ); + } } else { @@ -368,29 +225,8 @@ namespace NetLib "remove peer...", message.reason ); - StartDisconnectingRemotePeer( remotePeer.GetClientIndex(), false, ConnectionFailedReasonType::CFR_UNKNOWN ); - } - - void Server::CreateConnectionApprovedMessage( RemotePeer& remotePeer ) - { - std::unique_ptr< Message > message = _messageFactory.LendMessage( MessageType::ConnectionAccepted ); - if ( message == nullptr ) - { - LOG_ERROR( - "Can't create new Connection Accepted Message because the MessageFactory has returned a null message" ); - return; - } - - std::unique_ptr< ConnectionAcceptedMessage > connectionAcceptedPacket( - static_cast< ConnectionAcceptedMessage* >( message.release() ) ); - connectionAcceptedPacket->prefix = remotePeer.GetDataPrefix(); - connectionAcceptedPacket->clientIndexAssigned = remotePeer.GetClientIndex(); - remotePeer.AddMessage( std::move( connectionAcceptedPacket ) ); - } - - void Server::SendPacketToRemotePeer( const RemotePeer& remotePeer, const NetworkPacket& packet ) const - { - SendPacketToAddress( packet, remotePeer.GetAddress() ); + StartDisconnectingRemotePeer( remotePeer.GetClientIndex(), false, + Connection::ConnectionFailedReasonType::UNKNOWN ); } void Server::TickReplication() @@ -424,9 +260,13 @@ namespace NetLib return true; } - void Server::InternalOnRemotePeerConnect( RemotePeer& remote_peer ) + void Server::OnPendingConnectionAccepted( const Connection::SuccessConnectionData& data ) { - _remotePeerInputsHandler.CreateInputsBuffer( remote_peer.GetClientIndex() ); + ASSERT( !data.startedLocally, + "Server-side can't receive a connection accepted that was started locally. Server " + "does not start conenctions locally." ); + + _remotePeerInputsHandler.CreateInputsBuffer( data.id ); } void Server::InternalOnRemotePeerDisconnect( const RemotePeer& remote_peer ) diff --git a/NetworkLibrary/src/Core/Server.h b/NetworkLibrary/src/Core/Server.h index 914d16a..67d6b83 100644 --- a/NetworkLibrary/src/Core/Server.h +++ b/NetworkLibrary/src/Core/Server.h @@ -49,19 +49,15 @@ namespace NetLib protected: bool StartConcrete( const std::string& ip, uint32 port ) override; void ProcessMessageFromPeer( const Message& message, RemotePeer& remotePeer ) override; - void ProcessMessageFromUnknownPeer( const Message& message, const Address& address ) override; void TickConcrete( float32 elapsedTime ) override; bool StopConcrete() override; - void InternalOnRemotePeerConnect( RemotePeer& remote_peer ) override; + void OnPendingConnectionAccepted( const Connection::SuccessConnectionData& data ) override; + void OnPendingConnectionDenied( const Connection::FailedConnectionData& data ) override {} + void InternalOnRemotePeerDisconnect( const RemotePeer& remote_peer ) override; private: - uint64 GenerateServerSalt() const; - - void ProcessConnectionRequest( const ConnectionRequestMessage& message, const Address& address ); - void ProcessConnectionChallengeResponse( const ConnectionChallengeResponseMessage& message, - RemotePeer& remotePeer ); void ProcessTimeRequest( const TimeRequestMessage& message, RemotePeer& remotePeer ); void ProcessInputs( const InputStateMessage& message, RemotePeer& remotePeer ); void ProcessDisconnection( const DisconnectionMessage& message, RemotePeer& remotePeer ); @@ -77,12 +73,8 @@ namespace NetLib /// // int32 IsRemotePeerAbleToConnect(const Address& address) const; - void CreateConnectionChallengeMessage( RemotePeer& remotePeer ); - void CreateConnectionApprovedMessage( RemotePeer& remotePeer ); void CreateDisconnectionMessage( RemotePeer& remotePeer ); void CreateTimeResponseMessage( RemotePeer& remotePeer, const TimeRequestMessage& timeRequest ); - void SendConnectionDeniedPacket( const Address& address, ConnectionFailedReasonType reason ); - void SendPacketToRemotePeer( const RemotePeer& remotePeer, const NetworkPacket& packet ) const; void TickReplication(); diff --git a/NetworkLibrary/src/Core/connecting_remote_peer.h b/NetworkLibrary/src/Core/connecting_remote_peer.h new file mode 100644 index 0000000..c13e8b0 --- /dev/null +++ b/NetworkLibrary/src/Core/connecting_remote_peer.h @@ -0,0 +1,38 @@ +#pragma once +#include "numeric_types.h" + +#include "core/address.h" +#include "transmission_channels/unreliable_unordered_transmission_channel.h" +#include "metrics/metrics_handler.h" + +namespace NetLib +{ + enum class ConnectingRemotePeerState : uint8 + { + ConnectionRequest = 0, + ConnectionChallenge = 1 + }; + + class ConnectingRemotePeer + { + public: + ConnectingRemotePeer(); + + void StartUp(); + void ShutDown(); + + void Connect( const Address& address, uint64 clientSalt, uint64 serverSalt ); + void Disconnect(); + + private: + Address _address; + ConnectingRemotePeerState _currentState; + + float32 _maxInactivityTime; + float32 _inactivityTimeLeft; + uint64 _clientSalt; + uint64 _serverSalt; + + UnreliableUnorderedTransmissionChannel _transmissionChannel; + }; +} // namespace NetLib diff --git a/NetworkLibrary/src/Core/remote_peer.cpp b/NetworkLibrary/src/Core/remote_peer.cpp index b6ef887..4a87be2 100644 --- a/NetworkLibrary/src/Core/remote_peer.cpp +++ b/NetworkLibrary/src/Core/remote_peer.cpp @@ -10,7 +10,7 @@ #include "transmission_channels/unreliable_unordered_transmission_channel.h" #include "transmission_channels/reliable_ordered_channel.h" -#include "metrics/metric_names.h" +#include "metrics/metric_types.h" #include "core/Socket.h" @@ -68,7 +68,6 @@ namespace NetLib , _nextPacketSequenceNumber( 0 ) , _currentState( RemotePeerState::Disconnected ) , _transmissionChannels() - , _metricsEnabled( false ) { // InitTransmissionChannels(); } @@ -82,7 +81,6 @@ namespace NetLib , _nextPacketSequenceNumber( 0 ) , _currentState( RemotePeerState::Disconnected ) , _transmissionChannels() - , _metricsEnabled( false ) { InitTransmissionChannels( message_factory ); } @@ -92,7 +90,6 @@ namespace NetLib : _address( Address::GetInvalid() ) , _nextPacketSequenceNumber( 0 ) , _currentState( RemotePeerState::Disconnected ) - , _metricsEnabled( false ) { InitTransmissionChannels( message_factory ); Connect( address, id, maxInactivityTime, clientSalt, serverSalt ); @@ -112,27 +109,26 @@ namespace NetLib _transmissionChannels.clear(); } - void RemotePeer::ActivateNetworkStatistics() - { - _metricsEnabled = true; - _metricsHandler.Configure( 1.f ); - } - - void RemotePeer::DeactivateNetworkStatistics() - { - _metricsEnabled = false; - } - void RemotePeer::Connect( const Address& address, uint16 id, float32 maxInactivityTime, uint64 clientSalt, uint64 serverSalt ) { + bool result = true; + _address = address; _id = id; _maxInactivityTime = maxInactivityTime; _inactivityTimeLeft = _maxInactivityTime; _clientSalt = clientSalt; _serverSalt = serverSalt; - _currentState = RemotePeerState::Connecting; + _currentState = RemotePeerState::Connected; + + // TODO Add here the list of metrics or metrics data we can to enable for this remote peer + if ( !_metricsHandler.StartUp( 1.f, Metrics::MetricsEnableConfig::ENABLE_ALL ) ) + { + LOG_ERROR( "[RemotePeer.%s] Failed to start up metrics handler for remote peer %u.", THIS_FUNCTION_NAME, + _id ); + result = false; + } } void RemotePeer::Tick( float32 elapsedTime, MessageFactory& message_factory ) @@ -144,18 +140,14 @@ namespace NetLib _inactivityTimeLeft = 0.f; } - Metrics::MetricsHandler* metricsHandler = _metricsEnabled ? &_metricsHandler : nullptr; // Update transmission channels for ( uint32 i = 0; i < GetNumberOfTransmissionChannels(); ++i ) { - _transmissionChannels[ i ]->Update( elapsedTime, metricsHandler ); + _transmissionChannels[ i ]->Update( elapsedTime, _metricsHandler ); } - if ( _metricsEnabled ) - { - _metricsHandler.Update( elapsedTime ); - _pingPongMessagesSender.Update( elapsedTime, *this, message_factory ); - } + _metricsHandler.Update( elapsedTime ); + _pingPongMessagesSender.Update( elapsedTime, *this, message_factory ); } bool RemotePeer::AddMessage( std::unique_ptr< Message > message ) @@ -201,8 +193,7 @@ namespace NetLib for ( ; it < _transmissionChannels.end(); ++it ) { TransmissionChannel* channel = *it; - Metrics::MetricsHandler* metricsHandler = _metricsEnabled ? &_metricsHandler : nullptr; - channel->CreateAndSendPacket( socket, _address, metricsHandler ); + channel->CreateAndSendPacket( socket, _address, _metricsHandler ); } } @@ -232,9 +223,9 @@ namespace NetLib AddReceivedMessage( std::move( message ) ); } - if ( _metricsEnabled ) + if ( _metricsHandler.HasMetric( Metrics::MetricType::DOWNLOAD_BANDWIDTH ) ) { - _metricsHandler.AddValue( Metrics::DOWNLOAD_BANDWIDTH_METRIC, packet_size ); + _metricsHandler.AddValue( Metrics::MetricType::DOWNLOAD_BANDWIDTH, packet_size ); } } @@ -244,8 +235,7 @@ namespace NetLib TransmissionChannel* transmissionChannel = GetTransmissionChannelFromType( channelType ); if ( transmissionChannel != nullptr ) { - Metrics::MetricsHandler* metricsHandler = _metricsEnabled ? &_metricsHandler : nullptr; - transmissionChannel->ProcessACKs( acks, lastAckedMessageSequenceNumber, metricsHandler ); + transmissionChannel->ProcessACKs( acks, lastAckedMessageSequenceNumber, _metricsHandler ); } } @@ -258,8 +248,7 @@ namespace NetLib TransmissionChannel* transmissionChannel = GetTransmissionChannelFromType( channelType ); if ( transmissionChannel != nullptr ) { - Metrics::MetricsHandler* metricsHandler = _metricsEnabled ? &_metricsHandler : nullptr; - transmissionChannel->AddReceivedMessage( std::move( message ), metricsHandler ); + transmissionChannel->AddReceivedMessage( std::move( message ), _metricsHandler ); _inactivityTimeLeft = _maxInactivityTime; } else @@ -321,16 +310,18 @@ namespace NetLib return static_cast< uint32 >( _transmissionChannels.size() ); } - uint32 RemotePeer::GetMetric( const std::string& metric_name, const std::string& value_type ) const + uint32 RemotePeer::GetMetric( Metrics::MetricType metric_type, Metrics::ValueType value_type ) const { uint32 result = 0; - if ( _metricsEnabled ) + if ( _metricsHandler.HasMetric( metric_type ) ) { - result = _metricsHandler.GetValue( metric_name, value_type ); + result = _metricsHandler.GetValue( metric_type, value_type ); } else { - LOG_WARNING( "You are trying to get a metric value from a RemotePeer that doesn't have metrics enabled" ); + LOG_WARNING( "[RemotePeer.%s] You are trying to get a metric value from a RemotePeer that doesn't have " + "metric of type %u enabled", + THIS_FUNCTION_NAME, static_cast< uint8 >( metric_type ) ); } return result; @@ -338,6 +329,8 @@ namespace NetLib void RemotePeer::Disconnect() { + bool result = true; + // Reset transmission channels for ( uint32 i = 0; i < GetNumberOfTransmissionChannels(); ++i ) { @@ -349,6 +342,11 @@ namespace NetLib _currentState = RemotePeerState::Disconnected; - DeactivateNetworkStatistics(); + if ( !_metricsHandler.ShutDown() ) + { + LOG_ERROR( "[RemotePeer.%s] Failed to shut down metrics handler for remote peer %u.", THIS_FUNCTION_NAME, + _id ); + result = false; + } } } // namespace NetLib diff --git a/NetworkLibrary/src/Core/remote_peer.h b/NetworkLibrary/src/Core/remote_peer.h index cd8a044..a4fa74e 100644 --- a/NetworkLibrary/src/Core/remote_peer.h +++ b/NetworkLibrary/src/Core/remote_peer.h @@ -11,6 +11,7 @@ #include "core/ping_pong_messages_sender.h" #include "metrics/metrics_handler.h" +#include "metrics/metric_types.h" #include "transmission_channels/transmission_channel.h" @@ -21,11 +22,10 @@ namespace NetLib class Socket; class MessageFactory; - enum RemotePeerState : uint8 + enum class RemotePeerState : uint8 { Disconnected = 0, - Connected = 1, - Connecting = 2 + Connected = 1 }; class RemotePeer @@ -40,11 +40,11 @@ namespace NetLib uint64 _clientSalt; uint64 _serverSalt; + // TODO Unused variable. Remove it uint16 _nextPacketSequenceNumber; std::vector< TransmissionChannel* > _transmissionChannels; - bool _metricsEnabled; Metrics::MetricsHandler _metricsHandler; PingPongMessagesSender _pingPongMessagesSender; @@ -65,9 +65,6 @@ namespace NetLib RemotePeer& operator=( const RemotePeer& ) = delete; ~RemotePeer(); - void ActivateNetworkStatistics(); - void DeactivateNetworkStatistics(); - /// /// Initializes all the internal systems. You must call this method before performing any other operation. /// It is also automatically called in parameterized ctor @@ -114,7 +111,7 @@ namespace NetLib std::vector< TransmissionChannelType > GetAvailableTransmissionChannelTypes() const; uint32 GetNumberOfTransmissionChannels() const; - uint32 GetMetric( const std::string& metric_name, const std::string& value_type ) const; + uint32 GetMetric( Metrics::MetricType metric_type, Metrics::ValueType value_type ) const; /// /// Disconnect and reset the remote client diff --git a/NetworkLibrary/src/Core/remote_peers_handler.cpp b/NetworkLibrary/src/Core/remote_peers_handler.cpp index c65d11c..a49f346 100644 --- a/NetworkLibrary/src/Core/remote_peers_handler.cpp +++ b/NetworkLibrary/src/Core/remote_peers_handler.cpp @@ -66,7 +66,6 @@ namespace NetLib _remotePeerSlots[ slotIndex ] = true; _remotePeers[ slotIndex ].Connect( addressInfo, id, REMOTE_PEER_INACTIVITY_TIME, clientSalt, serverSalt ); - _remotePeers[ slotIndex ].ActivateNetworkStatistics(); auto it = _validRemotePeers.insert( &( _remotePeers[ slotIndex ] ) ); assert( it.second ); // If the element was already there it means that we are trying to add it again. ERROR!! @@ -107,6 +106,22 @@ namespace NetLib return freeIndex; } + uint32 RemotePeersHandler::GetNumberOfAvailableRemotePeerSlots() const + { + ASSERT( _isInitialized, "Remote peers handler is not initialized." ); + + uint32 availableIndex = 0; + for ( uint32 i = 0; i < _maxConnections; ++i ) + { + if ( !_remotePeerSlots[ i ] ) + { + ++availableIndex; + } + } + + return availableIndex; + } + RemotePeer* RemotePeersHandler::GetRemotePeerFromAddress( const Address& address ) { ASSERT( _isInitialized, "Remote peers handler is not initialized." ); diff --git a/NetworkLibrary/src/Core/remote_peers_handler.h b/NetworkLibrary/src/Core/remote_peers_handler.h index 4357750..020963c 100644 --- a/NetworkLibrary/src/Core/remote_peers_handler.h +++ b/NetworkLibrary/src/Core/remote_peers_handler.h @@ -35,6 +35,7 @@ namespace NetLib bool AddRemotePeer( const Address& addressInfo, uint16 id, uint64 clientSalt, uint64 serverSalt ); int32 FindFreeRemotePeerSlot() const; + uint32 GetNumberOfAvailableRemotePeerSlots() const; RemotePeer* GetRemotePeerFromAddress( const Address& address ); RemotePeer* GetRemotePeerFromId( uint32 id ); const RemotePeer* GetRemotePeerFromId( uint32 id ) const; @@ -47,6 +48,7 @@ namespace NetLib void RemoveAllRemotePeers(); bool RemoveRemotePeer( uint32 remotePeerId ); + uint32 GetMaxConnections() const { return _maxConnections; } private: int32 GetIndexFromId( uint32 id ) const; diff --git a/NetworkLibrary/src/communication/message.cpp b/NetworkLibrary/src/communication/message.cpp index c9be12d..4d47dbc 100644 --- a/NetworkLibrary/src/communication/message.cpp +++ b/NetworkLibrary/src/communication/message.cpp @@ -12,12 +12,20 @@ namespace NetLib buffer.WriteLong( clientSalt ); } - void ConnectionRequestMessage::Read( Buffer& buffer ) + bool ConnectionRequestMessage::Read( Buffer& buffer ) { _header.type = MessageType::ConnectionRequest; - _header.ReadWithoutHeader( buffer ); + if ( !_header.ReadWithoutHeader( buffer ) ) + { + return false; + } + + if ( !buffer.ReadLong( clientSalt ) ) + { + return false; + } - clientSalt = buffer.ReadLong(); + return true; } uint32 ConnectionRequestMessage::Size() const @@ -28,22 +36,28 @@ namespace NetLib void ConnectionChallengeMessage::Write( Buffer& buffer ) const { _header.Write( buffer ); - buffer.WriteLong( clientSalt ); buffer.WriteLong( serverSalt ); } - void ConnectionChallengeMessage::Read( Buffer& buffer ) + bool ConnectionChallengeMessage::Read( Buffer& buffer ) { _header.type = MessageType::ConnectionChallenge; - _header.ReadWithoutHeader( buffer ); + if ( !_header.ReadWithoutHeader( buffer ) ) + { + return false; + } + + if ( !buffer.ReadLong( serverSalt ) ) + { + return false; + } - clientSalt = buffer.ReadLong(); - serverSalt = buffer.ReadLong(); + return true; } uint32 ConnectionChallengeMessage::Size() const { - return MessageHeader::Size() + ( sizeof( uint64 ) * 2 ); + return MessageHeader::Size() + sizeof( uint64 ); } void ConnectionChallengeResponseMessage::Write( Buffer& buffer ) const @@ -52,12 +66,20 @@ namespace NetLib buffer.WriteLong( prefix ); } - void ConnectionChallengeResponseMessage::Read( Buffer& buffer ) + bool ConnectionChallengeResponseMessage::Read( Buffer& buffer ) { _header.type = MessageType::ConnectionChallengeResponse; - _header.ReadWithoutHeader( buffer ); + if ( !_header.ReadWithoutHeader( buffer ) ) + { + return false; + } + + if ( !buffer.ReadLong( prefix ) ) + { + return false; + } - prefix = buffer.ReadLong(); + return true; } uint32 ConnectionChallengeResponseMessage::Size() const @@ -72,13 +94,25 @@ namespace NetLib buffer.WriteShort( clientIndexAssigned ); } - void ConnectionAcceptedMessage::Read( Buffer& buffer ) + bool ConnectionAcceptedMessage::Read( Buffer& buffer ) { _header.type = MessageType::ConnectionAccepted; - _header.ReadWithoutHeader( buffer ); + if ( !_header.ReadWithoutHeader( buffer ) ) + { + return false; + } + + if ( !buffer.ReadLong( prefix ) ) + { + return false; + } + + if ( !buffer.ReadShort( clientIndexAssigned ) ) + { + return false; + } - prefix = buffer.ReadLong(); - clientIndexAssigned = buffer.ReadShort(); + return true; } uint32 ConnectionAcceptedMessage::Size() const @@ -92,11 +126,20 @@ namespace NetLib buffer.WriteByte( reason ); } - void ConnectionDeniedMessage::Read( Buffer& buffer ) + bool ConnectionDeniedMessage::Read( Buffer& buffer ) { _header.type = MessageType::ConnectionDenied; - _header.ReadWithoutHeader( buffer ); - reason = buffer.ReadByte(); + if ( !_header.ReadWithoutHeader( buffer ) ) + { + return false; + } + + if ( !buffer.ReadByte( reason ) ) + { + return false; + } + + return true; } uint32 ConnectionDeniedMessage::Size() const @@ -111,13 +154,25 @@ namespace NetLib buffer.WriteByte( reason ); } - void DisconnectionMessage::Read( Buffer& buffer ) + bool DisconnectionMessage::Read( Buffer& buffer ) { _header.type = MessageType::Disconnection; - _header.ReadWithoutHeader( buffer ); + if ( !_header.ReadWithoutHeader( buffer ) ) + { + return false; + } + + if ( !buffer.ReadLong( prefix ) ) + { + return false; + } + + if ( !buffer.ReadByte( reason ) ) + { + return false; + } - prefix = buffer.ReadLong(); - reason = buffer.ReadByte(); + return true; } uint32 DisconnectionMessage::Size() const @@ -131,12 +186,20 @@ namespace NetLib buffer.WriteInteger( remoteTime ); } - void TimeRequestMessage::Read( Buffer& buffer ) + bool TimeRequestMessage::Read( Buffer& buffer ) { _header.type = MessageType::TimeRequest; - _header.ReadWithoutHeader( buffer ); + if ( !_header.ReadWithoutHeader( buffer ) ) + { + return false; + } - remoteTime = buffer.ReadInteger(); + if ( !buffer.ReadInteger( remoteTime ) ) + { + return false; + } + + return true; } uint32 TimeRequestMessage::Size() const @@ -151,13 +214,25 @@ namespace NetLib buffer.WriteInteger( serverTime ); } - void TimeResponseMessage::Read( Buffer& buffer ) + bool TimeResponseMessage::Read( Buffer& buffer ) { _header.type = MessageType::TimeResponse; - _header.ReadWithoutHeader( buffer ); + if ( !_header.ReadWithoutHeader( buffer ) ) + { + return false; + } - remoteTime = buffer.ReadInteger(); - serverTime = buffer.ReadInteger(); + if ( !buffer.ReadInteger( remoteTime ) ) + { + return false; + } + + if ( !buffer.ReadInteger( serverTime ) ) + { + return false; + } + + return true; } uint32 TimeResponseMessage::Size() const @@ -173,33 +248,55 @@ namespace NetLib buffer.WriteInteger( controlledByPeerId ); buffer.WriteInteger( replicatedClassId ); buffer.WriteShort( dataSize ); - // TODO Create method called WriteData(data, size) in order to avoid this for loop - for ( uint16 i = 0; i < dataSize; ++i ) + if ( dataSize > 0 ) { - buffer.WriteByte( data[ i ] ); + buffer.WriteData( data, dataSize ); } } - void ReplicationMessage::Read( Buffer& buffer ) + bool ReplicationMessage::Read( Buffer& buffer ) { _header.type = MessageType::Replication; - _header.ReadWithoutHeader( buffer ); + if ( !_header.ReadWithoutHeader( buffer ) ) + { + return false; + } - replicationAction = buffer.ReadByte(); - networkEntityId = buffer.ReadInteger(); - controlledByPeerId = buffer.ReadInteger(); - replicatedClassId = buffer.ReadInteger(); - dataSize = buffer.ReadShort(); - if ( dataSize > 0 ) + if ( !buffer.ReadByte( replicationAction ) ) { - data = new uint8[ dataSize ]; + return false; + } + + if ( !buffer.ReadInteger( networkEntityId ) ) + { + return false; + } + + if ( !buffer.ReadInteger( controlledByPeerId ) ) + { + return false; } - // TODO Create method called ReadData(uint8& data, size) in order to avoid this for loop - for ( uint16 i = 0; i < dataSize; ++i ) + if ( !buffer.ReadInteger( replicatedClassId ) ) { - data[ i ] = buffer.ReadByte(); + return false; } + + if ( !buffer.ReadShort( dataSize ) ) + { + return false; + } + + if ( dataSize > 0 ) + { + data = new uint8[ dataSize ]; + if ( !buffer.ReadData( data, dataSize ) ) + { + return false; + } + } + + return true; } uint32 ReplicationMessage::Size() const @@ -231,29 +328,33 @@ namespace NetLib _header.Write( buffer ); buffer.WriteShort( dataSize ); - // TODO Create method called WriteData(data, size) in order to avoid this for loop - for ( uint32 i = 0; i < dataSize; ++i ) - { - buffer.WriteByte( data[ i ] ); - } + buffer.WriteData( data, dataSize ); } - void InputStateMessage::Read( Buffer& buffer ) + bool InputStateMessage::Read( Buffer& buffer ) { _header.type = MessageType::Inputs; - _header.ReadWithoutHeader( buffer ); + if ( !_header.ReadWithoutHeader( buffer ) ) + { + return false; + } + + if ( !buffer.ReadShort( dataSize ) ) + { + return false; + } - dataSize = buffer.ReadShort(); if ( dataSize > 0 ) { data = new uint8[ dataSize ]; } - // TODO Create method called ReadData(uint8& data, size) in order to avoid this for loop - for ( uint16 i = 0; i < dataSize; ++i ) + if ( !buffer.ReadData( data, dataSize ) ) { - data[ i ] = buffer.ReadByte(); + return false; } + + return true; } uint32 InputStateMessage::Size() const @@ -275,10 +376,15 @@ namespace NetLib _header.Write( buffer ); } - void PingPongMessage::Read( Buffer& buffer ) + bool PingPongMessage::Read( Buffer& buffer ) { _header.type = MessageType::PingPong; - _header.ReadWithoutHeader( buffer ); + if ( !_header.ReadWithoutHeader( buffer ) ) + { + return false; + } + + return true; } uint32 PingPongMessage::Size() const diff --git a/NetworkLibrary/src/communication/message.h b/NetworkLibrary/src/communication/message.h index 09b8300..3c0e6b9 100644 --- a/NetworkLibrary/src/communication/message.h +++ b/NetworkLibrary/src/communication/message.h @@ -18,7 +18,7 @@ namespace NetLib virtual void Write( Buffer& buffer ) const = 0; // Read it without the message header type - virtual void Read( Buffer& buffer ) = 0; + virtual bool Read( Buffer& buffer ) = 0; virtual uint32 Size() const = 0; // TODO Temp, until I find a better way to clean Replication's data field @@ -43,7 +43,7 @@ namespace NetLib } void Write( Buffer& buffer ) const override; - void Read( Buffer& buffer ) override; + bool Read( Buffer& buffer ) override; uint32 Size() const override; ~ConnectionRequestMessage() override {}; @@ -55,19 +55,17 @@ namespace NetLib { public: ConnectionChallengeMessage() - : clientSalt( 0 ) - , serverSalt( 0 ) + : serverSalt( 0 ) , Message( MessageType::ConnectionChallenge ) { } void Write( Buffer& buffer ) const override; - void Read( Buffer& buffer ) override; + bool Read( Buffer& buffer ) override; uint32 Size() const override; ~ConnectionChallengeMessage() override {}; - uint64 clientSalt; uint64 serverSalt; }; @@ -81,7 +79,7 @@ namespace NetLib } void Write( Buffer& buffer ) const override; - void Read( Buffer& buffer ) override; + bool Read( Buffer& buffer ) override; uint32 Size() const override; ~ConnectionChallengeResponseMessage() override {}; @@ -100,7 +98,7 @@ namespace NetLib } void Write( Buffer& buffer ) const override; - void Read( Buffer& buffer ) override; + bool Read( Buffer& buffer ) override; uint32 Size() const override; ~ConnectionAcceptedMessage() override {}; @@ -119,7 +117,7 @@ namespace NetLib } void Write( Buffer& buffer ) const override; - void Read( Buffer& buffer ) override; + bool Read( Buffer& buffer ) override; uint32 Size() const override; ~ConnectionDeniedMessage() override {}; @@ -138,7 +136,7 @@ namespace NetLib } void Write( Buffer& buffer ) const override; - void Read( Buffer& buffer ) override; + bool Read( Buffer& buffer ) override; uint32 Size() const override; ~DisconnectionMessage() override {}; @@ -157,7 +155,7 @@ namespace NetLib } void Write( Buffer& buffer ) const override; - void Read( Buffer& buffer ) override; + bool Read( Buffer& buffer ) override; uint32 Size() const override; ~TimeRequestMessage() override {}; @@ -176,7 +174,7 @@ namespace NetLib } void Write( Buffer& buffer ) const override; - void Read( Buffer& buffer ) override; + bool Read( Buffer& buffer ) override; uint32 Size() const override; ~TimeResponseMessage() override {}; @@ -200,7 +198,7 @@ namespace NetLib } void Write( Buffer& buffer ) const override; - void Read( Buffer& buffer ) override; + bool Read( Buffer& buffer ) override; uint32 Size() const override; // TODO Make this also dynamic based on replication action. Now it is set to its worst case @@ -227,7 +225,7 @@ namespace NetLib } void Write( Buffer& buffer ) const override; - void Read( Buffer& buffer ) override; + bool Read( Buffer& buffer ) override; uint32 Size() const override; void Reset() override; @@ -245,7 +243,7 @@ namespace NetLib } void Write( Buffer& buffer ) const override; - void Read( Buffer& buffer ) override; + bool Read( Buffer& buffer ) override; uint32 Size() const override; void Reset() override; diff --git a/NetworkLibrary/src/communication/message_header.cpp b/NetworkLibrary/src/communication/message_header.cpp index de94d6a..aeec88b 100644 --- a/NetworkLibrary/src/communication/message_header.cpp +++ b/NetworkLibrary/src/communication/message_header.cpp @@ -31,11 +31,22 @@ namespace NetLib ReadWithoutHeader( buffer ); } - void MessageHeader::ReadWithoutHeader( Buffer& buffer ) + bool MessageHeader::ReadWithoutHeader( Buffer& buffer ) { - messageSequenceNumber = buffer.ReadShort(); - uint8 flags = buffer.ReadByte(); + if ( !buffer.ReadShort( messageSequenceNumber ) ) + { + return false; + } + + uint8 flags; + if ( !buffer.ReadByte( flags ) ) + { + return false; + } + isReliable = BitwiseUtils::GetBitAtIndex( flags, 0 ); isOrdered = BitwiseUtils::GetBitAtIndex( flags, 1 ); + + return true; } } // namespace NetLib diff --git a/NetworkLibrary/src/communication/message_header.h b/NetworkLibrary/src/communication/message_header.h index 59f1e50..999e123 100644 --- a/NetworkLibrary/src/communication/message_header.h +++ b/NetworkLibrary/src/communication/message_header.h @@ -40,7 +40,7 @@ namespace NetLib void Write( Buffer& buffer ) const; void Read( Buffer& buffer ); - void ReadWithoutHeader( Buffer& buffer ); + bool ReadWithoutHeader( Buffer& buffer ); static uint32 Size() { return sizeof( MessageType ) + sizeof( uint16 ) + sizeof( uint8 ); } ~MessageHeader() {} diff --git a/NetworkLibrary/src/communication/message_utils.cpp b/NetworkLibrary/src/communication/message_utils.cpp index 87864cd..50e23c0 100644 --- a/NetworkLibrary/src/communication/message_utils.cpp +++ b/NetworkLibrary/src/communication/message_utils.cpp @@ -10,7 +10,13 @@ namespace NetLib { std::unique_ptr< Message > MessageUtils::ReadMessage( MessageFactory& message_factory, Buffer& buffer ) { - MessageType type = static_cast< MessageType >( buffer.ReadByte() ); + uint8 messageType; + if ( !buffer.ReadByte( messageType ) ) + { + return nullptr; + } + + const MessageType type = static_cast< MessageType >( messageType ); std::unique_ptr< Message > message = nullptr; switch ( type ) @@ -54,7 +60,10 @@ namespace NetLib if ( message != nullptr ) { - message->Read( buffer ); + if ( !message->Read( buffer ) ) + { + return nullptr; + } } return std::move( message ); diff --git a/NetworkLibrary/src/communication/network_packet.cpp b/NetworkLibrary/src/communication/network_packet.cpp index 32720c4..8f37b9a 100644 --- a/NetworkLibrary/src/communication/network_packet.cpp +++ b/NetworkLibrary/src/communication/network_packet.cpp @@ -1,12 +1,12 @@ #include "network_packet.h" #include "asserts.h" +#include "Logger.h" #include "core/buffer.h" #include "communication/message.h" #include "communication/message_utils.h" -#include "communication/message_factory.h" namespace NetLib { @@ -24,13 +24,6 @@ namespace NetLib buffer.WriteByte( channelType ); } - void NetworkPacketHeader::Read( Buffer& buffer ) - { - lastAckedSequenceNumber = buffer.ReadShort(); - ackBits = buffer.ReadInteger(); - channelType = buffer.ReadByte(); - } - NetworkPacket::NetworkPacket() : _header( 0, 0, 0 ) , _defaultMTUSizeInBytes( 1500 ) @@ -51,25 +44,20 @@ namespace NetLib } } - void NetworkPacket::Read( MessageFactory& message_factory, Buffer& buffer ) + bool NetworkPacket::AddMessage( std::unique_ptr< Message > message ) { - _header.Read( buffer ); - - const uint8 numberOfMessages = buffer.ReadByte(); + _messages.push_back( std::move( message ) ); + return true; + } - for ( uint32 i = 0; i < numberOfMessages; ++i ) + bool NetworkPacket::AddMessages( std::vector< std::unique_ptr< Message > >& messages ) + { + for ( auto it = messages.begin(); it != messages.end(); ++it ) { - std::unique_ptr< Message > message = MessageUtils::ReadMessage( message_factory, buffer ); - if ( message != nullptr ) - { - AddMessage( std::move( message ) ); - } + AddMessage( std::move( *it ) ); } - } - bool NetworkPacket::AddMessage( std::unique_ptr< Message > message ) - { - _messages.push_back( std::move( message ) ); + messages.clear(); return true; } @@ -92,7 +80,7 @@ namespace NetLib uint32 NetworkPacket::Size() const { - uint32 packetSize = NetworkPacketHeader::Size(); + uint32 packetSize = NetworkPacketHeader::SIZE; packetSize += 1; // We store in 1 byte the number of messages that this packet contains auto iterator = _messages.cbegin(); diff --git a/NetworkLibrary/src/communication/network_packet.h b/NetworkLibrary/src/communication/network_packet.h index c4de551..359807c 100644 --- a/NetworkLibrary/src/communication/network_packet.h +++ b/NetworkLibrary/src/communication/network_packet.h @@ -8,7 +8,6 @@ namespace NetLib { class Buffer; class Message; - class MessageFactory; struct NetworkPacketHeader { @@ -27,9 +26,6 @@ namespace NetLib } void Write( Buffer& buffer ) const; - void Read( Buffer& buffer ); - - static uint32 Size() { return sizeof( uint16 ) + sizeof( uint32 ) + sizeof( uint8 ); }; void SetACKs( uint32 acks ) { ackBits = acks; }; void SetHeaderLastAcked( uint16 lastAckedMessage ) { lastAckedSequenceNumber = lastAckedMessage; }; @@ -38,6 +34,8 @@ namespace NetLib uint16 lastAckedSequenceNumber; uint32 ackBits; uint8 channelType; + + static constexpr uint32 SIZE = sizeof( uint16 ) + sizeof( uint32 ) + sizeof( uint8 ); }; class NetworkPacket @@ -54,11 +52,11 @@ namespace NetLib NetworkPacket& operator=( NetworkPacket&& other ) noexcept = delete; void Write( Buffer& buffer ) const; - void Read( MessageFactory& message_factory, Buffer& buffer ); const NetworkPacketHeader& GetHeader() const { return _header; }; bool AddMessage( std::unique_ptr< Message > message ); + bool AddMessages( std::vector< std::unique_ptr< Message > >& messages ); std::unique_ptr< Message > TryGetNextMessage(); const std::vector< std::unique_ptr< Message > >& GetAllMessages() const; uint32 GetNumberOfMessages() const { return static_cast< uint32 >( _messages.size() ); } diff --git a/NetworkLibrary/src/communication/network_packet_utils.cpp b/NetworkLibrary/src/communication/network_packet_utils.cpp index f155e1c..a5779c3 100644 --- a/NetworkLibrary/src/communication/network_packet_utils.cpp +++ b/NetworkLibrary/src/communication/network_packet_utils.cpp @@ -1,12 +1,115 @@ #include "network_packet_utils.h" #include "asserts.h" +#include "Logger.h" +#include "core/buffer.h" #include "communication/network_packet.h" #include "communication/message_factory.h" +#include "communication/message_utils.h" + +#include "transmission_channels/transmission_channel.h" + +#include namespace NetLib { + static bool ReadNetworkPacketHeader( Buffer& buffer, NetworkPacketHeader& out_header ) + { + if ( buffer.GetRemainingSize() < NetworkPacketHeader::SIZE ) + { + LOG_ERROR( "Not enough data in buffer to read Network Packet header." ); + return false; + } + + if ( !buffer.ReadShort( out_header.lastAckedSequenceNumber ) ) + { + return false; + } + + if ( !buffer.ReadInteger( out_header.ackBits ) ) + { + return false; + } + + if ( !buffer.ReadByte( out_header.channelType ) ) + { + return false; + } + + if ( out_header.channelType >= TransmissionChannelType::Count ) + { + LOG_ERROR( "Invalid channel type %u in Network Packet header.", out_header.channelType ); + return false; + } + + return true; + } + + static bool ReadNetworkPacketMessages( Buffer& buffer, MessageFactory& message_factory, + std::vector< std::unique_ptr< Message > >& out_messages ) + { + // Read number of messages + uint8 numberOfMessages; + if ( !buffer.ReadByte( numberOfMessages ) ) + { + LOG_ERROR( "Error reading number of messages in Network Packet." ); + return false; + } + + // Read messages + uint32 messageCount = 0; + for ( ; messageCount < numberOfMessages; ++messageCount ) + { + std::unique_ptr< Message > message = MessageUtils::ReadMessage( message_factory, buffer ); + if ( message != nullptr ) + { + out_messages.push_back( std::move( message ) ); + } + else + { + LOG_ERROR( "Error reading Network Packet message %u/%u.", messageCount + 1, numberOfMessages ); + break; + } + } + + // If not all messages were read succesfully, release the ones that were read + if ( messageCount < numberOfMessages ) + { + for ( auto it = out_messages.begin(); it < out_messages.end(); ++it ) + { + message_factory.ReleaseMessage( std::move( *it ) ); + } + return false; + } + + return true; + } + + bool NetworkPacketUtils::ReadNetworkPacket( Buffer& buffer, MessageFactory& message_factory, + NetworkPacket& out_packet ) + { + // Read header + NetworkPacketHeader header; + if ( !ReadNetworkPacketHeader( buffer, header ) ) + { + LOG_ERROR( "Error reading Network Packet header." ); + return false; + } + + // Read messages + std::vector< std::unique_ptr< Message > > messages; + if ( !ReadNetworkPacketMessages( buffer, message_factory, messages ) ) + { + return false; + } + + // Configure output packet + out_packet.SetHeader( header ); + out_packet.AddMessages( std::move( messages ) ); + return true; + } + void NetworkPacketUtils::CleanPacket( MessageFactory& message_factory, NetworkPacket& packet ) { const uint32 numberOfMessages = packet.GetNumberOfMessages(); diff --git a/NetworkLibrary/src/communication/network_packet_utils.h b/NetworkLibrary/src/communication/network_packet_utils.h index 8288443..0610760 100644 --- a/NetworkLibrary/src/communication/network_packet_utils.h +++ b/NetworkLibrary/src/communication/network_packet_utils.h @@ -4,9 +4,11 @@ namespace NetLib { class NetworkPacket; class MessageFactory; + class Buffer; namespace NetworkPacketUtils { + bool ReadNetworkPacket( Buffer& buffer, MessageFactory& message_factory, NetworkPacket& out_packet ); void CleanPacket( MessageFactory& message_factory, NetworkPacket& packet ); }; } \ No newline at end of file diff --git a/NetworkLibrary/src/connection/client_connection_pipeline.cpp b/NetworkLibrary/src/connection/client_connection_pipeline.cpp new file mode 100644 index 0000000..655c614 --- /dev/null +++ b/NetworkLibrary/src/connection/client_connection_pipeline.cpp @@ -0,0 +1,203 @@ +#include "client_connection_pipeline.h" + +#include "connection/pending_connection.h" +#include "communication/message.h" +#include "communication/message_factory.h" + +#include "logger.h" + +namespace NetLib +{ + namespace Connection + { + static std::unique_ptr< Message > CreateConnectionChallengeResponseMessage( MessageFactory& message_factory, + uint64 data_prefix ) + { + LOG_INFO( "%s Creating connection challenge response message for pending connection", THIS_FUNCTION_NAME ); + + // Get a connection challenge message + std::unique_ptr< Message > message = + message_factory.LendMessage( MessageType::ConnectionChallengeResponse ); + if ( message == nullptr ) + { + LOG_ERROR( + "%s Can't create new Connection Challenge Response Message because the MessageFactory has returned " + "a null message", + THIS_FUNCTION_NAME ); + return nullptr; + } + + std::unique_ptr< ConnectionChallengeResponseMessage > connectionChallengeResponseMessage( + static_cast< ConnectionChallengeResponseMessage* >( message.release() ) ); + + // Set connection challenge fields + connectionChallengeResponseMessage->prefix = data_prefix; + + return connectionChallengeResponseMessage; + } + + static void ProcessConnectionChallenge( PendingConnection& pending_connection, + const ConnectionChallengeMessage& message, + MessageFactory& message_factory ) + { + LOG_INFO( "%s Processing challenge message for pending connection", THIS_FUNCTION_NAME ); + + // Check if state is valid for a connection challenge + if ( pending_connection.GetCurrentState() != PendingConnectionState::ConnectionChallenge && + pending_connection.GetCurrentState() != PendingConnectionState::Initializing ) + { + LOG_WARNING( "%s Pending connection is not in a valid state to process a connection challenge message. " + "Current state: %u", + THIS_FUNCTION_NAME, static_cast< uint32 >( pending_connection.GetCurrentState() ) ); + return; + } + + // Update pending connection + pending_connection.SetServerSalt( message.serverSalt ); + pending_connection.SetCurrentState( PendingConnectionState::ConnectionChallenge ); + + // Create connection challenge response + const uint64 dataPrefix = pending_connection.GetClientSalt() ^ message.serverSalt; + std::unique_ptr< Message > connectionChallengeMessage = + CreateConnectionChallengeResponseMessage( message_factory, dataPrefix ); + pending_connection.AddMessage( std::move( connectionChallengeMessage ) ); + } + + static void ProcessConnectionAccepted( PendingConnection& pending_connection, + const ConnectionAcceptedMessage& message ) + { + LOG_INFO( "%s Processing connection accepted message for pending connection", THIS_FUNCTION_NAME ); + + // Check if state is valid for a connection challenge + if ( pending_connection.GetCurrentState() != PendingConnectionState::ConnectionChallenge && + pending_connection.GetCurrentState() != PendingConnectionState::Completed ) + { + LOG_WARNING( "%s Pending connection is not in a valid state to process a connection accepted message. " + "Current state: %u", + THIS_FUNCTION_NAME, static_cast< uint32 >( pending_connection.GetCurrentState() ) ); + return; + } + + pending_connection.SetCurrentState( PendingConnectionState::Completed ); + pending_connection.SetId( 0 ); + pending_connection.SetClientSideId( message.clientIndexAssigned ); + } + + static void ProcessConnectionDenied( PendingConnection& pending_connection, + const ConnectionDeniedMessage& message ) + { + LOG_INFO( "%s Processing connection denied message for pending connection", THIS_FUNCTION_NAME ); + + pending_connection.SetCurrentState( PendingConnectionState::Failed ); + pending_connection.SetConnectionDeniedReason( static_cast< ConnectionFailedReasonType >( message.reason ) ); + } + + static void ProcessMessage( PendingConnection& pending_connection, const Message* message, + MessageFactory& message_factory ) + { + const MessageType type = message->GetHeader().type; + + switch ( type ) + { + case MessageType::ConnectionChallenge: + { + ProcessConnectionChallenge( pending_connection, + static_cast< const ConnectionChallengeMessage& >( *message ), + message_factory ); + } + break; + case MessageType::ConnectionAccepted: + { + ProcessConnectionAccepted( pending_connection, + static_cast< const ConnectionAcceptedMessage& >( *message ) ); + } + break; + case MessageType::ConnectionDenied: + { + ProcessConnectionDenied( pending_connection, + static_cast< const ConnectionDeniedMessage& >( *message ) ); + } + break; + default: + LOG_ERROR( "ClientConnectionPipeline.%s Incoming message not supported. Type: %u", + THIS_FUNCTION_NAME, static_cast< uint8 >( message->GetHeader().type ) ); + break; + } + } + + static std::unique_ptr< Message > CreateConnectionRequestMessage( MessageFactory& message_factory, + uint64 client_salt ) + { + // Get a connection challenge message + std::unique_ptr< Message > message = message_factory.LendMessage( MessageType::ConnectionRequest ); + if ( message == nullptr ) + { + LOG_ERROR( "%s Can't create new Connection Request Message because the MessageFactory has returned " + "a null message", + THIS_FUNCTION_NAME ); + return nullptr; + } + + std::unique_ptr< ConnectionRequestMessage > connectionRequestMessage( + static_cast< ConnectionRequestMessage* >( message.release() ) ); + + // Set connection request fields + connectionRequestMessage->clientSalt = client_salt; + + return connectionRequestMessage; + } + + static uint64 GenerateClientSaltNumber() + { + // TODO Change this for a better generator. rand is not generating a full 64bit integer since its maximum is + // roughly 32767. I have tried to use mt19937_64 but I think I get a conflict with winsocks and + // std::uniform_int_distribution + srand( static_cast< uint32 >( time( NULL ) ) ); + return rand(); + } + + static void AddConnectionRequestMessage( PendingConnection& pending_connection, + MessageFactory& message_factory ) + { + LOG_INFO( "%s Adding connection request message to pending connection", THIS_FUNCTION_NAME ); + + if ( !pending_connection.HasClientSaltAssigned() ) + { + const uint64 clientSalt = GenerateClientSaltNumber(); + pending_connection.SetClientSalt( clientSalt ); + } + + // Create connection request + std::unique_ptr< Message > connectionRequestMessage = + CreateConnectionRequestMessage( message_factory, pending_connection.GetClientSalt() ); + + // Add message to pending connection + pending_connection.AddMessage( std::move( connectionRequestMessage ) ); + } + + void ClientConnectionPipeline::ProcessConnection( PendingConnection& pending_connection, + MessageFactory& message_factory, float32 elapsed_time ) + { + // Process pending connections + bool areMessagesToProcess = true; + while ( areMessagesToProcess ) + { + const Message* message = pending_connection.GetPendingReadyToProcessMessage(); + if ( message != nullptr ) + { + ProcessMessage( pending_connection, message, message_factory ); + } + else + { + areMessagesToProcess = false; + } + } + + // Start with new connections + if ( pending_connection.GetCurrentState() == PendingConnectionState::Initializing ) + { + AddConnectionRequestMessage( pending_connection, message_factory ); + } + } + } // namespace Connection +} // namespace NetLib diff --git a/NetworkLibrary/src/connection/client_connection_pipeline.h b/NetworkLibrary/src/connection/client_connection_pipeline.h new file mode 100644 index 0000000..1923c5d --- /dev/null +++ b/NetworkLibrary/src/connection/client_connection_pipeline.h @@ -0,0 +1,15 @@ +#pragma once +#include "connection/i_connection_pipeline.h" + +namespace NetLib +{ + namespace Connection + { + class ClientConnectionPipeline : public IConnectionPipeline + { + public: + virtual void ProcessConnection( PendingConnection& pending_connection, MessageFactory& message_factory, + float32 elapsed_time ) override; + }; + } +} diff --git a/NetworkLibrary/src/connection/connection_failed_reason_type.h b/NetworkLibrary/src/connection/connection_failed_reason_type.h new file mode 100644 index 0000000..6209608 --- /dev/null +++ b/NetworkLibrary/src/connection/connection_failed_reason_type.h @@ -0,0 +1,18 @@ +#pragma once +#include "numeric_types.h" + +namespace NetLib +{ + namespace Connection + { + enum class ConnectionFailedReasonType : uint8 + { + UNKNOWN = 0, // Unexpect + TIMEOUT = 1, // The peer is inactive + SERVER_FULL = 2, // The server can't handle more connections, it has reached its maximum + PEER_SHUT_DOWN = 3, // The peer has shut down its Network system + CONNECTION_TIMEOUT = 4, // The in process connection has taken too long + WRONG_CHALLENGE_RESPONSE = 5 // The challenge response from the client didn't match the expected one + }; + } +} diff --git a/NetworkLibrary/src/connection/connection_manager.cpp b/NetworkLibrary/src/connection/connection_manager.cpp new file mode 100644 index 0000000..3202cae --- /dev/null +++ b/NetworkLibrary/src/connection/connection_manager.cpp @@ -0,0 +1,354 @@ +#include "connection_manager.h" + +#include "connection/i_connection_pipeline.h" +#include "communication/network_packet.h" +#include "communication/message.h" +#include "core/remote_peers_handler.h" + +#include "logger.h" +#include "asserts.h" + +namespace NetLib +{ + namespace Connection + { + ConnectionManager::ConnectionManager() + : _isStartedUp( false ) + , _messageFactory( nullptr ) + , _remotePeersHandler( nullptr ) + , _pendingConnections() + , _connectionPipeline( nullptr ) + , _connectionTimeoutSeconds( 0.f ) + , _canStartConnections( false ) + , _sendDenialOnTimeout( false ) + { + } + + bool ConnectionManager::StartUp( ConnectionConfiguration& configuration, MessageFactory* message_factory, + const RemotePeersHandler* remote_peers_handler ) + { + if ( _isStartedUp ) + { + LOG_ERROR( "[ConnectionManager.%s] ConnectionManager is already started up, ignoring call", + THIS_FUNCTION_NAME ); + return false; + } + + ASSERT( message_factory != nullptr, "Message factory can't be null" ); + ASSERT( remote_peers_handler != nullptr, "Remote peers handler can't be null" ); + + _messageFactory = message_factory; + _remotePeersHandler = remote_peers_handler; + + _pendingConnections.reserve( _remotePeersHandler->GetMaxConnections() ); + + if ( configuration.connectionPipeline != nullptr ) + { + _connectionPipeline = configuration.connectionPipeline; + _canStartConnections = configuration.canStartConnections; + _connectionTimeoutSeconds = configuration.connectionTimeoutSeconds; + _sendDenialOnTimeout = configuration.sendDenialOnTimeout; + + _isStartedUp = true; + } + else + { + LOG_ERROR( "[ConnectionManager.%s] ConnectionManager cannot start up. Configuration has invalid fields", + THIS_FUNCTION_NAME ); + } + + return _isStartedUp; + } + + bool ConnectionManager::ShutDown() + { + if ( !_isStartedUp ) + { + LOG_ERROR( "[ConnectionManager.%s] ConnectionManager is not started up, ignoring call", + THIS_FUNCTION_NAME ); + return false; + } + + // Shut down all pending connections + for ( auto& it = _pendingConnections.begin(); it != _pendingConnections.end(); ++it ) + { + if ( !it->second.ShutDown() ) + { + LOG_ERROR( "[ConnectionManager.%s] Failed to shut down pending connection.", THIS_FUNCTION_NAME ); + } + } + _pendingConnections.clear(); + + _messageFactory = nullptr; + _remotePeersHandler = nullptr; + + // Shut down connection pipeline + if ( _connectionPipeline != nullptr ) + { + delete _connectionPipeline; + _connectionPipeline = nullptr; + } + + _isStartedUp = false; + + return !_isStartedUp; + } + + void ConnectionManager::Tick( float32 elapsed_time ) + { + if ( !_isStartedUp ) + { + LOG_ERROR( "[ConnectionManager.%s] ConnectionManager is not started up, ignoring call", + THIS_FUNCTION_NAME ); + return; + } + + for ( auto& it = _pendingConnections.begin(); it != _pendingConnections.end(); ++it ) + { + PendingConnection& pc = it->second; + _connectionPipeline->ProcessConnection( pc, *_messageFactory, elapsed_time ); + + UpdateTimeout( pc, elapsed_time ); + } + } + + bool ConnectionManager::DoesPendingConnectionExist( const Address& address ) const + { + return ( _pendingConnections.find( address ) != _pendingConnections.end() ); + } + + bool ConnectionManager::ProcessPacket( const Address& address, NetworkPacket& packet ) + { + if ( !_isStartedUp ) + { + LOG_ERROR( "[ConnectionManager.%s] ConnectionManager is not started up, ignoring call", + THIS_FUNCTION_NAME ); + return false; + } + + ASSERT( address.IsValid(), "ConnectionManager.%s Address is not valid.", THIS_FUNCTION_NAME ); + + bool success = true; + // Check if pending connection exists + if ( !DoesPendingConnectionExist( address ) ) + { + // Try creating a pending connection if it doesn't exist - There have to be empty slots left + if ( !CreatePendingConnection( address, false ) ) + { + success = false; + + std::string fullAddress; + address.GetFull( fullAddress ); + LOG_WARNING( "ConnectionManager.%s Cannot create pending connection with address %s.", + THIS_FUNCTION_NAME, fullAddress.c_str() ); + } + } + + if ( success ) + { + // Process packet + PendingConnection& pendingConnection = _pendingConnections[ address ]; + pendingConnection.ProcessPacket( packet ); + } + + return success; + } + + bool ConnectionManager::CreatePendingConnection( const Address& address, bool started_locally ) + { + if ( !_isStartedUp ) + { + LOG_ERROR( "[ConnectionManager.%s] ConnectionManager is not started up, ignoring call", + THIS_FUNCTION_NAME ); + return false; + } + + bool success = false; + // Check if we have slots left for new pending connections + if ( AreSlotsAvailableForNewPendingConnection() ) + { + // Check if we have already created this pending connection + if ( !DoesPendingConnectionExist( address ) ) + { + // Create and start up + _pendingConnections.try_emplace( address, _messageFactory ); + success = _pendingConnections[ address ].StartUp( address, started_locally ); + if ( !success ) + { + _pendingConnections.erase( address ); + + std::string fullAddress; + address.GetFull( fullAddress ); + LOG_ERROR( "[ConnectionManager.%s] Failed to start up pending connection for address %s.", + THIS_FUNCTION_NAME, fullAddress.c_str() ); + } + } + else + { + std::string fullAddress; + address.GetFull( fullAddress ); + LOG_WARNING( "ConnectionManager.%s Pending connection with address %s already exists.", + THIS_FUNCTION_NAME, fullAddress.c_str() ); + } + } + else + { + LOG_WARNING( "ConnectionManager.%s Cannot create pending connection. Maximum " + "number of concurrent pending connections has been reached.", + THIS_FUNCTION_NAME ); + } + + return success; + } + + bool ConnectionManager::AreSlotsAvailableForNewPendingConnection() const + { + const uint32 numberOfAvailableRemotePeerSlots = _remotePeersHandler->GetNumberOfAvailableRemotePeerSlots(); + const uint32 numberOfCurrentPendingConnections = _pendingConnections.size(); + return ( numberOfAvailableRemotePeerSlots > numberOfCurrentPendingConnections ); + } + + bool ConnectionManager::StartConnectingToAddress( const Address& address ) + { + if ( !_isStartedUp ) + { + LOG_ERROR( "[ConnectionManager.%s] ConnectionManager is not started up, ignoring call", + THIS_FUNCTION_NAME ); + return false; + } + + bool success = false; + if ( _canStartConnections ) + { + if ( CreatePendingConnection( address, true ) ) + { + PendingConnection& pendingConnection = _pendingConnections[ address ]; + pendingConnection.SetCurrentState( PendingConnectionState::Initializing ); + success = true; + } + } + else + { + LOG_WARNING( + "ConnectionManager.%s Cannot start connection to address. Starting connections is disabled " + "in the configuration.", + THIS_FUNCTION_NAME ); + } + + return success; + } + + void ConnectionManager::GetSuccessConnectionsData( + std::vector< SuccessConnectionData >& out_success_connections ) + { + if ( !_isStartedUp ) + { + LOG_ERROR( "[ConnectionManager.%s] ConnectionManager is not started up, ignoring call", + THIS_FUNCTION_NAME ); + return; + } + + for ( auto& cit = _pendingConnections.cbegin(); cit != _pendingConnections.cend(); ++cit ) + { + const PendingConnection& pc = cit->second; + if ( pc.GetCurrentState() == PendingConnectionState::Completed ) + { + out_success_connections.emplace_back( pc.GetAddress(), pc.WasStartedLocally(), pc.GetId(), + pc.GetClientSideId(), pc.GetDataPrefix() ); + } + } + } + + void ConnectionManager::RemoveSuccessConnections() + { + if ( !_isStartedUp ) + { + LOG_ERROR( "[ConnectionManager.%s] ConnectionManager is not started up, ignoring call", + THIS_FUNCTION_NAME ); + return; + } + + for ( auto it = _pendingConnections.begin(); it != _pendingConnections.end(); ) + { + if ( it->second.GetCurrentState() == PendingConnectionState::Completed ) + { + it = _pendingConnections.erase( it ); + } + else + { + ++it; + } + } + } + + void ConnectionManager::GetFailedConnectionsData( std::vector< FailedConnectionData >& out_failed_connections ) + { + if ( !_isStartedUp ) + { + LOG_ERROR( "[ConnectionManager.%s] ConnectionManager is not started up, ignoring call", + THIS_FUNCTION_NAME ); + return; + } + + for ( auto& cit = _pendingConnections.cbegin(); cit != _pendingConnections.cend(); ++cit ) + { + const PendingConnection& pc = cit->second; + if ( pc.GetCurrentState() == PendingConnectionState::Failed ) + { + out_failed_connections.emplace_back( pc.GetAddress(), pc.GetConnectionDeniedReason() ); + } + } + } + + void ConnectionManager::RemoveFailedConnections() + { + if ( !_isStartedUp ) + { + LOG_ERROR( "[ConnectionManager.%s] ConnectionManager is not started up, ignoring call", + THIS_FUNCTION_NAME ); + return; + } + + for ( auto it = _pendingConnections.begin(); it != _pendingConnections.end(); ) + { + if ( it->second.GetCurrentState() == PendingConnectionState::Failed ) + { + it = _pendingConnections.erase( it ); + } + else + { + ++it; + } + } + } + + void ConnectionManager::SendData( Socket& socket ) + { + for ( auto it = _pendingConnections.begin(); it != _pendingConnections.end(); ++it ) + { + it->second.SendData( socket ); + } + } + + void ConnectionManager::UpdateTimeout( PendingConnection& pending_connection, float32 elapsed_time ) + { + pending_connection.UpdateConnectionElapsedTime( elapsed_time ); + if ( pending_connection.GetCurrentConnectionElapsedTime() >= _connectionTimeoutSeconds ) + { + std::string addressStr; + pending_connection.GetAddress().GetFull( addressStr ); + LOG_INFO( "ConnectionManager.%s Pending connection with address %s has timed out.", THIS_FUNCTION_NAME, + addressStr.c_str() ); + + pending_connection.SetCurrentState( PendingConnectionState::Failed ); + pending_connection.SetConnectionDeniedReason( ConnectionFailedReasonType::TIMEOUT ); + + if ( _sendDenialOnTimeout ) + { + // TODO Evaluate if it is worth sending a denial message on timeout since the client-side will also + // time out its connection (if config stays the same between server and client.) + } + } + } + } // namespace Connection +} // namespace NetLib diff --git a/NetworkLibrary/src/connection/connection_manager.h b/NetworkLibrary/src/connection/connection_manager.h new file mode 100644 index 0000000..8199112 --- /dev/null +++ b/NetworkLibrary/src/connection/connection_manager.h @@ -0,0 +1,220 @@ +#pragma once +#include "numeric_types.h" + +#include "core/address.h" +#include "connection/pending_connection.h" +#include "connection/connection_failed_reason_type.h" + +#include + +/* +/ The Connection Manager is the main orchestrator of the Connection Component. +/ +/ Responsabilities: +/ - Manage in-progress connections (timeouts, validations, etc) +/ - Handle incoming connection-related packets +/ +/ Dependencies: +/ - Message Factory (For creating connection-related messages) +/ - Remote Peers Handler (For checking if there are empty remote peer slots for new connections) +/ +*/ +namespace NetLib +{ + class Socket; + class NetworkPacket; + class RemotePeersHandler; + class MessageFactory; + + namespace Connection + { + class IConnectionPipeline; + + struct ConnectionConfiguration + { + // If true, this peer can send connection requests to other peers. + bool canStartConnections; + // Maximum time to wait for a connection to be established before timing out (in seconds). + float32 connectionTimeoutSeconds; + // If true, send a denial message when a connection times out. + bool sendDenialOnTimeout; + // The connection pipeline to use for processing connection states and messages. + IConnectionPipeline* connectionPipeline; + }; + + struct SuccessConnectionData + { + SuccessConnectionData( const Address& address, bool started_locally, uint16 id, uint16 client_side_id, + uint64 data_prefix ) + : address( address ) + , startedLocally( started_locally ) + , id( id ) + , clientSideId( client_side_id ) + , dataPrefix( data_prefix ) + { + } + + Address address; + bool startedLocally; + // The remote peer ID assigned by the server + uint16 id; + // The client-side ID assigned by the server (Use this variable when client opens a connection to the + // server and the server assigns an ID to this client's local peer) + uint16 clientSideId; + uint64 dataPrefix; + }; + + struct FailedConnectionData + { + FailedConnectionData( const Address& address, ConnectionFailedReasonType reason ) + : address( address ) + , reason( reason ) + { + } + + Address address; + ConnectionFailedReasonType reason; + }; + + class ConnectionManager + { + public: + ConnectionManager(); + + /*>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + / brief: Starts up the connection manager. + / + / notes: Call this method before calling any other from the connection manager. + / + / param configuration: The connection configuration to use + / param message_factory: The Message Factory dependency + / param remote_peers_handler: The Remote Peers Handler dependency + / + / returns: true if started up successfully, false otherwise + >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>*/ + bool StartUp( ConnectionConfiguration& configuration, MessageFactory* message_factory, + const RemotePeersHandler* remote_peers_handler ); + + /*>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + / brief: Shuts down the connection manager. + / + / returns: true if shut down successfully, false otherwise + >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>*/ + bool ShutDown(); + + /*>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + / brief: Process an incoming network packet that doesn't belong to any connected remote peer. + / + / param address: The source address of the network packet + / param packet: The network packet to process + / + / returns: true if processed successfully, false otherwise + >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>*/ + bool ProcessPacket( const Address& address, NetworkPacket& packet ); + + /*>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + / brief: Updates the connection manager and all its subsystems + / + / param elapsed_time: The time since the last tick + >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>*/ + void Tick( float32 elapsed_time ); + + /*>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + / brief: Starts connecting to a specified address + / + / notes: If a pending connection already exists with address, a new ConnectionMessage will be sent + / + / param address: The address to start connection with + / + / returns: true if connection started, false otherwise + >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>*/ + bool StartConnectingToAddress( const Address& address ); + + /*>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + / brief: Sends the pending connection-related messages, if any, to the remote connections. + / + / param socket: The socket used to transmit the data through + >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>*/ + void SendData( Socket& socket ); + + /*>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + / brief: Gets all the success connections that hasn't been removed yet. + / + / notes: Do not forget to call RemoveSuccessConnections after processing the data. + / + / param out_success_connections: An output array with the data of the success connections. + >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>*/ + void GetSuccessConnectionsData( std::vector< SuccessConnectionData >& out_success_connections ); + + /*>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + / brief: Remove all the success connections to get some available slots for the new ones. + / + / notes: Call this method after processing the success connections data, or you will lost them. + >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>*/ + void RemoveSuccessConnections(); + + /*>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + / brief: Gets all the failed connections that hasn't been removed yet. + / + / notes: Do not forget to call RemoveFailedConnections after processing the data. + / + / param out_success_connections: An output array with the data of the failed connections. + >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>*/ + void GetFailedConnectionsData( std::vector< FailedConnectionData >& out_failed_connections ); + + /*>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + / brief: Remove all the failed connections to get some available slots for the new ones. + / + / notes: Call this method after processing the failed connections data, or you will lost them. + >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>*/ + void RemoveFailedConnections(); + + private: + /*>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + / brief: Checks if a pending connection to the specified address already exists + / + / param address: The address to check + / + / returns: true if exists, false otherwise + >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>*/ + bool DoesPendingConnectionExist( const Address& address ) const; + + /*>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + / brief: Creates a new pending connection based on an address + / + / param address: The new pending connection's address + / param started_locally: Whether the connection was started in the local or remote peer + / + / returns: true if created, false otherwise + >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>*/ + bool CreatePendingConnection( const Address& address, bool started_locally ); + + /*>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + / brief: Check if it's possible to host another new connection + / + / returns: true if possible, false otherwise + >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>*/ + bool AreSlotsAvailableForNewPendingConnection() const; + + /*>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + / brief: Updates the pending connection's elapsed times. + / + / notes: If any exceeds the maximum elapsed time then force a connection failure. + / + / param pending_connection: The new pending connection's to update + / param elapsed_time: The elapsed time since the last call + >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>*/ + void UpdateTimeout( PendingConnection& pending_connection, float32 elapsed_time ); + + MessageFactory* _messageFactory; + const RemotePeersHandler* _remotePeersHandler; + + bool _isStartedUp; + std::unordered_map< Address, PendingConnection, AddressHasher > _pendingConnections; + IConnectionPipeline* _connectionPipeline; + float32 _connectionTimeoutSeconds; + bool _canStartConnections; + bool _sendDenialOnTimeout; + }; + } // namespace Connection +} // namespace NetLib diff --git a/NetworkLibrary/src/connection/i_connection_pipeline.h b/NetworkLibrary/src/connection/i_connection_pipeline.h new file mode 100644 index 0000000..b204848 --- /dev/null +++ b/NetworkLibrary/src/connection/i_connection_pipeline.h @@ -0,0 +1,23 @@ +#pragma once +#include "numeric_types.h" + +#include + +namespace NetLib +{ + class MessageFactory; + + namespace Connection + { + class PendingConnection; + + class IConnectionPipeline + { + public: + IConnectionPipeline() = default; + virtual ~IConnectionPipeline() = default; + virtual void ProcessConnection( PendingConnection& pending_connection, MessageFactory& message_factory, + float32 elapsed_time ) = 0; + }; + } +} diff --git a/NetworkLibrary/src/connection/pending_connection.cpp b/NetworkLibrary/src/connection/pending_connection.cpp new file mode 100644 index 0000000..db76ed8 --- /dev/null +++ b/NetworkLibrary/src/connection/pending_connection.cpp @@ -0,0 +1,172 @@ +#include "pending_connection.h" + +#include "logger.h" +#include "asserts.h" + +#include "communication/network_packet.h" + +namespace NetLib +{ + namespace Connection + { + PendingConnection::PendingConnection() + : _isStartedUp( false ) + , _address( Address::GetInvalid() ) + , _transmissionChannel( nullptr ) + , _startedLocally( false ) + , _metricsHandler() + , _currentState( PendingConnectionState::Initializing ) + , _currentConnectionElapsedTimeSeconds( 0.f ) + , _clientSalt( 0 ) + , _serverSalt( 0 ) + , _dataPrefix( 0 ) + , _hasClientSaltAssigned( false ) + , _hasServerSaltAssigned( false ) + , _id( 0 ) + , _clientSideId( 0 ) + , _connectionDeniedReason( ConnectionFailedReasonType::UNKNOWN ) + + { + } + + PendingConnection::PendingConnection( MessageFactory* message_factory ) + : _isStartedUp( false ) + , _address( Address::GetInvalid() ) + , _transmissionChannel( message_factory ) + , _startedLocally( false ) + , _metricsHandler() + , _currentState( PendingConnectionState::Initializing ) + , _currentConnectionElapsedTimeSeconds( 0.f ) + , _clientSalt( 0 ) + , _serverSalt( 0 ) + , _dataPrefix( 0 ) + , _hasClientSaltAssigned( false ) + , _hasServerSaltAssigned( false ) + , _id( 0 ) + , _clientSideId( 0 ) + , _connectionDeniedReason( ConnectionFailedReasonType::UNKNOWN ) + { + } + + bool PendingConnection::StartUp( const Address& address, bool started_locally ) + { + if ( _isStartedUp ) + { + LOG_ERROR( "[PendingConnection.%s] PendingConnection is already started up, ignoring call", + THIS_FUNCTION_NAME ); + return false; + } + + _address = address; + _startedLocally = started_locally; + _currentState = PendingConnectionState::Initializing; + _hasClientSaltAssigned = false; + _hasServerSaltAssigned = false; + _currentConnectionElapsedTimeSeconds = 0.f; + _isStartedUp = true; + return true; + } + + bool PendingConnection::ShutDown() + { + if ( !_isStartedUp ) + { + LOG_ERROR( "[PendingConnection.%s] PendingConnection is not started up, ignoring call", + THIS_FUNCTION_NAME ); + return false; + } + + _address = Address::GetInvalid(); + _isStartedUp = false; + return true; + } + + void PendingConnection::ProcessPacket( NetworkPacket& packet ) + { + // Process packet ACKs + const uint32 acks = packet.GetHeader().ackBits; + const uint16 lastAckedMessageSequenceNumber = packet.GetHeader().lastAckedSequenceNumber; + _transmissionChannel.ProcessACKs( acks, lastAckedMessageSequenceNumber, _metricsHandler ); + + // Process packet messages one by one + while ( packet.GetNumberOfMessages() > 0 ) + { + std::unique_ptr< Message > message = packet.TryGetNextMessage(); + AddReceivedMessage( std::move( message ) ); + } + + if ( _metricsHandler.HasMetric( Metrics::MetricType::DOWNLOAD_BANDWIDTH ) ) + { + const uint32 packet_size = packet.Size(); + _metricsHandler.AddValue( Metrics::MetricType::DOWNLOAD_BANDWIDTH, packet_size ); + } + } + + const Message* PendingConnection::GetPendingReadyToProcessMessage() + { + if ( !_isStartedUp ) + { + LOG_ERROR( "[PendingConnection.%s] PendingConnection is not started up, ignoring call", + THIS_FUNCTION_NAME ); + return nullptr; + } + + const Message* message = nullptr; + + if ( _transmissionChannel.ArePendingReadyToProcessMessages() ) + { + message = _transmissionChannel.GetReadyToProcessMessage(); + } + + return message; + } + + bool PendingConnection::AddMessage( std::unique_ptr< Message > message ) + { + if ( !_isStartedUp ) + { + LOG_ERROR( "[PendingConnection.%s] PendingConnection is not started up, ignoring call", + THIS_FUNCTION_NAME ); + return false; + } + + return _transmissionChannel.AddMessageToSend( std::move( message ) ); + } + + void PendingConnection::SetClientSalt( uint64 client_salt ) + { + _clientSalt = client_salt; + _hasClientSaltAssigned = true; + }; + + void PendingConnection::SetServerSalt( uint64 server_salt ) + { + _serverSalt = server_salt; + _hasServerSaltAssigned = true; + }; + + bool PendingConnection::AddReceivedMessage( std::unique_ptr< Message > message ) + { + if ( !_isStartedUp ) + { + LOG_ERROR( "[PendingConnection.%s] PendingConnection is not started up, ignoring call", + THIS_FUNCTION_NAME ); + return false; + } + + const bool result = _transmissionChannel.AddReceivedMessage( std::move( message ), _metricsHandler ); + + return result; + } + + void PendingConnection::SendData( Socket& socket ) + { + _transmissionChannel.CreateAndSendPacket( socket, _address, _metricsHandler ); + } + + void PendingConnection::UpdateConnectionElapsedTime( float32 elapsed_time ) + { + _currentConnectionElapsedTimeSeconds += elapsed_time; + }; + } // namespace Connection +} // namespace NetLib diff --git a/NetworkLibrary/src/connection/pending_connection.h b/NetworkLibrary/src/connection/pending_connection.h new file mode 100644 index 0000000..467b2ea --- /dev/null +++ b/NetworkLibrary/src/connection/pending_connection.h @@ -0,0 +1,91 @@ +#pragma once +#include "numeric_types.h" + +#include "core/address.h" +#include "transmission_channels/unreliable_unordered_transmission_channel.h" +#include "metrics/metrics_handler.h" +#include "connection/connection_failed_reason_type.h" + +namespace NetLib +{ + class MessageFactory; + class Socket; + class NetworkPacket; + + namespace Connection + { + enum class PendingConnectionState : uint8 + { + Initializing = 0, + ConnectionChallenge = 1, + Completed = 2, + Failed = 3 + }; + + class PendingConnection + { + public: + PendingConnection(); + PendingConnection( MessageFactory* message_factory ); + + bool StartUp( const Address& address, bool started_locally ); + bool ShutDown(); + + void ProcessPacket( NetworkPacket& packet ); + + PendingConnectionState GetCurrentState() const { return _currentState; }; + void SetCurrentState( PendingConnectionState new_state ) { _currentState = new_state; }; + + const Message* GetPendingReadyToProcessMessage(); + bool AddMessage( std::unique_ptr< Message > message ); + + void SendData( Socket& socket ); + + void UpdateConnectionElapsedTime( float32 elapsed_time ); + + float32 GetCurrentConnectionElapsedTime() const { return _currentConnectionElapsedTimeSeconds; }; + + uint64 GetClientSalt() const { return _clientSalt; }; + uint64 GetServerSalt() const { return _serverSalt; }; + uint64 GetDataPrefix() const { return _dataPrefix; }; + void GenerateDataPrefix() { _dataPrefix = _clientSalt ^ _serverSalt; }; + void SetClientSalt( uint64 client_salt ); + void SetServerSalt( uint64 server_salt ); + + uint16 GetId() const { return _id; }; + void SetId( uint16 id ) { _id = id; }; + uint16 GetClientSideId() const { return _clientSideId; }; + void SetClientSideId( uint16 client_side_id ) { _clientSideId = client_side_id; }; + const Address& GetAddress() const { return _address; }; + bool WasStartedLocally() const { return _startedLocally; }; + bool HasClientSaltAssigned() const { return _hasClientSaltAssigned; }; + bool HasServerSaltAssigned() const { return _hasServerSaltAssigned; }; + ConnectionFailedReasonType GetConnectionDeniedReason() const { return _connectionDeniedReason; } + void SetConnectionDeniedReason( ConnectionFailedReasonType reason ) + { + _connectionDeniedReason = reason; + }; + + private: + bool AddReceivedMessage( std::unique_ptr< Message > message ); + + bool _isStartedUp; + + Address _address; + PendingConnectionState _currentState; + UnreliableUnorderedTransmissionChannel _transmissionChannel; + Metrics::MetricsHandler _metricsHandler; + + uint64 _clientSalt; + bool _hasClientSaltAssigned; + uint64 _serverSalt; + bool _hasServerSaltAssigned; + uint64 _dataPrefix; + uint16 _id; + uint16 _clientSideId; + bool _startedLocally; + float32 _currentConnectionElapsedTimeSeconds; + ConnectionFailedReasonType _connectionDeniedReason; + }; + } // namespace Connection +} // namespace NetLib diff --git a/NetworkLibrary/src/connection/server_connection_pipeline.cpp b/NetworkLibrary/src/connection/server_connection_pipeline.cpp new file mode 100644 index 0000000..c1e5354 --- /dev/null +++ b/NetworkLibrary/src/connection/server_connection_pipeline.cpp @@ -0,0 +1,235 @@ +#include "server_connection_pipeline.h" + +#include "logger.h" +#include "asserts.h" + +#include "communication/message.h" +#include "communication/message_factory.h" +#include "connection/pending_connection.h" +#include "core/peer.h" + +namespace NetLib +{ + namespace Connection + { + static uint64 GenerateServerSalt() + { + // TODO Change this in order to get another random generator that generates 64bit numbers + srand( static_cast< uint32 >( time( NULL ) ) + 3589 ); + uint64 serverSalt = rand(); + return serverSalt; + } + + static std::unique_ptr< Message > CreateConnectionChallengeMessage( MessageFactory& message_factory, + uint64 client_salt, uint64 server_salt ) + { + LOG_INFO( "%s Creating connection challenge message for pending connection", THIS_FUNCTION_NAME ); + + std::unique_ptr< Message > message = message_factory.LendMessage( MessageType::ConnectionChallenge ); + if ( message == nullptr ) + { + LOG_ERROR( "%s Can't create new Connection Challenge Message because the MessageFactory has returned a " + "null message", + THIS_FUNCTION_NAME ); + return nullptr; + } + + std::unique_ptr< ConnectionChallengeMessage > connectionChallengeMessage( + static_cast< ConnectionChallengeMessage* >( message.release() ) ); + connectionChallengeMessage->serverSalt = server_salt; + + return connectionChallengeMessage; + } + + static std::unique_ptr< Message > CreateConnectionDeniedMessage( MessageFactory& message_factory, + ConnectionFailedReasonType reason ) + { + LOG_INFO( "%s Creating connection denied message for pending connection", THIS_FUNCTION_NAME ); + + std::unique_ptr< Message > message = message_factory.LendMessage( MessageType::ConnectionDenied ); + if ( message == nullptr ) + { + LOG_ERROR( "%s Can't create new connection denied Message because the MessageFactory has returned a " + "null message", + THIS_FUNCTION_NAME ); + return nullptr; + } + + std::unique_ptr< ConnectionDeniedMessage > connectionDeniedMessage( + static_cast< ConnectionDeniedMessage* >( message.release() ) ); + connectionDeniedMessage->reason = static_cast< uint8 >( reason ); + + return connectionDeniedMessage; + } + + static std::unique_ptr< Message > CreateConnectionAcceptedMessage( MessageFactory& message_factory, + uint64 data_prefix, uint16 id ) + { + LOG_INFO( "%s Creating connection accepted message for pending connection", THIS_FUNCTION_NAME ); + + std::unique_ptr< Message > message = message_factory.LendMessage( MessageType::ConnectionAccepted ); + if ( message == nullptr ) + { + LOG_ERROR( + "%s Can't create new Connection Accepted Message because the MessageFactory has returned a null " + "message", + THIS_FUNCTION_NAME ); + return nullptr; + } + + std::unique_ptr< ConnectionAcceptedMessage > connectionAcceptedMessage( + static_cast< ConnectionAcceptedMessage* >( message.release() ) ); + connectionAcceptedMessage->prefix = data_prefix; + connectionAcceptedMessage->clientIndexAssigned = id; + + return connectionAcceptedMessage; + } + + static void ProcessConnectionRequest( PendingConnection& pending_connection, + const ConnectionRequestMessage& message, MessageFactory& message_factory ) + { + LOG_INFO( "%s Processing connection request message for pending connection", THIS_FUNCTION_NAME ); + + std::unique_ptr< Message > outcomeMessage = nullptr; + + // If it is in initializing or connection challenge state, send challenge + if ( pending_connection.GetCurrentState() == PendingConnectionState::Initializing || + pending_connection.GetCurrentState() == PendingConnectionState::ConnectionChallenge ) + { + if ( pending_connection.GetCurrentState() == PendingConnectionState::Initializing ) + { + pending_connection.SetClientSalt( message.clientSalt ); + pending_connection.SetServerSalt( GenerateServerSalt() ); + pending_connection.GenerateDataPrefix(); + pending_connection.SetCurrentState( PendingConnectionState::ConnectionChallenge ); + } + + outcomeMessage = CreateConnectionChallengeMessage( message_factory, pending_connection.GetClientSalt(), + pending_connection.GetServerSalt() ); + } + // If it is in completed state, send accepted + else if ( pending_connection.GetCurrentState() == PendingConnectionState::Completed ) + { + outcomeMessage = CreateConnectionAcceptedMessage( message_factory, pending_connection.GetDataPrefix(), + pending_connection.GetId() ); + } + // If it is in failed state, send denied + else if ( pending_connection.GetCurrentState() == PendingConnectionState::Failed ) + { + outcomeMessage = + CreateConnectionDeniedMessage( message_factory, pending_connection.GetConnectionDeniedReason() ); + } + + ASSERT( outcomeMessage != nullptr, "Message can't be nullptr" ); + pending_connection.AddMessage( std::move( outcomeMessage ) ); + } + + ServerConnectionPipeline::ServerConnectionPipeline() + : IConnectionPipeline() + , _nextConnectionApprovedId( 1 ) + { + } + + void ServerConnectionPipeline::ProcessConnection( PendingConnection& pending_connection, + MessageFactory& message_factory, float32 elapsed_time ) + { + // Process pending connections + bool areMessagesToProcess = true; + while ( areMessagesToProcess ) + { + const Message* message = pending_connection.GetPendingReadyToProcessMessage(); + if ( message != nullptr ) + { + ProcessMessage( pending_connection, message, message_factory ); + } + else + { + areMessagesToProcess = false; + } + } + } + + void ServerConnectionPipeline::ProcessMessage( PendingConnection& pending_connection, const Message* message, + MessageFactory& message_factory ) + { + const MessageType type = message->GetHeader().type; + + switch ( type ) + { + case MessageType::ConnectionRequest: + { + ProcessConnectionRequest( pending_connection, + static_cast< const ConnectionRequestMessage& >( *message ), + message_factory ); + } + break; + case MessageType::ConnectionChallengeResponse: + { + ProcessConnectionChallengeResponse( + pending_connection, static_cast< const ConnectionChallengeResponseMessage& >( *message ), + message_factory ); + } + break; + default: + LOG_ERROR( "ServerConnectionPipeline.%s Incoming message not supported. Type: %u", + THIS_FUNCTION_NAME, static_cast< uint8 >( message->GetHeader().type ) ); + break; + } + } + + void ServerConnectionPipeline::ProcessConnectionChallengeResponse( + PendingConnection& pending_connection, const ConnectionChallengeResponseMessage& message, + MessageFactory& message_factory ) + { + LOG_INFO( "%s Processing challenge response message for pending connection", THIS_FUNCTION_NAME ); + + std::unique_ptr< Message > outcomeMessage = nullptr; + + // Check if challenge was calculated correctly + if ( pending_connection.GetDataPrefix() == message.prefix ) + { + if ( pending_connection.GetCurrentState() == PendingConnectionState::Failed ) + { + outcomeMessage = CreateConnectionDeniedMessage( message_factory, + pending_connection.GetConnectionDeniedReason() ); + } + else + { + if ( pending_connection.GetCurrentState() != PendingConnectionState::Completed ) + { + pending_connection.SetCurrentState( PendingConnectionState::Completed ); + pending_connection.SetId( GenerateNextConnectionApprovedId() ); + } + + outcomeMessage = CreateConnectionAcceptedMessage( + message_factory, pending_connection.GetDataPrefix(), pending_connection.GetId() ); + } + } + else + { + pending_connection.SetCurrentState( PendingConnectionState::Failed ); + pending_connection.SetConnectionDeniedReason( ConnectionFailedReasonType::WRONG_CHALLENGE_RESPONSE ); + outcomeMessage = + CreateConnectionDeniedMessage( message_factory, pending_connection.GetConnectionDeniedReason() ); + } + + ASSERT( outcomeMessage != nullptr, "Message can't be nullptr" ); + pending_connection.AddMessage( std::move( outcomeMessage ) ); + } + + uint16 ServerConnectionPipeline::GenerateNextConnectionApprovedId() + { + const uint16 result = _nextConnectionApprovedId; + if ( _nextConnectionApprovedId == MAX_UINT16 ) + { + _nextConnectionApprovedId = 1; + } + else + { + _nextConnectionApprovedId++; + } + + return result; + } + } // namespace Connection +} // namespace NetLib diff --git a/NetworkLibrary/src/connection/server_connection_pipeline.h b/NetworkLibrary/src/connection/server_connection_pipeline.h new file mode 100644 index 0000000..a6955e4 --- /dev/null +++ b/NetworkLibrary/src/connection/server_connection_pipeline.h @@ -0,0 +1,31 @@ +#pragma once +#include "connection/i_connection_pipeline.h" + +namespace NetLib +{ + class Message; + class ConnectionChallengeResponseMessage; + + namespace Connection + { + class ServerConnectionPipeline : public IConnectionPipeline + { + public: + ServerConnectionPipeline(); + virtual void ProcessConnection( PendingConnection& pending_connection, MessageFactory& message_factory, + float32 elapsed_time ) override; + + private: + void ProcessMessage( PendingConnection& pending_connection, const Message* message, + MessageFactory& message_factory ); + void ProcessConnectionChallengeResponse( PendingConnection& pending_connection, + const ConnectionChallengeResponseMessage& message, + MessageFactory& message_factory ); + uint16 GenerateNextConnectionApprovedId(); + + static constexpr uint16 SERVER_CONNECTION_ID = 0; + + uint16 _nextConnectionApprovedId; + }; + } +} diff --git a/NetworkLibrary/src/inputs/i_input_state.h b/NetworkLibrary/src/inputs/i_input_state.h index b45da46..6772df0 100644 --- a/NetworkLibrary/src/inputs/i_input_state.h +++ b/NetworkLibrary/src/inputs/i_input_state.h @@ -10,6 +10,12 @@ namespace NetLib public: virtual int32 GetSize() const = 0; virtual void Serialize( Buffer& buffer ) const = 0; - virtual void Deserialize( Buffer& buffer ) = 0; + + /// + /// Deserializes the input state fields from the provided buffer. This function also checks that the fields + /// are valid. + /// + /// True if deserialization and validation went successfully, False otherwise. + virtual bool Deserialize( Buffer& buffer ) = 0; }; } diff --git a/NetworkLibrary/src/inputs/remote_peer_inputs_handler.cpp b/NetworkLibrary/src/inputs/remote_peer_inputs_handler.cpp index 2092865..ba6a474 100644 --- a/NetworkLibrary/src/inputs/remote_peer_inputs_handler.cpp +++ b/NetworkLibrary/src/inputs/remote_peer_inputs_handler.cpp @@ -8,6 +8,9 @@ namespace NetLib { + // TODO: Make the IInputState to have fixed header fields (tick and serverTime) instead of letting the concrete + // input class to define them. In that way we can check for tampered messages or order later to reduce jitter + // effects. Then, in this method, we can check if input's tick and server time are valid. void RemotePeerInputsBuffer::AddInputState( IInputState* input ) { assert( input != nullptr ); diff --git a/NetworkLibrary/src/metrics/download_bandwidth_metric.cpp b/NetworkLibrary/src/metrics/download_bandwidth_metric.cpp index c581345..e4fa849 100644 --- a/NetworkLibrary/src/metrics/download_bandwidth_metric.cpp +++ b/NetworkLibrary/src/metrics/download_bandwidth_metric.cpp @@ -2,8 +2,6 @@ #include "logger.h" -#include "metrics/metric_names.h" - namespace NetLib { namespace Metrics @@ -17,26 +15,26 @@ namespace NetLib { } - void DownloadBandwidthMetric::GetName( std::string& out_name_buffer ) const + MetricType DownloadBandwidthMetric::GetType() const { - out_name_buffer.assign( DOWNLOAD_BANDWIDTH_METRIC ); + return MetricType::DOWNLOAD_BANDWIDTH; } - uint32 DownloadBandwidthMetric::GetValue( const std::string& value_type ) const + uint32 DownloadBandwidthMetric::GetValue( ValueType value_type ) const { uint32 result = 0; - if ( value_type == "MAX" ) + if ( value_type == ValueType::MAX ) { result = _maxValue; } - else if ( value_type == "CURRENT" ) + else if ( value_type == ValueType::CURRENT ) { result = _currentValue; } else { - LOG_WARNING( "Unknown value type '%s' for UploadBandwidthMetric", value_type.c_str() ); + LOG_WARNING( "Unknown value type '%u' for UploadBandwidthMetric", static_cast< uint8 >( value_type ) ); } return result; diff --git a/NetworkLibrary/src/metrics/download_bandwidth_metric.h b/NetworkLibrary/src/metrics/download_bandwidth_metric.h index 24820ca..d840445 100644 --- a/NetworkLibrary/src/metrics/download_bandwidth_metric.h +++ b/NetworkLibrary/src/metrics/download_bandwidth_metric.h @@ -10,8 +10,8 @@ namespace NetLib public: DownloadBandwidthMetric(); - void GetName( std::string& out_name_buffer ) const override; - uint32 GetValue( const std::string& value_type ) const override; + MetricType GetType() const override; + uint32 GetValue( ValueType value_type ) const override; void SetUpdateRate( float32 update_rate ) override; void Update( float32 elapsed_time ) override; void AddValueSample( uint32 value, const std::string& sample_type = "NONE" ) override; diff --git a/NetworkLibrary/src/metrics/i_metric.h b/NetworkLibrary/src/metrics/i_metric.h index b744ce5..65aeda3 100644 --- a/NetworkLibrary/src/metrics/i_metric.h +++ b/NetworkLibrary/src/metrics/i_metric.h @@ -1,6 +1,8 @@ #pragma once #include "numeric_types.h" +#include "metrics/metric_types.h" + #include namespace NetLib @@ -10,8 +12,8 @@ namespace NetLib class IMetric { public: - virtual void GetName( std::string& out_name_buffer ) const = 0; - virtual uint32 GetValue( const std::string& value_type ) const = 0; + virtual MetricType GetType() const = 0; + virtual uint32 GetValue( ValueType value_type ) const = 0; virtual void SetUpdateRate( float32 update_rate ) = 0; virtual void Update( float32 elapsed_time ) = 0; virtual void AddValueSample( uint32 value, const std::string& sample_type = "NONE" ) = 0; diff --git a/NetworkLibrary/src/metrics/increment_metric.cpp b/NetworkLibrary/src/metrics/increment_metric.cpp index 8c4d37d..5c3d7fd 100644 --- a/NetworkLibrary/src/metrics/increment_metric.cpp +++ b/NetworkLibrary/src/metrics/increment_metric.cpp @@ -6,27 +6,28 @@ namespace NetLib { namespace Metrics { - IncrementMetric::IncrementMetric( const std::string& name ) - : _name( name ) + IncrementMetric::IncrementMetric( MetricType type ) + : _type( type ) , _currentValue( 0 ) { } - void IncrementMetric::GetName( std::string& out_name_buffer ) const + MetricType IncrementMetric::GetType() const { - out_name_buffer.assign( _name ); + return _type; } - uint32 IncrementMetric::GetValue( const std::string& value_type ) const + uint32 IncrementMetric::GetValue( ValueType value_type ) const { uint32 result = 0; - if ( value_type == "CURRENT" ) + if ( value_type == ValueType::CURRENT ) { result = _currentValue; } else { - LOG_WARNING( "Unknown value type '%s' for %s Metric", value_type.c_str(), _name.c_str() ); + LOG_WARNING( "Unknown value type '%s' for %u Metric", static_cast< uint8 >( value_type ), + static_cast< uint8 >( _type ) ); } return result; diff --git a/NetworkLibrary/src/metrics/increment_metric.h b/NetworkLibrary/src/metrics/increment_metric.h index 3ac99c2..a269a74 100644 --- a/NetworkLibrary/src/metrics/increment_metric.h +++ b/NetworkLibrary/src/metrics/increment_metric.h @@ -1,6 +1,8 @@ #pragma once #include "metrics/i_metric.h" +#include "metrics/metric_types.h" + namespace NetLib { namespace Metrics @@ -8,10 +10,10 @@ namespace NetLib class IncrementMetric : public IMetric { public: - IncrementMetric( const std::string& name ); + IncrementMetric( MetricType type ); - void GetName( std::string& out_name_buffer ) const override; - uint32 GetValue( const std::string& value_type ) const override; + MetricType GetType() const override; + uint32 GetValue( ValueType value_type ) const override; void SetUpdateRate( float32 update_rate ) override; void Update( float32 elapsed_time ) override; void AddValueSample( uint32 value, const std::string& sample_type = "NONE" ) override; @@ -19,7 +21,7 @@ namespace NetLib private: uint32 _currentValue; - const std::string _name; + const MetricType _type; }; } } diff --git a/NetworkLibrary/src/metrics/jitter_metric.cpp b/NetworkLibrary/src/metrics/jitter_metric.cpp index 44152fa..1b927e8 100644 --- a/NetworkLibrary/src/metrics/jitter_metric.cpp +++ b/NetworkLibrary/src/metrics/jitter_metric.cpp @@ -3,8 +3,6 @@ #include "logger.h" #include "AlgorithmUtils.h" -#include "metrics/metric_names.h" - namespace NetLib { namespace Metrics @@ -18,26 +16,26 @@ namespace NetLib _latencySamples.reserve( MAX_BUFFER_SIZE ); } - void JitterMetric::GetName( std::string& out_name_buffer ) const + MetricType JitterMetric::GetType() const { - out_name_buffer.assign( JITTER_METRIC ); + return MetricType::JITTER; } - uint32 JitterMetric::GetValue( const std::string& value_type ) const + uint32 JitterMetric::GetValue( ValueType value_type ) const { uint32 result = 0; - if ( value_type == "MAX" ) + if ( value_type == ValueType::MAX ) { result = _maxValue; } - else if ( value_type == "CURRENT" ) + else if ( value_type == ValueType::CURRENT ) { result = _currentValue; } else { - LOG_WARNING( "Unknown value type '%s' for JitterMetric", value_type.c_str() ); + LOG_WARNING( "Unknown value type '%u' for JitterMetric", static_cast< uint8 >( value_type ) ); } return result; diff --git a/NetworkLibrary/src/metrics/jitter_metric.h b/NetworkLibrary/src/metrics/jitter_metric.h index 7f345cf..f5cac36 100644 --- a/NetworkLibrary/src/metrics/jitter_metric.h +++ b/NetworkLibrary/src/metrics/jitter_metric.h @@ -15,8 +15,8 @@ namespace NetLib JitterMetric(); - void GetName( std::string& out_name_buffer ) const override; - uint32 GetValue( const std::string& value_type ) const override; + MetricType GetType() const override; + uint32 GetValue( ValueType value_type ) const override; void SetUpdateRate( float32 update_rate ) override; void Update( float32 elapsed_time ) override; void AddValueSample( uint32 value, const std::string& sample_type = "NONE" ) override; diff --git a/NetworkLibrary/src/metrics/latency_metric.cpp b/NetworkLibrary/src/metrics/latency_metric.cpp index 1fdef45..2946cb1 100644 --- a/NetworkLibrary/src/metrics/latency_metric.cpp +++ b/NetworkLibrary/src/metrics/latency_metric.cpp @@ -3,8 +3,6 @@ #include "logger.h" #include "AlgorithmUtils.h" -#include "metrics/metric_names.h" - namespace NetLib { namespace Metrics @@ -17,26 +15,26 @@ namespace NetLib _samples.reserve( MAX_BUFFER_SIZE ); } - void LatencyMetric::GetName( std::string& out_name_buffer ) const + MetricType LatencyMetric::GetType() const { - out_name_buffer.assign( LATENCY_METRIC ); + return MetricType::LATENCY; } - uint32 LatencyMetric::GetValue( const std::string& value_type ) const + uint32 LatencyMetric::GetValue( ValueType value_type ) const { uint32 result = 0; - if ( value_type == "MAX" ) + if ( value_type == ValueType::MAX ) { result = _maxValue; } - else if ( value_type == "CURRENT" ) + else if ( value_type == ValueType::CURRENT ) { result = _currentValue; } else { - LOG_WARNING( "Unknown value type '%s' for LatencyMetric", value_type.c_str() ); + LOG_WARNING( "Unknown value type '%u' for LatencyMetric", static_cast< uint8 >( value_type ) ); } return result; diff --git a/NetworkLibrary/src/metrics/latency_metric.h b/NetworkLibrary/src/metrics/latency_metric.h index 95c2a1c..2bc8cda 100644 --- a/NetworkLibrary/src/metrics/latency_metric.h +++ b/NetworkLibrary/src/metrics/latency_metric.h @@ -15,8 +15,8 @@ namespace NetLib LatencyMetric(); - void GetName( std::string& out_name_buffer ) const override; - uint32 GetValue( const std::string& value_type ) const override; + MetricType GetType() const override; + uint32 GetValue( ValueType value_type ) const override; void SetUpdateRate( float32 update_rate ) override; void Update( float32 elapsed_time ) override; void AddValueSample( uint32 value, const std::string& sample_type = "NONE" ) override; diff --git a/NetworkLibrary/src/metrics/metric_names.h b/NetworkLibrary/src/metrics/metric_names.h deleted file mode 100644 index f5a6062..0000000 --- a/NetworkLibrary/src/metrics/metric_names.h +++ /dev/null @@ -1,21 +0,0 @@ -#pragma once - -namespace NetLib -{ - namespace Metrics - { - // Metric names - static constexpr char* LATENCY_METRIC = "LATENCY"; - static constexpr char* JITTER_METRIC = "JITTER"; - static constexpr char* PACKET_LOSS_METRIC = "PACKET_LOSS"; - static constexpr char* UPLOAD_BANDWIDTH_METRIC = "UPLOAD_BANDWIDTH"; - static constexpr char* DOWNLOAD_BANDWIDTH_METRIC = "DOWNLOAD_BANDWIDTH"; - static constexpr char* OUT_OF_ORDER_METRIC = "OUT_OF_ORDER"; - static constexpr char* DUPLICATE_METRIC = "DUPLICATE"; - static constexpr char* RETRANSMISSION_METRIC = "RETRANSMISSION"; - - // Metric value types - static constexpr char* CURRENT_VALUE_TYPE = "CURRENT"; - static constexpr char* MAX_VALUE_TYPE = "MAX"; - } -} diff --git a/NetworkLibrary/src/metrics/metric_types.h b/NetworkLibrary/src/metrics/metric_types.h new file mode 100644 index 0000000..e99ba42 --- /dev/null +++ b/NetworkLibrary/src/metrics/metric_types.h @@ -0,0 +1,25 @@ +#pragma once + +namespace NetLib +{ + namespace Metrics + { + enum class MetricType : uint8 + { + LATENCY = 0, + JITTER = 1, + PACKET_LOSS = 2, + UPLOAD_BANDWIDTH = 3, + DOWNLOAD_BANDWIDTH = 4, + RETRANSMISSIONS = 5, + OUT_OF_ORDER_MESSAGES = 6, + DUPLICATE_MESSAGES = 7 + }; + + enum class ValueType : uint8 + { + CURRENT = 0, + MAX = 1 + }; + } +} diff --git a/NetworkLibrary/src/metrics/metrics_handler.cpp b/NetworkLibrary/src/metrics/metrics_handler.cpp index db6eba9..6c48274 100644 --- a/NetworkLibrary/src/metrics/metrics_handler.cpp +++ b/NetworkLibrary/src/metrics/metrics_handler.cpp @@ -2,8 +2,8 @@ #include "AlgorithmUtils.h" #include "logger.h" +#include "asserts.h" -#include "metrics/metric_names.h" #include "metrics/latency_metric.h" #include "metrics/packet_loss_metric.h" #include "metrics/jitter_metric.h" @@ -11,37 +11,135 @@ #include "metrics/download_bandwidth_metric.h" #include "metrics/increment_metric.h" -#include - namespace NetLib { namespace Metrics { + const std::vector< MetricType > MetricsHandler::ALL_METRICS = { MetricType::LATENCY, + MetricType::JITTER, + MetricType::PACKET_LOSS, + MetricType::UPLOAD_BANDWIDTH, + MetricType::DOWNLOAD_BANDWIDTH, + MetricType::RETRANSMISSIONS, + MetricType::OUT_OF_ORDER_MESSAGES, + MetricType::DUPLICATE_MESSAGES }; + MetricsHandler::MetricsHandler() - : _entries() + : _isStartedUp( false ) + , _entries() { - AddEntry( std::make_unique< LatencyMetric >() ); - AddEntry( std::make_unique< JitterMetric >() ); - AddEntry( std::make_unique< PacketLossMetric >() ); - AddEntry( std::make_unique< UploadBandwidthMetric >() ); - AddEntry( std::make_unique< DownloadBandwidthMetric >() ); - AddEntry( std::make_unique< IncrementMetric >( DUPLICATE_METRIC ) ); - AddEntry( std::make_unique< IncrementMetric >( OUT_OF_ORDER_METRIC ) ); - AddEntry( std::make_unique< IncrementMetric >( RETRANSMISSION_METRIC ) ); } - void MetricsHandler::Configure( float32 update_rate ) + MetricsHandler::~MetricsHandler() { - Reset(); + ASSERT( !_isStartedUp, "[MetricsHandler.%s] Before destructing instance, call ShutDown.", + THIS_FUNCTION_NAME ); + } + + bool MetricsHandler::AddMetrics( float32 update_rate, const std::vector< MetricType >& metrics ) + { + bool result = true; + + for ( auto cit = metrics.cbegin(); cit != metrics.cend(); ++cit ) + { + if ( HasMetric( *cit ) ) + { + LOG_WARNING( "[MetricsHandler.%s] Metric of type %u already exists. Ignoring it.", + THIS_FUNCTION_NAME, static_cast< uint8 >( *cit ) ); + continue; + } + + switch ( *cit ) + { + case MetricType::LATENCY: + result &= AddEntry( new LatencyMetric() ); + break; + case MetricType::JITTER: + result &= AddEntry( new JitterMetric() ); + break; + case MetricType::PACKET_LOSS: + result &= AddEntry( new PacketLossMetric() ); + break; + case MetricType::UPLOAD_BANDWIDTH: + result &= AddEntry( new UploadBandwidthMetric() ); + break; + case MetricType::DOWNLOAD_BANDWIDTH: + result &= AddEntry( new DownloadBandwidthMetric() ); + break; + case MetricType::RETRANSMISSIONS: + case MetricType::OUT_OF_ORDER_MESSAGES: + case MetricType::DUPLICATE_MESSAGES: + result &= AddEntry( new IncrementMetric( *cit ) ); + break; + default: + LOG_ERROR( "[MetricsHandler.%s] Unknown MetricType %u. Ignoring it.", THIS_FUNCTION_NAME, + static_cast< uint8 >( *cit ) ); + break; + } + } for ( auto it = _entries.begin(); it != _entries.end(); ++it ) { it->second->SetUpdateRate( update_rate ); } + + return result; + } + + bool MetricsHandler::HasMetric( MetricType type ) const + { + return _entries.find( type ) != _entries.end(); + } + + bool MetricsHandler::StartUp( float32 update_rate, MetricsEnableConfig enable_config, + const std::vector< MetricType >& enabled_metrics ) + { + if ( _isStartedUp ) + { + LOG_ERROR( "[MetricsHandler.%s] MetricsHandler is already started up, ignoring call", + THIS_FUNCTION_NAME ); + return false; + } + + if ( enable_config == MetricsEnableConfig::ENABLE_ALL ) + { + AddMetrics( update_rate, ALL_METRICS ); + } + else if ( enable_config == MetricsEnableConfig::CUSTOM ) + { + AddMetrics( update_rate, enabled_metrics ); + } + + _isStartedUp = true; + return true; + } + + bool MetricsHandler::ShutDown() + { + if ( !_isStartedUp ) + { + LOG_ERROR( "[MetricsHandler.%s] MetricsHandler is not started up, ignoring call", THIS_FUNCTION_NAME ); + return false; + } + + for ( auto it = _entries.begin(); it != _entries.end(); ++it ) + { + delete it->second; + } + + _entries.clear(); + _isStartedUp = false; + return true; } void MetricsHandler::Update( float32 elapsed_time ) { + if ( !_isStartedUp ) + { + LOG_ERROR( "[MetricsHandler.%s] MetricsHandler is not started up, ignoring call", THIS_FUNCTION_NAME ); + return; + } + for ( auto it = _entries.begin(); it != _entries.end(); ++it ) { it->second->Update( elapsed_time ); @@ -53,61 +151,75 @@ namespace NetLib "BANDWIDTH: Current: %u, " "Max: %u\nDOWNLOAD BANDWIDTH: Current: %u, Max: %u\nRETRANSMISSIONS: Current: %u\nOUT OF ORDER: " "Current: %u\nDUPLICATE: Current: %u", - GetValue( LATENCY_METRIC, CURRENT_VALUE_TYPE ), GetValue( LATENCY_METRIC, MAX_VALUE_TYPE ), - GetValue( JITTER_METRIC, CURRENT_VALUE_TYPE ), GetValue( JITTER_METRIC, MAX_VALUE_TYPE ), - GetValue( PACKET_LOSS_METRIC, CURRENT_VALUE_TYPE ), GetValue( PACKET_LOSS_METRIC, MAX_VALUE_TYPE ), - GetValue( UPLOAD_BANDWIDTH_METRIC, CURRENT_VALUE_TYPE ), - GetValue( UPLOAD_BANDWIDTH_METRIC, MAX_VALUE_TYPE ), - GetValue( DOWNLOAD_BANDWIDTH_METRIC, CURRENT_VALUE_TYPE ), - GetValue( DOWNLOAD_BANDWIDTH_METRIC, MAX_VALUE_TYPE ), - GetValue( RETRANSMISSION_METRIC, CURRENT_VALUE_TYPE ), - GetValue( OUT_OF_ORDER_METRIC, CURRENT_VALUE_TYPE ), GetValue( DUPLICATE_METRIC, CURRENT_VALUE_TYPE ) ); + GetValue( MetricType::LATENCY, ValueType::CURRENT ), GetValue( MetricType::LATENCY, ValueType::MAX ), + GetValue( MetricType::JITTER, ValueType::CURRENT ), GetValue( MetricType::JITTER, ValueType::MAX ), + GetValue( MetricType::PACKET_LOSS, ValueType::CURRENT ), + GetValue( MetricType::PACKET_LOSS, ValueType::MAX ), + GetValue( MetricType::UPLOAD_BANDWIDTH, ValueType::CURRENT ), + GetValue( MetricType::UPLOAD_BANDWIDTH, ValueType::MAX ), + GetValue( MetricType::DOWNLOAD_BANDWIDTH, ValueType::CURRENT ), + GetValue( MetricType::DOWNLOAD_BANDWIDTH, ValueType::MAX ), + GetValue( MetricType::RETRANSMISSIONS, ValueType::CURRENT ), + GetValue( MetricType::OUT_OF_ORDER_MESSAGES, ValueType::CURRENT ), + GetValue( MetricType::DUPLICATE_MESSAGES, ValueType::CURRENT ) ); } - bool MetricsHandler::AddEntry( std::unique_ptr< IMetric > entry ) + bool MetricsHandler::AddEntry( IMetric* metric ) { - assert( entry != nullptr ); + ASSERT( metric != nullptr, "[MetricsHandler.%s] entry is nullptr.", THIS_FUNCTION_NAME ); bool result = false; - std::string name; - entry->GetName( name ); - if ( _entries.find( name ) == _entries.end() ) + const MetricType metricType = metric->GetType(); + if ( !HasMetric( metricType ) ) { - _entries[ name ] = std::move( entry ); + _entries[ metricType ] = metric; result = true; } else { - LOG_WARNING( "Network statistic entry with name '%s' already exists, ignoring the new one", - name.c_str() ); + LOG_WARNING( "Network statistic metric of type '%u' already exists, ignoring the new one", + static_cast< uint8 >( metricType ) ); } return result; } - uint32 MetricsHandler::GetValue( const std::string& entry_name, const std::string& value_type ) const + uint32 MetricsHandler::GetValue( MetricType metric_type, ValueType value_type ) const { + if ( !_isStartedUp ) + { + LOG_ERROR( "[MetricsHandler.%s] MetricsHandler is not started up, ignoring call", THIS_FUNCTION_NAME ); + return 0; + } + uint32 result = 0; - auto it = _entries.find( entry_name ); + auto it = _entries.find( metric_type ); if ( it != _entries.end() ) { result = it->second->GetValue( value_type ); } else { - LOG_WARNING( "Can't get a value from a metric that doesn't exist. Name: '%s'", entry_name.c_str() ); + LOG_WARNING( "Can't get a value from a metric that doesn't exist. Metric type: '%u'", + static_cast< uint8 >( metric_type ) ); } return result; } - bool MetricsHandler::AddValue( const std::string& entry_name, uint32 value, const std::string& sample_type ) + bool MetricsHandler::AddValue( MetricType metric_type, uint32 value, const std::string& sample_type ) { + if ( !_isStartedUp ) + { + LOG_ERROR( "[MetricsHandler.%s] MetricsHandler is not started up, ignoring call", THIS_FUNCTION_NAME ); + return false; + } + bool result = false; - auto it = _entries.find( entry_name ); + auto it = _entries.find( metric_type ); if ( it != _entries.end() ) { it->second->AddValueSample( value, sample_type ); @@ -115,18 +227,11 @@ namespace NetLib } else { - LOG_WARNING( "Can't add a value to a metric that doesn't exist. Name: '%s'", entry_name.c_str() ); + LOG_WARNING( "Can't add a value to a metric that doesn't exist. Metric type: '%u'", + static_cast< uint8 >( metric_type ) ); } return result; } - - void MetricsHandler::Reset() - { - for ( auto it = _entries.begin(); it != _entries.end(); ++it ) - { - it->second->Reset(); - } - } } // namespace Metrics } // namespace NetLib diff --git a/NetworkLibrary/src/metrics/metrics_handler.h b/NetworkLibrary/src/metrics/metrics_handler.h index 8296546..96863b5 100644 --- a/NetworkLibrary/src/metrics/metrics_handler.h +++ b/NetworkLibrary/src/metrics/metrics_handler.h @@ -2,6 +2,7 @@ #include "numeric_types.h" #include "metrics/i_metric.h" +#include "metrics/metric_types.h" #include #include @@ -12,35 +13,48 @@ namespace NetLib { namespace Metrics { + enum class MetricsEnableConfig : uint8 + { + ENABLE_ALL = 0, + DISABLE_ALL = 1, + CUSTOM = 2 + }; + class MetricsHandler { public: MetricsHandler(); + ~MetricsHandler(); - void Configure( float32 update_rate ); + bool StartUp( float32 update_rate, MetricsEnableConfig enable_config, + const std::vector< MetricType >& enabled_metrics = {} ); + bool ShutDown(); void Update( float32 elapsed_time ); /// - /// Add a network statistic entry. This class is responsible for handling the entries memory. + /// Gets the value of type value_type (MAX, CURRENT...) from the metric of type metric_type. If the + /// 1) Metrics handler is not started up, 2) The metric does not exist or 3) the value type is invalid + /// the function returns 0. /// - bool AddEntry( std::unique_ptr< IMetric > entry ); + uint32 GetValue( MetricType metric_type, ValueType value_type ) const; /// - /// Gets the value of type value_type (MAX, CURRENT...) from the entry with name entry_name. If the - /// entry name or the value type is invalid the function returns 0. + /// Adds a value to the metric of type metric_type. If the 1) Metrics handler is not started up or 2) + /// the metric doesn't exists it does nothing and returns false. /// - uint32 GetValue( const std::string& entry_name, const std::string& value_type ) const; + bool AddValue( MetricType metric_type, uint32 value, const std::string& sample_type = "NONE" ); - /// - /// Adds a value to the entry with name entry_name. If the entry doesn't exists it does nothing and - /// returns false. - /// - bool AddValue( const std::string& entry_name, uint32 value, const std::string& sample_type = "NONE" ); + bool HasMetric( MetricType type ) const; private: - void Reset(); - std::unordered_map< std::string, std::unique_ptr< IMetric > > _entries; + bool AddMetrics( float32 update_rate, const std::vector< MetricType >& metrics ); + bool AddEntry( IMetric* metric ); + + bool _isStartedUp; + std::unordered_map< MetricType, IMetric* > _entries; + + static const std::vector< MetricType > ALL_METRICS; }; - } + } // namespace Metrics } // namespace NetLib \ No newline at end of file diff --git a/NetworkLibrary/src/metrics/packet_loss_metric.cpp b/NetworkLibrary/src/metrics/packet_loss_metric.cpp index cafe6fd..747c9de 100644 --- a/NetworkLibrary/src/metrics/packet_loss_metric.cpp +++ b/NetworkLibrary/src/metrics/packet_loss_metric.cpp @@ -1,6 +1,6 @@ #include "packet_loss_metric.h" -#include "metrics/metric_names.h" +#include "metrics/metric_types.h" #include "logger.h" @@ -18,26 +18,26 @@ namespace NetLib { } - void PacketLossMetric::GetName( std::string& out_name_buffer ) const + MetricType PacketLossMetric::GetType() const { - out_name_buffer.assign( PACKET_LOSS_METRIC ); + return MetricType::PACKET_LOSS; } - uint32 PacketLossMetric::GetValue( const std::string& value_type ) const + uint32 PacketLossMetric::GetValue( ValueType value_type ) const { uint32 result = 0; - if ( value_type == "MAX" ) + if ( value_type == ValueType::MAX ) { result = _maxValue; } - else if ( value_type == "CURRENT" ) + else if ( value_type == ValueType::CURRENT ) { result = _currentValue; } else { - LOG_WARNING( "Unknown value type '%s' for PacketLossMetric", value_type.c_str() ); + LOG_WARNING( "Unknown value type '%u' for PacketLossMetric", static_cast< uint8 >( value_type ) ); } return result; diff --git a/NetworkLibrary/src/metrics/packet_loss_metric.h b/NetworkLibrary/src/metrics/packet_loss_metric.h index 83fbaec..69aa26f 100644 --- a/NetworkLibrary/src/metrics/packet_loss_metric.h +++ b/NetworkLibrary/src/metrics/packet_loss_metric.h @@ -12,8 +12,8 @@ namespace NetLib public: PacketLossMetric(); - void GetName( std::string& out_name_buffer ) const override; - uint32 GetValue( const std::string& value_type ) const override; + MetricType GetType() const override; + uint32 GetValue( ValueType value_type ) const override; void SetUpdateRate( float32 update_rate ) override; void Update( float32 elapsed_time ) override; void AddValueSample( uint32 value, const std::string& sample_type = "NONE" ) override; diff --git a/NetworkLibrary/src/metrics/upload_bandwidth_metric.cpp b/NetworkLibrary/src/metrics/upload_bandwidth_metric.cpp index 65b9289..206606b 100644 --- a/NetworkLibrary/src/metrics/upload_bandwidth_metric.cpp +++ b/NetworkLibrary/src/metrics/upload_bandwidth_metric.cpp @@ -2,8 +2,6 @@ #include "logger.h" -#include "metrics/metric_names.h" - namespace NetLib { namespace Metrics @@ -17,26 +15,26 @@ namespace NetLib { } - void UploadBandwidthMetric::GetName( std::string& out_name_buffer ) const + MetricType UploadBandwidthMetric::GetType() const { - out_name_buffer.assign( UPLOAD_BANDWIDTH_METRIC ); + return MetricType::UPLOAD_BANDWIDTH; } - uint32 UploadBandwidthMetric::GetValue( const std::string& value_type ) const + uint32 UploadBandwidthMetric::GetValue( ValueType value_type ) const { uint32 result = 0; - if ( value_type == "MAX" ) + if ( value_type == ValueType::MAX ) { result = _maxValue; } - else if ( value_type == "CURRENT" ) + else if ( value_type == ValueType::CURRENT ) { result = _currentValue; } else { - LOG_WARNING( "Unknown value type '%s' for UploadBandwidthMetric", value_type.c_str() ); + LOG_WARNING( "Unknown value type '%u' for UploadBandwidthMetric", static_cast< uint8 >( value_type ) ); } return result; diff --git a/NetworkLibrary/src/metrics/upload_bandwidth_metric.h b/NetworkLibrary/src/metrics/upload_bandwidth_metric.h index c0e9d90..c55bf95 100644 --- a/NetworkLibrary/src/metrics/upload_bandwidth_metric.h +++ b/NetworkLibrary/src/metrics/upload_bandwidth_metric.h @@ -10,8 +10,8 @@ namespace NetLib public: UploadBandwidthMetric(); - void GetName( std::string& out_name_buffer ) const override; - uint32 GetValue( const std::string& value_type ) const override; + MetricType GetType() const override; + uint32 GetValue( ValueType value_type ) const override; void SetUpdateRate( float32 update_rate ) override; void Update( float32 elapsed_time ) override; void AddValueSample( uint32 value, const std::string& sample_type = "NONE" ) override; diff --git a/NetworkLibrary/src/transmission_channels/reliable_ordered_channel.cpp b/NetworkLibrary/src/transmission_channels/reliable_ordered_channel.cpp index e2e59d8..e13109b 100644 --- a/NetworkLibrary/src/transmission_channels/reliable_ordered_channel.cpp +++ b/NetworkLibrary/src/transmission_channels/reliable_ordered_channel.cpp @@ -14,8 +14,8 @@ #include "core/socket.h" #include "core/address.h" -#include "metrics/metric_names.h" #include "metrics/metrics_handler.h" +#include "metrics/metric_types.h" #include "logger.h" #include "AlgorithmUtils.h" @@ -82,7 +82,7 @@ namespace NetLib } bool ReliableOrderedChannel::CreateAndSendPacket( Socket& socket, const Address& address, - Metrics::MetricsHandler* metrics_handler ) + Metrics::MetricsHandler& metrics_handler ) { bool result = false; @@ -134,9 +134,9 @@ namespace NetLib socket.SendTo( buffer.GetData(), buffer.GetSize(), address ); // TODO See what happens when the socket couldn't send the packet - if ( metrics_handler != nullptr ) + if ( metrics_handler.HasMetric( Metrics::MetricType::UPLOAD_BANDWIDTH ) ) { - metrics_handler->AddValue( Metrics::UPLOAD_BANDWIDTH_METRIC, packet.Size() ); + metrics_handler.AddValue( Metrics::MetricType::UPLOAD_BANDWIDTH, packet.Size() ); } _areUnsentACKs = false; @@ -147,9 +147,9 @@ namespace NetLib std::unique_ptr< Message > message = packet.TryGetNextMessage(); AddUnackedMessage( std::move( message ) ); - if ( metrics_handler != nullptr ) + if ( metrics_handler.HasMetric( Metrics::MetricType::PACKET_LOSS ) ) { - metrics_handler->AddValue( Metrics::PACKET_LOSS_METRIC, 1, "SENT" ); + metrics_handler.AddValue( Metrics::MetricType::PACKET_LOSS, 1, "SENT" ); } } @@ -177,7 +177,7 @@ namespace NetLib return ( !_unsentMessages.empty() || AreUnackedMessagesToResend() ); } - std::unique_ptr< Message > ReliableOrderedChannel::GetMessageToSend( Metrics::MetricsHandler* metrics_handler ) + std::unique_ptr< Message > ReliableOrderedChannel::GetMessageToSend( Metrics::MetricsHandler& metrics_handler ) { std::unique_ptr< Message > message = nullptr; if ( !_unsentMessages.empty() ) @@ -194,9 +194,9 @@ namespace NetLib else { message = TryGetUnackedMessageToResend(); - if ( message != nullptr && metrics_handler != nullptr ) + if ( message != nullptr && metrics_handler.HasMetric( Metrics::MetricType::RETRANSMISSIONS ) ) { - metrics_handler->AddValue( Metrics::RETRANSMISSION_METRIC, 1 ); + metrics_handler.AddValue( Metrics::MetricType::RETRANSMISSIONS, 1 ); } } @@ -244,7 +244,7 @@ namespace NetLib } bool ReliableOrderedChannel::AddReceivedMessage( std::unique_ptr< Message > message, - Metrics::MetricsHandler* metrics_handler ) + Metrics::MetricsHandler& metrics_handler ) { assert( message != nullptr ); @@ -259,9 +259,9 @@ namespace NetLib LOG_INFO( "The message with ID = %hu is duplicated. Ignoring it...", messageSequenceNumber ); // Submit duplicate message metric - if ( metrics_handler != nullptr ) + if ( metrics_handler.HasMetric( Metrics::MetricType::DUPLICATE_MESSAGES ) ) { - metrics_handler->AddValue( Metrics::DUPLICATE_METRIC, 1 ); + metrics_handler.AddValue( Metrics::MetricType::DUPLICATE_MESSAGES, 1 ); } // Release duplicate message @@ -404,7 +404,7 @@ namespace NetLib // Check if with this new message received we can process other newer (out of order) messages received in the // previous states. bool continueProcessing = true; - while ( !_unorderedMessagesWaitingForPrevious.empty() || continueProcessing ) + while ( !_unorderedMessagesWaitingForPrevious.empty() && continueProcessing ) { uint32 index = 0; if ( DoesUnorderedMessagesBufferContainsSequenceNumber( _nextOrderedMessageSequenceNumber, index ) ) @@ -425,12 +425,12 @@ namespace NetLib } void ReliableOrderedChannel::ProcessUnorderedMessage( std::unique_ptr< Message > message, - Metrics::MetricsHandler* metrics_handler ) + Metrics::MetricsHandler& metrics_handler ) { AddUnorderedMessage( std::move( message ) ); - if ( metrics_handler != nullptr ) + if ( metrics_handler.HasMetric( Metrics::MetricType::OUT_OF_ORDER_MESSAGES ) ) { - metrics_handler->AddValue( Metrics::OUT_OF_ORDER_METRIC, 1 ); + metrics_handler.AddValue( Metrics::MetricType::OUT_OF_ORDER_MESSAGES, 1 ); } } @@ -440,7 +440,7 @@ namespace NetLib } bool ReliableOrderedChannel::TryRemoveAckedMessageFromUnacked( uint16 sequence_number, - Metrics::MetricsHandler* metrics_handler ) + Metrics::MetricsHandler& metrics_handler ) { bool result = false; @@ -454,11 +454,14 @@ namespace NetLib UpdateRTT( messageRTT ); // Submit latency and jitter metrics - if ( metrics_handler != nullptr ) + const uint32 latency = messageRTT / 2; + if ( metrics_handler.HasMetric( Metrics::MetricType::LATENCY ) ) { - const uint32 latency = messageRTT / 2; - metrics_handler->AddValue( Metrics::LATENCY_METRIC, latency ); - metrics_handler->AddValue( Metrics::JITTER_METRIC, latency ); + metrics_handler.AddValue( Metrics::MetricType::LATENCY, latency ); + } + if ( metrics_handler.HasMetric( Metrics::MetricType::JITTER ) ) + { + metrics_handler.AddValue( Metrics::MetricType::JITTER, latency ); } // Remove message from buffers @@ -607,7 +610,7 @@ namespace NetLib } void ReliableOrderedChannel::ProcessACKs( uint32 acks, uint16 lastAckedMessageSequenceNumber, - Metrics::MetricsHandler* metrics_handler ) + Metrics::MetricsHandler& metrics_handler ) { LOG_INFO( "Last acked from client = %hu", lastAckedMessageSequenceNumber ); @@ -639,7 +642,7 @@ namespace NetLib return result; } - void ReliableOrderedChannel::Update( float32 deltaTime, Metrics::MetricsHandler* metrics_handler ) + void ReliableOrderedChannel::Update( float32 deltaTime, Metrics::MetricsHandler& metrics_handler ) { // Update unacked message timeouts std::list< float32 >::iterator it = _unackedReliableMessageTimeouts.begin(); @@ -650,9 +653,9 @@ namespace NetLib if ( timeout <= 0 ) { - if ( metrics_handler != nullptr ) + if ( metrics_handler.HasMetric( Metrics::MetricType::PACKET_LOSS ) ) { - metrics_handler->AddValue( Metrics::PACKET_LOSS_METRIC, 1, "LOST" ); + metrics_handler.AddValue( Metrics::MetricType::PACKET_LOSS, 1, "LOST" ); } timeout = 0; } diff --git a/NetworkLibrary/src/transmission_channels/reliable_ordered_channel.h b/NetworkLibrary/src/transmission_channels/reliable_ordered_channel.h index 2369c4c..b9d8316 100644 --- a/NetworkLibrary/src/transmission_channels/reliable_ordered_channel.h +++ b/NetworkLibrary/src/transmission_channels/reliable_ordered_channel.h @@ -37,21 +37,21 @@ namespace NetLib ReliableOrderedChannel& operator=( ReliableOrderedChannel&& other ) noexcept; bool CreateAndSendPacket( Socket& socket, const Address& address, - Metrics::MetricsHandler* metrics_handler ) override; + Metrics::MetricsHandler& metrics_handler ) override; bool AddMessageToSend( std::unique_ptr< Message > message ) override; bool ArePendingMessagesToSend() const override; - std::unique_ptr< Message > GetMessageToSend( Metrics::MetricsHandler* metrics_handler ); + std::unique_ptr< Message > GetMessageToSend( Metrics::MetricsHandler& metrics_handler ); bool AddReceivedMessage( std::unique_ptr< Message > message, - Metrics::MetricsHandler* metrics_handler ) override; + Metrics::MetricsHandler& metrics_handler ) override; bool ArePendingReadyToProcessMessages() const override; const Message* GetReadyToProcessMessage() override; void ProcessACKs( uint32 acks, uint16 lastAckedMessageSequenceNumber, - Metrics::MetricsHandler* metrics_handler ) override; + Metrics::MetricsHandler& metrics_handler ) override; - void Update( float32 deltaTime, Metrics::MetricsHandler* metrics_handler ) override; + void Update( float32 deltaTime, Metrics::MetricsHandler& metrics_handler ) override; void Reset() override; @@ -146,7 +146,7 @@ namespace NetLib /// A pointer to the metrics handler to update LATENCY and JITTER /// metrics. /// True if the acked message was removed from _unackedReliableMessages, False otherwise. - bool TryRemoveAckedMessageFromUnacked( uint16 sequence_number, Metrics::MetricsHandler* metrics_handler ); + bool TryRemoveAckedMessageFromUnacked( uint16 sequence_number, Metrics::MetricsHandler& metrics_handler ); /// /// Gets, if available, from the _unackedReliableMessages buffer the unacked message index associated to @@ -190,7 +190,7 @@ namespace NetLib /// The unordered message received /// The metrics handler to update out of order metric metric void ProcessUnorderedMessage( std::unique_ptr< Message > message, - Metrics::MetricsHandler* metrics_handler ); + Metrics::MetricsHandler& metrics_handler ); /// /// Adds an unordered message to the _unorderedMessagesWaitingForPrevious buffer until you get them in the diff --git a/NetworkLibrary/src/transmission_channels/transmission_channel.h b/NetworkLibrary/src/transmission_channels/transmission_channel.h index e1b47ae..852efe9 100644 --- a/NetworkLibrary/src/transmission_channels/transmission_channel.h +++ b/NetworkLibrary/src/transmission_channels/transmission_channel.h @@ -1,13 +1,14 @@ #pragma once #include "numeric_types.h" +#include "communication/message.h" + #include #include #include namespace NetLib { - class Message; class MessageFactory; class Socket; class Address; @@ -21,7 +22,8 @@ namespace NetLib { UnreliableOrdered = 0, ReliableOrdered = 1, - UnreliableUnordered = 2 + UnreliableUnordered = 2, + Count = 3 }; class TransmissionChannel @@ -45,7 +47,7 @@ namespace NetLib /// bandwidth. /// True if the packet was created and sent, False otherwise. virtual bool CreateAndSendPacket( Socket& socket, const Address& address, - Metrics::MetricsHandler* metrics_handler ) = 0; + Metrics::MetricsHandler& metrics_handler ) = 0; /// /// Adds to the channel a message pending to be sent through the network. The header of the message must be @@ -68,15 +70,15 @@ namespace NetLib /// The message received to be processed. /// True if the message was stored correclt, False otherwise. virtual bool AddReceivedMessage( std::unique_ptr< Message > message, - Metrics::MetricsHandler* metrics_handler ) = 0; + Metrics::MetricsHandler& metrics_handler ) = 0; virtual bool ArePendingReadyToProcessMessages() const = 0; virtual const Message* GetReadyToProcessMessage() = 0; void FreeProcessedMessages(); virtual void ProcessACKs( uint32 acks, uint16 lastAckedMessageSequenceNumber, - Metrics::MetricsHandler* metrics_handler ) = 0; + Metrics::MetricsHandler& metrics_handler ) = 0; - virtual void Update( float32 deltaTime, Metrics::MetricsHandler* metrics_handler ) = 0; + virtual void Update( float32 deltaTime, Metrics::MetricsHandler& metrics_handler ) = 0; virtual void Reset(); diff --git a/NetworkLibrary/src/transmission_channels/unreliable_ordered_transmission_channel.cpp b/NetworkLibrary/src/transmission_channels/unreliable_ordered_transmission_channel.cpp index 10b40ee..f8e2f95 100644 --- a/NetworkLibrary/src/transmission_channels/unreliable_ordered_transmission_channel.cpp +++ b/NetworkLibrary/src/transmission_channels/unreliable_ordered_transmission_channel.cpp @@ -11,8 +11,8 @@ #include "core/socket.h" #include "core/address.h" -#include "metrics/metric_names.h" #include "metrics/metrics_handler.h" +#include "metrics/metric_types.h" namespace NetLib { @@ -41,7 +41,7 @@ namespace NetLib } bool UnreliableOrderedTransmissionChannel::CreateAndSendPacket( Socket& socket, const Address& address, - Metrics::MetricsHandler* metrics_handler ) + Metrics::MetricsHandler& metrics_handler ) { bool result = false; @@ -85,9 +85,9 @@ namespace NetLib socket.SendTo( buffer.GetData(), buffer.GetSize(), address ); // TODO See what happens when the socket couldn't send the packet - if ( metrics_handler != nullptr ) + if ( metrics_handler.HasMetric( Metrics::MetricType::UPLOAD_BANDWIDTH ) ) { - metrics_handler->AddValue( Metrics::UPLOAD_BANDWIDTH_METRIC, packet.Size() ); + metrics_handler.AddValue( Metrics::MetricType::UPLOAD_BANDWIDTH, packet.Size() ); } // Clean messages @@ -122,7 +122,7 @@ namespace NetLib } std::unique_ptr< Message > UnreliableOrderedTransmissionChannel::GetMessageToSend( - Metrics::MetricsHandler* metrics_handler ) + Metrics::MetricsHandler& metrics_handler ) { if ( !ArePendingMessagesToSend() ) { @@ -151,7 +151,7 @@ namespace NetLib } bool UnreliableOrderedTransmissionChannel::AddReceivedMessage( std::unique_ptr< Message > message, - Metrics::MetricsHandler* metrics_handler ) + Metrics::MetricsHandler& metrics_handler ) { assert( message != nullptr ); @@ -195,7 +195,7 @@ namespace NetLib } void UnreliableOrderedTransmissionChannel::ProcessACKs( uint32 acks, uint16 lastAckedMessageSequenceNumber, - Metrics::MetricsHandler* metrics_handler ) + Metrics::MetricsHandler& metrics_handler ) { // This channel is not supporting ACKs since it is unreliable. So do nothing } @@ -205,7 +205,7 @@ namespace NetLib return false; } - void UnreliableOrderedTransmissionChannel::Update( float32 deltaTime, Metrics::MetricsHandler* metrics_handler ) + void UnreliableOrderedTransmissionChannel::Update( float32 deltaTime, Metrics::MetricsHandler& metrics_handler ) { } diff --git a/NetworkLibrary/src/transmission_channels/unreliable_ordered_transmission_channel.h b/NetworkLibrary/src/transmission_channels/unreliable_ordered_transmission_channel.h index 43b572a..a82e6f5 100644 --- a/NetworkLibrary/src/transmission_channels/unreliable_ordered_transmission_channel.h +++ b/NetworkLibrary/src/transmission_channels/unreliable_ordered_transmission_channel.h @@ -16,23 +16,23 @@ namespace NetLib UnreliableOrderedTransmissionChannel& operator=( UnreliableOrderedTransmissionChannel&& other ) noexcept; bool CreateAndSendPacket( Socket& socket, const Address& address, - Metrics::MetricsHandler* metrics_handler ) override; + Metrics::MetricsHandler& metrics_handler ) override; bool AddMessageToSend( std::unique_ptr< Message > message ) override; bool ArePendingMessagesToSend() const override; - std::unique_ptr< Message > GetMessageToSend( Metrics::MetricsHandler* metrics_handler ); + std::unique_ptr< Message > GetMessageToSend( Metrics::MetricsHandler& metrics_handler ); uint32 GetSizeOfNextUnsentMessage() const; bool AddReceivedMessage( std::unique_ptr< Message > message, - Metrics::MetricsHandler* metrics_handler ) override; + Metrics::MetricsHandler& metrics_handler ) override; bool ArePendingReadyToProcessMessages() const override; const Message* GetReadyToProcessMessage() override; void ProcessACKs( uint32 acks, uint16 lastAckedMessageSequenceNumber, - Metrics::MetricsHandler* metrics_handler ) override; + Metrics::MetricsHandler& metrics_handler ) override; bool IsMessageDuplicated( uint16 messageSequenceNumber ) const; - void Update( float32 deltaTime, Metrics::MetricsHandler* metrics_handler ) override; + void Update( float32 deltaTime, Metrics::MetricsHandler& metrics_handler ) override; void Reset() override; diff --git a/NetworkLibrary/src/transmission_channels/unreliable_unordered_transmission_channel.cpp b/NetworkLibrary/src/transmission_channels/unreliable_unordered_transmission_channel.cpp index 8eba3d1..20c508b 100644 --- a/NetworkLibrary/src/transmission_channels/unreliable_unordered_transmission_channel.cpp +++ b/NetworkLibrary/src/transmission_channels/unreliable_unordered_transmission_channel.cpp @@ -11,8 +11,8 @@ #include "core/socket.h" #include "core/address.h" -#include "metrics/metric_names.h" #include "metrics/metrics_handler.h" +#include "metrics/metric_types.h" namespace NetLib { @@ -35,7 +35,7 @@ namespace NetLib } bool UnreliableUnorderedTransmissionChannel::CreateAndSendPacket( Socket& socket, const Address& address, - Metrics::MetricsHandler* metrics_handler ) + Metrics::MetricsHandler& metrics_handler ) { bool result = false; @@ -77,9 +77,9 @@ namespace NetLib socket.SendTo( buffer.GetData(), buffer.GetSize(), address ); // TODO See what happens when the socket couldn't send the packet - if ( metrics_handler != nullptr ) + if ( metrics_handler.HasMetric( Metrics::MetricType::UPLOAD_BANDWIDTH ) ) { - metrics_handler->AddValue( Metrics::UPLOAD_BANDWIDTH_METRIC, packet.Size() ); + metrics_handler.AddValue( Metrics::MetricType::UPLOAD_BANDWIDTH, packet.Size() ); } // Send messages ownership back to remote peer @@ -114,7 +114,7 @@ namespace NetLib } std::unique_ptr< Message > UnreliableUnorderedTransmissionChannel::GetMessageToSend( - Metrics::MetricsHandler* metrics_handler ) + Metrics::MetricsHandler& metrics_handler ) { if ( !ArePendingMessagesToSend() ) { @@ -142,7 +142,7 @@ namespace NetLib } bool UnreliableUnorderedTransmissionChannel::AddReceivedMessage( std::unique_ptr< Message > message, - Metrics::MetricsHandler* metrics_handler ) + Metrics::MetricsHandler& metrics_handler ) { assert( message != nullptr ); @@ -177,7 +177,7 @@ namespace NetLib } void UnreliableUnorderedTransmissionChannel::ProcessACKs( uint32 acks, uint16 lastAckedMessageSequenceNumber, - Metrics::MetricsHandler* metrics_handler ) + Metrics::MetricsHandler& metrics_handler ) { } @@ -186,7 +186,7 @@ namespace NetLib return false; } - void UnreliableUnorderedTransmissionChannel::Update( float32 deltaTime, Metrics::MetricsHandler* metrics_handler ) + void UnreliableUnorderedTransmissionChannel::Update( float32 deltaTime, Metrics::MetricsHandler& metrics_handler ) { } diff --git a/NetworkLibrary/src/transmission_channels/unreliable_unordered_transmission_channel.h b/NetworkLibrary/src/transmission_channels/unreliable_unordered_transmission_channel.h index 272c1a5..a8af606 100644 --- a/NetworkLibrary/src/transmission_channels/unreliable_unordered_transmission_channel.h +++ b/NetworkLibrary/src/transmission_channels/unreliable_unordered_transmission_channel.h @@ -21,23 +21,23 @@ namespace NetLib UnreliableUnorderedTransmissionChannel&& other ) noexcept; bool CreateAndSendPacket( Socket& socket, const Address& address, - Metrics::MetricsHandler* metrics_handler ) override; + Metrics::MetricsHandler& metrics_handler ) override; bool AddMessageToSend( std::unique_ptr< Message > message ) override; bool ArePendingMessagesToSend() const override; - std::unique_ptr< Message > GetMessageToSend( Metrics::MetricsHandler* metrics_handler ); + std::unique_ptr< Message > GetMessageToSend( Metrics::MetricsHandler& metrics_handler ); uint32 GetSizeOfNextUnsentMessage() const; bool AddReceivedMessage( std::unique_ptr< Message > message, - Metrics::MetricsHandler* metrics_handler ) override; + Metrics::MetricsHandler& metrics_handler ) override; bool ArePendingReadyToProcessMessages() const override; const Message* GetReadyToProcessMessage() override; void ProcessACKs( uint32 acks, uint16 lastAckedMessageSequenceNumber, - Metrics::MetricsHandler* metrics_handler ) override; + Metrics::MetricsHandler& metrics_handler ) override; bool IsMessageDuplicated( uint16 messageSequenceNumber ) const; - void Update( float32 deltaTime, Metrics::MetricsHandler* metrics_handler ) override; + void Update( float32 deltaTime, Metrics::MetricsHandler& metrics_handler ) override; private: /// diff --git a/README.md b/README.md index 199d44b..881b234 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ A Network Library for competitive realtime multiplayer games. >_**NOTE:** Beginning January 1st, 2026, no further updates will be publicly released for this project in order to comply with my employer’s policies. I apologize for any inconvenience this may cause._ This repository contains four different projects: -1. Network library +1. Network library [Docs](docs/network_library/network_library_index.md) 2. Demo game. [Docs](docs/demo_game/demo_game_index.md) 3. Engine [Docs](docs/engine/engine_index.md) 4. Common (Shared files between Demo game and Network Library projects) diff --git a/docs/network_library/execution_flow.md b/docs/network_library/execution_flow.md new file mode 100644 index 0000000..60f88d7 --- /dev/null +++ b/docs/network_library/execution_flow.md @@ -0,0 +1,192 @@ +# Execution flow +This library operates through four well-defined phases: **Start → PreTick → Tick → Stop**. +Each phase has a specific responsibility and must not perform tasks belonging to other phases. +This ensures predictable behavior, easier debugging, and a clearer architecture. +![Execution flow diagram](./images/full_execution_flow_diagram.png) + +## Index: +- [Start Phase](#Start-Phase) +- [PreTick Phase](#PreTick-Phase) +- [Tick Phase](#Tick-Phase) +- [Stop Phase](#Stop-Phase) + +## Start Phase +Initialize the peer and all required internal systems in preparation for the network loop. +This phase is executed **once**, before the first tick. + +The start phase contains the following **subphases**: +1. [Set connection state to `Connecting`](#1-Set-connection-state-to-Connecting) +2. [Start socket](#2-Start-socket) +3. [Start peer-type specific logic](#3-Start-peer-type-specific-logic) +4. [Set current tick](#4-Set-current-tick) + +### 1 Set connection state to Connecting +Transition the peer's connection state from `Disconnected` to `Connecting`. + +Description of the **procedure**: +- Set local peer state to `Connecting`. + +**Notes**: +- **Server**: promoted later to `Connected` inside this Start process. +- **Client**: remains in `Connecting` until receiving either _ConnectionAccepted_ or _ConnectionDenied_ from the server. + +### 2 Start socket +Create and initialize the underlying network socket. + +Description of the **procedure**: +- Initialize Berkeley Sockets API. +- Create socket object. +- Configure blocking / non-blocking behavior. + +### 3 Start peer-type specific logic +Initialize systems exclusive to the peer type (Client or Server). +Description of the **procedure** for the **server**: +- Bind the socket to listen on all incoming IPs. + +Description of the **procedure** for the **client**: +- Bind the socket to only listen on server's IP. +- Generate communication salt number. +- Create the server's _early remote peer_. + +### 4 Set current tick +Initialize simulation tick counter to `1`. + +## PreTick Phase +Receive network data, validate it, build packets, and place messages into transmission channels and process them. This phase is executed inside a **fixed update loop** and always before the Tick phase. + +This phase **does NOT**: +- Send any data. + +The pre-tick phase contains the following **subphases**: +1. [Read received data](#1-Read-received-data) +2. [Process received data](#2-Process-received-data) + +### 1 Read received data +Read incoming datagrams and convert them into `NetworkPacket` objects. Then stores them within the transmission channels to be ready for being processed. This receive pipeline does not process messages yet. + +Description of the **procedure**: +- For each datagram received: + 1. Read raw data from socket. + 2. Validate datagram. + 3. Construct `NetworkPacket`. + 4. Route its messages to: + - an existing `RemotePeer`, or + - an `EarlyRemotePeer` (peer still in the connection/handshake process). + +### 2 Process received data +Interpret and execute pending messages that were queued during data reception. This is where connection requests, ping-pongs, acknowledgments, and replication messages are consumed. + +Description of the **procedure**: +- For each **RemotePeer**: + - For each pending message: + - Process message. +- For each **EarlyRemotePeer** (⚠️ currently missing): + - For each pending message: + - Process message. + +## Tick Phase +Advance internal systems, update peer state, build outgoing data, send packets, and handle disconnections. This phase is executed inside a **fixed update loop** and always after the Pre-tick phase. + +This phase **does NOT**: +- Read incoming socket data + +The tick phase contains the following **subphases**: +1. [Update remote peers](#1-Update-remote-peers) +2. [Update peer-type specific logic](#2-Update-peer-type-specific-logic) +3. [Finish disconnecting remote peers](#3-Finish-disconnecting-remote-peers) +4. [Send pending data](#4-Send-pending-data) +5. [Stop peer, if requested](#5-Stop-peer-if-requested) + +### 1 Update remote peers +Update systems that belong to remote peers. + +Description of the **procedure**: +- For each Remote Peer: + - **Peer lifecycle systems** + - Update inactivity system. + - Detect conditions to initiate disconnection. + - **Peer communication systems** + - Update transmission channels. + - Update metrics / telemetry. + - Update Ping-Pong. + +### 2 Update peer-type specific logic +Execute logic specific to client or server. + +Description of the **procedure** for the **server**: +- Update replication component. + +Description of the **procedure** for the **client**: +- If connection state `Disconnected`, enqueue a connection request. +- Update time syncer component. + +### 3 Finish disconnecting remote peers +Finalize disconnection of remote peers that entered the disconnecting state. + +Description of the **procedure**: +For each disconnection request: +- If `shouldNotify`, send a disconnection message. +- Perform disconnection cleanup. + +### 4 Send pending data +Build and transmit outgoing packets for all remote peers and all their transmission channels. + +Description of the **procedure**: +- For each **RemotePeer**: + - For each transmission channel: + - Build and send a network packet. +- For each **EarlyRemotePeer** (⚠️ currently missing): + - For each transmission channel: + - Build and send a network packet. + +**Notes**: +- A **separate packet is built per transmission channel**. +- Messages from different channels are **never mixed** in the same packet. + +### 5 Stop peer if requested +If a stop has been requested during this tick, transition the system toward the Stop phase. + +## Stop Phase +Shut down the peer and release all network resources. This phase is executed **once**, after the last Tick. + +The stop phase contains the following **subphases**: +1. [Stop peer-type specific logic](#1-Stop-peer-type-specific-logic) +2. [Disconnect remote peers](#2-Disconnect-remote-peers) +3. [Close socket](#3-Close-socket) +4. [Send pending data](#4-Send-pending-data) +5. [Set connection state to Disconnected](#5-Set-connection-state-to-Disconnected) + +### 1 Stop peer-type specific logic +Shutdown logic specific to client or server. + +Description of the **procedure** for the **server**: +- None + +Description of the **procedure** for the **client**: +- None + +### 2 Disconnect remote peers +Remove all remote peers and optionally notify them of the disconnection. + +Description of the **procedure**: +- For each remote peer: + - If `shouldNotify`, send Disconnect message. +- Remove all remote peers. + +### 3 Close socket +Release all socket API resources. + +Description of the **procedure**: +- Destroy socket object. +- Deinitialize socket API. + +### 4 Set connection state to Disconnected +Mark the peer as fully disconnected by setting the connection state to `Disconnected` + +## Sum up +| Phase name | What it does | What it does NOT | When it is called | +|------------|---------|------------------|-------------------| +| Start|Initialize socket and systems|Read data, Send data|Once| +| PreTick|Read data, Process data|Send data|Each fixed tick| +| Tick|Update logic, Send data|Read data|Each fixed tick| +| Stop|Cleanup socket and systems|Read data|Once| \ No newline at end of file diff --git a/docs/network_library/images/full_execution_flow_diagram.png b/docs/network_library/images/full_execution_flow_diagram.png new file mode 100644 index 0000000..358ccb8 Binary files /dev/null and b/docs/network_library/images/full_execution_flow_diagram.png differ diff --git a/docs/network_library/images/network_library_flow_diagram.png b/docs/network_library/images/network_library_flow_diagram.png new file mode 100644 index 0000000..24d3f08 Binary files /dev/null and b/docs/network_library/images/network_library_flow_diagram.png differ diff --git a/docs/network_library/network_library_index.md b/docs/network_library/network_library_index.md new file mode 100644 index 0000000..1077422 --- /dev/null +++ b/docs/network_library/network_library_index.md @@ -0,0 +1,5 @@ +# Network library documentation index +This is the main page of the network library documentation + +## Topics: +1. [Execution flow](execution_flow.md)