Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: remove needless Semaphore::(u32::MAX) #1051

Merged
merged 1 commit into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 4 additions & 18 deletions core/src/server/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore};

use super::rpc_module::{DisconnectError, SendTimeoutError, SubscriptionMessage, TrySendError};

/// Subscription permit.
pub type SubscriptionPermit = OwnedSemaphorePermit;

/// Bounded writer that allows writing at most `max_len` bytes.
///
/// ```
Expand Down Expand Up @@ -191,20 +194,6 @@ pub fn prepare_error(data: &[u8]) -> (Id<'_>, ErrorCode) {
}
}

/// A permitted subscription.
#[derive(Debug)]
pub struct SubscriptionPermit {
_permit: OwnedSemaphorePermit,
resource: Arc<Notify>,
}

impl SubscriptionPermit {
/// Get the handle to [`tokio::sync::Notify`].
pub fn handle(&self) -> Arc<Notify> {
self.resource.clone()
}
}

/// Wrapper over [`tokio::sync::Notify`] with bounds check.
#[derive(Debug, Clone)]
pub struct BoundedSubscriptions {
Expand All @@ -227,10 +216,7 @@ impl BoundedSubscriptions {
///
/// Fails if `max_subscriptions` have been exceeded.
pub fn acquire(&self) -> Option<SubscriptionPermit> {
Arc::clone(&self.guard)
.try_acquire_owned()
.ok()
.map(|p| SubscriptionPermit { _permit: p, resource: self.resource.clone() })
Arc::clone(&self.guard).try_acquire_owned().ok()
}

/// Get the maximum number of permitted subscriptions.
Expand Down
33 changes: 18 additions & 15 deletions core/src/server/rpc_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ pub type MaxResponseSize = usize;
/// A 3-tuple containing:
/// - Call result as a `String`,
/// - a [`mpsc::UnboundedReceiver<String>`] to receive future subscription results
/// - a [`crate::server::helpers::SubscriptionPermit`] to allow subscribers to notify their [`SubscriptionSink`] when
/// they disconnect.
pub type RawRpcResponse = (MethodResponse, mpsc::Receiver<String>, SubscriptionPermit);
pub type RawRpcResponse = (MethodResponse, mpsc::Receiver<String>);

/// Error that may occur during [`SubscriptionSink::try_send`].
#[derive(Debug)]
Expand Down Expand Up @@ -408,7 +406,7 @@ impl Methods {
let params = params.to_rpc_params()?;
let req = Request::new(method.into(), params.as_ref().map(|p| p.as_ref()), Id::Number(0));
tracing::trace!("[Methods::call] Method: {:?}, params: {:?}", method, params);
let (resp, _, _) = self.inner_call(req, 1).await;
let (resp, _) = self.inner_call(req, 1, mock_subscription_permit()).await;

if resp.success {
serde_json::from_str::<Response<T>>(&resp.result).map(|r| r.result).map_err(Into::into)
Expand Down Expand Up @@ -456,27 +454,28 @@ impl Methods {
) -> Result<(MethodResponse, mpsc::Receiver<String>), Error> {
tracing::trace!("[Methods::raw_json_request] Request: {:?}", request);
let req: Request = serde_json::from_str(request)?;
let (resp, rx, _) = self.inner_call(req, buf_size).await;
let (resp, rx) = self.inner_call(req, buf_size, mock_subscription_permit()).await;

Ok((resp, rx))
}

/// Execute a callback.
async fn inner_call(&self, req: Request<'_>, buf_size: usize) -> RawRpcResponse {
async fn inner_call(
&self,
req: Request<'_>,
buf_size: usize,
subscription_permit: SubscriptionPermit,
) -> RawRpcResponse {
let (tx, mut rx) = mpsc::channel(buf_size);
let id = req.id.clone();
let params = Params::new(req.params.map(|params| params.get()));
let bounded_subs = BoundedSubscriptions::new(u32::MAX);
let p1 = bounded_subs.acquire().expect("u32::MAX permits is sufficient; qed");
let p2 = bounded_subs.acquire().expect("u32::MAX permits is sufficient; qed");

let response = match self.method(&req.method) {
None => MethodResponse::error(req.id, ErrorObject::from(ErrorCode::MethodNotFound)),
Some(MethodCallback::Sync(cb)) => (cb)(id, params, usize::MAX),
Some(MethodCallback::Async(cb)) => (cb)(id.into_owned(), params.into_owned(), 0, usize::MAX).await,
Some(MethodCallback::Subscription(cb)) => {
let conn_state =
ConnState { conn_id: 0, id_provider: &RandomIntegerIdProvider, subscription_permit: p1 };
let conn_state = ConnState { conn_id: 0, id_provider: &RandomIntegerIdProvider, subscription_permit };
let res = match (cb)(id, params, MethodSink::new(tx.clone()), conn_state).await {
Ok(rp) => rp,
Err(id) => MethodResponse::error(id, ErrorObject::from(ErrorCode::InternalError)),
Expand All @@ -495,7 +494,7 @@ impl Methods {

tracing::trace!("[Methods::inner_call] Method: {}, response: {:?}", req.method, response);

(response, rx, p2)
(response, rx)
}

/// Helper to create a subscription on the `RPC module` without having to spin up a server.
Expand Down Expand Up @@ -544,7 +543,7 @@ impl Methods {

tracing::trace!("[Methods::subscribe] Method: {}, params: {:?}", sub_method, params);

let (resp, rx, permit) = self.inner_call(req, buf_size).await;
let (resp, rx) = self.inner_call(req, buf_size, mock_subscription_permit()).await;

let subscription_response = match serde_json::from_str::<Response<RpcSubscriptionId>>(&resp.result) {
Ok(r) => r,
Expand All @@ -556,7 +555,7 @@ impl Methods {

let sub_id = subscription_response.result.into_owned();

Ok(Subscription { sub_id, rx, _permit: permit })
Ok(Subscription { sub_id, rx })
}

/// Returns an `Iterator` with all the method names registered on this server.
Expand Down Expand Up @@ -1127,7 +1126,6 @@ impl Drop for SubscriptionSink {
pub struct Subscription {
rx: mpsc::Receiver<String>,
sub_id: RpcSubscriptionId<'static>,
_permit: SubscriptionPermit,
}

impl Subscription {
Expand Down Expand Up @@ -1168,3 +1166,8 @@ impl Drop for Subscription {
self.close();
}
}

// Mock subscription permit to be able to make a call.
fn mock_subscription_permit() -> SubscriptionPermit {
BoundedSubscriptions::new(1).acquire().expect("1 permit should exist; qed")
}