Skip to content

Commit

Permalink
Merge pull request #45 from xmakro/main
Browse files Browse the repository at this point in the history
Fix deadlock in Session::renew
  • Loading branch information
dmzmk authored May 24, 2024
2 parents cc165e9 + acb1ddc commit 951791f
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 42 deletions.
2 changes: 1 addition & 1 deletion snowflake-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ license = "Apache-2.0"
name = "snowflake-api"
readme = "README.md"
repository = "https://github.com/mycelial/snowflake-rs"
version = "0.8.0"
version = "0.8.1"

[features]
all = ["cert-auth", "polars"]
Expand Down
78 changes: 37 additions & 41 deletions snowflake-api/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,13 @@ impl Session {
.is_some_and(|at| at.session_token.is_expired())
{
// Renew old session token
let tokens = self.renew().await?;
let old_token = auth_tokens.take().unwrap();
let tokens = self.renew(old_token).await?;
*auth_tokens = Some(tokens);
}
auth_tokens.as_mut().unwrap().sequence_id += 1;
let session_token_auth_header = auth_tokens.as_ref().unwrap().session_token.auth_header();
Ok(AuthParts {
session_token_auth_header,
session_token_auth_header: auth_tokens.as_ref().unwrap().session_token.auth_header(),
sequence_id: auth_tokens.as_ref().unwrap().sequence_id,
})
}
Expand Down Expand Up @@ -380,47 +380,43 @@ impl Session {
}
}

async fn renew(&self) -> Result<AuthTokens, AuthError> {
if let Some(token) = self.auth_tokens.lock().await.take() {
log::debug!("Renewing the token");
let auth = token.master_token.auth_header();
let body = RenewSessionRequest {
old_session_token: token.session_token.token.clone(),
request_type: "RENEW".to_string(),
};
async fn renew(&self, token: AuthTokens) -> Result<AuthTokens, AuthError> {
log::debug!("Renewing the token");
let auth = token.master_token.auth_header();
let body = RenewSessionRequest {
old_session_token: token.session_token.token.clone(),
request_type: "RENEW".to_string(),
};

let resp = self
.connection
.request(
QueryType::TokenRequest,
&self.account_identifier,
&[],
Some(&auth),
body,
)
.await?;
let resp = self
.connection
.request(
QueryType::TokenRequest,
&self.account_identifier,
&[],
Some(&auth),
body,
)
.await?;

match resp {
AuthResponse::Renew(rs) => {
let session_token =
AuthToken::new(&rs.data.session_token, rs.data.validity_in_seconds_s_t);
let master_token =
AuthToken::new(&rs.data.master_token, rs.data.validity_in_seconds_m_t);

Ok(AuthTokens {
session_token,
master_token,
sequence_id: token.sequence_id,
})
}
AuthResponse::Error(e) => Err(AuthError::AuthFailed(
e.code.unwrap_or_default(),
e.message.unwrap_or_default(),
)),
_ => Err(AuthError::UnexpectedResponse),
match resp {
AuthResponse::Renew(rs) => {
let session_token =
AuthToken::new(&rs.data.session_token, rs.data.validity_in_seconds_s_t);
let master_token =
AuthToken::new(&rs.data.master_token, rs.data.validity_in_seconds_m_t);

Ok(AuthTokens {
session_token,
master_token,
sequence_id: token.sequence_id,
})
}
} else {
Err(AuthError::OutOfOrderRenew)
AuthResponse::Error(e) => Err(AuthError::AuthFailed(
e.code.unwrap_or_default(),
e.message.unwrap_or_default(),
)),
_ => Err(AuthError::UnexpectedResponse),
}
}
}

0 comments on commit 951791f

Please sign in to comment.