diff --git a/core/src/server/helpers.rs b/core/src/server/helpers.rs index df62cde412..00babedc7b 100644 --- a/core/src/server/helpers.rs +++ b/core/src/server/helpers.rs @@ -38,6 +38,9 @@ use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore}; use super::rpc_module::{DisconnectError, SendTimeoutError, SubscriptionMessage, TrySendError}; +/// Subscription permit. +pub type SubscriptionPermit = OwnedSemaphorePermit; + /// Bounded writer that allows writing at most `max_len` bytes. /// /// ``` @@ -191,20 +194,6 @@ pub fn prepare_error(data: &[u8]) -> (Id<'_>, ErrorCode) { } } -/// A permitted subscription. -#[derive(Debug)] -pub struct SubscriptionPermit { - _permit: OwnedSemaphorePermit, - resource: Arc, -} - -impl SubscriptionPermit { - /// Get the handle to [`tokio::sync::Notify`]. - pub fn handle(&self) -> Arc { - self.resource.clone() - } -} - /// Wrapper over [`tokio::sync::Notify`] with bounds check. #[derive(Debug, Clone)] pub struct BoundedSubscriptions { @@ -227,10 +216,7 @@ impl BoundedSubscriptions { /// /// Fails if `max_subscriptions` have been exceeded. pub fn acquire(&self) -> Option { - Arc::clone(&self.guard) - .try_acquire_owned() - .ok() - .map(|p| SubscriptionPermit { _permit: p, resource: self.resource.clone() }) + Arc::clone(&self.guard).try_acquire_owned().ok() } /// Get the maximum number of permitted subscriptions. diff --git a/core/src/server/rpc_module.rs b/core/src/server/rpc_module.rs index f91f4b658b..1ba4e0625d 100644 --- a/core/src/server/rpc_module.rs +++ b/core/src/server/rpc_module.rs @@ -76,9 +76,7 @@ pub type MaxResponseSize = usize; /// A 3-tuple containing: /// - Call result as a `String`, /// - a [`mpsc::UnboundedReceiver`] to receive future subscription results -/// - a [`crate::server::helpers::SubscriptionPermit`] to allow subscribers to notify their [`SubscriptionSink`] when -/// they disconnect. -pub type RawRpcResponse = (MethodResponse, mpsc::Receiver, SubscriptionPermit); +pub type RawRpcResponse = (MethodResponse, mpsc::Receiver); /// Error that may occur during [`SubscriptionSink::try_send`]. #[derive(Debug)] @@ -408,7 +406,7 @@ impl Methods { let params = params.to_rpc_params()?; let req = Request::new(method.into(), params.as_ref().map(|p| p.as_ref()), Id::Number(0)); tracing::trace!("[Methods::call] Method: {:?}, params: {:?}", method, params); - let (resp, _, _) = self.inner_call(req, 1).await; + let (resp, _) = self.inner_call(req, 1, mock_subscription_permit()).await; if resp.success { serde_json::from_str::>(&resp.result).map(|r| r.result).map_err(Into::into) @@ -456,27 +454,28 @@ impl Methods { ) -> Result<(MethodResponse, mpsc::Receiver), Error> { tracing::trace!("[Methods::raw_json_request] Request: {:?}", request); let req: Request = serde_json::from_str(request)?; - let (resp, rx, _) = self.inner_call(req, buf_size).await; + let (resp, rx) = self.inner_call(req, buf_size, mock_subscription_permit()).await; Ok((resp, rx)) } /// Execute a callback. - async fn inner_call(&self, req: Request<'_>, buf_size: usize) -> RawRpcResponse { + async fn inner_call( + &self, + req: Request<'_>, + buf_size: usize, + subscription_permit: SubscriptionPermit, + ) -> RawRpcResponse { let (tx, mut rx) = mpsc::channel(buf_size); let id = req.id.clone(); let params = Params::new(req.params.map(|params| params.get())); - let bounded_subs = BoundedSubscriptions::new(u32::MAX); - let p1 = bounded_subs.acquire().expect("u32::MAX permits is sufficient; qed"); - let p2 = bounded_subs.acquire().expect("u32::MAX permits is sufficient; qed"); let response = match self.method(&req.method) { None => MethodResponse::error(req.id, ErrorObject::from(ErrorCode::MethodNotFound)), Some(MethodCallback::Sync(cb)) => (cb)(id, params, usize::MAX), Some(MethodCallback::Async(cb)) => (cb)(id.into_owned(), params.into_owned(), 0, usize::MAX).await, Some(MethodCallback::Subscription(cb)) => { - let conn_state = - ConnState { conn_id: 0, id_provider: &RandomIntegerIdProvider, subscription_permit: p1 }; + let conn_state = ConnState { conn_id: 0, id_provider: &RandomIntegerIdProvider, subscription_permit }; let res = match (cb)(id, params, MethodSink::new(tx.clone()), conn_state).await { Ok(rp) => rp, Err(id) => MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError)), @@ -495,7 +494,7 @@ impl Methods { tracing::trace!("[Methods::inner_call] Method: {}, response: {:?}", req.method, response); - (response, rx, p2) + (response, rx) } /// Helper to create a subscription on the `RPC module` without having to spin up a server. @@ -544,7 +543,7 @@ impl Methods { tracing::trace!("[Methods::subscribe] Method: {}, params: {:?}", sub_method, params); - let (resp, rx, permit) = self.inner_call(req, buf_size).await; + let (resp, rx) = self.inner_call(req, buf_size, mock_subscription_permit()).await; let subscription_response = match serde_json::from_str::>(&resp.result) { Ok(r) => r, @@ -556,7 +555,7 @@ impl Methods { let sub_id = subscription_response.result.into_owned(); - Ok(Subscription { sub_id, rx, _permit: permit }) + Ok(Subscription { sub_id, rx }) } /// Returns an `Iterator` with all the method names registered on this server. @@ -1127,7 +1126,6 @@ impl Drop for SubscriptionSink { pub struct Subscription { rx: mpsc::Receiver, sub_id: RpcSubscriptionId<'static>, - _permit: SubscriptionPermit, } impl Subscription { @@ -1168,3 +1166,8 @@ impl Drop for Subscription { self.close(); } } + +// Mock subscription permit to be able to make a call. +fn mock_subscription_permit() -> SubscriptionPermit { + BoundedSubscriptions::new(1).acquire().expect("1 permit should exist; qed") +}