Skip to content

Commit

Permalink
Add Pubsub getVersion, and support programSubscribe filter mapping (#…
Browse files Browse the repository at this point in the history
…26482)

* Add pubsub getVersion api

* Generalize maybe_map_filters

* Add filter mapping to blocking PubsubClient

* Add version tracking to nonblocking PubsubClient

* Add filter mapping to nonblocking PubsubClient
  • Loading branch information
Tyera Eulberg authored Jul 8, 2022
1 parent 3127487 commit b8b5215
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 26 deletions.
97 changes: 92 additions & 5 deletions client/src/nonblocking/pubsub_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ use {
RpcProgramAccountsConfig, RpcSignatureSubscribeConfig, RpcTransactionLogsConfig,
RpcTransactionLogsFilter,
},
rpc_filter::maybe_map_filters,
rpc_response::{
Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse,
RpcSignatureResult, RpcVote, SlotInfo, SlotUpdate,
RpcSignatureResult, RpcVersionInfo, RpcVote, SlotInfo, SlotUpdate,
},
},
futures_util::{
Expand All @@ -25,7 +26,7 @@ use {
thiserror::Error,
tokio::{
net::TcpStream,
sync::{mpsc, oneshot},
sync::{mpsc, oneshot, RwLock},
task::JoinHandle,
time::{sleep, Duration},
},
Expand Down Expand Up @@ -62,18 +63,28 @@ pub enum PubsubClientError {

#[error("subscribe failed: {reason}")]
SubscribeFailed { reason: String, message: String },

#[error("request failed: {reason}")]
RequestFailed { reason: String, message: String },
}

type UnsubscribeFn = Box<dyn FnOnce() -> BoxFuture<'static, ()> + Send>;
type SubscribeResponseMsg =
Result<(mpsc::UnboundedReceiver<Value>, UnsubscribeFn), PubsubClientError>;
type SubscribeRequestMsg = (String, Value, oneshot::Sender<SubscribeResponseMsg>);
type SubscribeResult<'a, T> = PubsubClientResult<(BoxStream<'a, T>, UnsubscribeFn)>;
type RequestMsg = (
String,
Value,
oneshot::Sender<Result<Value, PubsubClientError>>,
);

#[derive(Debug)]
pub struct PubsubClient {
subscribe_tx: mpsc::UnboundedSender<SubscribeRequestMsg>,
request_tx: mpsc::UnboundedSender<RequestMsg>,
shutdown_tx: oneshot::Sender<()>,
node_version: RwLock<Option<semver::Version>>,
ws: JoinHandle<PubsubClientResult>,
}

Expand All @@ -85,12 +96,20 @@ impl PubsubClient {
.map_err(PubsubClientError::ConnectionError)?;

let (subscribe_tx, subscribe_rx) = mpsc::unbounded_channel();
let (request_tx, request_rx) = mpsc::unbounded_channel();
let (shutdown_tx, shutdown_rx) = oneshot::channel();

Ok(Self {
subscribe_tx,
request_tx,
shutdown_tx,
ws: tokio::spawn(PubsubClient::run_ws(ws, subscribe_rx, shutdown_rx)),
node_version: RwLock::new(None),
ws: tokio::spawn(PubsubClient::run_ws(
ws,
subscribe_rx,
request_rx,
shutdown_rx,
)),
})
}

Expand All @@ -99,6 +118,37 @@ impl PubsubClient {
self.ws.await.unwrap() // WS future should not be cancelled or panicked
}

async fn get_node_version(&self) -> PubsubClientResult<semver::Version> {
let r_node_version = self.node_version.read().await;
if let Some(version) = &*r_node_version {
Ok(version.clone())
} else {
drop(r_node_version);
let mut w_node_version = self.node_version.write().await;
let node_version = self.get_version().await?;
*w_node_version = Some(node_version.clone());
Ok(node_version)
}
}

async fn get_version(&self) -> PubsubClientResult<semver::Version> {
let (response_tx, response_rx) = oneshot::channel();
self.request_tx
.send(("getVersion".to_string(), Value::Null, response_tx))
.map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))?;
let result = response_rx
.await
.map_err(|err| PubsubClientError::ConnectionClosed(err.to_string()))??;
let node_version: RpcVersionInfo = serde_json::from_value(result)?;
let node_version = semver::Version::parse(&node_version.solana_core).map_err(|e| {
PubsubClientError::RequestFailed {
reason: format!("failed to parse cluster version: {}", e),
message: "getVersion".to_string(),
}
})?;
Ok(node_version)
}

