From 128eb29883fe1efe753e7a0a23a5d27b5c6950db Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Fri, 1 Mar 2024 22:01:54 -0600 Subject: [PATCH] feat: joining works --- Cargo.lock | 16 +- Cargo.toml | 8 +- README.md | 70 +++++ client/.gitignore | 1 - client/Cargo.toml | 66 ----- client/README.md | 0 client/clippy.toml | 3 - client/rustfmt.toml | 21 -- client/src/main.rs | 44 ---- protocol-765/Cargo.toml | 2 +- protocol-765/src/clientbound.rs | 95 ++++--- protocol-765/src/lib.rs | 11 - protocol-765/src/serverbound.rs | 19 +- ser-macro/src/lib.rs | 166 +++++++----- ser/Cargo.toml | 2 +- ser/src/lib.rs | 202 ++++++--------- ser/src/types.rs | 447 +++++++++++--------------------- server/Cargo.toml | 3 + server/clippy.toml | 2 +- server/src/main.rs | 373 +++++++++++++++++++------- 20 files changed, 779 insertions(+), 772 deletions(-) delete mode 100644 client/.gitignore delete mode 100644 client/Cargo.toml delete mode 100644 client/README.md delete mode 100644 client/clippy.toml delete mode 100644 client/rustfmt.toml delete mode 100644 client/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index d0a3775e..bc8a67c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -74,17 +74,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" -[[package]] -name = "client" -version = "0.1.0" -dependencies = [ - "anyhow", - "protocol-765", - "ser", - "tokio", - "tracing-subscriber", -] - [[package]] name = "getrandom" version = "0.2.12" @@ -251,9 +240,9 @@ dependencies = [ name = "protocol-765" version = "0.1.0" dependencies = [ + "anyhow", "ser", "serde", - "tokio", "uuid", ] @@ -297,11 +286,11 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" name = "ser" version = "0.1.0" dependencies = [ + "anyhow", "byteorder", "ser-macro", "serde", "serde_json", - "tokio", "tracing", "uuid", ] @@ -351,6 +340,7 @@ name = "server" version = "0.1.0" dependencies = [ "anyhow", + "bytes", "protocol-765", "ser", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 3739e79b..84a7efff 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,6 @@ members = [ "ser", "ser-macro", "server", - "client" ] [workspace.dependencies] @@ -14,3 +13,10 @@ protocol-765 = { path = "protocol-765" } ser = { path = "ser" } ser-macro = { path = "ser-macro" } +# max perf +[profile.release] +#lto = true +#codegen-units = 1 + + + diff --git a/README.md b/README.md index e69de29b..c8f9397e 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,70 @@ +# Project 10k + +How can we get 10k players to PvP at once on a Minecraft server? + +There are many faction servers which have 500 players on Start of The World (SOTW). +Usually this is around the upper limit for the number of players that can be in one world in vanilla Minecraft. + +## The world + +Suppose there is a 10k x 10k world. +This we can allocate every player (10k x 10k) / 10k = 10k blocks. + +This is equivalent of a square of length sqrt(10k) = 100. If we place the player in the middle, this will mean that +we can allocate a square that stretches 50 blocks NSEW of the center where we can place a player. + +A circle of radius r has an area of pi * r^2. If we allocate circles we will have + +pi * r^2 = 10k +r^2 = 10k/pi +r = sqrt(10k / pi) +r = 56.41 + +Which means the distance to the nearest player would be 2*r = 112.82 + +So if we spread players out equally, there will be 112.82 blocks between them. Of course this is not +possible as circles can not cover the entire map, but perhaps this would be the average distance +to the nearest player if we chose random locations (not sure about maths. +If we assigned players to a grid, then there would be exactly 100 blocks between them. + +r_c = 56.41 is 3.525625 chunks and +r_s = 50 is 3.125 chunks + +If players have > 3 chunk render distance, the entire map will be rendered at once. + +## Memory + +If we have a superflat world with one type of block, we would not have to store any blocks. +However, we probably do not want to do this. + +Suppose the world is 20 blocks deep. This means the total volume of the map is + +10k x 10k x 20 blocks = 2,000,000,000 (2 billion) + +If we have one byte per block (which is realistic if we restrict the number of blocks) we get this only taking + +2B bytes = 2 GB + +This is absolutely feasible. + +In fact, if we had a normal size world + +10k x 10k x 256 and one byte per block this would only take + +25.6 GB + +## Core Count + +Suppose we get a 64-core machine. This means that we can allocate +10k / 64 = 156.25 players per core. +This is much under what a normal vanilla server can do on one core. + +## Network + +Network is very dependent on player packing. +A large factor of sending packets over network has to do with sending player updates. +The bandwidth will be O(nm), where m is a "packing factor" and the number of players within a given radius. +Where all players can see all other players (i.e., there is a small radius), the bandwidth will be O(n^2). + +If we adjust the map size so that there is always a constant number of players m within a certain radius of a map, +we will get the bandwidth will be O(nm) = O(Cn) = CO(n) = O(n) for a constant C. diff --git a/client/.gitignore b/client/.gitignore deleted file mode 100644 index ea8c4bf7..00000000 --- a/client/.gitignore +++ /dev/null @@ -1 +0,0 @@ -/target diff --git a/client/Cargo.toml b/client/Cargo.toml deleted file mode 100644 index 317376dd..00000000 --- a/client/Cargo.toml +++ /dev/null @@ -1,66 +0,0 @@ -[package] -name = "client" -version = "0.1.0" -edition = "2021" -authors = ["Andrew Gazelka "] -readme = "README.md" - -[dependencies] -anyhow = "1.0.80" -protocol-765.workspace = true -ser.workspace = true -tokio = { version = "1.36.0", features = ["full"] } -tracing-subscriber = "0.3.18" - - -[lints.rust] -warnings = "deny" - -[lints.clippy] -# cargo -cargo_common_metadata = "allow" -multiple_crate_versions = "warn" -negative_feature_names = "deny" -redundant_feature_names = "deny" -wildcard_dependencies = "deny" - -restriction = { level = "deny", priority = -1 } -missing_docs_in_private_items = "allow" -question_mark_used = "allow" -print_stdout = "allow" -implicit_return = "allow" -shadow_reuse = "allow" -absolute_paths = "allow" -use_debug = "allow" -unwrap_used = "allow" -std_instead_of_alloc = "allow" # consider denying -default_numeric_fallback = "allow" -as_conversions = "allow" -arithmetic_side_effects = "allow" -shadow_unrelated = "allow" -unseparated_literal_suffix = "allow" -else_if_without_else = "allow" -float_arithmetic = "allow" -single_call_fn = "allow" -missing_inline_in_public_items = "allow" -exhaustive_structs = "allow" -pub_use = "allow" - -complexity = "deny" - -nursery = "deny" - -pedantic = { level = "deny", priority = -1 } -uninlined_format_args = "allow" # consider denying; this is allowed because Copilot often generates code that triggers this lint -needless_pass_by_value = "allow" # consider denying -cast_lossless = "allow" -cast_possible_truncation = "allow" # consider denying -cast_precision_loss = "allow" # consider denying -missing_errors_doc = "allow" # consider denying - -perf = "deny" - -style = "deny" - -suspicious = { level = "deny", priority = -1 } -blanket_clippy_restriction_lints = "allow" diff --git a/client/README.md b/client/README.md deleted file mode 100644 index e69de29b..00000000 diff --git a/client/clippy.toml b/client/clippy.toml deleted file mode 100644 index 50600c99..00000000 --- a/client/clippy.toml +++ /dev/null @@ -1,3 +0,0 @@ -# https://doc.rust-lang.org/nightly/clippy/lint_configuration.html -cognitive-complexity-threshold = 5 -excessive-nesting-threshold = 3 diff --git a/client/rustfmt.toml b/client/rustfmt.toml deleted file mode 100644 index 7073741e..00000000 --- a/client/rustfmt.toml +++ /dev/null @@ -1,21 +0,0 @@ -combine_control_expr = true -comment_width = 100 # https://lkml.org/lkml/2020/5/29/1038 -condense_wildcard_suffixes = true -control_brace_style = "AlwaysSameLine" -edition = "2021" -format_code_in_doc_comments = true -format_macro_bodies = true -format_macro_matchers = true -format_strings = true -group_imports = "StdExternalCrate" -imports_granularity = "Crate" -merge_derives = false -newline_style = "Unix" -normalize_comments = true -normalize_doc_attributes = true -overflow_delimited_expr = true -reorder_impl_items = true -reorder_imports = true -unstable_features = true -wrap_comments = true - diff --git a/client/src/main.rs b/client/src/main.rs deleted file mode 100644 index 3724dd91..00000000 --- a/client/src/main.rs +++ /dev/null @@ -1,44 +0,0 @@ -#![allow(unused_imports)] - -use protocol_765::{clientbound, serverbound, serverbound::StatusRequest}; -use ser::{ExactPacket, ReadExtAsync, Writable, WritePacket}; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::TcpStream, -}; - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - tracing_subscriber::fmt::init(); - // connect to localhost:25565 - let stream = TcpStream::connect("localhost:25565").await?; - - let (reader, mut writer) = tokio::io::split(stream); - - let mut reader = tokio::io::BufReader::new(reader); - - let handshake = serverbound::Handshake { - protocol_version: 765.into(), - server_address: "localhost".to_owned(), - server_port: 25565, - next_state: serverbound::NextState::Status, - }; - - WritePacket::new(handshake).write_async(&mut writer).await?; - - // writer.flush().await?; - - println!("wrote handshake"); - - WritePacket::new(StatusRequest) - .write_async(&mut writer) - .await?; - - println!("wrote status request"); - - let ExactPacket(clientbound::StatusResponse { json }) = reader.read_type().await?; - - println!("read status response json: {}", json); - - Ok(()) -} diff --git a/protocol-765/Cargo.toml b/protocol-765/Cargo.toml index 5cfbf59b..07b51aed 100644 --- a/protocol-765/Cargo.toml +++ b/protocol-765/Cargo.toml @@ -9,7 +9,7 @@ readme = "README.md" ser.workspace = true uuid = { version = "1.7.0", features = ["v4"] } serde = { version = "1.0.197", features = ["derive"] } -tokio = { version = "1.36.0", features = ["full"] } +anyhow = "1.0.80" [lints.rust] diff --git a/protocol-765/src/clientbound.rs b/protocol-765/src/clientbound.rs index 19851a25..bddd4c3a 100644 --- a/protocol-765/src/clientbound.rs +++ b/protocol-765/src/clientbound.rs @@ -1,11 +1,14 @@ +use std::io::Write; + use ser::{Packet, Readable, Writable}; +use uuid::Uuid; // Status Response // packet id 0x0 #[derive(Packet, Readable, Writable, Debug, Eq, PartialEq, Clone)] #[packet(0x00, Handshake)] -pub struct StatusResponse { - pub json: String, +pub struct StatusResponse<'a> { + pub json: &'a str, } // Pong @@ -16,39 +19,61 @@ pub struct Pong { pub payload: i64, } -// // Encryption Request -// // Packet ID State Bound To Field Name Field Type Notes -// // 0x01 Login Client Server ID String (20) Appears to be empty. -// // Public Key Length VarInt Length of Public Key -// // Public Key Byte Array The server's public key, in bytes. -// // Verify Token Length VarInt Length of Verify Token. Always 4 for Notchian servers. -// // Verify Token Byte Array A sequence of random bytes generated by the server. -// #[derive(Packet, Writable, Debug)] -// #[packet(0x01, Handshake)] -// pub struct EncryptionRequest { -// pub server_id: String, -// pub public_key: Vec, -// pub verify_token: Vec -// } - -#[cfg(test)] -mod tests { - use std::io::Cursor; - - use ser::{ReadExt, Writable}; - - use crate::clientbound::StatusResponse; - - #[test] - fn test_round_trip() { - let json = r#"{"version":{"name":"1.16.5","protocol":754},"players":{"max":20,"online":0,"sample":[]},"description":{"text":"Hello world"}}"#; - let status_response = super::StatusResponse { - json: json.to_string(), +#[derive(Packet, Readable, Writable, Debug)] +#[packet(0x02, Handshake)] +pub struct LoginSuccess<'a> { + pub uuid: Uuid, + pub username: &'a str, + pub properties: Vec>, +} + +#[derive(Readable, Writable, Debug)] +pub struct PropertyHeader<'a> { + pub name: &'a str, + pub value: &'a str, + pub is_signed: bool, +} + +#[derive(Debug)] +pub struct Property<'a> { + pub name: &'a str, + pub value: &'a str, + pub is_signed: bool, + pub signature: Option<&'a str>, +} + +impl<'a> Writable for Property<'a> { + fn write(&self, writer: &mut impl Write) -> anyhow::Result<()> { + self.name.write(writer)?; + self.value.write(writer)?; + self.is_signed.write(writer)?; + if let Some(signature) = self.signature { + true.write(writer)?; + signature.write(writer)?; + } else { + false.write(writer)?; + } + Ok(()) + } +} + +impl<'a> Readable<'a> for Property<'a> { + fn decode(r: &mut &'a [u8]) -> anyhow::Result { + let PropertyHeader { + name, + value, + is_signed, + } = PropertyHeader::decode(r)?; + let signature = if is_signed { + Some(<&str>::decode(r)?) + } else { + None }; - let mut data = Vec::new(); - status_response.clone().write(&mut data).unwrap(); - let mut reader = std::io::Cursor::new(data); - let status_response2: StatusResponse = reader.read_type().unwrap(); - assert_eq!(status_response, status_response2); + Ok(Self { + name, + value, + is_signed, + signature, + }) } } diff --git a/protocol-765/src/lib.rs b/protocol-765/src/lib.rs index f67fc513..9b0a5874 100644 --- a/protocol-765/src/lib.rs +++ b/protocol-765/src/lib.rs @@ -1,16 +1,5 @@ // https://wiki.vg/Protocol -// The login process is as follows: -// 1. C→S: Handshake with Next State set to 2 (login) -// 2. C→S: Login Start -// 3. S→C: Encryption Request -// 4. Client auth -// 5. C→S: Encryption Response -// 6. Server auth, both enable encryption -// 7. S→C: Set Compression (optional) -// 8. S→C: Login Success -// 9. C→S: Login Acknowledged - pub mod clientbound; pub mod serverbound; pub mod status; diff --git a/protocol-765/src/serverbound.rs b/protocol-765/src/serverbound.rs index 8f8d9ee6..0676e78d 100644 --- a/protocol-765/src/serverbound.rs +++ b/protocol-765/src/serverbound.rs @@ -4,9 +4,9 @@ use uuid::Uuid; // packet id 0x0 #[derive(Packet, Writable, Readable, Debug)] #[packet(0x0, Handshake)] -pub struct Handshake { +pub struct Handshake<'a> { pub protocol_version: VarInt, - pub server_address: String, + pub server_address: &'a str, pub server_port: u16, pub next_state: NextState, } @@ -16,7 +16,7 @@ pub struct Handshake { #[packet(0x0, Handshake)] pub struct StatusRequest; -#[derive(EnumReadable, EnumWritable, Debug, Eq, PartialEq)] +#[derive(EnumReadable, EnumWritable, Debug, Eq, PartialEq, Copy, Clone)] pub enum NextState { Status = 1, Login = 2, @@ -25,8 +25,8 @@ pub enum NextState { // login start #[derive(Packet, Readable, Debug)] #[packet(0x0, Handshake)] -pub struct LoginStart { - pub username: String, +pub struct LoginStart<'a> { + pub username: &'a str, pub uuid: Uuid, } @@ -35,3 +35,12 @@ pub struct LoginStart { pub struct Ping { pub payload: i64, } + +// Login Acknowledged +// Acknowledgement to the Login Success packet sent by the server. +// +// Packet ID State Bound To Field Name Field Type Notes +// 0x03 Login Server no fields +#[derive(Packet, Writable, Readable, Debug)] +#[packet(0x3, Handshake)] +pub struct LoginAcknowledged; diff --git a/ser-macro/src/lib.rs b/ser-macro/src/lib.rs index 499d2730..2380993e 100644 --- a/ser-macro/src/lib.rs +++ b/ser-macro/src/lib.rs @@ -2,7 +2,8 @@ use proc_macro::TokenStream; use quote::quote; use syn::{ parse::{Parse, ParseStream}, - parse_macro_input, DeriveInput, Error, ItemEnum, ItemStruct, Meta, Token, + parse_macro_input, parse_quote, Data, DeriveInput, Error, GenericParam, Generics, ItemEnum, + Meta, Token, }; struct PacketParams(syn::LitInt, syn::Ident); @@ -16,7 +17,6 @@ impl Parse for PacketParams { } } -// https://doc.rust-lang.org/reference/procedural-macros.html#attribute-macros #[proc_macro_derive(Packet, attributes(packet))] pub fn packet(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); @@ -33,9 +33,10 @@ pub fn packet(input: TokenStream) -> TokenStream { return TokenStream::from(error.to_compile_error()); }; - let ident = input.ident; + let ident = &input.ident; + let generics = &input.generics; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - // let tokens = packet_attr.to_token_stream(); let Meta::List(meta) = &packet_attr.meta else { let error = Error::new_spanned( &packet_attr.meta, @@ -55,9 +56,10 @@ pub fn packet(input: TokenStream) -> TokenStream { }; let expanded = quote! { - impl ::ser::Packet for #ident { + impl #impl_generics ::ser::Packet for #ident #ty_generics #where_clause { const ID: i32 = #id; const STATE: ser::types::PacketState = ser::types::PacketState::#kind; + const NAME: &'static str = stringify!(#ident); } }; @@ -66,59 +68,105 @@ pub fn packet(input: TokenStream) -> TokenStream { #[proc_macro_derive(Writable)] pub fn writable(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as ItemStruct); + let input = parse_macro_input!(input as DeriveInput); let name = input.ident; + let generics = input.generics; + let (generics, where_clause) = extend_generics_and_create_where_clause(&generics); - let idents: Vec<_> = input - .fields - .iter() - .map(|x| x.ident.as_ref().unwrap()) - .collect(); + // Extracting field identifiers and ensuring that the struct can actually be used with this + // macro. + let idents: Vec<_> = match input.data { + syn::Data::Struct(data_struct) => data_struct + .fields + .iter() + .filter_map(|f| f.ident.clone()) + .collect(), + _ => return TokenStream::new(), // Early return if not struct + }; let expanded = quote! { - impl ::ser::Writable for #name { - fn write(self, writer: &mut impl ::std::io::Write) -> ::std::io::Result<()> { - // todo: make sure to make sure all fields are ::ser::Writable + impl #generics ::ser::Writable for #name #generics #where_clause { + fn write(&self, writer: &mut impl ::std::io::Write) -> ::anyhow::Result<()> { #(self.#idents.write(writer)?;)* Ok(()) } - - async fn write_async(self, writer: &mut (impl ::tokio::io::AsyncWrite + ::std::marker::Unpin)) -> ::std::io::Result<()> { - // todo: make sure to make sure all fields are ::ser::Writable - #(self.#idents.write_async(writer).await?;)* - Ok(()) - } + // + // async fn write_async(self, writer: &mut (impl ::tokio::io::AsyncWrite + ::std::marker::Unpin)) -> ::anyhow::Result<()> { + // #(self.#idents.write_async(writer).await?;)* + // Ok(()) + // } } }; TokenStream::from(expanded) } +// lifetime if needed, and to create a where clause that bounds all fields by the +// specified lifetime. It returns a tuple of the possibly extended generics and +// the where clause. +fn extend_generics_and_create_where_clause( + generics: &Generics, +) -> (Generics, proc_macro2::TokenStream) { + let mut generics = generics.clone(); + let mut where_clause = generics.make_where_clause().predicates.clone(); + + for param in &generics.params { + if let GenericParam::Type(type_param) = param { + // Assuming all types must implement the 'Readable' trait bound by a certain lifetime + where_clause.push(parse_quote!(#type_param: ::ser::Readable<'a>)); + } + } + + (generics, quote!(#where_clause)) +} + #[proc_macro_derive(Readable)] pub fn readable(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as ItemStruct); + let input = parse_macro_input!(input as DeriveInput); let name = input.ident; + let generics = input.generics; + let (generics, where_clause) = extend_generics_and_create_where_clause(&generics); + + // if generics is empty, then make it <'_> + let readable_generics = if generics.params.is_empty() { + quote! { <'_> } + } else { + quote! { #generics } + }; - let idents: Vec<_> = input + let slice_generics = { + let lifetimes = generics.lifetimes().collect::>(); + if !lifetimes.is_empty() { + // Assume using the first lifetime if available + let lifetime = lifetimes[0]; + quote! { &#lifetime } + } else { + // Default to '_ if no lifetimes are present + quote! { &'_ } + } + }; + + let Data::Struct(data) = input.data else { + let error = Error::new_spanned( + &name, + "only structs are supported for the `#[derive(Readable)]` attribute", + ); + return TokenStream::from(error.to_compile_error()); + }; + + let idents: Vec<_> = data .fields .iter() - .map(|x| x.ident.as_ref().unwrap()) + .map(|f| f.ident.as_ref().unwrap()) .collect(); - let types: Vec<_> = input.fields.iter().map(|x| &x.ty).collect(); let expanded = quote! { - impl ::ser::Readable for #name { - fn read(reader: &mut impl ::std::io::BufRead) -> ::std::io::Result { - Ok(#name { - #(#idents: <#types as ::ser::Readable>::read(reader)?),* - }) - } - - async fn read_async(reader: &mut (impl ::tokio::io::AsyncBufRead + ::std::marker::Unpin)) -> ::std::io::Result { - Ok(#name { - #(#idents: <#types as ::ser::Readable>::read_async(reader).await?),* + impl #generics ::ser::Readable #readable_generics for #name #generics where #where_clause { + fn decode(r: &mut #slice_generics [u8]) -> ::anyhow::Result { + Ok(Self { + #(#idents: ::ser::Readable::decode(r)?,)* }) } } @@ -129,23 +177,23 @@ pub fn readable(input: TokenStream) -> TokenStream { #[proc_macro_derive(EnumWritable)] pub fn enum_writable(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as ItemEnum); + let input = parse_macro_input!(input as DeriveInput); let name = input.ident; let expanded = quote! { impl ::ser::Writable for #name { - fn write(self, writer: &mut impl ::std::io::Write) -> ::std::io::Result<()> { - let v = self as i32; + fn write(&self, writer: &mut impl ::std::io::Write) -> ::anyhow::Result<()> { + let v = *self as i32; let v = VarInt(v); v.write(writer) } - async fn write_async(self, writer: &mut (impl ::tokio::io::AsyncWrite + ::std::marker::Unpin)) -> ::std::io::Result<()> { - let v = self as i32; - let v = VarInt(v); - v.write_async(writer).await - } + // async fn write_async(self, writer: &mut (impl ::tokio::io::AsyncWrite + ::std::marker::Unpin)) -> ::anyhow::Result<()> { + // let v = self as i32; + // let v = VarInt(v); + // v.write_async(writer).await + // } } }; @@ -154,16 +202,24 @@ pub fn enum_writable(input: TokenStream) -> TokenStream { #[proc_macro_derive(EnumReadable)] pub fn enum_readable_count(input: TokenStream) -> TokenStream { - let input = parse_macro_input!(input as ItemEnum); + let input = parse_macro_input!(input as DeriveInput); let name = input.ident; - let idents: Vec<_> = input.variants.iter().map(|x| x.ident.clone()).collect(); + let Data::Enum(data) = input.data else { + let error = Error::new_spanned( + &name, + "only enums are supported for the `#[derive(EnumReadable)]` attribute", + ); + return TokenStream::from(error.to_compile_error()); + }; + + let idents: Vec<_> = data.variants.iter().map(|x| x.ident.clone()).collect(); // for instance if we have enum Foo { A = 3, B = 5, // C = 7}, then the discriminants will be 3, 5, 7 else default to idx // let discriminants = // todo - let discriminants: Vec<_> = input + let discriminants: Vec<_> = data .variants .iter() .enumerate() @@ -180,22 +236,12 @@ pub fn enum_readable_count(input: TokenStream) -> TokenStream { .collect(); let expanded = quote! { - impl ser::Readable for #name { - fn read(byte_reader: &mut impl ::std::io::BufRead) -> ::std::io::Result { - let VarInt(inner) = VarInt::read(byte_reader)?; - - match inner { - #(#discriminants => Ok(#name::#idents)),*, - _ => ::std::result::Result::Err(::std::io::Error::new(::std::io::ErrorKind::InvalidData, "Invalid enum discriminant")) - } - } - - async fn read_async(byte_reader: &mut (impl ::tokio::io::AsyncBufRead + ::std::marker::Unpin)) -> ::std::io::Result { - let VarInt(inner) = VarInt::read_async(byte_reader).await?; - + impl ser::Readable<'_> for #name { + fn decode(r: &mut &[u8]) -> anyhow::Result { + let VarInt(inner) = VarInt::decode(r)?; match inner { - #(#discriminants => Ok(#name::#idents)),*, - _ => ::std::result::Result::Err(::std::io::Error::new(::std::io::ErrorKind::InvalidData, "Invalid enum discriminant")) + #(#discriminants => Ok(#name::#idents),)* + _ => Err(anyhow::anyhow!("invalid discriminant")) } } } diff --git a/ser/Cargo.toml b/ser/Cargo.toml index 83d38ec3..d9b1218f 100644 --- a/ser/Cargo.toml +++ b/ser/Cargo.toml @@ -12,7 +12,7 @@ uuid = "1.7.0" tracing = "0.1.40" serde = "1.0.197" serde_json = "1.0.114" -tokio = { version = "1.36.0", features = ["io-util"] } +anyhow = "1.0.80" [features] default = ["ser-macro"] diff --git a/ser/src/lib.rs b/ser/src/lib.rs index b04a9991..16afa064 100644 --- a/ser/src/lib.rs +++ b/ser/src/lib.rs @@ -1,9 +1,8 @@ -use std::{fmt::Debug, future::Future, io::Cursor}; +use std::{fmt::Debug, io::Write}; // re-export the `ser-macro` crate #[cfg(feature = "ser-macro")] pub use ser_macro::*; -use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tracing::debug; use crate::types::{VarInt, VarUInt}; @@ -11,51 +10,47 @@ use crate::types::{VarInt, VarUInt}; pub mod types; pub trait Writable { - fn write(self, writer: &mut impl std::io::Write) -> std::io::Result<()>; - fn write_async( - self, - writer: &mut (impl AsyncWrite + Unpin), - ) -> impl Future> - where - Self: Sized; -} - -pub trait Readable { - fn read(reader: &mut impl std::io::BufRead) -> std::io::Result - where - Self: Sized; - - fn read_async( - reader: &mut (impl tokio::io::AsyncBufRead + Unpin), - ) -> impl Future> - where - Self: Sized; + fn write(&self, writer: &mut impl Write) -> anyhow::Result<()>; + // fn write_async( + // &self, + // writer: &mut (impl AsyncWrite + Unpin), + // ) -> impl Future> + // where + // Self: Sized; } -// ext trait on std::io::BufRead -pub trait ReadExt: std::io::BufRead { - fn read_type(&mut self) -> std::io::Result - where - Self: Sized, - { - T::read(self) - } -} - -pub trait ReadExtAsync: tokio::io::AsyncBufRead + Unpin { - fn read_type(&mut self) -> impl Future> - where - Self: Sized, - { - T::read_async(self) - } +pub trait Readable<'a>: Sized { + /// Reads this object from the provided byte slice. + /// + /// Implementations of `Readable` are expected to shrink the slice from the + /// front as bytes are read. + fn decode(r: &mut &'a [u8]) -> anyhow::Result; } -impl ReadExt for T {} -impl ReadExtAsync for T {} - -pub trait WriteExt: std::io::Write { - fn write_type(&mut self, data: T) -> std::io::Result<&mut Self> +// // ext trait on std::io::BufRead +// pub trait ReadExt: std::io::BufRead { +// fn read_type(&mut self) -> std::anyhow::Result +// where +// Self: Sized, +// { +// T::read(self) +// } +// } +// +// pub trait ReadExtAsync: tokio::io::AsyncBufRead + Unpin { +// fn read_type(&mut self) -> impl Future> +// where +// Self: Sized, +// { +// T::read_async(self) +// } +// } +// +// impl ReadExt for T {} +// impl ReadExtAsync for T {} + +pub trait WriteExt: Write { + fn write_type(&mut self, data: T) -> anyhow::Result<&mut Self> where Self: Sized, { @@ -64,27 +59,28 @@ pub trait WriteExt: std::io::Write { } } -pub trait WriteExtAsync: AsyncWrite + Unpin { - fn write_type( - &mut self, - data: T, - ) -> impl Future> - where - Self: Sized, - { - async move { - data.write_async(self).await?; - Ok(self) - } - } -} - -impl WriteExt for T {} -impl WriteExtAsync for T {} +// pub trait WriteExtAsync: AsyncWrite + Unpin { +// fn write_type( +// &mut self, +// data: T, +// ) -> impl Future> +// where +// Self: Sized, +// { +// async move { +// data.write_async(self).await?; +// Ok(self) +// } +// } +// } + +impl WriteExt for T {} +// impl WriteExtAsync for T {} pub trait Packet { const ID: i32; const STATE: types::PacketState; + const NAME: &'static str; } #[derive(Debug)] @@ -111,7 +107,7 @@ impl WritePacket { } impl Writable for WritePacket { - fn write(self, writer: &mut impl std::io::Write) -> std::io::Result<()> { + fn write(&self, writer: &mut impl Write) -> anyhow::Result<()> { let mut data = Vec::new(); self.id.write(&mut data)?; self.data.write(&mut data)?; @@ -126,72 +122,20 @@ impl Writable for WritePacket { } // #[tracing::instrument(skip(writer))] - async fn write_async(self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> { - let mut data = Vec::new(); - self.id.write(&mut data)?; - self.data.write(&mut data)?; - - // todo: unnecessary allocation - VarUInt(data.len() as u32).write_async(writer).await?; - writer.write_all(&data).await?; - - debug!("wrote packet ID: {:#x} length: {}", self.id.0, data.len()); - - // format hex debug! raw data 0x00 - // debug!("{:#x?}", data); - - Ok(()) - } -} - -impl Readable for ExactPacket { - fn read(reader: &mut impl std::io::BufRead) -> std::io::Result - where - Self: Sized, - { - debug!("reading packet"); - let length = VarUInt::read(reader)?; - let mut data = vec![0; length.0 as usize]; - - reader.read_exact(&mut data)?; - - let mut cursor = Cursor::new(data); - - let _id = VarInt::read(&mut cursor)?; - - let result = T::read(&mut cursor)?; - - debug!("read packet {:?}", result); - - Ok(Self(result)) - } - - async fn read_async( - reader: &mut (impl tokio::io::AsyncBufRead + Unpin), - ) -> std::io::Result - where - Self: Sized, - { - debug!("reading packet"); - let VarUInt(length) = reader.read_type().await?; - let mut data = vec![0; length as usize]; - - reader.read_exact(&mut data).await?; - - let mut cursor = Cursor::new(data); - - let id = VarInt::read(&mut cursor)?; - - if id.0 != T::ID { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("expected packet ID: {:#x} got: {:#x}", T::ID, id.0), - )); - } - - let result = T::read(&mut cursor)?; - debug!("read packet {:?}", result); - - Ok(Self(result)) - } + // async fn write_async(&self, writer: &mut (impl AsyncWrite + Unpin)) -> anyhow::Result<()> { + // let mut data = Vec::new(); + // self.id.write(&mut data)?; + // self.data.write(&mut data)?; + // + // // todo: unnecessary allocation + // VarUInt(data.len() as u32).write_async(writer).await?; + // writer.write_all(&data).await?; + // + // debug!("wrote packet ID: {:#x} length: {}", self.id.0, data.len()); + // + // // format hex debug! raw data 0x00 + // // debug!("{:#x?}", data); + // + // Ok(()) + // } } diff --git a/ser/src/types.rs b/ser/src/types.rs index 326439fe..1f12612f 100644 --- a/ser/src/types.rs +++ b/ser/src/types.rs @@ -1,16 +1,17 @@ use std::{ fmt::Debug, - io, - io::{BufRead, Write}, + io::{Read, Write}, }; -use byteorder::{ReadBytesExt, WriteBytesExt}; -use serde::{de::DeserializeOwned, Serialize}; -use tokio::io::{AsyncBufRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use serde::Serialize; use tracing::debug; use uuid::Uuid; -use crate::{Readable, Writable, WriteExt, WriteExtAsync}; +use crate::{Readable, Writable, WriteExt}; + +/// The maximum number of bytes in a single Minecraft packet. +pub const MAX_PACKET_SIZE: i32 = 2_097_152; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[repr(u8)] @@ -19,39 +20,61 @@ pub enum PacketState { } impl Writable for i64 { - fn write(self, writer: &mut impl Write) -> io::Result<()> { - writer.write_i64::(self) + fn write(&self, writer: &mut impl Write) -> anyhow::Result<()> { + writer.write_i64::(*self)?; + Ok(()) } +} - async fn write_async(self, writer: &mut (impl AsyncWrite + Unpin)) -> io::Result<()> { - writer.write_i64(self).await +impl<'a> Readable<'a> for &'a str { + fn decode(r: &mut &'a [u8]) -> anyhow::Result { + let len = VarInt::decode(r)?; + #[allow( + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_sign_loss + )] + let len = len.0 as usize; + let s = std::str::from_utf8(&r[..len])?; + *r = &r[len..]; + Ok(s) } } -impl Readable for String { - fn read(reader: &mut impl BufRead) -> io::Result - where - Self: Sized, - { - let length = VarUInt::read(reader)?.0 as usize; - let mut buffer = vec![0; length]; - reader.read_exact(&mut buffer)?; - Ok(Self::from_utf8(buffer).unwrap()) +impl<'a, T: Readable<'a>> Readable<'a> for Vec { + fn decode(r: &mut &'a [u8]) -> anyhow::Result { + let len = VarInt::decode(r)?; + + #[allow( + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_sign_loss + )] + let mut vec = Self::with_capacity(len.0 as usize); + for _ in 0..len.0 { + vec.push(T::decode(r)?); + } + Ok(vec) } +} - async fn read_async(reader: &mut (impl AsyncBufRead + Unpin)) -> io::Result - where - Self: Sized, - { - let length = VarUInt::read_async(reader).await?.0 as usize; - let mut buffer = vec![0; length]; - reader.read_exact(&mut buffer).await?; - Ok(Self::from_utf8(buffer).unwrap()) +impl Writable for Vec { + fn write(&self, writer: &mut impl Write) -> anyhow::Result<()> { + #[allow( + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_sign_loss + )] + VarInt(self.len() as i32).write(writer)?; + for item in self { + item.write(writer)?; + } + Ok(()) } } impl Writable for String { - fn write(self, writer: &mut impl Write) -> io::Result<()> { + fn write(&self, writer: &mut impl Write) -> anyhow::Result<()> { let bytes = self.as_bytes(); let length = bytes.len() as u32; @@ -59,110 +82,54 @@ impl Writable for String { debug!("Writing string (sync): {self} with: {length} bytes and {str_length} characters"); - writer.write_type(VarUInt(length))?.write_all(bytes) - } - - async fn write_async(self, writer: &mut (impl AsyncWrite + Unpin)) -> io::Result<()> { - let bytes = self.as_bytes(); - let length = bytes.len() as u32; - - let str_length = self.len(); - - debug!("Writing string: {self} with: {length} bytes and {str_length} characters"); - - writer - .write_type(VarUInt(length)) - .await? - .write_all(bytes) - .await + writer.write_type(VarUInt(length))?.write_all(bytes)?; + Ok(()) } } -impl Readable for u16 { - fn read(reader: &mut impl BufRead) -> io::Result - where - Self: Sized, - { - reader.read_u16::() - } - - async fn read_async(reader: &mut (impl AsyncBufRead + Unpin)) -> io::Result - where - Self: Sized, - { - let mut buffer = [0; 2]; - reader.read_exact(&mut buffer).await?; - Ok(Self::from_be_bytes(buffer)) +impl Readable<'_> for u16 { + fn decode(r: &mut &[u8]) -> anyhow::Result { + Ok(r.read_u16::()?) } } impl Writable for u16 { - fn write(self, writer: &mut impl Write) -> io::Result<()> { - writer.write_u16::(self) - } - - async fn write_async(self, writer: &mut (impl AsyncWrite + Unpin)) -> io::Result<()> { - writer.write_u16(self).await + fn write(&self, writer: &mut impl Write) -> anyhow::Result<()> { + writer.write_u16::(*self)?; + Ok(()) } } -impl Readable for i64 { - fn read(reader: &mut impl BufRead) -> io::Result - where - Self: Sized, - { - reader.read_i64::() +impl Readable<'_> for i16 { + fn decode(r: &mut &[u8]) -> anyhow::Result { + Ok(r.read_i16::()?) } +} - async fn read_async(reader: &mut (impl AsyncBufRead + Unpin)) -> io::Result - where - Self: Sized, - { - reader.read_i64().await +impl Readable<'_> for i64 { + fn decode(r: &mut &'_ [u8]) -> anyhow::Result { + Ok(r.read_i64::()?) } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Json(pub T); -impl Readable for Json { - fn read(reader: &mut impl BufRead) -> io::Result - where - Self: Sized, - { - let string = String::read(reader)?; - let Ok(value) = serde_json::from_str(&string) else { - return Err(io::Error::new(io::ErrorKind::InvalidData, "Invalid JSON")); - }; - - Ok(Self(value)) - } - - async fn read_async(reader: &mut (impl AsyncBufRead + Unpin)) -> io::Result - where - Self: Sized, - { - let string = String::read_async(reader).await?; - let Ok(value) = serde_json::from_str(&string) else { - return Err(io::Error::new(io::ErrorKind::InvalidData, "Invalid JSON")); - }; - +impl<'a, T: serde::de::Deserialize<'a>> Readable<'a> for Json { + fn decode(r: &mut &'a [u8]) -> anyhow::Result { + let s = std::str::from_utf8(r)?; + let value = serde_json::from_str(s)?; + *r = &r[s.len()..]; Ok(Self(value)) } } impl Writable for Json { - fn write(self, writer: &mut impl Write) -> io::Result<()> { + fn write(&self, writer: &mut impl Write) -> anyhow::Result<()> { let string = serde_json::to_string_pretty(&self.0)?; debug!("Writing JSON:\n{string}"); string.write(writer) } - - async fn write_async(self, writer: &mut (impl AsyncWrite + Unpin)) -> io::Result<()> { - let string = serde_json::to_string_pretty(&self.0)?; - debug!("Writing JSON:\n{string}"); - string.write_async(writer).await - } } // Variable-length data encoding a two's complement signed 32-bit integer; more info in their @@ -170,143 +137,112 @@ impl Writable for Json { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct VarInt(pub i32); +pub enum VarIntDecodeError { + Incomplete, + TooLarge, +} + +impl VarInt { + /// The maximum number of bytes a `VarInt` could occupy when read from and + /// written to the Minecraft protocol. + pub const MAX_SIZE: usize = 5; + + pub fn decode_partial(mut r: impl Read) -> Result { + let mut val = 0; + for i in 0..Self::MAX_SIZE { + let byte = r.read_u8().map_err(|_| VarIntDecodeError::Incomplete)?; + val |= (byte as i32 & 0b0111_1111) << (i * 7); + if byte & 0b1000_0000 == 0 { + return Ok(val); + } + } + + Err(VarIntDecodeError::TooLarge) + } + + #[must_use] + pub const fn written_size(self) -> usize { + let mut value = self.0; + let mut size = 0; + loop { + size += 1; + value >>= 7; + if value == 0 { + break; + } + } + size + } +} + impl From for VarInt { fn from(value: i32) -> Self { Self(value) } } +impl Readable<'_> for bool { + fn decode(r: &mut &[u8]) -> anyhow::Result { + Ok(r.read_u8()? != 0) + } +} + +impl Writable for bool { + fn write(&self, writer: &mut impl Write) -> anyhow::Result<()> { + writer.write_all(&[*self as u8])?; + Ok(()) + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct VarUInt(pub u32); impl Writable for VarUInt { - fn write(self, writer: &mut impl Write) -> io::Result<()> { + fn write(&self, writer: &mut impl Write) -> anyhow::Result<()> { #[allow(clippy::cast_sign_loss, clippy::cast_possible_wrap)] let value = VarInt(self.0 as i32); value.write(writer) } - - async fn write_async(self, writer: &mut (impl AsyncWrite + Unpin)) -> io::Result<()> { - #[allow(clippy::cast_sign_loss, clippy::cast_possible_wrap)] - let value = VarInt(self.0 as i32); - value.write_async(writer).await - } } -impl Readable for VarUInt { - fn read(reader: &mut impl BufRead) -> io::Result - where - Self: Sized, - { - #[allow(clippy::cast_sign_loss)] - Ok(Self(VarInt::read(reader)?.0 as u32)) - } - - async fn read_async(reader: &mut (impl AsyncBufRead + Unpin)) -> io::Result - where - Self: Sized, - { - #[allow(clippy::cast_sign_loss)] - Ok(Self(VarInt::read_async(reader).await?.0 as u32)) +impl Readable<'_> for Uuid { + fn decode(r: &mut &'_ [u8]) -> anyhow::Result { + let x = r.read_u128::()?; + Ok(Self::from_u128(x)) } } -impl Readable for Uuid { - fn read(reader: &mut impl BufRead) -> std::io::Result - where - Self: Sized, - { - // Encoded as an unsigned 128-bit integer (or two unsigned 64-bit integers: the most - // significant 64 bits and then the least significant 64 bits) - let value = reader.read_u128::()?; - debug!("Read UUID: {}", value); - Ok(Self::from_u128(value)) - } - - async fn read_async(reader: &mut (impl AsyncBufRead + Unpin)) -> std::io::Result - where - Self: Sized, - { - let value = reader.read_u128().await?; - debug!("Read UUID: {}", value); - Ok(Self::from_u128(value)) +impl Writable for Uuid { + fn write(&self, writer: &mut impl Write) -> anyhow::Result<()> { + writer.write_u128::(self.as_u128())?; + Ok(()) } } -const SEGMENT_BITS: u8 = 0x7F; -const CONTINUE_BIT: u8 = 0x80; - -impl Readable for VarInt { - fn read(reader: &mut impl BufRead) -> std::io::Result - where - Self: Sized, - { - let mut value = 0i32; - let mut position = 0; - - loop { - let mut buffer = [0u8; 1]; - reader.read_exact(&mut buffer)?; - let current_byte = buffer[0]; - - let segment_value = (current_byte & SEGMENT_BITS) as i32; - // Ensure we're not shifting bits into oblivion, which can happen with a malformed - // VarInt. - if position > 32 { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "VarInt is too big", - )); - } - // SAFETY: `position` is guaranteed to be at most 28 here, ensuring the shift is safe. - value |= segment_value << position; - - if current_byte & CONTINUE_BIT == 0 { - break; - } - - position += 7; - } - Ok(Self(value)) - } - - async fn read_async(reader: &mut (impl AsyncBufRead + Unpin)) -> std::io::Result - where - Self: Sized, - { - let mut value = 0i32; - let mut position = 0; - +impl Readable<'_> for VarInt { + fn decode(r: &mut &[u8]) -> anyhow::Result { + let mut result = 0; + let mut shift = 0; loop { - let mut buffer = [0u8; 1]; - reader.read_exact(&mut buffer).await?; - let current_byte = buffer[0]; - - let segment_value = (current_byte & SEGMENT_BITS) as i32; - // Ensure we're not shifting bits into oblivion, which can happen with a malformed - // VarInt. - if position > 32 { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "VarInt is too big", - )); - } - // SAFETY: `position` is guaranteed to be at most 28 here, ensuring the shift is safe. - value |= segment_value << position; - - if current_byte & CONTINUE_BIT == 0 { + let byte = r.read_u8()?; + result |= ((byte & 0x7F) as i32) << shift; + if byte & 0x80 == 0 { break; } - - position += 7; + shift += 7; } - Ok(Self(value)) + Ok(Self(result)) } } +const SEGMENT_BITS: u8 = 0x7F; +const CONTINUE_BIT: u8 = 0x80; + impl Writable for VarInt { #[allow(clippy::cast_sign_loss, clippy::cast_possible_wrap)] - fn write(self, writer: &mut impl std::io::Write) -> std::io::Result<()> { + fn write(&self, writer: &mut impl Write) -> anyhow::Result<()> { + // todO: + // should this take by value? let mut value = self.0; loop { if (value & !SEGMENT_BITS as i32) == 0 { @@ -320,37 +256,16 @@ impl Writable for VarInt { value = ((value as u32) >> 7) as i32; } } - - #[allow(clippy::cast_sign_loss, clippy::cast_possible_wrap)] - async fn write_async(self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> { - let mut value = self.0; - loop { - if (value & !SEGMENT_BITS as i32) == 0 { - writer.write_all(&[value as u8]).await?; - return Ok(()); - } - writer - .write_all(&[(value as u8 & SEGMENT_BITS) | CONTINUE_BIT]) - .await?; - // Note: Rust does not have a logical right shift operator (>>>), but since we're - // working with a signed int, converting to u32 for the shift operation - // achieves the same effect of not preserving the sign bit. - value = ((value as u32) >> 7) as i32; - } - } } pub struct VarLong(pub i64); -impl Readable for VarLong { - fn read(reader: &mut impl BufRead) -> std::io::Result - where - Self: Sized, - { +impl Readable<'_> for VarLong { + fn decode(r: &mut &[u8]) -> anyhow::Result { let mut result = 0; let mut shift = 0; loop { - let byte = reader.read_u8()?; + let byte = r.read_u8()?; result |= ((byte & 0x7F) as i64) << shift; if byte & 0x80 == 0 { break; @@ -359,27 +274,19 @@ impl Readable for VarLong { } Ok(Self(result)) } +} - async fn read_async(reader: &mut (impl AsyncBufRead + Unpin)) -> std::io::Result - where - Self: Sized, - { - let mut result = 0; - let mut shift = 0; - loop { - let byte = reader.read_u8().await?; - result |= ((byte & 0x7F) as i64) << shift; - if byte & 0x80 == 0 { - break; - } - shift += 7; - } - Ok(Self(result)) +impl Writable for &str { + fn write(&self, writer: &mut impl Write) -> anyhow::Result<()> { + let bytes = self.as_bytes(); + let length = bytes.len() as u32; + writer.write_type(VarUInt(length))?.write_all(bytes)?; + Ok(()) } } impl Writable for VarLong { - fn write(self, writer: &mut impl std::io::Write) -> std::io::Result<()> { + fn write(&self, writer: &mut impl Write) -> anyhow::Result<()> { let mut value = self.0; loop { #[allow(clippy::cast_sign_loss)] @@ -395,42 +302,4 @@ impl Writable for VarLong { } Ok(()) } - - async fn write_async(self, writer: &mut (impl AsyncWrite + Unpin)) -> std::io::Result<()> { - let mut value = self.0; - loop { - #[allow(clippy::cast_sign_loss)] - let mut byte = (value & 0x7F) as u8; - value >>= 7; - if value != 0 { - byte |= 0x80; - } - writer.write_all(&[byte]).await?; - if value == 0 { - break; - } - } - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use crate::{types::VarInt, Readable, Writable}; - - fn round_trip(num: i32) { - let original = VarInt(num); - let mut buffer = Vec::new(); - original.write(&mut buffer).unwrap(); - - let mut cursor = std::io::Cursor::new(buffer); - let round_tripped = VarInt::read(&mut cursor).unwrap(); - assert_eq!(original.0, round_tripped.0); - } - - #[test] - fn test_round_trip_varint() { - round_trip(0x1234_5678); - round_trip(32); - } } diff --git a/server/Cargo.toml b/server/Cargo.toml index f56df545..261df51d 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -13,6 +13,7 @@ ser.workspace = true tracing-subscriber = "0.3.18" tracing = "0.1.40" serde_json = "1.0.114" +bytes = "1.5.0" [lints.rust] @@ -49,10 +50,12 @@ exhaustive_structs = "allow" pub_use = "allow" let_underscore_untyped = "allow" infinite_loop = "allow" +single_char_lifetime_names = "allow" complexity = "deny" nursery = { level = "deny", priority = -1 } +future_not_send = "allow" pedantic = { level = "deny", priority = -1 } uninlined_format_args = "allow" # consider denying; this is allowed because Copilot often generates code that triggers this lint diff --git a/server/clippy.toml b/server/clippy.toml index 50600c99..5ad7f93b 100644 --- a/server/clippy.toml +++ b/server/clippy.toml @@ -1,3 +1,3 @@ # https://doc.rust-lang.org/nightly/clippy/lint_configuration.html cognitive-complexity-threshold = 5 -excessive-nesting-threshold = 3 +excessive-nesting-threshold = 4 diff --git a/server/src/main.rs b/server/src/main.rs index 80c8f381..b1e24a66 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,7 +1,14 @@ #![allow(unused)] -use anyhow::{bail, ensure}; -use protocol_765::{serverbound, serverbound::NextState, status::Root}; -use ser::{ExactPacket, ReadExtAsync, Readable, Writable, WritePacket}; + +use std::{io, io::ErrorKind}; + +use anyhow::{bail, ensure, Context}; +use bytes::{Buf, BufMut, BytesMut}; +use protocol_765::{clientbound, serverbound, serverbound::NextState, status::Root}; +use ser::{ + types::{VarInt, VarIntDecodeError, MAX_PACKET_SIZE}, + ExactPacket, Packet, Readable, Writable, WritePacket, +}; use serde_json::json; use tokio::{ io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter, ReadHalf, WriteHalf}, @@ -10,38 +17,240 @@ use tokio::{ }; use tracing::{debug, error, info, instrument, warn}; -struct Process { - writer: WriteHalf, - reader: BufReader>, +#[derive(Default)] +struct PacketDecoder { + buf: BytesMut, } -impl Process { - fn new(stream: TcpStream) -> Self { - let (reader, writer) = tokio::io::split(stream); - let reader = BufReader::new(reader); - // let writer = BufWriter::new(writer); - Self { writer, reader } +#[derive(Clone, Debug, Default)] +pub struct PacketFrame { + /// The ID of the decoded packet. + pub id: i32, + /// The contents of the packet after the leading VarInt ID. + pub body: BytesMut, +} + +impl PacketFrame { + pub fn decode<'a, P>(&'a self) -> anyhow::Result

+ where + P: Packet + Readable<'a>, + { + ensure!( + P::ID == self.id, + "packet ID mismatch while decoding '{}': expected {}, got {}", + P::NAME, + P::ID, + self.id + ); + + #[allow(clippy::min_ident_chars)] + let mut r = &*self.body; + + let pkt = P::decode(&mut r)?; + + ensure!( + r.is_empty(), + "missed {} bytes while decoding '{}'", + r.len(), + P::NAME + ); + + Ok(pkt) } +} - #[instrument(skip(self))] - async fn process(mut self, id: usize) -> anyhow::Result<()> { - let bytes = self.reader.fill_buf().await?; +impl PacketDecoder { + #[allow( + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_sign_loss, + clippy::min_ident_chars + )] + pub fn try_next_packet(&mut self) -> anyhow::Result> { + let mut r = &*self.buf; + + let packet_len = match VarInt::decode_partial(&mut r) { + Ok(len) => len, + Err(VarIntDecodeError::Incomplete) => return Ok(None), + Err(VarIntDecodeError::TooLarge) => bail!("malformed packet length VarInt"), + }; + + ensure!( + (0..=MAX_PACKET_SIZE).contains(&packet_len), + "packet length of {packet_len} is out of bounds" + ); + + if r.len() < packet_len as usize { + // Not enough data arrived yet. + return Ok(None); + } + + let packet_len_len = VarInt(packet_len).written_size(); + + let mut data; + + self.buf.advance(packet_len_len); + + data = self.buf.split_to(packet_len as usize); + + // Decode the leading packet ID. + r = &*data; + let packet_id = VarInt::decode(&mut r) + .context("failed to decode packet ID")? + .0; + + data.advance(data.len() - r.len()); + + Ok(Some(PacketFrame { + id: packet_id, + body: data, + })) + } + + pub fn queue_bytes(&mut self, mut bytes: BytesMut) { + self.buf.unsplit(bytes); + } + + pub fn take_capacity(&mut self) -> BytesMut { + self.buf.split_off(self.buf.len()) + } + + pub fn reserve(&mut self, additional: usize) { + self.buf.reserve(additional); + } +} + +const READ_BUF_SIZE: usize = 4096; + +#[derive(Default)] +pub struct PacketEncoder { + buf: BytesMut, +} + +impl PacketEncoder { + #[allow( + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::cast_sign_loss, + clippy::min_ident_chars + )] + pub fn append_packet

(&mut self, pkt: &P) -> anyhow::Result<()> + where + P: Packet + Writable, + { + let start_len = self.buf.len(); + + let mut writer = (&mut self.buf).writer(); + VarInt(P::ID).write(&mut writer)?; + + pkt.write(&mut writer)?; + + let data_len = self.buf.len() - start_len; + + let packet_len = data_len; + + ensure!( + packet_len <= MAX_PACKET_SIZE as usize, + "packet exceeds maximum length" + ); + + let packet_len_size = VarInt(packet_len as i32).written_size(); + + self.buf.put_bytes(0, packet_len_size); + self.buf + .copy_within(start_len..start_len + data_len, start_len + packet_len_size); + + #[allow(clippy::indexing_slicing)] + let mut front = &mut self.buf[start_len..]; + VarInt(packet_len as i32).write(&mut front)?; - if let Some(byte) = bytes.first() { - if *byte == 0xfe { - warn!("first byte: {:#x}", byte); - self.status().await?; - return Ok(()); + Ok(()) + } + + /// Takes all the packets written so far and encrypts them if encryption is + /// enabled. + pub fn take(&mut self) -> BytesMut { + self.buf.split() + } +} + +struct Io { + stream: TcpStream, + dec: PacketDecoder, + enc: PacketEncoder, + frame: PacketFrame, +} + +impl Io { + pub async fn recv_packet<'a, P>(&'a mut self) -> anyhow::Result

+ where + P: Packet + Readable<'a>, + { + loop { + if let Some(frame) = self.dec.try_next_packet()? { + self.frame = frame; + return self.frame.decode(); + } + + self.dec.reserve(READ_BUF_SIZE); + let mut buf = self.dec.take_capacity(); + + if self.stream.read_buf(&mut buf).await? == 0 { + return Err(io::Error::from(ErrorKind::UnexpectedEof).into()); } - warn!("no first byte"); + + // This should always be an O(1) unsplit because we reserved space earlier and + // the call to `read_buf` shouldn't have grown the allocation. + self.dec.queue_bytes(buf); + } + } + + pub async fn recv_packet_raw(&mut self) -> anyhow::Result { + loop { + if let Some(frame) = self.dec.try_next_packet()? { + return Ok(frame); + } + + self.dec.reserve(READ_BUF_SIZE); + let mut buf = self.dec.take_capacity(); + + if self.stream.read_buf(&mut buf).await? == 0 { + return Err(io::Error::from(ErrorKind::UnexpectedEof).into()); + } + + // This should always be an O(1) unsplit because we reserved space earlier and + // the call to `read_buf` shouldn't have grown the allocation. + self.dec.queue_bytes(buf); + } + } + + fn new(stream: TcpStream) -> Self { + Self { + stream, + dec: PacketDecoder::default(), + enc: PacketEncoder::default(), + frame: PacketFrame::default(), } + } + + pub(crate) async fn send_packet

(&mut self, pkt: &P) -> anyhow::Result<()> + where + P: Packet + Writable, + { + self.enc.append_packet(pkt)?; + let bytes = self.enc.take(); + self.stream.write_all(&bytes).await?; + Ok(()) + } - let ExactPacket(serverbound::Handshake { + #[instrument(skip(self))] + async fn process(mut self, id: usize) -> anyhow::Result<()> { + let serverbound::Handshake { protocol_version, server_address, server_port, next_state, - }) = self.reader.read_type().await?; + } = self.recv_packet().await?; ensure!(protocol_version.0 == 765, "expected protocol version 765"); ensure!(server_port == 25565, "expected server port 25565"); @@ -54,23 +263,55 @@ impl Process { Ok(()) } + // The login process is as follows: + // 1. C→S: Handshake with Next State set to 2 (login) + // 2. C→S: Login Start + // 3. S→C: Encryption Request + // 4. Client auth + // 5. C→S: Encryption Response + // 6. Server auth, both enable encryption + // 7. S→C: Set Compression (optional) + // 8. S→C: Login Success + // 9. C→S: Login Acknowledged async fn login(mut self) -> anyhow::Result<()> { - info!("login"); + debug!("login"); - let ExactPacket(serverbound::LoginStart { username, uuid }) = - self.reader.read_type().await?; + let serverbound::LoginStart { username, uuid } = self.recv_packet().await?; debug!("username: {username}"); debug!("uuid: {uuid}"); + let username = username.to_owned(); + + let packet = clientbound::LoginSuccess { + uuid, + username: &username, + properties: vec![], + }; + + debug!("sending {packet:?}"); + + self.send_packet(&packet).await?; + + let serverbound::LoginAcknowledged = self.recv_packet().await?; + + debug!("received login acknowledged"); + + self.main_loop().await?; + Ok(()) } - async fn status(mut self) -> anyhow::Result<()> { - info!("status"); - let ExactPacket(serverbound::StatusRequest) = self.reader.read_type().await?; + async fn main_loop(mut self) -> anyhow::Result<()> { + loop { + let packet = self.recv_packet_raw().await?; + debug!("received {packet:?}"); + } + } - info!("byte"); + async fn status(mut self) -> anyhow::Result<()> { + debug!("status"); + let serverbound::StatusRequest = self.recv_packet().await?; let mut json = json!({ "version": { @@ -85,81 +326,31 @@ impl Process { "description": "10k babyyyyy", }); - let send = WritePacket::new(protocol_765::clientbound::StatusResponse { - json: json.to_string(), - }); + let json = serde_json::to_string_pretty(&json)?; - send.write_async(&mut self.writer).await?; + let send = clientbound::StatusResponse { json: &json }; - info!("wrote status response"); + self.send_packet(&send).await?; - let ExactPacket(serverbound::Ping { payload }) = self.reader.read_type().await?; + debug!("wrote status response"); - info!("read ping {}", payload); + let serverbound::Ping { payload } = self.recv_packet().await?; - let pong = WritePacket::new(protocol_765::clientbound::Pong { payload }); - pong.write_async(&mut self.writer).await?; + debug!("read ping {}", payload); + + let pong = clientbound::Pong { payload }; + self.send_packet(&pong).await?; Ok(()) } } -async fn print_errors(future: impl std::future::Future>) { +async fn print_errors(future: impl core::future::Future>) { if let Err(err) = future.await { error!("{:?}", err); } } -// #[allow(clippy::infinite_loop)] -// async fn process(id: usize, stream: TcpStream) -> anyhow::Result<()> { -// println!("{handshake:?}"); -// -// ensure!( -// handshake.data.next_state == NextState::Status, -// "expected status" -// ); -// -// let _status_pkt = ExactPacket::::read_async(&mut reader).await?; -// // ensure!(status_pkt.0 == 0, "expected status packet"); -// -// let mut writer = BufWriter::new(writer); -// -// let json = Root::sample(); -// let json = serde_json::to_string_pretty(&json)?; -// -// println!("{}", json); -// -// let response = protocol_765::clientbound::StatusResponse { json }; -// -// let login_start = WritePacket::new(response); -// login_start.write_async(&mut writer).await?; -// -// debug!("wrote status response"); -// -// let pong = WritePacket::new(protocol_765::clientbound::Pong { payload: 0 }); -// pong.write_async(&mut writer).await?; -// -// println!("wrote pong"); -// -// // wait for the client to disconnect -// // loop { -// // writer.write_u8(0).await?; -// // // let byte = reader.read_u8().await?; -// // sleep(core::time::Duration::from_millis(10)).await; -// // } -// -// while let Ok(byte) = reader.read_u8().await { -// println!("{:#x}", byte); -// } -// // let x = reader.read_u8().await?; -// // -// // println!("{:#x}", x); -// -// println!("done"); -// -// Ok(()) -// } - #[tokio::main] async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt::init(); @@ -172,7 +363,7 @@ async fn main() -> anyhow::Result<()> { loop { let (stream, _) = listener.accept().await?; - let process = Process::new(stream); + let process = Io::new(stream); let action = process.process(id); let action = print_errors(action);