From ad91f7e1be0e84c5d0de8239f289b5602538948f Mon Sep 17 00:00:00 2001 From: Abhik Jain Date: Thu, 17 Feb 2022 16:01:19 +0530 Subject: [PATCH] benchmarks: add a version agnostic simple router Signed-off-by: Abhik Jain --- Cargo.lock | 269 ++- Cargo.toml | 3 +- benchmarks/simplerouter/Cargo.toml | 13 + .../simplerouter/src/bin/simplerouter.rs | 11 + benchmarks/simplerouter/src/lib.rs | 122 ++ benchmarks/simplerouter/src/network.rs | 102 + benchmarks/simplerouter/src/protocol/mod.rs | 430 ++++ benchmarks/simplerouter/src/protocol/v4.rs | 863 ++++++++ benchmarks/simplerouter/src/protocol/v5.rs | 1952 +++++++++++++++++ 9 files changed, 3672 insertions(+), 93 deletions(-) create mode 100644 benchmarks/simplerouter/Cargo.toml create mode 100644 benchmarks/simplerouter/src/bin/simplerouter.rs create mode 100644 benchmarks/simplerouter/src/lib.rs create mode 100644 benchmarks/simplerouter/src/network.rs create mode 100644 benchmarks/simplerouter/src/protocol/mod.rs create mode 100644 benchmarks/simplerouter/src/protocol/v4.rs create mode 100644 benchmarks/simplerouter/src/protocol/v5.rs diff --git a/Cargo.lock b/Cargo.lock index 8dbaf5cef..5281ecfe4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -103,18 +103,18 @@ dependencies = [ [[package]] name = "async-tungstenite" -version = "0.13.1" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07b30ef0ea5c20caaa54baea49514a206308989c68be7ecd86c7f956e4da6378" +checksum = "5682ea0913e5c20780fe5785abacb85a411e7437bf52a1bedb93ddb3972cb8dd" dependencies = [ "futures-io", "futures-util", "log", "pin-project-lite 0.2.7", - "tokio 1.8.2", + "rustls-native-certs", + "tokio 1.17.0", "tokio-rustls", - "tungstenite 0.13.0", - "webpki-roots", + "tungstenite 0.16.0", ] [[package]] @@ -126,7 +126,7 @@ dependencies = [ "futures 0.3.15", "pharos", "rustc_version 0.3.3", - "tokio 1.8.2", + "tokio 1.17.0", ] [[package]] @@ -185,7 +185,7 @@ version = "0.4.0" dependencies = [ "argh", "async-channel", - "bytes 1.0.1", + "bytes 1.1.0", "futures 0.3.15", "itoa", "jemallocator", @@ -197,7 +197,7 @@ dependencies = [ "rumqttlog 0.9.0", "serde", "serde_json", - "tokio 1.8.2", + "tokio 1.17.0", ] [[package]] @@ -263,9 +263,9 @@ checksum = "0e4cec68f03f32e44924783795810fa50a7035d8c8ebe78580ad7e6c703fba38" [[package]] name = "bytes" -version = "1.0.1" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" +checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" [[package]] name = "cache-padded" @@ -798,7 +798,7 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "825343c4eef0b63f541f8903f395dc5beb362a979b5799a84062527ef1e37726" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "fnv", "futures-core", "futures-sink", @@ -806,7 +806,7 @@ dependencies = [ "http", "indexmap", "slab", - "tokio 1.8.2", + "tokio 1.17.0", "tokio-util", "tracing", ] @@ -831,7 +831,7 @@ checksum = "f0b7591fb62902706ae8e7aaff416b1b0fa2c0fd0878b46dc13baa3712d8a855" dependencies = [ "base64 0.13.0", "bitflags", - "bytes 1.0.1", + "bytes 1.1.0", "headers-core", "http", "mime", @@ -872,7 +872,7 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "527e8c9ac747e28542699a951517aa9a6945af506cd1f2e1b53a576c17b6cc11" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "fnv", "itoa", ] @@ -883,7 +883,7 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60daa14be0e0786db0f03a9e57cb404c9d756eed2b6c62b9ea98ec5743ec75a9" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "http", "pin-project-lite 0.2.7", ] @@ -915,7 +915,7 @@ version = "0.14.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7728a72c4c7d72665fde02204bcbd93b247721025b222ef78606f14513e0fd03" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "futures-channel", "futures-core", "futures-util", @@ -927,7 +927,7 @@ dependencies = [ "itoa", "pin-project-lite 0.2.7", "socket2", - "tokio 1.8.2", + "tokio 1.17.0", "tower-service", "tracing", "want", @@ -978,7 +978,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f97967975f448f1a7ddb12b0bc41069d09ed6a1c161a92687e057325db35d413" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", ] [[package]] @@ -1111,15 +1111,15 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.98" +version = "0.2.118" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "320cfe77175da3a483efed4bc0adc1968ca050b098ce4f2f1c13a56626128790" +checksum = "06e509672465a0504304aa87f9f176f2b2b716ed8fb105ebe5c02dc6dce96a94" [[package]] name = "lock_api" -version = "0.4.4" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0382880606dff6d15c9476c416d18690b72742aa7b605bb6dd6ec9030fbf07eb" +checksum = "88943dd7ef4a2e5a4bfa2753aaab3013e34ce2533d1996fb18ef591e315e2b3b" dependencies = [ "scopeguard", ] @@ -1211,9 +1211,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.7.13" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c2bdb6314ec10835cd3293dd268473a835c02b7b352e788be788b3c6ca6bb16" +checksum = "ba272f85fa0b41fc91872be579b3bbe0f56b792aa361a380eb669469f68dafb2" dependencies = [ "libc", "log", @@ -1279,7 +1279,7 @@ dependencies = [ name = "mqttbytes" version = "0.6.0" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "criterion", "pretty_assertions", "rand 0.7.3", @@ -1513,7 +1513,17 @@ checksum = "6d7744ac029df22dca6284efe4e898991d28e3085c706c972bcd7da4a27a15eb" dependencies = [ "instant", "lock_api", - "parking_lot_core", + "parking_lot_core 0.8.3", +] + +[[package]] +name = "parking_lot" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f5ec2493a61ac0506c0f4199f99070cbe83857b0337006a30f3e6719b8ef58" +dependencies = [ + "lock_api", + "parking_lot_core 0.9.1", ] [[package]] @@ -1530,6 +1540,19 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "parking_lot_core" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28141e0cc4143da2443301914478dc976a61ffdb3f043058310c70df2fed8954" +dependencies = [ + "cfg-if 1.0.0", + "libc", + "redox_syscall", + "smallvec", + "windows-sys", +] + [[package]] name = "pem" version = "0.8.3" @@ -1666,7 +1689,7 @@ dependencies = [ "libc", "log", "nix 0.19.1", - "parking_lot", + "parking_lot 0.11.1", "prost 0.7.0", "prost-build", "prost-derive 0.7.0", @@ -1687,7 +1710,7 @@ dependencies = [ "libc", "log", "nix 0.20.0", - "parking_lot", + "parking_lot 0.11.1", "prost 0.7.0", "prost-build", "prost-derive 0.7.0", @@ -1761,7 +1784,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e6984d2f1a23009bd270b8bb56d0926810a3d483f59c987d77969e9d8e840b2" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "prost-derive 0.7.0", ] @@ -1771,7 +1794,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32d3ebd75ac2679c2af3a92246639f9fcc8a442ee420719cc4fe195b98dd5fa3" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "heck", "itertools 0.9.0", "log", @@ -1815,7 +1838,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b518d7cdd93dab1d1122cf07fa9a60771836c668dde9d9e2a139f957f0d9f1bb" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "prost 0.7.0", ] @@ -2030,7 +2053,7 @@ version = "0.10.0" dependencies = [ "async-channel", "async-tungstenite", - "bytes 1.0.1", + "bytes 1.1.0", "color-backtrace", "crossbeam-channel", "envy", @@ -2043,9 +2066,10 @@ dependencies = [ "pretty_env_logger", "rustls", "rustls-native-certs", + "rustls-pemfile 0.3.0", "serde", "thiserror", - "tokio 1.8.2", + "tokio 1.17.0", "tokio-rustls", "url", "webpki", @@ -2057,7 +2081,7 @@ name = "rumqttd" version = "0.9.0" dependencies = [ "argh", - "bytes 1.0.1", + "bytes 1.1.0", "confy", "futures-util", "jemallocator", @@ -2066,9 +2090,10 @@ dependencies = [ "pprof 0.4.4", "pretty_env_logger", "rumqttlog 0.9.0", + "rustls-pemfile 0.3.0", "serde", "thiserror", - "tokio 1.8.2", + "tokio 1.17.0", "tokio-native-tls", "tokio-rustls", "warp", @@ -2101,7 +2126,7 @@ dependencies = [ "argh", "bencher", "byteorder", - "bytes 1.0.1", + "bytes 1.1.0", "fnv", "futures-util", "jackiechan", @@ -2123,9 +2148,9 @@ dependencies = [ name = "rumqttmesh" version = "0.1.0" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "rumqttlog 0.1.4", - "tokio 1.8.2", + "tokio 1.17.0", ] [[package]] @@ -2154,11 +2179,10 @@ dependencies = [ [[package]] name = "rustls" -version = "0.19.1" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35edb675feee39aec9c99fa5ff985081995a06d594114ae14cbe797ad7b7a6d7" +checksum = "b323592e3164322f5b193dc4302e4e36cd8d37158a712d664efae1a5c2791700" dependencies = [ - "base64 0.13.0", "log", "ring", "sct", @@ -2167,16 +2191,34 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.5.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a07b7c1885bd8ed3831c289b7870b13ef46fe0e856d288c30d9cc17d75a2092" +checksum = "5ca9ebdfa27d3fc180e42879037b5338ab1c040c06affd00d8338598e7800943" dependencies = [ "openssl-probe", - "rustls", + "rustls-pemfile 0.2.1", "schannel", "security-framework", ] +[[package]] +name = "rustls-pemfile" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eebeaeb360c87bfb72e84abdb3447159c0eaececf1bef2aecd65a8be949d1c9" +dependencies = [ + "base64 0.13.0", +] + +[[package]] +name = "rustls-pemfile" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ee86d63972a7c661d1536fefe8c3c8407321c3df668891286de28abcd087360" +dependencies = [ + "base64 0.13.0", +] + [[package]] name = "ryu" version = "1.0.5" @@ -2222,9 +2264,9 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "sct" -version = "0.6.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b362b83898e0e69f38515b82ee15aa80636befe47c3b6d3d89a911e78fc228ce" +checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" dependencies = [ "ring", "untrusted", @@ -2375,6 +2417,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "simplerouter" +version = "0.1.0" +dependencies = [ + "bytes 1.1.0", + "log", + "pretty_env_logger", + "thiserror", + "tokio 1.17.0", +] + [[package]] name = "slab" version = "0.4.3" @@ -2389,9 +2442,9 @@ checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" [[package]] name = "socket2" -version = "0.4.0" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e3dfc207c526015c632472a77be09cf1b6e46866581aecae5cc38fb4235dea2" +checksum = "66d72b759436ae32898a2af0a14218dbf55efde3feeb170eb623637db85ee1e0" dependencies = [ "libc", "winapi 0.3.9", @@ -2492,18 +2545,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.26" +version = "1.0.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93119e4feac1cbe6c798c34d3a53ea0026b0b1de6a120deef895137c0529bfe2" +checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.26" +version = "1.0.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "060d69a0afe7796bf42e9e2ff91f5ee691fb15c53d38b4b62a9a53eb23164745" +checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b" dependencies = [ "proc-macro2", "quote", @@ -2571,21 +2624,21 @@ dependencies = [ [[package]] name = "tokio" -version = "1.8.2" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2602b8af3767c285202012822834005f596c811042315fa7e9f5b12b2a43207" +checksum = "2af73ac49756f3f7c01172e34a23e5d0216f6c32333757c2c61feb2bbff5a5ee" dependencies = [ - "autocfg", - "bytes 1.0.1", + "bytes 1.1.0", "libc", "memchr", - "mio 0.7.13", + "mio 0.8.0", "num_cpus", "once_cell", - "parking_lot", + "parking_lot 0.12.0", "pin-project-lite 0.2.7", "signal-hook-registry", - "tokio-macros 1.3.0", + "socket2", + "tokio-macros 1.7.0", "winapi 0.3.9", ] @@ -2602,9 +2655,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "1.3.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54473be61f4ebe4efd09cec9bd5d16fa51d70ea0192213d754d2d500457db110" +checksum = "b557f72f448c511a979e2564e55d74e6c4432fc96ff4f6241bc6bded342643b7" dependencies = [ "proc-macro2", "quote", @@ -2618,17 +2671,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7d995660bd2b7f8c1568414c1126076c13fbb725c40112dc0120b78eb9b717b" dependencies = [ "native-tls", - "tokio 1.8.2", + "tokio 1.17.0", ] [[package]] name = "tokio-rustls" -version = "0.22.0" +version = "0.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6" +checksum = "a27d5f2b839802bd8267fa19b0530f5a08b9c08cd417976be2a65d130fe1c11b" dependencies = [ "rustls", - "tokio 1.8.2", + "tokio 1.17.0", "webpki", ] @@ -2640,7 +2693,7 @@ checksum = "7b2f3f698253f03119ac0102beaa64f67a67e08074d03a22d18784104543727f" dependencies = [ "futures-core", "pin-project-lite 0.2.7", - "tokio 1.8.2", + "tokio 1.17.0", ] [[package]] @@ -2652,7 +2705,7 @@ dependencies = [ "futures-util", "log", "pin-project", - "tokio 1.8.2", + "tokio 1.17.0", "tungstenite 0.12.0", ] @@ -2662,12 +2715,12 @@ version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1caa0b0c8d94a049db56b5acf8cba99dc0623aab1b26d5b5f5e2d945846b3592" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "futures-core", "futures-sink", "log", "pin-project-lite 0.2.7", - "tokio 1.8.2", + "tokio 1.17.0", ] [[package]] @@ -2720,7 +2773,7 @@ checksum = "8ada8297e8d70872fa9a551d93250a9f407beb9f37ef86494eb20012a2ff7c24" dependencies = [ "base64 0.13.0", "byteorder", - "bytes 1.0.1", + "bytes 1.1.0", "http", "httparse", "input_buffer", @@ -2733,16 +2786,15 @@ dependencies = [ [[package]] name = "tungstenite" -version = "0.13.0" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fe8dada8c1a3aeca77d6b51a4f1314e0f4b8e438b7b1b71e3ddaca8080e4093" +checksum = "6ad3713a14ae247f22a728a0456a545df14acf3867f905adff84be99e23b3ad1" dependencies = [ "base64 0.13.0", "byteorder", - "bytes 1.0.1", + "bytes 1.1.0", "http", "httparse", - "input_buffer", "log", "rand 0.8.4", "rustls", @@ -2751,7 +2803,6 @@ dependencies = [ "url", "utf-8", "webpki", - "webpki-roots", ] [[package]] @@ -2889,7 +2940,7 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "332d47745e9a0c38636dbd454729b147d16bd1ed08ae67b3ab281c4506771054" dependencies = [ - "bytes 1.0.1", + "bytes 1.1.0", "futures 0.3.15", "headers", "http", @@ -2904,7 +2955,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", - "tokio 1.8.2", + "tokio 1.17.0", "tokio-stream", "tokio-tungstenite", "tokio-util", @@ -2990,23 +3041,14 @@ dependencies = [ [[package]] name = "webpki" -version = "0.21.4" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8e38c0608262c46d4a56202ebabdeb094cef7e560ca7a226c6bf055188aa4ea" +checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" dependencies = [ "ring", "untrusted", ] -[[package]] -name = "webpki-roots" -version = "0.21.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aabe153544e473b775453675851ecc86863d2a81d786d741f6b76778f2a48940" -dependencies = [ - "webpki", -] - [[package]] name = "which" version = "4.1.0" @@ -3060,6 +3102,49 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3df6e476185f92a12c072be4a189a0210dcdcf512a1891d6dff9edb874deadc6" +dependencies = [ + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_msvc" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8e92753b1c443191654ec532f14c199742964a061be25d77d7a96f09db20bf5" + +[[package]] +name = "windows_i686_gnu" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a711c68811799e017b6038e0922cb27a5e2f43a2ddb609fe0b6f3eeda9de615" + +[[package]] +name = "windows_i686_msvc" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c11bb1a02615db74680b32a68e2d61f553cc24c4eb5b4ca10311740e44172" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c912b12f7454c6620635bbff3450962753834be2a594819bd5e945af18ec64bc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.32.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "504a2476202769977a040c6364301a3f65d0cc9e3fb08600b2bda150a0488316" + [[package]] name = "ws2_32-sys" version = "0.2.1" @@ -3072,9 +3157,9 @@ dependencies = [ [[package]] name = "ws_stream_tungstenite" -version = "0.6.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34c786fc3d0a792f8a6e7a69f3b85afa1cf7b2560bbd434d7d5c32a580e153c0" +checksum = "a672ec78525bf189cefa7f1b72c55f928b3edbdb967e680ca49748ab20821045" dependencies = [ "async-tungstenite", "async_io_stream", @@ -3086,6 +3171,6 @@ dependencies = [ "log", "pharos", "rustc_version 0.4.0", - "tokio 1.8.2", - "tungstenite 0.13.0", + "tokio 1.17.0", + "tungstenite 0.16.0", ] diff --git a/Cargo.toml b/Cargo.toml index ecf22a6f5..d25103f1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,5 +6,6 @@ members = [ "rumqttlog", "rumqttmesh", "rumqttd", - "benchmarks" + "benchmarks", + "benchmarks/simplerouter", ] diff --git a/benchmarks/simplerouter/Cargo.toml b/benchmarks/simplerouter/Cargo.toml new file mode 100644 index 000000000..3e2ee24d8 --- /dev/null +++ b/benchmarks/simplerouter/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "simplerouter" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +bytes = "1.1.0" +log = "0.4.14" +pretty_env_logger = "0.4.0" +thiserror = "1.0.30" +tokio = { version = "1.17.0", features = ["net", "sync", "rt-multi-thread", "io-util", "macros"] } diff --git a/benchmarks/simplerouter/src/bin/simplerouter.rs b/benchmarks/simplerouter/src/bin/simplerouter.rs new file mode 100644 index 000000000..cf5527749 --- /dev/null +++ b/benchmarks/simplerouter/src/bin/simplerouter.rs @@ -0,0 +1,11 @@ +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + +#[tokio::main] +async fn main() { + pretty_env_logger::init(); + simplerouter::run(simplerouter::Config { + addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1883)), + }) + .await + .unwrap(); +} diff --git a/benchmarks/simplerouter/src/lib.rs b/benchmarks/simplerouter/src/lib.rs new file mode 100644 index 000000000..18e7fa4ee --- /dev/null +++ b/benchmarks/simplerouter/src/lib.rs @@ -0,0 +1,122 @@ +use std::{io, net::SocketAddr}; + +use bytes::BytesMut; +use log::*; +use tokio::net::TcpListener; + +mod network; +mod protocol; +use network::Network; +use protocol::{v4, v5}; + +pub struct Config { + pub addr: SocketAddr, +} + +pub async fn run(config: Config) -> Result<(), Error> { + let listener = TcpListener::bind(config.addr).await?; + info!("router: listening on {}", config.addr); + + loop { + let (stream, addr) = listener.accept().await?; + info!("router: accepted connection from {}", addr); + let (network, _) = match Network::read_connect(stream).await { + Ok(v) => v, + Err(e) => { + error!("router: unable to read connect : {}", e); + continue; + } + }; + info!("connection: sent connack"); + tokio::spawn(publisher_handle(network)); + } +} + +async fn publisher_handle(mut network: Network) { + let mut payload = BytesMut::with_capacity(2); + v4::pingresp::write(&mut payload).unwrap(); + let pingresp_bytes = payload.split().freeze(); + + loop { + let packet = match network.poll().await { + Ok(packet) => packet, + Err(e) => { + error!("connection: unable to read packet: {}", e); + return; + } + }; + match packet { + protocol::Packet::V4(packet) => match packet { + v4::Packet::Disconnect => { + info!("connection: received disconnect, exiting"); + return; + }, + v4::Packet::PingReq => { + if let Err(e) = network.send_data(&pingresp_bytes).await { + error!("unable to send pingresp, exiting : {}", e); + return; + }; + } + v4::Packet::Publish(publish) => { + let pkid = match publish.view_meta() { + Ok(v) => v.2, + Err(e) => { + error!("connection: malformed publish packet : {}", e); + continue; + } + }; + payload.reserve(2); + v4::puback::write(pkid, &mut payload).unwrap(); + if let Err(e) = network.send_data(&payload.split().freeze()).await { + error!("unable to send puback pkid = {}, exiting : {}", pkid, e); + return; + }; + } + p => { + error!("connection: invalid packet {:?}", p); + continue; + } + }, + protocol::Packet::V5(packet) => match packet { + v5::Packet::Disconnect => { + info!("connection: received disconnect, exiting"); + return; + }, + v5::Packet::PingReq => { + if let Err(e) = network.send_data(&pingresp_bytes).await { + error!("unable to send pingresp, exiting : {}", e); + return; + }; + } + v5::Packet::Publish(publish) => { + let pkid = match publish.view_meta() { + Ok(v) => v.2, + Err(e) => { + error!("connection: malformed publish packet : {}", e); + continue; + } + }; + payload.reserve(8); + v5::puback::write(pkid, v5::puback::PubAckReason::Success, None, &mut payload) + .unwrap(); + if let Err(e) = network.send_data(&payload.split().freeze()).await { + error!("unable to send puback pkid = {}, exiting : {}", pkid, e); + return; + }; + } + p => { + error!("connection: invalid packet {:?}", p); + continue; + } + }, + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("MQTT : {0}")] + MQTT(#[from] crate::protocol::Error), + #[error("i/O : {0}")] + IO(#[from] io::Error), +} diff --git a/benchmarks/simplerouter/src/network.rs b/benchmarks/simplerouter/src/network.rs new file mode 100644 index 000000000..b61c02fe4 --- /dev/null +++ b/benchmarks/simplerouter/src/network.rs @@ -0,0 +1,102 @@ +use std::io; + +use bytes::{Bytes, BytesMut}; +use log::*; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpStream, +}; + +use crate::{ + protocol::{self, v4, v5, Connect, Packet}, + Error, +}; + +pub(crate) struct Network { + stream: TcpStream, + buf: BytesMut, + protocol_level: u8, +} + +impl Network { + pub(crate) async fn read_connect(stream: TcpStream) -> Result<(Self, Connect), Error> { + let mut network = Self { + stream, + buf: BytesMut::with_capacity(4096), + protocol_level: 0, + }; + debug!("network: reading from stream"); + network.stream.read_buf(&mut network.buf).await?; + let connect_packet = loop { + match protocol::read_first_connect(&mut network.buf, 4096) { + Err(protocol::Error::InsufficientBytes(count)) => { + network.read_atleast(count).await? + } + res => break res?, + } + }; + debug!("network: read connect"); + match &connect_packet { + Connect::V4(_) => { + network.protocol_level = 4; + let mut payload = BytesMut::with_capacity(10); + v4::connack::write(v4::connack::ConnectReturnCode::Success, false, &mut payload)?; + network.send_data(&payload.split().freeze()).await?; + } + Connect::V5(_) => { + network.protocol_level = 5; + let mut payload = BytesMut::with_capacity(10); + v5::connack::write( + v5::connack::ConnectReturnCode::Success, + false, + None, + &mut payload, + )?; + network.send_data(&payload.split().freeze()).await?; + } + } + debug!("network: sent connack"); + Ok((network, connect_packet)) + } + + async fn read_atleast(&mut self, count: usize) -> io::Result<()> { + let mut len = 0; + while len < count { + len += self.stream.read_buf(&mut self.buf).await?; + } + debug!("network: read {} bytes", len); + + Ok(()) + } + + pub(crate) async fn poll(&mut self) -> Result { + loop { + match self.protocol_level { + 4 => match v4::read_mut(&mut self.buf, 4096) { + Err(protocol::Error::InsufficientBytes(count)) => { + self.read_atleast(count).await?; + continue; + } + res => return Ok(Packet::V4(res?)), + }, + 5 => match v5::read_mut(&mut self.buf, 4096) { + Err(protocol::Error::InsufficientBytes(count)) => { + self.read_atleast(count).await?; + continue; + } + res => return Ok(Packet::V5(res?)), + }, + // SAFETY: we don't allow changing protocol_level + _ => unsafe { std::hint::unreachable_unchecked() }, + } + } + } + + pub(crate) async fn send_data(&mut self, data: &Bytes) -> Result<(), Error> { + debug!( + "network: sent {} bytes", + self.stream.write(data.as_ref()).await? + ); + Ok(()) + } +} diff --git a/benchmarks/simplerouter/src/protocol/mod.rs b/benchmarks/simplerouter/src/protocol/mod.rs new file mode 100644 index 000000000..2ecdbd320 --- /dev/null +++ b/benchmarks/simplerouter/src/protocol/mod.rs @@ -0,0 +1,430 @@ +#![allow(dead_code)] +use std::{slice::Iter, str::Utf8Error}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +pub mod v4; +pub mod v5; + +/// Checks if the filter is valid +/// +/// https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106 +pub fn valid_filter(filter: &str) -> bool { + if filter.is_empty() { + return false; + } + + let hirerarchy = filter.split('/').collect::>(); + if let Some((last, remaining)) = hirerarchy.split_last() { + // # is not allowed in filer except as a last entry + // invalid: sport/tennis#/player + // invalid: sport/tennis/#/ranking + for entry in remaining.iter() { + if entry.contains('#') { + return false; + } + } + + // only single '#" is allowed in last entry + // invalid: sport/tennis# + if last.len() != 1 && last.contains('#') { + return false; + } + } + + true +} + +/// MQTT packet type +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PacketType { + Connect = 1, + ConnAck, + Publish, + PubAck, + PubRec, + PubRel, + PubComp, + Subscribe, + SubAck, + Unsubscribe, + UnsubAck, + PingReq, + PingResp, + Disconnect, +} + +/// Error during serialization and deserialization +#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +pub enum Error { + #[error("Expected connect packet, received = {0:?}")] + NotConnect(PacketType), + #[error("Received an unexpected connect packet")] + UnexpectedConnect, + #[error("Invalid return code received as response for connect = {0}")] + InvalidConnectReturnCode(u8), + #[error("Invalid reason = {0}")] + InvalidReason(u8), + #[error("Invalid protocol used")] + InvalidProtocol, + #[error("Invalid protocol level")] + InvalidProtocolLevel(u8), + #[error("Invalid packet format")] + IncorrectPacketFormat, + #[error("Invalid packet type = {0}")] + InvalidPacketType(u8), + #[error("Packet type unsupported = {0:?}")] + UnsupportedPacket(PacketType), + #[error("Invalid retain forward rule = {0}")] + InvalidRetainForwardRule(u8), + #[error("Invalid QoS level = {0}")] + InvalidQoS(u8), + #[error("Invalid subscribe reason code = {0}")] + InvalidSubscribeReasonCode(u8), + #[error("Packet received has id Zero")] + PacketIdZero, + #[error("Subscription had id Zero")] + SubscriptionIdZero, + #[error("Payload size is incorrect")] + PayloadSizeIncorrect, + #[error("Payload is too long")] + PayloadTooLong, + #[error("Payload size has been exceeded by {0} bytes")] + PayloadSizeLimitExceeded(usize), + #[error("Payload is required")] + PayloadRequired, + #[error("Topic not utf-8 = {0}")] + TopicNotUtf8(#[from] Utf8Error), + #[error("Promised boundary crossed, contains {0} bytes")] + BoundaryCrossed(usize), + #[error("Packet is malformed")] + MalformedPacket, + #[error("Remaining length is malformed")] + MalformedRemainingLength, + /// More bytes required to frame packet. Argument + /// implies minimum additional bytes required to + /// proceed further + #[error("Insufficient number of bytes to frame packet, {0} more bytes required")] + InsufficientBytes(usize), + #[error("Property does not exist = {0}")] + InvalidPropertyType(u8), +} + +/// Quality of service +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub enum QoS { + AtMostOnce = 0, + AtLeastOnce = 1, +} + +/// Maps a number to QoS +pub fn qos(num: u8) -> Result { + match num { + 0 => Ok(QoS::AtMostOnce), + 1 => Ok(QoS::AtLeastOnce), + qos => Err(Error::InvalidQoS(qos)), + } +} + +/// Packet type from a byte +/// +/// ```ignore +/// 7 3 0 +/// +--------------------------+--------------------------+ +/// byte 1 | MQTT Control Packet Type | Flags for each type | +/// +--------------------------+--------------------------+ +/// | Remaining Bytes Len (1/2/3/4 bytes) | +/// +-----------------------------------------------------+ +/// +/// http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Figure_2.2_- +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub struct FixedHeader { + /// First byte of the stream. Used to identify packet types and + /// several flags + pub byte1: u8, + /// Length of fixed header. Byte 1 + (1..4) bytes. So fixed header + /// len can vary from 2 bytes to 5 bytes + /// 1..4 bytes are variable length encoded to represent remaining length + pub fixed_header_len: usize, + /// Remaining length of the packet. Doesn't include fixed header bytes + /// Represents variable header + payload size + pub remaining_len: usize, +} + +impl FixedHeader { + pub fn new(byte1: u8, remaining_len_len: usize, remaining_len: usize) -> FixedHeader { + FixedHeader { + byte1, + fixed_header_len: remaining_len_len + 1, + remaining_len, + } + } + + pub fn packet_type(&self) -> Result { + let num = self.byte1 >> 4; + match num { + 1 => Ok(PacketType::Connect), + 2 => Ok(PacketType::ConnAck), + 3 => Ok(PacketType::Publish), + 4 => Ok(PacketType::PubAck), + 5 => Ok(PacketType::PubRec), + 6 => Ok(PacketType::PubRel), + 7 => Ok(PacketType::PubComp), + 8 => Ok(PacketType::Subscribe), + 9 => Ok(PacketType::SubAck), + 10 => Ok(PacketType::Unsubscribe), + 11 => Ok(PacketType::UnsubAck), + 12 => Ok(PacketType::PingReq), + 13 => Ok(PacketType::PingResp), + 14 => Ok(PacketType::Disconnect), + _ => Err(Error::InvalidPacketType(num)), + } + } + + /// Returns the size of full packet (fixed header + variable header + payload) + /// Fixed header is enough to get the size of a frame in the stream + pub fn frame_length(&self) -> usize { + self.fixed_header_len + self.remaining_len + } +} + +/// Checks if the stream has enough bytes to frame a packet and returns fixed header +/// only if a packet can be framed with existing bytes in the `stream`. +/// The passed stream doesn't modify parent stream's cursor. If this function +/// returned an error, next `check` on the same parent stream is forced start +/// with cursor at 0 again (Iter is owned. Only Iter's cursor is changed internally) +pub fn check(stream: Iter, max_packet_size: usize) -> Result { + // Create fixed header if there are enough bytes in the stream + // to frame full packet + let stream_len = stream.len(); + let fixed_header = parse_fixed_header(stream)?; + + // Don't let rogue connections attack with huge payloads. + // Disconnect them before reading all that data + if fixed_header.remaining_len > max_packet_size { + return Err(Error::PayloadSizeLimitExceeded(fixed_header.remaining_len)); + } + + // If the current call fails due to insufficient bytes in the stream, + // after calculating remaining length, we extend the stream + let frame_length = fixed_header.frame_length(); + if stream_len < frame_length { + return Err(Error::InsufficientBytes(frame_length - stream_len)); + } + + Ok(fixed_header) +} + +/// Parses fixed header +fn parse_fixed_header(mut stream: Iter) -> Result { + // At least 2 bytes are necessary to frame a packet + let stream_len = stream.len(); + if stream_len < 2 { + return Err(Error::InsufficientBytes(2 - stream_len)); + } + + let byte1 = stream.next().unwrap(); + let (len_len, len) = length(stream)?; + + Ok(FixedHeader::new(*byte1, len_len, len)) +} + +/// Parses variable byte integer in the stream and returns the length +/// and number of bytes that make it. Used for remaining length calculation +/// as well as for calculating property lengths +pub fn length(stream: Iter) -> Result<(usize, usize), Error> { + let mut len: usize = 0; + let mut len_len = 0; + let mut done = false; + let mut shift = 0; + + // Use continuation bit at position 7 to continue reading next + // byte to frame 'length'. + // Stream 0b1xxx_xxxx 0b1yyy_yyyy 0b1zzz_zzzz 0b0www_wwww will + // be framed as number 0bwww_wwww_zzz_zzzz_yyy_yyyy_xxx_xxxx + for byte in stream { + len_len += 1; + let byte = *byte as usize; + len += (byte & 0x7F) << shift; + + // stop when continue bit is 0 + done = (byte & 0x80) == 0; + if done { + break; + } + + shift += 7; + + // Only a max of 4 bytes allowed for remaining length + // more than 4 shifts (0, 7, 14, 21) implies bad length + if shift > 21 { + return Err(Error::MalformedRemainingLength); + } + } + + // Not enough bytes to frame remaining length. wait for + // one more byte + if !done { + return Err(Error::InsufficientBytes(1)); + } + + Ok((len_len, len)) +} + +/// Returns big endian u16 view from next 2 bytes +pub fn view_u16(stream: &[u8]) -> Result { + let v = match stream.get(0..2) { + Some(v) => (v[0] as u16) << 8 | (v[1] as u16), + None => return Err(Error::MalformedPacket), + }; + + Ok(v) +} + +/// Returns big endian u16 view from next 2 bytes +pub fn view_str(stream: &[u8], end: usize) -> Result<&str, Error> { + let v = match stream.get(0..end) { + Some(v) => v, + None => return Err(Error::BoundaryCrossed(stream.len())), + }; + + let v = std::str::from_utf8(v)?; + Ok(v) +} + +/// After collecting enough bytes to frame a packet (packet's frame()) +/// , It's possible that content itself in the stream is wrong. Like expected +/// packet id or qos not being present. In cases where `read_mqtt_string` or +/// `read_mqtt_bytes` exhausted remaining length but packet framing expects to +/// parse qos next, these pre checks will prevent `bytes` crashes + +fn read_u32(stream: &mut Bytes) -> Result { + if stream.len() < 4 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u32()) +} + +pub fn read_u16(stream: &mut Bytes) -> Result { + if stream.len() < 2 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u16()) +} + +fn read_u8(stream: &mut Bytes) -> Result { + if stream.len() < 1 { + return Err(Error::MalformedPacket); + } + + Ok(stream.get_u8()) +} + +/// Reads a series of bytes with a length from a byte stream +fn read_mqtt_bytes(stream: &mut Bytes) -> Result { + let len = read_u16(stream)? as usize; + + // Prevent attacks with wrong remaining length. This method is used in + // `packet.assembly()` with (enough) bytes to frame packet. Ensures that + // reading variable len string or bytes doesn't cross promised boundary + // with `read_fixed_header()` + if len > stream.len() { + return Err(Error::BoundaryCrossed(len)); + } + + Ok(stream.split_to(len)) +} + +/// Serializes bytes to stream (including length) +fn write_mqtt_bytes(stream: &mut BytesMut, bytes: &[u8]) { + stream.put_u16(bytes.len() as u16); + stream.extend_from_slice(bytes); +} + +/// Serializes a string to stream +pub fn write_mqtt_string(stream: &mut BytesMut, string: &str) { + write_mqtt_bytes(stream, string.as_bytes()); +} + +/// Writes remaining length to stream and returns number of bytes for remaining length +pub fn write_remaining_length(stream: &mut BytesMut, len: usize) -> Result { + if len > 268_435_455 { + return Err(Error::PayloadTooLong); + } + + let mut done = false; + let mut x = len; + let mut count = 0; + + while !done { + let mut byte = (x % 128) as u8; + x /= 128; + if x > 0 { + byte |= 128; + } + + stream.put_u8(byte); + count += 1; + done = x == 0; + } + + Ok(count) +} + +/// Return number of remaining length bytes required for encoding length +fn len_len(len: usize) -> usize { + if len >= 2_097_152 { + 4 + } else if len >= 16_384 { + 3 + } else if len >= 128 { + 2 + } else { + 1 + } +} + +pub enum Connect { + V4(v4::connect::Connect), + V5(v5::connect::Connect), +} + +#[derive(Debug)] +pub enum Packet { + V4(v4::Packet), + V5(v5::Packet), +} + +pub(crate) fn read_first_connect(stream: &mut BytesMut, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + match fixed_header.packet_type()? { + PacketType::Connect => {} + p => return Err(Error::NotConnect(p)), + } + let mut packet = packet.freeze(); + + let variable_header_index = fixed_header.fixed_header_len; + packet.advance(variable_header_index); + + // Variable header + let protocol_name = read_mqtt_bytes(&mut packet)?; + let protocol_name = std::str::from_utf8(&protocol_name)?.to_owned(); + let protocol_level = read_u8(&mut packet)?; + if protocol_name != "MQTT" { + return Err(Error::InvalidProtocol); + } + + match protocol_level { + 4 => Ok(Connect::V4(v4::connect::connect_v4_part(packet)?)), + 5 => Ok(Connect::V5(v5::connect::connect_v5_part(packet)?)), + _ => Err(Error::InvalidProtocolLevel(protocol_level)), + } +} diff --git a/benchmarks/simplerouter/src/protocol/v4.rs b/benchmarks/simplerouter/src/protocol/v4.rs new file mode 100644 index 000000000..e418877c0 --- /dev/null +++ b/benchmarks/simplerouter/src/protocol/v4.rs @@ -0,0 +1,863 @@ +#![allow(dead_code)] + +use super::*; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +pub(crate) mod connect { + use super::*; + use bytes::Bytes; + + /// Connection packet initiated by the client + #[derive(Debug, Clone, PartialEq)] + pub struct Connect { + /// Mqtt keep alive time + pub keep_alive: u16, + /// Client Id + pub client_id: String, + /// Clean session. Asks the broker to clear previous state + pub clean_session: bool, + /// Will that broker needs to publish when the client disconnects + pub last_will: Option, + /// Login credentials + pub login: Option, + } + + impl Connect { + pub fn new>(id: S) -> Connect { + Connect { + keep_alive: 10, + client_id: id.into(), + clean_session: true, + last_will: None, + login: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + "MQTT".len() // protocol name + + 1 // protocol version + + 1 // connect flags + + 2; // keep alive + + len += 2 + self.client_id.len(); + + // last will len + if let Some(last_will) = &self.last_will { + len += last_will.len(); + } + + // username and password len + if let Some(login) = &self.login { + len += login.len(); + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + // Variable header + let protocol_name = read_mqtt_bytes(&mut bytes)?; + let protocol_name = std::str::from_utf8(&protocol_name)?.to_owned(); + let protocol_level = read_u8(&mut bytes)?; + if protocol_name != "MQTT" { + return Err(Error::InvalidProtocol); + } + + if protocol_level != 4 { + return Err(Error::InvalidProtocolLevel(protocol_level)); + } + + connect_v4_part(bytes) + } + + pub fn write(&self, buffer: &mut BytesMut) -> Result { + let len = self.len(); + buffer.put_u8(0b0001_0000); + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, "MQTT"); + buffer.put_u8(0x04); + + let flags_index = 1 + count + 2 + 4 + 1; + + let mut connect_flags = 0; + if self.clean_session { + connect_flags |= 0x02; + } + + buffer.put_u8(connect_flags); + buffer.put_u16(self.keep_alive); + write_mqtt_string(buffer, &self.client_id); + + if let Some(last_will) = &self.last_will { + connect_flags |= last_will.write(buffer)?; + } + + if let Some(login) = &self.login { + connect_flags |= login.write(buffer); + } + + // update connect flags + buffer[flags_index] = connect_flags; + Ok(len) + } + } + + pub(crate) fn connect_v4_part(mut bytes: Bytes) -> Result { + let connect_flags = read_u8(&mut bytes)?; + let clean_session = (connect_flags & 0b10) != 0; + let keep_alive = read_u16(&mut bytes)?; + + let client_id = read_mqtt_bytes(&mut bytes)?; + let client_id = std::str::from_utf8(&client_id)?.to_owned(); + let last_will = LastWill::read(connect_flags, &mut bytes)?; + let login = Login::read(connect_flags, &mut bytes)?; + + let connect = Connect { + keep_alive, + client_id, + clean_session, + last_will, + login, + }; + + Ok(connect) + } + + /// LastWill that broker forwards on behalf of the client + #[derive(Debug, Clone, PartialEq)] + pub struct LastWill { + pub topic: String, + pub message: Bytes, + pub qos: QoS, + pub retain: bool, + } + + impl LastWill { + pub fn _new( + topic: impl Into, + payload: impl Into>, + qos: QoS, + retain: bool, + ) -> LastWill { + LastWill { + topic: topic.into(), + message: Bytes::from(payload.into()), + qos, + retain, + } + } + + fn len(&self) -> usize { + let mut len = 0; + len += 2 + self.topic.len() + 2 + self.message.len(); + len + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let last_will = match connect_flags & 0b100 { + 0 if (connect_flags & 0b0011_1000) != 0 => { + return Err(Error::IncorrectPacketFormat); + } + 0 => None, + _ => { + let will_topic = read_mqtt_bytes(&mut bytes)?; + let will_topic = std::str::from_utf8(&will_topic)?.to_owned(); + let will_message = read_mqtt_bytes(&mut bytes)?; + let will_qos = qos((connect_flags & 0b11000) >> 3)?; + Some(LastWill { + topic: will_topic, + message: will_message, + qos: will_qos, + retain: (connect_flags & 0b0010_0000) != 0, + }) + } + }; + + Ok(last_will) + } + + fn write(&self, buffer: &mut BytesMut) -> Result { + let mut connect_flags = 0; + + connect_flags |= 0x04 | (self.qos as u8) << 3; + if self.retain { + connect_flags |= 0x20; + } + + write_mqtt_string(buffer, &self.topic); + write_mqtt_bytes(buffer, &self.message); + Ok(connect_flags) + } + } + + #[derive(Debug, Clone, PartialEq)] + pub struct Login { + username: String, + password: String, + } + + impl Login { + pub fn new>(u: S, p: S) -> Login { + Login { + username: u.into(), + password: p.into(), + } + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let username = match connect_flags & 0b1000_0000 { + 0 => String::new(), + _ => { + let username = read_mqtt_bytes(&mut bytes)?; + std::str::from_utf8(&username)?.to_owned() + } + }; + + let password = match connect_flags & 0b0100_0000 { + 0 => String::new(), + _ => { + let password = read_mqtt_bytes(&mut bytes)?; + std::str::from_utf8(&password)?.to_owned() + } + }; + + if username.is_empty() && password.is_empty() { + Ok(None) + } else { + Ok(Some(Login { username, password })) + } + } + + fn len(&self) -> usize { + let mut len = 0; + + if !self.username.is_empty() { + len += 2 + self.username.len(); + } + + if !self.password.is_empty() { + len += 2 + self.password.len(); + } + + len + } + + fn write(&self, buffer: &mut BytesMut) -> u8 { + let mut connect_flags = 0; + if !self.username.is_empty() { + connect_flags |= 0x80; + write_mqtt_string(buffer, &self.username); + } + + if !self.password.is_empty() { + connect_flags |= 0x40; + write_mqtt_string(buffer, &self.password); + } + + connect_flags + } + } +} + +pub(crate) mod connack { + use super::*; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + + /// Return code in connack + #[derive(Debug, Clone, Copy, PartialEq)] + #[repr(u8)] + pub enum ConnectReturnCode { + Success = 0, + RefusedProtocolVersion, + BadClientId, + ServiceUnavailable, + BadUserNamePassword, + NotAuthorized, + } + + /// Acknowledgement to connect packet + #[derive(Debug, Clone, PartialEq)] + pub struct ConnAck { + pub session_present: bool, + pub code: ConnectReturnCode, + } + + impl ConnAck { + pub fn new(code: ConnectReturnCode, session_present: bool) -> ConnAck { + ConnAck { + code, + session_present, + } + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let flags = read_u8(&mut bytes)?; + let return_code = read_u8(&mut bytes)?; + + let session_present = (flags & 0x01) == 1; + let code = connect_return(return_code)?; + let connack = ConnAck { + session_present, + code, + }; + + Ok(connack) + } + } + + pub fn write( + code: ConnectReturnCode, + session_present: bool, + buffer: &mut BytesMut, + ) -> Result { + // sesssion present + code + let len = 1 + 1; + buffer.put_u8(0x20); + + let count = write_remaining_length(buffer, len)?; + buffer.put_u8(session_present as u8); + buffer.put_u8(code as u8); + + Ok(1 + count + len) + } + + /// Connection return code type + fn connect_return(num: u8) -> Result { + match num { + 0 => Ok(ConnectReturnCode::Success), + 1 => Ok(ConnectReturnCode::RefusedProtocolVersion), + 2 => Ok(ConnectReturnCode::BadClientId), + 3 => Ok(ConnectReturnCode::ServiceUnavailable), + 4 => Ok(ConnectReturnCode::BadUserNamePassword), + 5 => Ok(ConnectReturnCode::NotAuthorized), + num => Err(Error::InvalidConnectReturnCode(num)), + } + } +} + +pub(crate) mod publish { + use super::*; + use bytes::{BufMut, Bytes, BytesMut}; + + #[derive(Debug, Clone, PartialEq)] + pub struct Publish { + pub fixed_header: FixedHeader, + pub raw: Bytes, + } + + impl Publish { + // pub fn new, P: Into>>(topic: S, qos: QoS, payload: P) -> Publish { + // Publish { + // dup: false, + // qos, + // retain: false, + // pkid: 0, + // topic: topic.into(), + // payload: Bytes::from(payload.into()), + // } + // } + + // pub fn from_bytes>(topic: S, qos: QoS, payload: Bytes) -> Publish { + // Publish { + // dup: false, + // qos, + // retain: false, + // pkid: 0, + // topic: topic.into(), + // payload, + // } + // } + + // pub fn len(&self) -> usize { + // let mut len = 2 + self.topic.len(); + // if self.qos != QoS::AtMostOnce && self.pkid != 0 { + // len += 2; + // } + + // len += self.payload.len(); + // len + // } + + pub fn view_meta(&self) -> Result<(&str, u8, u16, bool, bool), Error> { + let qos = (self.fixed_header.byte1 & 0b0110) >> 1; + let dup = (self.fixed_header.byte1 & 0b1000) != 0; + let retain = (self.fixed_header.byte1 & 0b0001) != 0; + + // FIXME: Remove indexes and use get method + let stream = &self.raw[self.fixed_header.fixed_header_len..]; + let topic_len = view_u16(&stream)? as usize; + + let stream = &stream[2..]; + let topic = view_str(stream, topic_len)?; + + let pkid = match qos { + 0 => 0, + 1 => { + let stream = &stream[topic_len..]; + let pkid = view_u16(stream)?; + pkid + } + v => return Err(Error::InvalidQoS(v)), + }; + + if qos == 1 && pkid == 0 { + return Err(Error::PacketIdZero); + } + + Ok((topic, qos, pkid, dup, retain)) + } + + pub fn view_topic(&self) -> Result<&str, Error> { + // FIXME: Remove indexes + let stream = &self.raw[self.fixed_header.fixed_header_len..]; + let topic_len = view_u16(&stream)? as usize; + + let stream = &stream[2..]; + let topic = view_str(stream, topic_len)?; + Ok(topic) + } + + pub fn take_topic_and_payload(mut self) -> Result<(Bytes, Bytes), Error> { + let qos = (self.fixed_header.byte1 & 0b0110) >> 1; + + let variable_header_index = self.fixed_header.fixed_header_len; + self.raw.advance(variable_header_index); + let topic = read_mqtt_bytes(&mut self.raw)?; + + match qos { + 0 => (), + 1 => self.raw.advance(2), + v => return Err(Error::InvalidQoS(v)), + }; + + let payload = self.raw; + Ok((topic, payload)) + } + + pub fn read(fixed_header: FixedHeader, bytes: Bytes) -> Result { + let publish = Publish { + fixed_header, + raw: bytes, + }; + + Ok(publish) + } + } + + pub struct PublishBytes(pub Bytes); + + impl From for Result { + fn from(raw: PublishBytes) -> Self { + let fixed_header = check(raw.0.iter(), 100 * 1024 * 1024)?; + Ok(Publish { + fixed_header, + raw: raw.0, + }) + } + } + + pub fn write( + topic: &str, + qos: QoS, + pkid: u16, + dup: bool, + retain: bool, + payload: &[u8], + buffer: &mut BytesMut, + ) -> Result { + let mut len = 2 + topic.len(); + if qos != QoS::AtMostOnce { + len += 2; + } + + len += payload.len(); + + let dup = dup as u8; + let qos = qos as u8; + let retain = retain as u8; + + buffer.put_u8(0b0011_0000 | retain | qos << 1 | dup << 3); + + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, topic); + + if qos != 0 { + if pkid == 0 { + return Err(Error::PacketIdZero); + } + + buffer.put_u16(pkid); + } + + buffer.extend_from_slice(&payload); + + // TODO: Returned length is wrong in other packets. Fix it + Ok(1 + count + len) + } +} + +pub(crate) mod puback { + use super::*; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + + /// Acknowledgement to QoS1 publish + #[derive(Debug, Clone, PartialEq)] + pub struct PubAck { + pub pkid: u16, + } + + impl PubAck { + pub fn new(pkid: u16) -> PubAck { + PubAck { pkid } + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + // No reason code or properties if remaining length == 2 + if fixed_header.remaining_len == 2 { + return Ok(PubAck { pkid }); + } + + // No properties len or properties if remaining len > 2 but < 4 + if fixed_header.remaining_len < 4 { + return Ok(PubAck { pkid }); + } + + let puback = PubAck { pkid }; + + Ok(puback) + } + } + + pub fn write(pkid: u16, buffer: &mut BytesMut) -> Result { + let len = 2; // pkid + buffer.put_u8(0x40); + + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(pkid); + Ok(1 + count + len) + } +} + +pub(crate) mod subscribe { + use super::*; + use bytes::{Buf, Bytes}; + + /// Subscription packet + #[derive(Debug, Clone, PartialEq)] + pub struct Subscribe { + pub pkid: u16, + pub filters: Vec, + } + + impl Subscribe { + pub fn new>(path: S, qos: QoS) -> Subscribe { + let filter = SubscribeFilter { + path: path.into(), + qos, + }; + + let mut filters = Vec::new(); + filters.push(filter); + Subscribe { pkid: 0, filters } + } + + pub fn add(&mut self, path: String, qos: QoS) -> &mut Self { + let filter = SubscribeFilter { path, qos }; + + self.filters.push(filter); + self + } + + pub fn len(&self) -> usize { + let len = 2 + self.filters.iter().fold(0, |s, t| s + t.len()); // len of pkid + vec![subscribe filter len] + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + + // variable header size = 2 (packet identifier) + let mut filters = Vec::new(); + + while bytes.has_remaining() { + let path = read_mqtt_bytes(&mut bytes)?; + let path = std::str::from_utf8(&path)?.to_owned(); + let options = read_u8(&mut bytes)?; + let requested_qos = options & 0b0000_0011; + + filters.push(SubscribeFilter { + path, + qos: qos(requested_qos)?, + }); + } + + let subscribe = Subscribe { pkid, filters }; + + Ok(subscribe) + } + } + + pub fn write( + filters: Vec, + pkid: u16, + buffer: &mut BytesMut, + ) -> Result { + let len = 2 + filters.iter().fold(0, |s, t| s + t.len()); // len of pkid + vec![subscribe filter len] + // write packet type + buffer.put_u8(0x82); + + // write remaining length + let remaining_len_bytes = write_remaining_length(buffer, len)?; + + // write packet id + buffer.put_u16(pkid); + + // write filters + for filter in filters.iter() { + filter.write(buffer); + } + + Ok(1 + remaining_len_bytes + len) + } + + /// Subscription filter + #[derive(Debug, Clone, PartialEq)] + pub struct SubscribeFilter { + pub path: String, + pub qos: QoS, + } + + impl SubscribeFilter { + pub fn new(path: String, qos: QoS) -> SubscribeFilter { + SubscribeFilter { path, qos } + } + + pub fn len(&self) -> usize { + // filter len + filter + options + 2 + self.path.len() + 1 + } + + fn write(&self, buffer: &mut BytesMut) { + let mut options = 0; + options |= self.qos as u8; + + write_mqtt_string(buffer, self.path.as_str()); + buffer.put_u8(options); + } + } +} + +pub(crate) mod suback { + use std::convert::{TryFrom, TryInto}; + + use super::*; + use bytes::{Buf, Bytes}; + + /// Acknowledgement to subscribe + #[derive(Debug, Clone, PartialEq)] + pub struct SubAck { + pub pkid: u16, + pub return_codes: Vec, + } + + impl SubAck { + pub fn new(pkid: u16, return_codes: Vec) -> SubAck { + SubAck { pkid, return_codes } + } + + pub fn len(&self) -> usize { + let len = 2 + self.return_codes.len(); + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + if !bytes.has_remaining() { + return Err(Error::MalformedPacket); + } + + let mut return_codes = Vec::new(); + while bytes.has_remaining() { + let return_code = read_u8(&mut bytes)?; + return_codes.push(return_code.try_into()?); + } + + let suback = SubAck { pkid, return_codes }; + Ok(suback) + } + } + + pub fn write( + return_codes: Vec, + pkid: u16, + buffer: &mut BytesMut, + ) -> Result { + let len = 2 + return_codes.len(); + buffer.put_u8(0x90); + + let remaining_len_bytes = write_remaining_length(buffer, len)?; + buffer.put_u16(pkid); + + let p: Vec = return_codes + .iter() + .map(|&code| match code { + SubscribeReasonCode::Success(qos) => qos as u8, + SubscribeReasonCode::Failure => 0x80, + }) + .collect(); + + buffer.extend_from_slice(&p); + Ok(1 + remaining_len_bytes + len) + } + + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub enum SubscribeReasonCode { + Success(QoS), + Failure, + } + + impl TryFrom for SubscribeReasonCode { + type Error = Error; + + fn try_from(value: u8) -> Result { + let v = match value { + 0 => SubscribeReasonCode::Success(QoS::AtMostOnce), + 1 => SubscribeReasonCode::Success(QoS::AtLeastOnce), + 128 => SubscribeReasonCode::Failure, + v => return Err(Error::InvalidSubscribeReasonCode(v)), + }; + + Ok(v) + } + } + + pub fn codes(c: Vec) -> Vec { + c.into_iter() + .map(|v| { + let qos = qos(v).unwrap(); + SubscribeReasonCode::Success(qos) + }) + .collect() + } +} + +pub(crate) mod pingresp { + use super::*; + + pub fn write(payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xD0, 0x00]); + Ok(2) + } +} + +pub(crate) mod pingreq { + use super::*; + + pub fn write(payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xC0, 0x00]); + Ok(2) + } +} + +/// Reads a stream of bytes and extracts next MQTT packet out of it +pub fn read_mut(stream: &mut BytesMut, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; + + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + PacketType::Disconnect => Ok(Packet::Disconnect), + _ => Err(Error::PayloadRequired), + }; + } + + let packet = packet.freeze(); + let packet = match packet_type { + PacketType::Connect => Packet::Connect(connect::Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(connack::ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(publish::Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(puback::PubAck::read(fixed_header, packet)?), + PacketType::Subscribe => { + Packet::Subscribe(subscribe::Subscribe::read(fixed_header, packet)?) + } + PacketType::SubAck => Packet::SubAck(suback::SubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect, + v => return Err(Error::UnsupportedPacket(v)), + }; + + Ok(packet) +} + +/// Reads a stream of bytes and extracts next MQTT packet out of it +pub fn read(stream: &mut Bytes, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; + + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + PacketType::Disconnect => Ok(Packet::Disconnect), + _ => Err(Error::PayloadRequired), + }; + } + + let packet = match packet_type { + PacketType::Connect => Packet::Connect(connect::Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(connack::ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(publish::Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(puback::PubAck::read(fixed_header, packet)?), + PacketType::Subscribe => { + Packet::Subscribe(subscribe::Subscribe::read(fixed_header, packet)?) + } + PacketType::SubAck => Packet::SubAck(suback::SubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect, + v => return Err(Error::UnsupportedPacket(v)), + }; + + Ok(packet) +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Packet { + Connect(connect::Connect), + Publish(publish::Publish), + ConnAck(connack::ConnAck), + PubAck(puback::PubAck), + PingReq, + PingResp, + Subscribe(subscribe::Subscribe), + SubAck(suback::SubAck), + Disconnect, +} diff --git a/benchmarks/simplerouter/src/protocol/v5.rs b/benchmarks/simplerouter/src/protocol/v5.rs new file mode 100644 index 000000000..c63fe816e --- /dev/null +++ b/benchmarks/simplerouter/src/protocol/v5.rs @@ -0,0 +1,1952 @@ +#![allow(dead_code)] + +use std::fmt; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use super::*; + +pub(crate) mod connect { + use super::*; + use bytes::Bytes; + + /// Connection packet initiated by the client + #[derive(Debug, Clone, PartialEq)] + pub struct Connect { + /// Mqtt keep alive time + pub keep_alive: u16, + /// Client Id + pub client_id: String, + /// Clean session. Asks the broker to clear previous state + pub clean_session: bool, + /// Will that broker needs to publish when the client disconnects + pub last_will: Option, + /// Login credentials + pub login: Option, + /// Properties + pub properties: Option, + } + + impl Connect { + pub fn new>(id: S) -> Connect { + Connect { + keep_alive: 10, + client_id: id.into(), + clean_session: true, + last_will: None, + login: None, + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + "MQTT".len() // protocol name + + 1 // protocol version + + 1 // connect flags + + 2; // keep alive + + len += 2 + self.client_id.len(); + + // last will len + if let Some(last_will) = &self.last_will { + len += last_will.len(); + } + + // username and password len + if let Some(login) = &self.login { + len += login.len(); + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + // Variable header + let protocol_name = read_mqtt_bytes(&mut bytes)?; + let protocol_name = std::str::from_utf8(&protocol_name)?.to_owned(); + if protocol_name != "MQTT" { + return Err(Error::InvalidProtocol); + } + + let protocol_level = read_u8(&mut bytes)?; + if protocol_level != 5 { + return Err(Error::InvalidProtocolLevel(protocol_level)); + } + + connect_v5_part(bytes) + } + } + + pub(crate) fn connect_v5_part(mut bytes: Bytes) -> Result { + let connect_flags = read_u8(&mut bytes)?; + let clean_session = (connect_flags & 0b10) != 0; + let keep_alive = read_u16(&mut bytes)?; + + let properties = ConnectProperties::read(&mut bytes)?; + + // Payload + let client_id = read_mqtt_bytes(&mut bytes)?; + let client_id = std::str::from_utf8(&client_id)?.to_owned(); + let last_will = LastWill::read(connect_flags, &mut bytes)?; + let login = Login::read(connect_flags, &mut bytes)?; + + let connect = Connect { + keep_alive, + client_id, + clean_session, + last_will, + login, + properties, + }; + + Ok(connect) + } + + /// LastWill that broker forwards on behalf of the client + #[derive(Debug, Clone, PartialEq)] + pub struct LastWill { + pub topic: String, + pub message: Bytes, + pub qos: QoS, + pub retain: bool, + } + + impl LastWill { + pub fn _new( + topic: impl Into, + payload: impl Into>, + qos: QoS, + retain: bool, + ) -> LastWill { + LastWill { + topic: topic.into(), + message: Bytes::from(payload.into()), + qos, + retain, + } + } + + fn len(&self) -> usize { + let mut len = 0; + len += 2 + self.topic.len() + 2 + self.message.len(); + len + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let last_will = match connect_flags & 0b100 { + 0 if (connect_flags & 0b0011_1000) != 0 => { + return Err(Error::IncorrectPacketFormat); + } + 0 => None, + _ => { + let will_topic = read_mqtt_bytes(&mut bytes)?; + let will_topic = std::str::from_utf8(&will_topic)?.to_owned(); + let will_message = read_mqtt_bytes(&mut bytes)?; + let will_qos = qos((connect_flags & 0b11000) >> 3)?; + Some(LastWill { + topic: will_topic, + message: will_message, + qos: will_qos, + retain: (connect_flags & 0b0010_0000) != 0, + }) + } + }; + + Ok(last_will) + } + } + + #[derive(Debug, Clone, PartialEq)] + pub struct Login { + username: String, + password: String, + } + + impl Login { + pub fn new>(u: S, p: S) -> Login { + Login { + username: u.into(), + password: p.into(), + } + } + + fn read(connect_flags: u8, mut bytes: &mut Bytes) -> Result, Error> { + let username = match connect_flags & 0b1000_0000 { + 0 => String::new(), + _ => { + let username = read_mqtt_bytes(&mut bytes)?; + std::str::from_utf8(&username)?.to_owned() + } + }; + + let password = match connect_flags & 0b0100_0000 { + 0 => String::new(), + _ => { + let password = read_mqtt_bytes(&mut bytes)?; + std::str::from_utf8(&password)?.to_owned() + } + }; + + if username.is_empty() && password.is_empty() { + Ok(None) + } else { + Ok(Some(Login { username, password })) + } + } + + fn len(&self) -> usize { + let mut len = 0; + + if !self.username.is_empty() { + len += 2 + self.username.len(); + } + + if !self.password.is_empty() { + len += 2 + self.password.len(); + } + + len + } + } + + #[derive(Debug, Clone, PartialEq)] + pub struct ConnectProperties { + /// Expiry interval property after loosing connection + pub session_expiry_interval: Option, + /// Maximum simultaneous packets + pub receive_maximum: Option, + /// Maximum packet size + pub max_packet_size: Option, + /// Maximum mapping integer for a topic + pub topic_alias_max: Option, + pub request_response_info: Option, + pub request_problem_info: Option, + /// List of user properties + pub user_properties: Vec<(String, String)>, + /// Method of authentication + pub authentication_method: Option, + /// Authentication data + pub authentication_data: Option, + } + + impl ConnectProperties { + fn _new() -> ConnectProperties { + ConnectProperties { + session_expiry_interval: None, + receive_maximum: None, + max_packet_size: None, + topic_alias_max: None, + request_response_info: None, + request_problem_info: None, + user_properties: Vec::new(), + authentication_method: None, + authentication_data: None, + } + } + + fn read(mut bytes: &mut Bytes) -> Result, Error> { + let mut session_expiry_interval = None; + let mut receive_maximum = None; + let mut max_packet_size = None; + let mut topic_alias_max = None; + let mut request_response_info = None; + let mut request_problem_info = None; + let mut user_properties = Vec::new(); + let mut authentication_method = None; + let mut authentication_data = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReceiveMaximum => { + receive_maximum = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::MaximumPacketSize => { + max_packet_size = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAliasMaximum => { + topic_alias_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::RequestResponseInformation => { + request_response_info = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::RequestProblemInformation => { + request_problem_info = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::UserProperty => { + let key = read_mqtt_bytes(&mut bytes)?; + let key = std::str::from_utf8(&key)?.to_owned(); + let value = read_mqtt_bytes(&mut bytes)?; + let value = std::str::from_utf8(&value)?.to_owned(); + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::AuthenticationMethod => { + let method = read_mqtt_bytes(&mut bytes)?; + let method = std::str::from_utf8(&method)?.to_owned(); + cursor += 2 + method.len(); + authentication_method = Some(method); + } + PropertyType::AuthenticationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + authentication_data = Some(data); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(ConnectProperties { + session_expiry_interval, + receive_maximum, + max_packet_size, + topic_alias_max, + request_response_info, + request_problem_info, + user_properties, + authentication_method, + authentication_data, + })) + } + + fn len(&self) -> usize { + let mut len = 0; + + if self.session_expiry_interval.is_some() { + len += 1 + 4; + } + + if self.receive_maximum.is_some() { + len += 1 + 2; + } + + if self.max_packet_size.is_some() { + len += 1 + 4; + } + + if self.topic_alias_max.is_some() { + len += 1 + 2; + } + + if self.request_response_info.is_some() { + len += 1 + 1; + } + + if self.request_problem_info.is_some() { + len += 1 + 1; + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(authentication_method) = &self.authentication_method { + len += 1 + 2 + authentication_method.len(); + } + + if let Some(authentication_data) = &self.authentication_data { + len += 1 + 2 + authentication_data.len(); + } + + len + } + } +} + +pub(crate) mod connack { + use super::*; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + + /// Return code in connack + #[derive(Debug, Clone, Copy, PartialEq)] + #[repr(u8)] + pub enum ConnectReturnCode { + Success = 0x00, + UnspecifiedError = 0x80, + MalformedPacket = 0x81, + ProtocolError = 0x82, + ImplementationSpecificError = 0x83, + UnsupportedProtocolVersion = 0x84, + ClientIdentifierNotValid = 0x85, + BadUserNamePassword = 0x86, + NotAuthorized = 0x87, + ServerUnavailable = 0x88, + ServerBusy = 0x89, + Banned = 0x8a, + BadAuthenticationMethod = 0x8c, + TopicNameInvalid = 0x90, + PacketTooLarge = 0x95, + QuotaExceeded = 0x97, + PayloadFormatInvalid = 0x99, + RetainNotSupported = 0x9a, + QoSNotSupported = 0x9b, + UseAnotherServer = 0x9c, + ServerMoved = 0x9d, + ConnectionRateExceeded = 0x94, + } + + /// Acknowledgement to connect packet + #[derive(Debug, Clone, PartialEq)] + pub struct ConnAck { + pub session_present: bool, + pub code: ConnectReturnCode, + pub properties: Option, + } + + impl ConnAck { + pub fn new(code: ConnectReturnCode, session_present: bool) -> ConnAck { + ConnAck { + code, + session_present, + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 1 // session present + + 1; // code + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // 1 byte for 0 len + len += 1; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let flags = read_u8(&mut bytes)?; + let return_code = read_u8(&mut bytes)?; + + let session_present = (flags & 0x01) == 1; + let code = connect_return(return_code)?; + let connack = ConnAck { + session_present, + code, + properties: ConnAckProperties::extract(&mut bytes)?, + }; + + Ok(connack) + } + } + + pub fn write( + code: ConnectReturnCode, + session_present: bool, + properties: Option, + buffer: &mut BytesMut, + ) -> Result { + // TODO: maybe we can remove double checking if properties == None ? + + let mut len = 1 // session present + + 1; // code + if let Some(ref properties) = properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // 1 byte for 0 len + len += 1; + } + + buffer.put_u8(0x20); + + let count = write_remaining_length(buffer, len)?; + + buffer.put_u8(session_present as u8); + buffer.put_u8(code as u8); + + if let Some(properties) = properties { + properties.write(buffer)?; + } else { + // 1 byte for 0 len + buffer.put_u8(0); + } + + Ok(1 + count + len) + } + + #[derive(Debug, Clone, PartialEq)] + pub struct ConnAckProperties { + pub session_expiry_interval: Option, + pub receive_max: Option, + pub max_qos: Option, + pub retain_available: Option, + pub max_packet_size: Option, + pub assigned_client_identifier: Option, + pub topic_alias_max: Option, + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, + pub wildcard_subscription_available: Option, + pub subscription_identifiers_available: Option, + pub shared_subscription_available: Option, + pub server_keep_alive: Option, + pub response_information: Option, + pub server_reference: Option, + pub authentication_method: Option, + pub authentication_data: Option, + } + + impl ConnAckProperties { + pub fn new() -> ConnAckProperties { + ConnAckProperties { + session_expiry_interval: None, + receive_max: None, + max_qos: None, + retain_available: None, + max_packet_size: None, + assigned_client_identifier: None, + topic_alias_max: None, + reason_string: None, + user_properties: Vec::new(), + wildcard_subscription_available: None, + subscription_identifiers_available: None, + shared_subscription_available: None, + server_keep_alive: None, + response_information: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(_) = &self.session_expiry_interval { + len += 1 + 4; + } + + if let Some(_) = &self.receive_max { + len += 1 + 2; + } + + if let Some(_) = &self.max_qos { + len += 1 + 1; + } + + if let Some(_) = &self.retain_available { + len += 1 + 1; + } + + if let Some(_) = &self.max_packet_size { + len += 1 + 4; + } + + if let Some(id) = &self.assigned_client_identifier { + len += 1 + 2 + id.len(); + } + + if let Some(_) = &self.topic_alias_max { + len += 1 + 2; + } + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + if let Some(_) = &self.wildcard_subscription_available { + len += 1 + 1; + } + + if let Some(_) = &self.subscription_identifiers_available { + len += 1 + 1; + } + + if let Some(_) = &self.shared_subscription_available { + len += 1 + 1; + } + + if let Some(_) = &self.server_keep_alive { + len += 1 + 2; + } + + if let Some(info) = &self.response_information { + len += 1 + 2 + info.len(); + } + + if let Some(reference) = &self.server_reference { + len += 1 + 2 + reference.len(); + } + + if let Some(authentication_method) = &self.authentication_method { + len += 1 + 2 + authentication_method.len(); + } + + if let Some(authentication_data) = &self.authentication_data { + len += 1 + 2 + authentication_data.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut session_expiry_interval = None; + let mut receive_max = None; + let mut max_qos = None; + let mut retain_available = None; + let mut max_packet_size = None; + let mut assigned_client_identifier = None; + let mut topic_alias_max = None; + let mut reason_string = None; + let mut user_properties = Vec::new(); + let mut wildcard_subscription_available = None; + let mut subscription_identifiers_available = None; + let mut shared_subscription_available = None; + let mut server_keep_alive = None; + let mut response_information = None; + let mut server_reference = None; + let mut authentication_method = None; + let mut authentication_data = None; + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SessionExpiryInterval => { + session_expiry_interval = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::ReceiveMaximum => { + receive_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::MaximumQos => { + max_qos = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::RetainAvailable => { + retain_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::AssignedClientIdentifier => { + let bytes = read_mqtt_bytes(&mut bytes)?; + let id = std::str::from_utf8(&bytes)?.to_owned(); + cursor += 2 + id.len(); + assigned_client_identifier = Some(id); + } + PropertyType::MaximumPacketSize => { + max_packet_size = Some(read_u32(&mut bytes)?); + cursor += 4; + } + PropertyType::TopicAliasMaximum => { + topic_alias_max = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ReasonString => { + let reason = read_mqtt_bytes(&mut bytes)?; + let reason = std::str::from_utf8(&reason)?.to_owned(); + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_bytes(&mut bytes)?; + let key = std::str::from_utf8(&key)?.to_owned(); + let value = read_mqtt_bytes(&mut bytes)?; + let value = std::str::from_utf8(&value)?.to_owned(); + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + PropertyType::WildcardSubscriptionAvailable => { + wildcard_subscription_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::SubscriptionIdentifierAvailable => { + subscription_identifiers_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::SharedSubscriptionAvailable => { + shared_subscription_available = Some(read_u8(&mut bytes)?); + cursor += 1; + } + PropertyType::ServerKeepAlive => { + server_keep_alive = Some(read_u16(&mut bytes)?); + cursor += 2; + } + PropertyType::ResponseInformation => { + let info = read_mqtt_bytes(&mut bytes)?; + let info = std::str::from_utf8(&info)?.to_owned(); + cursor += 2 + info.len(); + response_information = Some(info); + } + PropertyType::ServerReference => { + let bytes = read_mqtt_bytes(&mut bytes)?; + let reference = std::str::from_utf8(&bytes)?.to_owned(); + cursor += 2 + reference.len(); + server_reference = Some(reference); + } + PropertyType::AuthenticationMethod => { + let bytes = read_mqtt_bytes(&mut bytes)?; + let method = std::str::from_utf8(&bytes)?.to_owned(); + cursor += 2 + method.len(); + authentication_method = Some(method); + } + PropertyType::AuthenticationData => { + let data = read_mqtt_bytes(&mut bytes)?; + cursor += 2 + data.len(); + authentication_data = Some(data); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(ConnAckProperties { + session_expiry_interval, + receive_max, + max_qos, + retain_available, + max_packet_size, + assigned_client_identifier, + topic_alias_max, + reason_string, + user_properties, + wildcard_subscription_available, + subscription_identifiers_available, + shared_subscription_available, + server_keep_alive, + response_information, + server_reference, + authentication_method, + authentication_data, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + buffer.put_u8(PropertyType::SessionExpiryInterval as u8); + buffer.put_u32(session_expiry_interval); + } + + if let Some(receive_maximum) = self.receive_max { + buffer.put_u8(PropertyType::ReceiveMaximum as u8); + buffer.put_u16(receive_maximum); + } + + if let Some(qos) = self.max_qos { + buffer.put_u8(PropertyType::MaximumQos as u8); + buffer.put_u8(qos); + } + + if let Some(retain_available) = self.retain_available { + buffer.put_u8(PropertyType::RetainAvailable as u8); + buffer.put_u8(retain_available); + } + + if let Some(max_packet_size) = self.max_packet_size { + buffer.put_u8(PropertyType::MaximumPacketSize as u8); + buffer.put_u32(max_packet_size); + } + + if let Some(id) = &self.assigned_client_identifier { + buffer.put_u8(PropertyType::AssignedClientIdentifier as u8); + write_mqtt_string(buffer, id); + } + + if let Some(topic_alias_max) = self.topic_alias_max { + buffer.put_u8(PropertyType::TopicAliasMaximum as u8); + buffer.put_u16(topic_alias_max); + } + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + if let Some(w) = self.wildcard_subscription_available { + buffer.put_u8(PropertyType::WildcardSubscriptionAvailable as u8); + buffer.put_u8(w); + } + + if let Some(s) = self.subscription_identifiers_available { + buffer.put_u8(PropertyType::SubscriptionIdentifierAvailable as u8); + buffer.put_u8(s); + } + + if let Some(s) = self.shared_subscription_available { + buffer.put_u8(PropertyType::SharedSubscriptionAvailable as u8); + buffer.put_u8(s); + } + + if let Some(keep_alive) = self.server_keep_alive { + buffer.put_u8(PropertyType::ServerKeepAlive as u8); + buffer.put_u16(keep_alive); + } + + if let Some(info) = &self.response_information { + buffer.put_u8(PropertyType::ResponseInformation as u8); + write_mqtt_string(buffer, info); + } + + if let Some(reference) = &self.server_reference { + buffer.put_u8(PropertyType::ServerReference as u8); + write_mqtt_string(buffer, reference); + } + + if let Some(authentication_method) = &self.authentication_method { + buffer.put_u8(PropertyType::AuthenticationMethod as u8); + write_mqtt_string(buffer, authentication_method); + } + + if let Some(authentication_data) = &self.authentication_data { + buffer.put_u8(PropertyType::AuthenticationData as u8); + write_mqtt_bytes(buffer, authentication_data); + } + + Ok(()) + } + } + + /// Connection return code type + fn connect_return(num: u8) -> Result { + match num { + 0x00 => Ok(ConnectReturnCode::Success), + 0x80 => Ok(ConnectReturnCode::UnspecifiedError), + 0x81 => Ok(ConnectReturnCode::MalformedPacket), + 0x82 => Ok(ConnectReturnCode::ProtocolError), + 0x83 => Ok(ConnectReturnCode::ImplementationSpecificError), + 0x84 => Ok(ConnectReturnCode::UnsupportedProtocolVersion), + 0x85 => Ok(ConnectReturnCode::ClientIdentifierNotValid), + 0x86 => Ok(ConnectReturnCode::BadUserNamePassword), + 0x87 => Ok(ConnectReturnCode::NotAuthorized), + 0x88 => Ok(ConnectReturnCode::ServerUnavailable), + 0x89 => Ok(ConnectReturnCode::ServerBusy), + 0x8a => Ok(ConnectReturnCode::Banned), + 0x8c => Ok(ConnectReturnCode::BadAuthenticationMethod), + 0x90 => Ok(ConnectReturnCode::TopicNameInvalid), + 0x95 => Ok(ConnectReturnCode::PacketTooLarge), + 0x97 => Ok(ConnectReturnCode::QuotaExceeded), + 0x99 => Ok(ConnectReturnCode::PayloadFormatInvalid), + 0x9a => Ok(ConnectReturnCode::RetainNotSupported), + 0x9b => Ok(ConnectReturnCode::QoSNotSupported), + 0x9c => Ok(ConnectReturnCode::UseAnotherServer), + 0x9d => Ok(ConnectReturnCode::ServerMoved), + 0x94 => Ok(ConnectReturnCode::ConnectionRateExceeded), + num => Err(Error::InvalidConnectReturnCode(num)), + } + } +} + +pub(crate) mod publish { + use super::*; + use bytes::{BufMut, Bytes, BytesMut}; + + #[derive(Debug, Clone, PartialEq)] + pub struct Publish { + pub fixed_header: FixedHeader, + pub raw: Bytes, + } + + impl Publish { + // pub fn new, P: Into>>(topic: S, qos: QoS, payload: P) -> Publish { + // Publish { + // dup: false, + // qos, + // retain: false, + // pkid: 0, + // topic: topic.into(), + // payload: Bytes::from(payload.into()), + // } + // } + + // pub fn from_bytes>(topic: S, qos: QoS, payload: Bytes) -> Publish { + // Publish { + // dup: false, + // qos, + // retain: false, + // pkid: 0, + // topic: topic.into(), + // payload, + // } + // } + + // pub fn len(&self) -> usize { + // let mut len = 2 + self.topic.len(); + // if self.qos != QoS::AtMostOnce && self.pkid != 0 { + // len += 2; + // } + + // len += self.payload.len(); + // len + // } + + pub fn view_meta(&self) -> Result<(&str, u8, u16, bool, bool), Error> { + let qos = (self.fixed_header.byte1 & 0b0110) >> 1; + let dup = (self.fixed_header.byte1 & 0b1000) != 0; + let retain = (self.fixed_header.byte1 & 0b0001) != 0; + + // FIXME: Remove indexes and use get method + let stream = &self.raw[self.fixed_header.fixed_header_len..]; + let topic_len = view_u16(&stream)? as usize; + + let stream = &stream[2..]; + let topic = view_str(stream, topic_len)?; + + let pkid = match qos { + 0 => 0, + 1 => { + let stream = &stream[topic_len..]; + let pkid = view_u16(stream)?; + pkid + } + v => return Err(Error::InvalidQoS(v)), + }; + + if qos == 1 && pkid == 0 { + return Err(Error::PacketIdZero); + } + + Ok((topic, qos, pkid, dup, retain)) + } + + pub fn view_topic(&self) -> Result<&str, Error> { + // FIXME: Remove indexes + let stream = &self.raw[self.fixed_header.fixed_header_len..]; + let topic_len = view_u16(&stream)? as usize; + + let stream = &stream[2..]; + let topic = view_str(stream, topic_len)?; + Ok(topic) + } + + pub fn take_topic_and_payload(mut self) -> Result<(Bytes, Bytes), Error> { + let qos = (self.fixed_header.byte1 & 0b0110) >> 1; + + let variable_header_index = self.fixed_header.fixed_header_len; + self.raw.advance(variable_header_index); + let topic = read_mqtt_bytes(&mut self.raw)?; + + match qos { + 0 => (), + 1 => self.raw.advance(2), + v => return Err(Error::InvalidQoS(v)), + }; + + let payload = self.raw; + Ok((topic, payload)) + } + + pub fn read(fixed_header: FixedHeader, bytes: Bytes) -> Result { + let publish = Publish { + fixed_header, + raw: bytes, + }; + + Ok(publish) + } + } + + pub struct PublishBytes(pub Bytes); + + impl From for Result { + fn from(raw: PublishBytes) -> Self { + let fixed_header = check(raw.0.iter(), 100 * 1024 * 1024)?; + Ok(Publish { + fixed_header, + raw: raw.0, + }) + } + } + + pub fn write( + topic: &str, + qos: QoS, + pkid: u16, + dup: bool, + retain: bool, + payload: &[u8], + buffer: &mut BytesMut, + ) -> Result { + let mut len = 2 + topic.len(); + if qos != QoS::AtMostOnce { + len += 2; + } + + len += payload.len(); + + let dup = dup as u8; + let qos = qos as u8; + let retain = retain as u8; + + buffer.put_u8(0b0011_0000 | retain | qos << 1 | dup << 3); + + let count = write_remaining_length(buffer, len)?; + write_mqtt_string(buffer, topic); + + if qos != 0 { + if pkid == 0 { + return Err(Error::PacketIdZero); + } + + buffer.put_u16(pkid); + } + + buffer.extend_from_slice(&payload); + + // TODO: Returned length is wrong in other packets. Fix it + Ok(1 + count + len) + } +} + +pub(crate) mod puback { + use super::*; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + + /// Acknowledgement to QoS1 publish + #[derive(Debug, Clone, PartialEq)] + pub struct PubAck { + pub pkid: u16, + pub reason: PubAckReason, + pub properties: Option, + } + + impl PubAck { + pub fn new(pkid: u16) -> PubAck { + PubAck { + pkid, + reason: PubAckReason::Success, + properties: None, + } + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + let pkid = read_u16(&mut bytes)?; + + // No reason code or properties if remaining length == 2 + if fixed_header.remaining_len == 2 { + return Ok(PubAck { + pkid, + reason: PubAckReason::Success, + properties: None, + }); + } + + // No properties len or properties if remaining len > 2 but < 4 + let ack_reason = read_u8(&mut bytes)?; + if fixed_header.remaining_len < 4 { + return Ok(PubAck { + pkid, + reason: reason(ack_reason)?, + properties: None, + }); + } + + let puback = PubAck { + pkid, + reason: reason(ack_reason)?, + properties: PubAckProperties::extract(&mut bytes)?, + }; + + Ok(puback) + } + } + + pub fn write( + pkid: u16, + reason: PubAckReason, + properties: Option, + buffer: &mut BytesMut, + ) -> Result { + buffer.put_u8(0x40); + + match &properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + let len = 2 + 1 + properties_len_len + properties_len; + + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(pkid); + buffer.put_u8(reason as u8); + properties.write(buffer)?; + + Ok(len + count + 1) + } + None => { + // Unlike other packets, property length can be ignored if there are + // no properties in acks + // + // TODO: maybe we should set len = 2 for PubAckReason == Success + let len = 2 + 1; + let count = write_remaining_length(buffer, len)?; + buffer.put_u16(pkid); + buffer.put_u8(reason as u8); + + Ok(len + count + 1) + } + } + } + + #[derive(Debug, Clone, PartialEq)] + pub struct PubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, + } + + /// Return code in connack + #[derive(Debug, Clone, Copy, PartialEq)] + #[repr(u8)] + pub enum PubAckReason { + Success = 0, + NoMatchingSubscribers = 16, + UnspecifiedError = 128, + ImplementationSpecificError = 131, + NotAuthorized = 135, + TopicNameInvalid = 144, + PacketIdentifierInUse = 145, + QuotaExceeded = 151, + PayloadFormatInvalid = 153, + } + + impl PubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let bytes = read_mqtt_bytes(&mut bytes)?; + let reason = std::str::from_utf8(&bytes)?.to_owned(); + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_bytes(&mut bytes)?; + let key = std::str::from_utf8(&key)?.to_owned(); + let value = read_mqtt_bytes(&mut bytes)?; + let value = std::str::from_utf8(&value)?.to_owned(); + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(PubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } + } + /// Connection return code type + fn reason(num: u8) -> Result { + let code = match num { + 0 => PubAckReason::Success, + 16 => PubAckReason::NoMatchingSubscribers, + 128 => PubAckReason::UnspecifiedError, + 131 => PubAckReason::ImplementationSpecificError, + 135 => PubAckReason::NotAuthorized, + 144 => PubAckReason::TopicNameInvalid, + 145 => PubAckReason::PacketIdentifierInUse, + 151 => PubAckReason::QuotaExceeded, + 153 => PubAckReason::PayloadFormatInvalid, + num => return Err(Error::InvalidConnectReturnCode(num)), + }; + + Ok(code) + } +} + +pub(crate) mod subscribe { + use super::*; + use bytes::{Buf, Bytes}; + + /// Subscription packet + #[derive(Clone, PartialEq)] + pub struct Subscribe { + pub pkid: u16, + pub filters: Vec, + pub properties: Option, + } + + impl Subscribe { + pub fn new>(path: S, qos: QoS) -> Subscribe { + let filter = SubscribeFilter { + path: path.into(), + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + }; + + let mut filters = Vec::new(); + filters.push(filter); + Subscribe { + pkid: 0, + filters, + properties: None, + } + } + + pub fn add(&mut self, path: String, qos: QoS) -> &mut Self { + let filter = SubscribeFilter { + path, + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + }; + + self.filters.push(filter); + self + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.filters.iter().fold(0, |s, t| s + t.len()); + + if let Some(properties) = &self.properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // just 1 byte representing 0 len + len += 1; + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = SubscribeProperties::extract(&mut bytes)?; + + // variable header size = 2 (packet identifier) + let mut filters = Vec::new(); + + while bytes.has_remaining() { + let path = read_mqtt_bytes(&mut bytes)?; + let path = std::str::from_utf8(&path)?.to_owned(); + let options = read_u8(&mut bytes)?; + let requested_qos = options & 0b0000_0011; + + let nolocal = options >> 2 & 0b0000_0001; + let nolocal = if nolocal == 0 { false } else { true }; + + let preserve_retain = options >> 3 & 0b0000_0001; + let preserve_retain = if preserve_retain == 0 { false } else { true }; + + let retain_forward_rule = (options >> 4) & 0b0000_0011; + let retain_forward_rule = match retain_forward_rule { + 0 => RetainForwardRule::OnEverySubscribe, + 1 => RetainForwardRule::OnNewSubscribe, + 2 => RetainForwardRule::Never, + r => return Err(Error::InvalidRetainForwardRule(r)), + }; + + filters.push(SubscribeFilter { + path, + qos: qos(requested_qos)?, + nolocal, + preserve_retain, + retain_forward_rule, + }); + } + + let subscribe = Subscribe { + pkid, + filters, + properties, + }; + + Ok(subscribe) + } + } + + pub fn write( + filters: Vec, + pkid: u16, + properties: Option, + buffer: &mut BytesMut, + ) -> Result { + // write packet type + buffer.put_u8(0x82); + + // write remaining length + let mut len = 2 + filters.iter().fold(0, |s, t| s + t.len()); + + if let Some(properties) = &properties { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } else { + // just 1 byte representing 0 len + len += 1; + } + let remaining_len = len; + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + // write packet id + buffer.put_u16(pkid); + + match &properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + // write filters + for filter in filters.iter() { + filter.write(buffer); + } + + Ok(1 + remaining_len_bytes + remaining_len) + } + + /// Subscription filter + #[derive(Clone, PartialEq)] + pub struct SubscribeFilter { + pub path: String, + pub qos: QoS, + pub nolocal: bool, + pub preserve_retain: bool, + pub retain_forward_rule: RetainForwardRule, + } + + impl SubscribeFilter { + pub fn new(path: String, qos: QoS) -> SubscribeFilter { + SubscribeFilter { + path, + qos, + nolocal: false, + preserve_retain: false, + retain_forward_rule: RetainForwardRule::OnEverySubscribe, + } + } + + pub fn len(&self) -> usize { + // filter len + filter + options + 2 + self.path.len() + 1 + } + + fn write(&self, buffer: &mut BytesMut) { + let mut options = 0; + options |= self.qos as u8; + + if self.nolocal { + options |= 1 << 2; + } + + if self.preserve_retain { + options |= 1 << 3; + } + + match self.retain_forward_rule { + RetainForwardRule::OnEverySubscribe => options |= 0 << 4, + RetainForwardRule::OnNewSubscribe => options |= 1 << 4, + RetainForwardRule::Never => options |= 2 << 4, + } + + write_mqtt_string(buffer, self.path.as_str()); + buffer.put_u8(options); + } + } + + #[derive(Debug, Clone, PartialEq)] + pub struct SubscribeProperties { + pub id: Option, + pub user_properties: Vec<(String, String)>, + } + + impl SubscribeProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(id) = &self.id { + len += 1 + len_len(*id); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut id = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::SubscriptionIdentifier => { + let (id_len, sub_id) = length(bytes.iter())?; + // TODO: Validate 1 +. Tests are working either way + cursor += 1 + id_len; + bytes.advance(id_len); + id = Some(sub_id) + } + PropertyType::UserProperty => { + let key = read_mqtt_bytes(&mut bytes)?; + let key = std::str::from_utf8(&key)?.to_owned(); + let value = read_mqtt_bytes(&mut bytes)?; + let value = std::str::from_utf8(&value)?.to_owned(); + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(SubscribeProperties { + id, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(id) = &self.id { + buffer.put_u8(PropertyType::SubscriptionIdentifier as u8); + write_remaining_length(buffer, *id)?; + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } + } + + #[derive(Debug, Clone, PartialEq)] + pub enum RetainForwardRule { + OnEverySubscribe, + OnNewSubscribe, + Never, + } + + impl fmt::Debug for Subscribe { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Filters = {:?}, Packet id = {:?}", + self.filters, self.pkid + ) + } + } + + impl fmt::Debug for SubscribeFilter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "Filter = {}, Qos = {:?}, Nolocal = {}, Preserve retain = {}, Forward rule = {:?}", + self.path, self.qos, self.nolocal, self.preserve_retain, self.retain_forward_rule + ) + } + } +} + +pub(crate) mod suback { + use std::convert::{TryFrom, TryInto}; + + use super::*; + use bytes::{Buf, Bytes}; + + /// Acknowledgement to subscribe + #[derive(Debug, Clone, PartialEq)] + pub struct SubAck { + pub pkid: u16, + pub return_codes: Vec, + pub properties: Option, + } + + impl SubAck { + pub fn new(pkid: u16, return_codes: Vec) -> SubAck { + SubAck { + pkid, + return_codes, + properties: None, + } + } + + pub fn len(&self) -> usize { + let mut len = 2 + self.return_codes.len(); + + match &self.properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + len + } + + pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result { + let variable_header_index = fixed_header.fixed_header_len; + bytes.advance(variable_header_index); + + let pkid = read_u16(&mut bytes)?; + let properties = SubAckProperties::extract(&mut bytes)?; + + if !bytes.has_remaining() { + return Err(Error::MalformedPacket); + } + + let mut return_codes = Vec::new(); + while bytes.has_remaining() { + let return_code = read_u8(&mut bytes)?; + return_codes.push(return_code.try_into()?); + } + + let suback = SubAck { + pkid, + return_codes, + properties, + }; + + Ok(suback) + } + } + + pub fn write( + return_codes: Vec, + pkid: u16, + properties: Option, + buffer: &mut BytesMut, + ) -> Result { + buffer.put_u8(0x90); + + let mut len = 2 + return_codes.len(); + + match &properties { + Some(properties) => { + let properties_len = properties.len(); + let properties_len_len = len_len(properties_len); + len += properties_len_len + properties_len; + } + None => { + // just 1 byte representing 0 len + len += 1; + } + } + + let remaining_len = len; + let remaining_len_bytes = write_remaining_length(buffer, remaining_len)?; + + buffer.put_u16(pkid); + + match &properties { + Some(properties) => properties.write(buffer)?, + None => { + write_remaining_length(buffer, 0)?; + } + }; + + let p: Vec = return_codes.iter().map(|code| *code as u8).collect(); + buffer.extend_from_slice(&p); + Ok(1 + remaining_len_bytes + remaining_len) + } + + #[derive(Debug, Clone, PartialEq)] + pub struct SubAckProperties { + pub reason_string: Option, + pub user_properties: Vec<(String, String)>, + } + + impl SubAckProperties { + pub fn len(&self) -> usize { + let mut len = 0; + + if let Some(reason) = &self.reason_string { + len += 1 + 2 + reason.len(); + } + + for (key, value) in self.user_properties.iter() { + len += 1 + 2 + key.len() + 2 + value.len(); + } + + len + } + + pub fn extract(mut bytes: &mut Bytes) -> Result, Error> { + let mut reason_string = None; + let mut user_properties = Vec::new(); + + let (properties_len_len, properties_len) = length(bytes.iter())?; + bytes.advance(properties_len_len); + if properties_len == 0 { + return Ok(None); + } + + let mut cursor = 0; + // read until cursor reaches property length. properties_len = 0 will skip this loop + while cursor < properties_len { + let prop = read_u8(&mut bytes)?; + cursor += 1; + + match property(prop)? { + PropertyType::ReasonString => { + let bytes = read_mqtt_bytes(&mut bytes)?; + let reason = std::str::from_utf8(&bytes)?.to_owned(); + cursor += 2 + reason.len(); + reason_string = Some(reason); + } + PropertyType::UserProperty => { + let key = read_mqtt_bytes(&mut bytes)?; + let key = std::str::from_utf8(&key)?.to_owned(); + let value = read_mqtt_bytes(&mut bytes)?; + let value = std::str::from_utf8(&value)?.to_owned(); + cursor += 2 + key.len() + 2 + value.len(); + user_properties.push((key, value)); + } + _ => return Err(Error::InvalidPropertyType(prop)), + } + } + + Ok(Some(SubAckProperties { + reason_string, + user_properties, + })) + } + + fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> { + let len = self.len(); + write_remaining_length(buffer, len)?; + + if let Some(reason) = &self.reason_string { + buffer.put_u8(PropertyType::ReasonString as u8); + write_mqtt_string(buffer, reason); + } + + for (key, value) in self.user_properties.iter() { + buffer.put_u8(PropertyType::UserProperty as u8); + write_mqtt_string(buffer, key); + write_mqtt_string(buffer, value); + } + + Ok(()) + } + } + + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub enum SubscribeReasonCode { + QoS0 = 0, + QoS1 = 1, + QoS2 = 2, + Unspecified = 128, + ImplementationSpecific = 131, + NotAuthorized = 135, + TopicFilterInvalid = 143, + PkidInUse = 145, + QuotaExceeded = 151, + SharedSubscriptionsNotSupported = 158, + SubscriptionIdNotSupported = 161, + WildcardSubscriptionsNotSupported = 162, + } + + impl TryFrom for SubscribeReasonCode { + type Error = Error; + + fn try_from(value: u8) -> Result { + let v = match value { + 0 => SubscribeReasonCode::QoS0, + 1 => SubscribeReasonCode::QoS1, + 2 => SubscribeReasonCode::QoS2, + 128 => SubscribeReasonCode::Unspecified, + 131 => SubscribeReasonCode::ImplementationSpecific, + 135 => SubscribeReasonCode::NotAuthorized, + 143 => SubscribeReasonCode::TopicFilterInvalid, + 145 => SubscribeReasonCode::PkidInUse, + 151 => SubscribeReasonCode::QuotaExceeded, + 158 => SubscribeReasonCode::SharedSubscriptionsNotSupported, + 161 => SubscribeReasonCode::SubscriptionIdNotSupported, + 162 => SubscribeReasonCode::WildcardSubscriptionsNotSupported, + v => return Err(Error::InvalidSubscribeReasonCode(v)), + }; + + Ok(v) + } + } + + pub fn codes(c: Vec) -> Vec { + c.into_iter() + .map(|v| match qos(v).unwrap() { + QoS::AtMostOnce => SubscribeReasonCode::QoS0, + QoS::AtLeastOnce => SubscribeReasonCode::QoS1, + }) + .collect() + } +} + +pub(crate) mod pingresp { + use super::*; + + pub fn write(payload: &mut BytesMut) -> Result { + payload.put_slice(&[0xD0, 0x00]); + Ok(2) + } +} + +/// Reads a stream of bytes and extracts next MQTT packet out of it +pub fn read_mut(stream: &mut BytesMut, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; + + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + PacketType::Disconnect => Ok(Packet::Disconnect), + _ => Err(Error::PayloadRequired), + }; + } + + let packet = packet.freeze(); + let packet = match packet_type { + PacketType::Connect => Packet::Connect(connect::Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(connack::ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(publish::Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(puback::PubAck::read(fixed_header, packet)?), + PacketType::Subscribe => { + Packet::Subscribe(subscribe::Subscribe::read(fixed_header, packet)?) + } + PacketType::SubAck => Packet::SubAck(suback::SubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect, + v => return Err(Error::UnsupportedPacket(v)), + }; + + Ok(packet) +} + +/// Reads a stream of bytes and extracts next MQTT packet out of it +pub fn read(stream: &mut Bytes, max_size: usize) -> Result { + let fixed_header = check(stream.iter(), max_size)?; + + // Test with a stream with exactly the size to check border panics + let packet = stream.split_to(fixed_header.frame_length()); + let packet_type = fixed_header.packet_type()?; + + if fixed_header.remaining_len == 0 { + // no payload packets + return match packet_type { + PacketType::PingReq => Ok(Packet::PingReq), + PacketType::PingResp => Ok(Packet::PingResp), + PacketType::Disconnect => Ok(Packet::Disconnect), + _ => Err(Error::PayloadRequired), + }; + } + + let packet = match packet_type { + PacketType::Connect => Packet::Connect(connect::Connect::read(fixed_header, packet)?), + PacketType::ConnAck => Packet::ConnAck(connack::ConnAck::read(fixed_header, packet)?), + PacketType::Publish => Packet::Publish(publish::Publish::read(fixed_header, packet)?), + PacketType::PubAck => Packet::PubAck(puback::PubAck::read(fixed_header, packet)?), + PacketType::Subscribe => { + Packet::Subscribe(subscribe::Subscribe::read(fixed_header, packet)?) + } + PacketType::SubAck => Packet::SubAck(suback::SubAck::read(fixed_header, packet)?), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect, + v => return Err(Error::UnsupportedPacket(v)), + }; + + Ok(packet) +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Packet { + Connect(connect::Connect), + Publish(publish::Publish), + ConnAck(connack::ConnAck), + PubAck(puback::PubAck), + PingReq, + PingResp, + Subscribe(subscribe::Subscribe), + SubAck(suback::SubAck), + Disconnect, +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PropertyType { + PayloadFormatIndicator = 1, + MessageExpiryInterval = 2, + ContentType = 3, + ResponseTopic = 8, + CorrelationData = 9, + SubscriptionIdentifier = 11, + SessionExpiryInterval = 17, + AssignedClientIdentifier = 18, + ServerKeepAlive = 19, + AuthenticationMethod = 21, + AuthenticationData = 22, + RequestProblemInformation = 23, + WillDelayInterval = 24, + RequestResponseInformation = 25, + ResponseInformation = 26, + ServerReference = 28, + ReasonString = 31, + ReceiveMaximum = 33, + TopicAliasMaximum = 34, + TopicAlias = 35, + MaximumQos = 36, + RetainAvailable = 37, + UserProperty = 38, + MaximumPacketSize = 39, + WildcardSubscriptionAvailable = 40, + SubscriptionIdentifierAvailable = 41, + SharedSubscriptionAvailable = 42, +} + +fn property(num: u8) -> Result { + let property = match num { + 1 => PropertyType::PayloadFormatIndicator, + 2 => PropertyType::MessageExpiryInterval, + 3 => PropertyType::ContentType, + 8 => PropertyType::ResponseTopic, + 9 => PropertyType::CorrelationData, + 11 => PropertyType::SubscriptionIdentifier, + 17 => PropertyType::SessionExpiryInterval, + 18 => PropertyType::AssignedClientIdentifier, + 19 => PropertyType::ServerKeepAlive, + 21 => PropertyType::AuthenticationMethod, + 22 => PropertyType::AuthenticationData, + 23 => PropertyType::RequestProblemInformation, + 24 => PropertyType::WillDelayInterval, + 25 => PropertyType::RequestResponseInformation, + 26 => PropertyType::ResponseInformation, + 28 => PropertyType::ServerReference, + 31 => PropertyType::ReasonString, + 33 => PropertyType::ReceiveMaximum, + 34 => PropertyType::TopicAliasMaximum, + 35 => PropertyType::TopicAlias, + 36 => PropertyType::MaximumQos, + 37 => PropertyType::RetainAvailable, + 38 => PropertyType::UserProperty, + 39 => PropertyType::MaximumPacketSize, + 40 => PropertyType::WildcardSubscriptionAvailable, + 41 => PropertyType::SubscriptionIdentifierAvailable, + 42 => PropertyType::SharedSubscriptionAvailable, + num => return Err(Error::InvalidPropertyType(num)), + }; + + Ok(property) +}