async fn subscribe<'a, T>(&self, operation: &str, params: Value) -> SubscribeResult<'a, T>
where
T: DeserializeOwned + Send + 'a,
Expand Down Expand Up @@ -147,8 +197,22 @@ impl PubsubClient {
pub async fn program_subscribe(
&self,
pubkey: &Pubkey,
config: Option<RpcProgramAccountsConfig>,
mut config: Option<RpcProgramAccountsConfig>,
) -> SubscribeResult<'_, RpcResponse<RpcKeyedAccount>> {
if let Some(ref mut config) = config {
if let Some(ref mut filters) = config.filters {
let node_version = self.get_node_version().await.ok();
// If node does not support the pubsub `getVersion` method, assume version is old
// and filters should be mapped (node_version.is_none()).
maybe_map_filters(node_version, filters).map_err(|e| {
PubsubClientError::RequestFailed {
reason: e,
message: "maybe_map_filters".to_string(),
}
})?;
}
}

let params = json!([pubkey.to_string(), config]);
self.subscribe("program", params).await
}
Expand Down Expand Up @@ -181,12 +245,14 @@ impl PubsubClient {
async fn run_ws(
mut ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
mut subscribe_rx: mpsc::UnboundedReceiver<SubscribeRequestMsg>,
mut request_rx: mpsc::UnboundedReceiver<RequestMsg>,
mut shutdown_rx: oneshot::Receiver<()>,
) -> PubsubClientResult {
let mut request_id: u64 = 0;

let mut requests_subscribe = BTreeMap::new();
let mut requests_unsubscribe = BTreeMap::<u64, oneshot::Sender<()>>::new();
let mut other_requests = BTreeMap::new();
let mut subscriptions = BTreeMap::new();
let (unsubscribe_tx, mut unsubscribe_rx) = mpsc::unbounded_channel();

Expand Down Expand Up @@ -220,6 +286,13 @@ impl PubsubClient {
ws.send(Message::Text(text)).await?;
requests_unsubscribe.insert(request_id, response_tx);
},
// Read message for other requests
Some((method, params, response_tx)) = request_rx.recv() => {
request_id += 1;
let text = json!({"jsonrpc":"2.0","id":request_id,"method":method,"params":params}).to_string();
ws.send(Message::Text(text)).await?;
other_requests.insert(request_id, response_tx);
}
// Read incoming WebSocket message
next_msg = ws.next() => {
let msg = match next_msg {
Expand Down Expand Up @@ -264,7 +337,21 @@ impl PubsubClient {
}
});

if let Some(response_tx) = requests_unsubscribe.remove(&id) {
if let Some(response_tx) = other_requests.remove(&id) {
match err {
Some(reason) => {
let _ = response_tx.send(Err(PubsubClientError::RequestFailed { reason, message: text.clone()}));
},
None => {
let json_result = json.get("result").ok_or_else(|| {
PubsubClientError::RequestFailed { reason: "missing `result` field".into(), message: text.clone() }
})?;
if response_tx.send(Ok(json_result.clone())).is_err() {
break;
}
}
}
} else if let Some(response_tx) = requests_unsubscribe.remove(&id) {
let _ = response_tx.send(()); // do not care if receiver is closed
} else if let Some((operation, response_tx)) = requests_subscribe.remove(&id) {
match err {
Expand Down
22 changes: 3 additions & 19 deletions client/src/nonblocking/rpc_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use {
mock_sender::MockSender,
rpc_client::{GetConfirmedSignaturesForAddress2Config, RpcClientConfig},
rpc_config::{RpcAccountInfoConfig, *},
rpc_filter::{MemcmpEncodedBytes, RpcFilterType},
rpc_filter::{self, RpcFilterType},
rpc_request::{RpcError, RpcRequest, RpcResponseErrorData, TokenAccountsFilter},
rpc_response::*,
rpc_sender::*,
Expand Down Expand Up @@ -587,24 +587,8 @@ impl RpcClient {
mut filters: Vec<RpcFilterType>,
) -> Result<Vec<RpcFilterType>, RpcError> {
let node_version = self.get_node_version().await?;
if node_version < semver::Version::new(1, 11, 2) {
for filter in filters.iter_mut() {
if let RpcFilterType::Memcmp(memcmp) = filter {
match &memcmp.bytes {
MemcmpEncodedBytes::Base58(string) => {
memcmp.bytes = MemcmpEncodedBytes::Binary(string.clone());
}
MemcmpEncodedBytes::Base64(_) => {
return Err(RpcError::RpcRequestError(format!(
"RPC node on old version {} does not support base64 encoding for memcmp filters",
node_version
)));
}
_ => {}
}
}
}
}
rpc_filter::maybe_map_filters(Some(node_version), &mut filters)
.map_err(RpcError::RpcRequestError)?;
Ok(filters)
}

Expand Down
54 changes: 53 additions & 1 deletion client/src/pubsub_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use {
RpcProgramAccountsConfig, RpcSignatureSubscribeConfig, RpcTransactionLogsConfig,
RpcTransactionLogsFilter,
},
rpc_filter,
rpc_response::{
Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse,
RpcSignatureResult, RpcVote, SlotInfo, SlotUpdate,
Expand Down Expand Up @@ -48,6 +49,9 @@ pub enum PubsubClientError {

#[error("unexpected message format: {0}")]
UnexpectedMessageError(String),

#[error("request error: {0}")]
RequestError(String),
}

pub struct PubsubClientSubscription<T>
Expand Down Expand Up @@ -123,6 +127,43 @@ where
.map_err(|err| err.into())
}

fn get_version(
writable_socket: &Arc<RwLock<WebSocket<MaybeTlsStream<TcpStream>>>>,
) -> Result<semver::Version, PubsubClientError> {
writable_socket
.write()
.unwrap()
.write_message(Message::Text(
json!({
"jsonrpc":"2.0","id":1,"method":"getVersion",
})
.to_string(),
))?;
let message = writable_socket.write().unwrap().read_message()?;
let message_text = &message.into_text()?;
let json_msg: Map<String, Value> = serde_json::from_str(message_text)?;

if let Some(Object(version_map)) = json_msg.get("result") {
if let Some(node_version) = version_map.get("solana-core") {
let node_version = semver::Version::parse(
node_version.as_str().unwrap_or_default(),
)
.map_err(|e| {
PubsubClientError::RequestError(format!(
"failed to parse cluster version: {}",
e
))
})?;
return Ok(node_version);
}
}
// TODO: Add proper JSON RPC response/error handling...
Err(PubsubClientError::UnexpectedMessageError(format!(
"{:?}",
json_msg
)))
}

fn read_message(
writable_socket: &Arc<RwLock<WebSocket<MaybeTlsStream<TcpStream>>>>,
) -> Result<T, PubsubClientError> {
Expand Down Expand Up @@ -357,7 +398,7 @@ impl PubsubClient {
pub fn program_subscribe(
url: &str,
pubkey: &Pubkey,
config: Option<RpcProgramAccountsConfig>,
mut config: Option<RpcProgramAccountsConfig>,
) -> Result<ProgramSubscription, PubsubClientError> {
let url = Url::parse(url)?;
let socket = connect_with_retry(url)?;
Expand All @@ -367,6 +408,17 @@ impl PubsubClient {
let socket_clone = socket.clone();
let exit = Arc::new(AtomicBool::new(false));
let exit_clone = exit.clone();

if let Some(ref mut config) = config {
if let Some(ref mut filters) = config.filters {
let node_version = PubsubProgramClientSubscription::get_version(&socket_clone).ok();
// If node does not support the pubsub `getVersion` method, assume version is old
// and filters should be mapped (node_version.is_none()).
rpc_filter::maybe_map_filters(node_version, filters)
.map_err(PubsubClientError::RequestError)?;
}
}

let body = json!({
"jsonrpc":"2.0",
"id":1,
Expand Down
24 changes: 24 additions & 0 deletions client/src/rpc_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,30 @@ impl From<RpcMemcmp> for Memcmp {
}
}

pub(crate) fn maybe_map_filters(
node_version: Option<semver::Version>,
filters: &mut [RpcFilterType],
) -> Result<(), String> {
if node_version.is_none() || node_version.unwrap() < semver::Version::new(1, 11, 2) {
for filter in filters.iter_mut() {
if let RpcFilterType::Memcmp(memcmp) = filter {
match &memcmp.bytes {
MemcmpEncodedBytes::Base58(string) => {
memcmp.bytes = MemcmpEncodedBytes::Binary(string.clone());
}
MemcmpEncodedBytes::Base64(_) => {
return Err("RPC node on old version does not support base64 \
encoding for memcmp filters"
.to_string());
}
_ => {}
}
}
}
}
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
31 changes: 30 additions & 1 deletion rpc/src/rpc_pubsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use {
},
rpc_response::{
Response as RpcResponse, RpcBlockUpdate, RpcKeyedAccount, RpcLogsResponse,
RpcSignatureResult, RpcVote, SlotInfo, SlotUpdate,
RpcSignatureResult, RpcVersionInfo, RpcVote, SlotInfo, SlotUpdate,
},
},
solana_sdk::{clock::Slot, pubkey::Pubkey, signature::Signature},
Expand Down Expand Up @@ -348,6 +348,10 @@ mod internal {
// Unsubscribe from slot notification subscription.
#[rpc(name = "rootUnsubscribe")]
fn root_unsubscribe(&self, id: SubscriptionId) -> Result<bool>;

// Get the current solana version running on the node
#[rpc(name = "getVersion")]
fn get_version(&self) -> Result<RpcVersionInfo>;
}
}

Expand Down Expand Up @@ -576,6 +580,14 @@ impl RpcSolPubSubInternal for RpcSolPubSubImpl {
fn root_unsubscribe(&self, id: SubscriptionId) -> Result<bool> {
self.unsubscribe(id)
}

fn get_version(&self) -> Result<RpcVersionInfo> {
let version = solana_version::Version::default();
Ok(RpcVersionInfo {
solana_core: version.to_string(),
feature_set: Some(version.feature_set),
})
}
}

#[cfg(test)]
Expand Down Expand Up @@ -1370,4 +1382,21 @@ mod tests {
assert!(rpc.vote_unsubscribe(42.into()).is_err());
assert!(rpc.vote_unsubscribe(sub_id).is_ok());
}

#[test]
fn test_get_version() {
let GenesisConfigInfo { genesis_config, .. } = create_genesis_config(10_000);
let bank = Bank::new_for_tests(&genesis_config);
let bank_forks = Arc::new(RwLock::new(BankForks::new(bank)));
let max_complete_transaction_status_slot = Arc::new(AtomicU64::default());
let rpc_subscriptions = Arc::new(RpcSubscriptions::default_with_bank_forks(
max_complete_transaction_status_slot,
bank_forks,
));
let (rpc, _receiver) = rpc_pubsub_service::test_connection(&rpc_subscriptions);
let version = rpc.get_version().unwrap();
let expected_version = solana_version::Version::default();
assert_eq!(version.to_string(), expected_version.to_string());
assert_eq!(version.feature_set.unwrap(), expected_version.feature_set);
}
}

0 comments on commit b8b5215

Please sign in to comment.