Skip to content

Commit

Permalink
Code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
algesten committed Jul 7, 2024
1 parent 1def28a commit 6ade769
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 80 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ features = []
[features]

[dependencies]
hoot = { git = "https://github.com/algesten/hoot", rev = "040c74c" }
hoot = { git = "https://github.com/algesten/hoot", rev = "3e52ff1" }
http = "1.1.0"
log = "0.4.22"
once_cell = "1.19.0"
Expand Down
15 changes: 13 additions & 2 deletions src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,18 @@ impl Agent {
.unwrap_or(Buffers::empty());

match unit.poll_event(current_time(), buffers)? {
Event::Reset => {
Event::Reset { must_close } => {
addr = None;
connection = None;
response = None;

if let Some(c) = connection.take() {
if must_close {
c.close();
} else {
c.reuse();
}
}

unit.handle_input(current_time(), Input::Begin, &mut [])?;
}

Expand Down Expand Up @@ -227,6 +235,9 @@ impl Agent {
}
}

let response = response.expect("above loop to exit when there is a response");
let unit = unit.release_body();

todo!()
}
}
Expand Down
8 changes: 8 additions & 0 deletions src/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,12 @@ impl Connection {
pub fn consume_input(&mut self, amount: usize) {
self.conn.consume_input(amount)
}

pub(crate) fn close(self) {
todo!()
}

pub(crate) fn reuse(self) {
todo!()
}
}
212 changes: 136 additions & 76 deletions src/unit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::collections::VecDeque;
use std::mem;
use std::time::Duration;

use hoot::client::flow::{state::*, Await100Result, RecvResponseResult, SendRequestResult};
use hoot::client::flow::{
state::*, Await100Result, RecvBodyResult, RecvResponseResult, SendRequestResult,
};
use http::{Request, Response, Uri};

use crate::error::TimeoutReason;
Expand Down Expand Up @@ -46,7 +48,7 @@ macro_rules! extract {
}

pub enum Event<'a> {
Reset,
Reset { must_close: bool },
Resolve { uri: &'a Uri, timeout: Duration },
OpenConnection { uri: &'a Uri, timeout: Duration },
Await100 { timeout: Duration },
Expand Down Expand Up @@ -91,8 +93,6 @@ impl<'c, 'b, 'a> Unit<'c, 'b, 'a> {
}

pub fn poll_event(&mut self, now: Instant, buffers: Buffers) -> Result<Event, Error> {
let Buffers { input, output } = buffers;

// Queued events go first.
if let Some(queued) = self.queued_event.pop_front() {
return Ok(queued);
Expand All @@ -114,52 +114,36 @@ impl<'c, 'b, 'a> Unit<'c, 'b, 'a> {
}));
}

// These outputs don't borrow from the State, but they might proceed the FSM. Hence
// we return an Output<'static> meaning we are free the call self.maybe_change_state()
// since self.state is not borrowed.
let output: Option<Event<'static>> = match &mut self.state {
State::Begin(_) => Some(Event::Reset),
// Events that do not borrow any state, but might proceed the FSM
let maybe_event = self.poll_event_static(buffers, timeout)?;

if let Some(event) = maybe_event {
self.poll_event_maybe_proceed_state(now);
return Ok(event);
}

// Events that borrow the state and don't proceed the FSM.
self.poll_event_borrow(timeout)
}

// These events don't borrow from the State, but they might proceed the FSM. Hence
// we return an Event<'static> meaning we are free the call self.poll_event_maybe_proceed_state()
// since self.state is not borrowed.
fn poll_event_static(
&mut self,
buffers: Buffers,
timeout: Duration,
) -> Result<Option<Event<'static>>, Error> {
let Buffers { input, output } = buffers;

Ok(match &mut self.state {
State::Begin(_) => Some(Event::Reset { must_close: false }),

// State::Resolve (see below)
// State::OpenConnection (see below)
State::SendRequest(flow) => {
let output_used = flow.write(output)?;

Some(Event::Transmit {
amount: output_used,
timeout,
})
}
State::SendRequest(flow) => Some(send_request(flow, output, timeout)?),

State::SendBody(flow) => {
let input_len = input.len();

// The + 1 and floor() is to make even powers of 16 right.
// The + 4 is for the \r\n overhead. A chunk is:
// <digits_in_hex>\r\n
// <chunk>\r\n
// 0\r\n
// \r\n
let chunk_overhead = ((output.len() as f64).log(16.0) + 1.0).floor() as usize + 4;
assert!(input_len > chunk_overhead);
let max_input = input_len - chunk_overhead;

// TODO(martin): for any body that is BodyInner::ByteSlice, it's not great to
// go via self.body.read() since we're incurring on more memcopy than we need.
let input = &mut input[..max_input];
let n = self.body.read(input)?;

let (input_used, output_used) = flow.write(&input[..n], output)?;

// Since output is "a bit" larger than the input (compensate for chunk ovherhead),
// the entire input we read from the body should also be shipped to the output.
assert!(input_used == n);

Some(Event::Transmit {
amount: output_used,
timeout,
})
}
State::SendBody(flow) => Some(send_body(flow, input, output, timeout, &mut self.body)?),

State::Await100(_) => Some(Event::Await100 { timeout }),

Expand All @@ -174,14 +158,17 @@ impl<'c, 'b, 'a> Unit<'c, 'b, 'a> {
}),

State::Redirect(flow) => {
// Whether the previous connection must be closed.
let must_close = flow.must_close_connection();

let maybe_new_flow = flow.as_new_flow(self.config.redirect_auth_headers)?;

if let Some(flow) = maybe_new_flow {
// Start over the state
self.state = State::Begin(flow);

// Tell caller to reset state
Some(Event::Reset)
Some(Event::Reset { must_close })
} else {
return Err(Error::RedirectFailed);
}
Expand All @@ -192,36 +179,14 @@ impl<'c, 'b, 'a> Unit<'c, 'b, 'a> {
State::Empty => unreachable!("self.state should never be in State::Empty"),

_ => None,
};

if let Some(output) = output {
self.poll_output_maybe_proceed_state(now);
return Ok(output);
}

// These Outputs borrow from the State, but they don't proceed the FSM.
let output = match &mut self.state {
State::Resolve(flow) => Event::Resolve {
uri: flow.uri(),
timeout,
},

State::OpenConnection(flow) => Event::OpenConnection {
uri: flow.uri(),
timeout,
},

_ => unreachable!("State must be covered in first or second match"),
};

Ok(output)
})
}

