Skip to content

Commit

Permalink
Client: add [getter, setter] for auth, auth_bearer, params, headers
Browse files Browse the repository at this point in the history
  • Loading branch information
deedy5 committed Dec 18, 2024
1 parent cf25070 commit f7bfd13
Showing 1 changed file with 37 additions and 28 deletions.
65 changes: 37 additions & 28 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::str::FromStr;
use std::sync::{Arc, LazyLock};
use std::sync::{Arc, LazyLock, Mutex};
use std::time::Duration;

use ahash::RandomState;
Expand All @@ -11,7 +11,7 @@ use pyo3::prelude::*;
use pyo3::types::PyBytes;
use pythonize::depythonize;
use rquest::{
header::{HeaderMap, HeaderName, HeaderValue, COOKIE},
header::{HeaderValue, COOKIE},
multipart,
redirect::Policy,
tls::Impersonate,
Expand All @@ -23,6 +23,9 @@ use tokio::runtime::{self, Runtime};
mod response;
use response::Response;

mod traits;
use traits::HeadersTraits;

mod utils;
use utils::load_ca_certs;

Expand All @@ -37,10 +40,14 @@ static RUNTIME: LazyLock<Runtime> = LazyLock::new(|| {
#[pyclass]
/// HTTP client that can impersonate web browsers.
pub struct Client {
client: Arc<rquest::Client>,
client: Arc<Mutex<rquest::Client>>,
#[pyo3(get, set)]
auth: Option<(String, Option<String>)>,
#[pyo3(get, set)]
auth_bearer: Option<String>,
#[pyo3(get, set)]
params: Option<IndexMap<String, String, RandomState>>,
#[pyo3(get)]
cookies: Option<IndexMap<String, String, RandomState>>,
}

Expand Down Expand Up @@ -133,15 +140,7 @@ impl Client {

// Headers
if let Some(headers) = headers {
let headers_new = headers
.iter()
.filter_map(|(k, v)| {
HeaderName::from_bytes(k.as_bytes())
.ok()
.and_then(|name| HeaderValue::from_str(v).ok().map(|value| (name, value)))
})
.collect::<HeaderMap>();
client_builder = client_builder.default_headers(headers_new);
client_builder = client_builder.default_headers(headers.to_headermap());
}

// Cookie_store
Expand Down Expand Up @@ -195,7 +194,7 @@ impl Client {
client_builder = client_builder.http2_only();
}

let client = Arc::new(client_builder.build()?);
let client = Arc::new(Mutex::new(client_builder.build()?));

Ok(Client {
client,
Expand All @@ -206,6 +205,28 @@ impl Client {
})
}

#[getter]
pub fn get_headers(&self) -> Result<IndexMap<String, String, RandomState>> {
let headers = self.client.lock().unwrap().headers_mut().to_indexmap();
Ok(headers)
}

#[setter]
pub fn set_headers(
&self,
new_headers: Option<IndexMap<String, String, RandomState>>,
) -> Result<()> {
let mut client = self.client.lock().unwrap();
let headers = client.headers_mut();
headers.clear();
if let Some(new_headers) = new_headers {
for (k, v) in new_headers {
headers.insert_key_value(k, v)?
}
}
Ok(())
}

/// Constructs an HTTP request with the given method, URL, and optionally sets a timeout, headers, and query parameters.
/// Sends the request and returns a `Response` object containing the server's response.
///
Expand Down Expand Up @@ -274,7 +295,7 @@ impl Client {

let future = async move {
// Create request builder
let mut request_builder = client.request(method, url);
let mut request_builder = client.lock().unwrap().request(method, url);

// Params
if let Some(params) = params {
Expand All @@ -283,15 +304,7 @@ impl Client {

// Headers
if let Some(headers) = headers {
let headers_new = headers
.iter()
.filter_map(|(k, v)| {
HeaderName::from_bytes(k.as_bytes()).ok().and_then(|name| {
HeaderValue::from_str(v).ok().map(|value| (name, value))
})
})
.collect::<HeaderMap>();
request_builder = request_builder.headers(headers_new);
request_builder = request_builder.headers(headers.to_headermap());
}

// Cookies
Expand Down Expand Up @@ -351,11 +364,7 @@ impl Client {
.cookies()
.map(|cookie| (cookie.name().to_string(), cookie.value().to_string()))
.collect();
let headers: IndexMap<String, String, RandomState> = resp
.headers()
.iter()
.map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
let headers: IndexMap<String, String, RandomState> = resp.headers().to_indexmap();
let status_code = resp.status().as_u16();
let url = resp.url().to_string();
let buf = resp.bytes().await?;
Expand Down

0 comments on commit f7bfd13

Please sign in to comment.