diff --git a/CMakeLists.txt b/CMakeLists.txt index b9c47d4a18..21c3d4dfc8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,6 +49,7 @@ endif () # Find packages. find_package(OpenSSL 3.0.0 REQUIRED COMPONENTS Crypto) +find_package(ZLIB REQUIRED) find_package(fmt 8.1.1 CONFIG) if (NOT fmt_FOUND) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3c09a5a534..d89e2beecb 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -178,6 +178,7 @@ target_link_libraries(tfslib PRIVATE fmt::fmt OpenSSL::Crypto pugixml::pugixml + ZLIB::ZLIB ${CMAKE_THREAD_LIBS_INIT} ${LUA_LIBRARIES} ${MYSQL_CLIENT_LIBS} diff --git a/src/outputmessage.h b/src/outputmessage.h index 402a4dc6f0..2741572ca5 100644 --- a/src/outputmessage.h +++ b/src/outputmessage.h @@ -21,12 +21,12 @@ class OutputMessage : public NetworkMessage void writeMessageLength() { add_header(info.length); } - void addCryptoHeader(checksumMode_t mode, uint32_t& sequence) + void addCryptoHeader(checksumMode_t mode) { if (mode == CHECKSUM_ADLER) { add_header(adlerChecksum(&buffer[outputBufferStart], info.length)); } else if (mode == CHECKSUM_SEQUENCE) { - add_header(sequence++); + add_header(getSequenceId()); } writeMessageLength(); @@ -48,6 +48,9 @@ class OutputMessage : public NetworkMessage info.position += msgLen; } + void setSequenceId(uint32_t sequence) { sequenceId = sequence; } + uint32_t getSequenceId() const { return sequenceId; } + private: template void add_header(T add) @@ -60,6 +63,7 @@ class OutputMessage : public NetworkMessage } MsgSize_t outputBufferStart = INITIAL_BUFFER_POSITION; + uint32_t sequenceId; }; namespace tfs::net { diff --git a/src/protocol.cpp b/src/protocol.cpp index 71b7f6c554..2077a89134 100644 --- a/src/protocol.cpp +++ b/src/protocol.cpp @@ -46,11 +46,20 @@ bool XTEA_decrypt(NetworkMessage& msg, const xtea::round_keys& key) void Protocol::onSendMessage(const OutputMessage_ptr& msg) { if (!rawMessages) { + if (encryptionEnabled && checksumMode == CHECKSUM_SEQUENCE) { + uint32_t compressionChecksum = 0; + if (msg->getLength() >= 128 && deflateMessage(*msg)) { + compressionChecksum = 0x80000000; + } + + msg->setSequenceId(compressionChecksum | getNextSequenceId()); + } + msg->writeMessageLength(); if (encryptionEnabled) { XTEA_encrypt(*msg, key); - msg->addCryptoHeader(checksumMode, sequenceNumber); + msg->addCryptoHeader(checksumMode); } } } @@ -86,6 +95,37 @@ bool Protocol::RSA_decrypt(NetworkMessage& msg) return msg.getByte() == 0; } +bool Protocol::deflateMessage(OutputMessage& msg) +{ + static thread_local std::vector buffer(NETWORKMESSAGE_MAXSIZE); + + zstream.next_in = msg.getOutputBuffer(); + zstream.avail_in = msg.getLength(); + zstream.next_out = buffer.data(); + zstream.avail_out = buffer.size(); + + const auto result = deflate(&zstream, Z_FINISH); + if (result != Z_OK && result != Z_STREAM_END) { + std::cout << "Error while deflating packet data error: " << (zstream.msg ? zstream.msg : "unknown") + << std::endl; + return false; + } + + const auto size = zstream.total_out; + deflateReset(&zstream); + + if (size <= 0) { + std::cout << "Deflated packet data had invalid size: " << size + << " error: " << (zstream.msg ? zstream.msg : "unknown") << std::endl; + return false; + } + + msg.reset(); + msg.addBytes(reinterpret_cast(buffer.data()), size); + + return true; +} + Connection::Address Protocol::getIP() const { if (auto connection = getConnection()) { diff --git a/src/protocol.h b/src/protocol.h index a6ea13dddf..639993b1ed 100644 --- a/src/protocol.h +++ b/src/protocol.h @@ -7,10 +7,17 @@ #include "connection.h" #include "xtea.h" +#include + class Protocol : public std::enable_shared_from_this { public: - explicit Protocol(Connection_ptr connection) : connection(connection) {} + explicit Protocol(Connection_ptr connection) : connection(connection) + { + if (deflateInit2(&zstream, 6, Z_DEFLATED, -15, 8, Z_DEFAULT_STRATEGY) != Z_OK) { + std::cout << "ZLIB initialization error: " << (zstream.msg ? zstream.msg : "unknown") << std::endl; + } + } virtual ~Protocol() = default; // non-copyable @@ -42,6 +49,16 @@ class Protocol : public std::enable_shared_from_this } } + uint32_t getNextSequenceId() + { + const auto sequence = ++sequenceNumber; + if (sequenceNumber >= static_cast(std::numeric_limits::max())) { + sequenceNumber = 0; + } + + return sequence; + } + protected: static constexpr size_t RSA_BUFFER_LENGTH = 128; @@ -57,6 +74,8 @@ class Protocol : public std::enable_shared_from_this static bool RSA_decrypt(NetworkMessage& msg); + bool deflateMessage(OutputMessage& msg); + void setRawMessages(bool value) { rawMessages = value; } virtual void release() {} @@ -72,6 +91,8 @@ class Protocol : public std::enable_shared_from_this bool encryptionEnabled = false; checksumMode_t checksumMode = CHECKSUM_ADLER; bool rawMessages = false; + + z_stream zstream{}; }; #endif // FS_PROTOCOL_H diff --git a/src/protocolgame.cpp b/src/protocolgame.cpp index f1cb31f884..ee5d981e7f 100644 --- a/src/protocolgame.cpp +++ b/src/protocolgame.cpp @@ -383,7 +383,7 @@ void ProtocolGame::onRecvFirstMessage(NetworkMessage& msg) } // Change packet verifying mode for QT clients - if (version >= 1111 && operatingSystem >= CLIENTOS_QT_LINUX && operatingSystem < CLIENTOS_OTCLIENT_LINUX) { + if (version >= 1111 && operatingSystem >= CLIENTOS_QT_LINUX && operatingSystem <= CLIENTOS_OTCLIENT_MAC) { setChecksumMode(CHECKSUM_SEQUENCE); } diff --git a/vcpkg.json b/vcpkg.json index eb7f23d7e2..d47df2137f 100644 --- a/vcpkg.json +++ b/vcpkg.json @@ -14,38 +14,27 @@ }, "libmariadb", "openssl", - "pugixml" + "pugixml", + "zlib" ], "features": { "http": { "description": "Enable HTTP support", - "dependencies": [ - "boost-beast", - "boost-json" - ] + "dependencies": ["boost-beast", "boost-json"] }, "lua": { "description": "Use Lua instead of LuaJIT", - "dependencies": [ - "lua" - ] + "dependencies": ["lua"] }, "luajit": { "description": "Use LuaJIT instead of Lua", - "dependencies": [ - "luajit" - ] + "dependencies": ["luajit"] }, "unit-tests": { "description": "Build unit tests", - "dependencies": [ - "boost-test" - ] + "dependencies": ["boost-test"] } }, - "default-features": [ - "lua", - "http" - ], + "default-features": ["lua", "http"], "builtin-baseline": "215a2535590f1f63788ac9bd2ed58ad15e6afdff" }