Skip to content

Commit

Permalink
Introduce websocket support (#619)
Browse files Browse the repository at this point in the history
  • Loading branch information
FabijanC committed Oct 7, 2024
1 parent 9d8eea5 commit 47610ea
Show file tree
Hide file tree
Showing 10 changed files with 270 additions and 5 deletions.
39 changes: 37 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ keywords = ["starknet", "cairo", "testnet", "local", "server"]
[workspace.dependencies]

# axum
axum = { version = "0.7" }
axum = { version = "0.7", features = ["ws"] }
http-body-util = { version = "0.1" }
tower-http = { version = "0.5", features = ["full"] }

Expand Down Expand Up @@ -115,6 +115,7 @@ parking_lot = "0.12.3"
serial_test = "3.1.1"
hex = "0.4.3"
lazy_static = { version = "1.4.0" }
tokio-tungstenite = { version = "0.21.0" }

# Benchmarking
criterion = { version = "0.3.4", features = ["async_tokio"] }
47 changes: 47 additions & 0 deletions crates/starknet-devnet-server/src/api/json_rpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ mod write_endpoints;

pub const RPC_SPEC_VERSION: &str = "0.7.1";

use axum::extract::ws::{Message, WebSocket};
use enum_helper_macros::{AllVariantsSerdeRenames, VariantName};
use futures::StreamExt;
use models::{
BlockAndClassHashInput, BlockAndContractAddressInput, BlockAndIndexInput, CallInput,
EstimateFeeInput, EventsInput, GetStorageInput, TransactionHashInput, TransactionHashOutput,
Expand Down Expand Up @@ -154,6 +156,28 @@ impl RpcHandler for JsonRpcHandler {
}
}
}

async fn on_websocket(&self, mut socket: WebSocket) {
while let Some(msg) = socket.next().await {
match msg {
Ok(Message::Text(text)) => {
self.on_websocket_rpc_call(text.as_bytes(), &mut socket).await;
}
Ok(Message::Binary(bytes)) => {
self.on_websocket_rpc_call(&bytes, &mut socket).await;
}
Ok(Message::Close(_)) => {
tracing::info!("Websocket disconnected");
return;
}
other => {
tracing::error!("Socket handler got an unexpected message: {other:?}")
}
}
}

tracing::error!("Failed socket read");
}
}

impl JsonRpcHandler {
Expand Down Expand Up @@ -347,6 +371,29 @@ impl JsonRpcHandler {
starknet_resp.to_rpc_result()
}

/// Takes `bytes` to be an encoded RPC call, executes it, and sends the response back via `ws`.
async fn on_websocket_rpc_call(&self, bytes: &[u8], ws: &mut WebSocket) {
match serde_json::from_slice(bytes) {
Ok(call) => {
let resp = self.on_call(call).await;
let resp_serialized = serde_json::to_string(&resp).unwrap_or_else(|e| {
let err_msg = format!("Error converting RPC response to string: {e}");
tracing::error!(err_msg);
err_msg
});

if let Err(e) = ws.send(Message::Text(resp_serialized)).await {
tracing::error!("Error sending websocket message: {e}");
}
}
Err(e) => {
if let Err(e) = ws.send(Message::Text(e.to_string())).await {
tracing::error!("Error sending websocket message: {e}");
}
}
}
}

const DUMPABLE_METHODS: &'static [&'static str] = &[
"devnet_impersonateAccount",
"devnet_stopImpersonateAccount",
Expand Down
19 changes: 18 additions & 1 deletion crates/starknet-devnet-server/src/rpc_handler.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::fmt::{self};

use axum::extract::rejection::JsonRejection;
use axum::extract::State;
use axum::extract::ws::WebSocket;
use axum::extract::{State, WebSocketUpgrade};
use axum::response::IntoResponse;
use axum::Json;
use futures::{future, FutureExt};
use serde::de::DeserializeOwned;
Expand Down Expand Up @@ -33,6 +35,9 @@ pub trait RpcHandler: Clone + Send + Sync + 'static {
/// **Note**: override this function if the expected `Request` deviates from `{ "method" :
/// "<name>", "params": "<params>" }`
async fn on_call(&self, call: RpcMethodCall) -> RpcResponse;

/// Handles websocket connection, from start to finish.
async fn on_websocket(&self, mut socket: WebSocket);
}

/// Handles incoming JSON-RPC Request
Expand All @@ -52,6 +57,18 @@ pub async fn handle<THandler: RpcHandler>(
}
}

pub async fn handle_socket<THandler: RpcHandler>(
ws_upgrade: WebSocketUpgrade,
State(handler): State<THandler>,
) -> impl IntoResponse {
tracing::info!("New websocket connection!");
ws_upgrade.on_failed_upgrade(|e| tracing::error!("Failed websocket upgrade: {e:?}")).on_upgrade(
move |socket| async move {
handler.on_websocket(socket).await;
},
)
}

#[macro_export]
/// Match a list of comma-separated pairs enclosed in square brackets. First pair member is the HTTP
/// path which is mapped to an RPC request with the method that is the second pair member. Using the
Expand Down
1 change: 1 addition & 0 deletions crates/starknet-devnet-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ fn json_rpc_routes<TJsonRpcHandler: RpcHandler>(json_rpc_handler: TJsonRpcHandle
Router::new()
.route("/", post(rpc_handler::handle::<TJsonRpcHandler>))
.route("/rpc", post(rpc_handler::handle::<TJsonRpcHandler>))
.route("/ws", get(rpc_handler::handle_socket::<TJsonRpcHandler>))
.with_state(json_rpc_handler)
}

Expand Down
1 change: 1 addition & 0 deletions crates/starknet-devnet/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ usc = { workspace = true }
reqwest = { workspace = true }
criterion = { workspace = true }
serial_test = { workspace = true }
tokio-tungstenite = { workspace = true }


[[bench]]
Expand Down
6 changes: 5 additions & 1 deletion crates/starknet-devnet/tests/common/background_devnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use url::Url;

use super::constants::{
ACCOUNTS, HEALTHCHECK_PATH, HOST, MAX_PORT, MIN_PORT, PREDEPLOYED_ACCOUNT_INITIAL_BALANCE,
RPC_PATH, SEED,
RPC_PATH, SEED, WS_PATH,
};
use super::errors::TestError;
use super::reqwest_client::{PostReqwestSender, ReqwestClient};
Expand Down Expand Up @@ -160,6 +160,10 @@ impl BackgroundDevnet {
Err(TestError::DevnetNotStartable)
}

pub fn ws_url(&self) -> String {
format!("ws://{HOST}:{}{WS_PATH}", self.port)
}

pub async fn send_custom_rpc(
&self,
method: &str,
Expand Down
1 change: 1 addition & 0 deletions crates/starknet-devnet/tests/common/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub const CHAIN_ID_CLI_PARAM: &str = "TESTNET";
// URL paths
pub const RPC_PATH: &str = "/rpc";
pub const HEALTHCHECK_PATH: &str = "/is_alive";
pub const WS_PATH: &str = "/ws";

// predeployed account info with seed=42
pub const PREDEPLOYED_ACCOUNT_ADDRESS: &str =
Expand Down
Loading

0 comments on commit 47610ea

Please sign in to comment.