fn poll_output_maybe_proceed_state(&mut self, now: Instant) {
fn poll_event_maybe_proceed_state(&mut self, now: Instant) {
let state = mem::replace(&mut self.state, State::Empty);

let new_state = match state {
// State might move on poll_output
// State moves on poll_output
State::SendRequest(flow) => {
if flow.can_proceed() {
self.call_timings.time_send_request = Some(now);
Expand All @@ -243,23 +208,44 @@ impl<'c, 'b, 'a> Unit<'c, 'b, 'a> {
}
}

// State might move on handle_input()
// Special handling above.
State::Redirect(flow) => State::Redirect(flow),

// State moves on handle_input()
State::Begin(flow) => State::Begin(flow),
State::Resolve(flow) => State::Resolve(flow),
State::OpenConnection(flow) => State::OpenConnection(flow),
State::Await100(flow) => State::Await100(flow),
State::RecvResponse(flow) => State::RecvResponse(flow),

// TODO(martin): decide when state moves
State::RecvBody(flow) => State::RecvBody(flow),
State::Redirect(flow) => State::Redirect(flow),

State::Cleanup(flow) => State::Cleanup(flow),

State::Empty => unreachable!("self.state should never be State::Empty"),
};

self.state = new_state;
}

// These events borrow from the State, but they don't proceed the FSM.
fn poll_event_borrow(&self, timeout: Duration) -> Result<Event, Error> {
let event = match &self.state {
State::Resolve(flow) => Event::Resolve {
uri: flow.uri(),
timeout,
},

State::OpenConnection(flow) => Event::OpenConnection {
uri: flow.uri(),
timeout,
},

_ => unreachable!("State must be covered in first or second match"),
};

Ok(event)
}

pub fn handle_input(
&mut self,
now: Instant,
Expand Down Expand Up @@ -348,6 +334,19 @@ impl<'c, 'b, 'a> Unit<'c, 'b, 'a> {
amount: output_used,
});

if flow.can_proceed() {
let flow = extract!(&mut self.state, State::RecvBody)
.expect("Input::Input requires State::RecvBody");

let state = match flow.proceed().unwrap() {
RecvBodyResult::Redirect(flow) => State::Redirect(flow),
RecvBodyResult::Cleanup(flow) => State::Cleanup(flow),
};

self.call_timings.time_recv_body = Some(now);
self.state = state;
}

return Ok(input_used);
}
_ => {}
Expand All @@ -367,6 +366,67 @@ impl<'c, 'b, 'a> Unit<'c, 'b, 'a> {
Await100Result::RecvResponse(flow) => State::RecvResponse(flow),
};
}

pub fn release_body(self) -> Unit<'c, 'b, 'static> {
Unit {
config: self.config,
global_start: self.global_start,
call_timings: self.call_timings,
state: self.state,
body: Body::empty(),
queued_event: self.queued_event,
redirect_count: self.redirect_count,
}
}
}

fn send_request(
flow: &mut Flow<SendRequest>,
output: &mut [u8],
timeout: Duration,
) -> Result<Event<'static>, Error> {
let output_used = flow.write(output)?;

Ok(Event::Transmit {
amount: output_used,
timeout,
})
}

fn send_body(
flow: &mut Flow<SendBody>,
input: &mut [u8],
output: &mut [u8],
timeout: Duration,
body: &mut Body,
) -> Result<Event<'static>, Error> {
let input_len = input.len();

// The + 1 and floor() is to make even powers of 16 right.
// The + 4 is for the \r\n overhead. A chunk is:
// <digits_in_hex>\r\n
// <chunk>\r\n
// 0\r\n
// \r\n
let chunk_overhead = ((output.len() as f64).log(16.0) + 1.0).floor() as usize + 4;
assert!(input_len > chunk_overhead);
let max_input = input_len - chunk_overhead;

// TODO(martin): for any body that is BodyInner::ByteSlice, it's not great to
// go via self.body.read() since we're incurring on more memcopy than we need.
let input = &mut input[..max_input];
let n = body.read(input)?;

let (input_used, output_used) = flow.write(&input[..n], output)?;

// Since output is "a bit" larger than the input (compensate for chunk ovherhead),
// the entire input we read from the body should also be shipped to the output.
assert!(input_used == n);

Ok(Event::Transmit {
amount: output_used,
timeout,
})
}

#[derive(Debug, Default)]
Expand Down

0 comments on commit 6ade769

Please sign in to comment.