Skip to content

Commit

Permalink
Added 100-continue handling to async client.
Browse files Browse the repository at this point in the history
  • Loading branch information
5cript committed Dec 11, 2023
1 parent b502c04 commit 9420042
Show file tree
Hide file tree
Showing 7 changed files with 364 additions and 26 deletions.
169 changes: 147 additions & 22 deletions include/roar/client.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,23 +129,21 @@ namespace Roar
}

/**
* @brief Connects the client to a server and performs a request
* @brief Resolve host and connect to server and perform SSL handshake.
*
* @param requestParameters see RequestParameters
* @return A promise to continue after the connect.
* @param host The host to connect to.
* @param port The port to connect to.
* @param timeout The timeout for the connection.
* @return Detail::PromiseTypeBind<Detail::PromiseTypeBindThen<>, Detail::PromiseTypeBindFail<Error>>
*/
template <typename BodyT>
Detail::PromiseTypeBind<Detail::PromiseTypeBindThen<>, Detail::PromiseTypeBindFail<Error>>
request(Request<BodyT>&& request, std::chrono::seconds timeout = defaultTimeout)
connect(std::string const& host, std::string const& port, std::chrono::seconds timeout = defaultTimeout)
{
return promise::newPromise([&, this](promise::Defer d) mutable {
const auto host = request.host();
const auto port = request.port();
std::shared_ptr<Request<BodyT>> requestPtr = std::make_shared<Request<BodyT>>(std::move(request));
doResolve(
host,
port,
[weak = weak_from_this(), timeout, requestPtr, d = std::move(d)](
[weak = weak_from_this(), timeout, d = std::move(d), host](
boost::beast::error_code ec, boost::asio::ip::tcp::resolver::results_type results) mutable {
auto self = weak.lock();
if (!self)
Expand All @@ -158,7 +156,7 @@ namespace Roar
socket.expires_after(timeout);
socket.async_connect(
results,
[weak = self->weak_from_this(), d = std::move(d), requestPtr, timeout](
[weak = self->weak_from_this(), d = std::move(d), timeout, host](
boost::beast::error_code ec,
boost::asio::ip::tcp::resolver::results_type::endpoint_type endpoint) mutable {
auto self = weak.lock();
Expand All @@ -170,14 +168,40 @@ namespace Roar
return d.reject(Error{.error = ec, .additionalInfo = "TCP connect failed."});

self->endpoint_ = endpoint;

self->onConnect(std::move(*requestPtr), std::move(d), timeout);
self->onConnect(host, std::move(d), timeout);
});
});
});
});
}

/**
* @brief Connects the client to a server and performs a request
*
* @param requestParameters see RequestParameters
* @return A promise to continue after the connect.
*/
template <typename BodyT>
Detail::PromiseTypeBind<Detail::PromiseTypeBindThen<>, Detail::PromiseTypeBindFail<Error>>
request(Request<BodyT>&& req, std::chrono::seconds timeout = defaultTimeout)
{
return promise::newPromise([req = std::move(req), timeout, this](promise::Defer d) mutable {
const auto host = req.host();
const auto port = req.port();
connect(host, port, timeout)
.then([weak = weak_from_this(), req = std::move(req), timeout, d]() mutable {
auto self = weak.lock();
if (!self)
return d.reject(Error{.additionalInfo = "Client is no longer alive."});

self->performRequest(std::move(req), std::move(d), timeout);
})
.fail([d](auto error) mutable {
d.reject(std::move(error));
});
});
}

/**
* @brief Reads only the header, will need be followed up by a readResponse.
*
Expand Down Expand Up @@ -305,8 +329,8 @@ namespace Roar
Detail::PromiseTypeBindFail<Error>>
requestAndReadResponse(Request<RequestBodyT>&& request, std::chrono::seconds timeout = defaultTimeout)
{
return promise::newPromise([r = std::move(request), timeout, this](promise::Defer d) mutable {
this->request(std::move(r), timeout)
return promise::newPromise([request = std::move(request), timeout, this](promise::Defer d) mutable {
this->request(std::move(request), timeout)
.then([weak = weak_from_this(), timeout, d]() {
auto client = weak.lock();
if (!client)
Expand Down Expand Up @@ -410,12 +434,11 @@ namespace Roar
std::function<void(boost::beast::error_code ec, boost::asio::ip::tcp::resolver::results_type results)>
onResolve);

template <typename BodyT>
void onConnect(Request<BodyT>&& request, promise::Defer&& d, std::chrono::seconds timeout)
void onConnect(std::string const& host, promise::Defer&& d, std::chrono::seconds timeout)
{
if (std::holds_alternative<boost::beast::ssl_stream<boost::beast::tcp_stream>>(socket_))
{
auto maybeError = setupSsl(request.host());
auto maybeError = setupSsl(host);
if (maybeError)
return d.reject(*maybeError);

Expand All @@ -425,29 +448,32 @@ namespace Roar
});
sslSocket.async_handshake(
boost::asio::ssl::stream_base::client,
[weak = weak_from_this(), d = std::move(d), request = std::move(request), timeout](
boost::beast::error_code ec) mutable {
[weak = weak_from_this(), d = std::move(d)](boost::beast::error_code ec) mutable {
auto self = weak.lock();
if (!self)
return d.reject(Error{.error = ec, .additionalInfo = "Client is no longer alive."});

if (ec)
return d.reject(Error{.error = ec, .additionalInfo = "SSL handshake failed."});

self->performRequest(std::move(request), std::move(d), timeout);
d.resolve();
});
}
else
performRequest(std::move(request), std::move(d), timeout);
d.resolve();
}

template <typename BodyT>
void performRequest(Request<BodyT>&& request, promise::Defer&& d, std::chrono::seconds timeout)
{
auto iter = request.find(boost::beast::http::field::expect);
if (iter != request.end() && iter->value() == "100-continue")
return performRequestWithExpectContinue(std::move(request), std::move(d), timeout);

withLowerLayerDo([timeout](auto& socket) {
socket.expires_after(timeout);
});
withStreamDo([this, request = std::move(request), &d](auto& socket) mutable {
withStreamDo([this, request = std::move(request), &d, timeout](auto& socket) mutable {
std::shared_ptr<Request<BodyT>> requestPtr = std::make_shared<Request<BodyT>>(std::move(request));
boost::beast::http::async_write(
socket,
Expand All @@ -466,6 +492,105 @@ namespace Roar
});
}

template <typename BodyT>
void
performRequestWithExpectContinue(Request<BodyT>&& request, promise::Defer&& d, std::chrono::seconds timeout)
{
withLowerLayerDo([timeout](auto& socket) {
socket.expires_after(timeout);
});

auto sharedRequest = std::make_shared<Request<BodyT>>(std::move(request));
auto serializerPtr = std::make_shared<boost::beast::http::request_serializer<BodyT>>(*sharedRequest);

withStreamDo([this, serializerPtr, sharedRequest, d = std::move(d), timeout](auto& socket) mutable {
boost::beast::http::async_write_header(
socket,
*serializerPtr,
[weak = weak_from_this(), d = std::move(d), serializerPtr, sharedRequest, timeout](
boost::beast::error_code ec, std::size_t) mutable {
auto self = weak.lock();
if (!self)
return d.reject(Error{.error = ec, .additionalInfo = "Client is no longer alive."});

if (ec)
return d.reject(Error{.error = ec, .additionalInfo = "HTTP write failed."});

self->read100ContinueResponse(
std::make_pair(std::move(serializerPtr), std::move(sharedRequest)), std::move(d), timeout);
});
});
}

template <typename BodyT>
void read100ContinueResponse(
std::pair<std::shared_ptr<boost::beast::http::request_serializer<BodyT>>, std::shared_ptr<Request<BodyT>>>&&
requestPair,
promise::Defer&& d,
std::chrono::seconds timeout)
{
withStreamDo([requestPair = std::move(requestPair), d = std::move(d), timeout, this](auto& socket) mutable {
auto response = std::make_shared<boost::beast::http::response<boost::beast::http::string_body>>();

withLowerLayerDo([timeout](auto& socket) {
socket.expires_after(timeout);
});
boost::beast::http::async_read(
socket,
*buffer_,
*response,
[d = std::move(d),
buffer = this->buffer_,
response,
requestPair = std::move(requestPair),
timeout,
weak = weak_from_this()](boost::beast::error_code ec, std::size_t) mutable {
auto self = weak.lock();
if (!self)
return d.reject(Error{.error = ec, .additionalInfo = "Client is no longer alive."});

if (ec)
return d.reject(Error{.error = ec, .additionalInfo = "HTTP read failed."});

if (response->result() != boost::beast::http::status::continue_)
{
using namespace std::string_literals;
return d.reject(Error{
.additionalInfo = "Server did not respond with 100-continue, but with "s +
std::to_string(response->result_int()) + "."s});
}
else
{
self->complete100ContinueRequest<BodyT>(std::move(requestPair), std::move(d), timeout);
}
});
});
}

template <typename BodyT>
void complete100ContinueRequest(
std::pair<std::shared_ptr<boost::beast::http::request_serializer<BodyT>>, std::shared_ptr<Request<BodyT>>>&&
requestPair,
promise::Defer&& d,
std::chrono::seconds timeout)
{
withLowerLayerDo([timeout](auto& socket) {
socket.expires_after(timeout);
});

withStreamDo([&requestPair, &d](auto& socket) mutable {
boost::beast::http::async_write(
socket,
*requestPair.first,
[d = std::move(d), requestPair](boost::beast::error_code ec, std::size_t) mutable {
if (ec)
return d.reject(Error{.error = ec, .additionalInfo = "HTTP write failed."});

d.resolve();
});
});
}

std::optional<Error> setupSsl(std::string const& host);

private:
Expand Down
76 changes: 74 additions & 2 deletions include/roar/session/session.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ namespace Roar
}

template <typename BodyT>
[[nodiscard]] std::shared_ptr<SendIntermediate<BodyT>> sendWithAllAcceptedCors()
[[nodiscard]] std::shared_ptr<SendIntermediate<BodyT>> send()
{
return std::shared_ptr<SendIntermediate<BodyT>>(new SendIntermediate<BodyT>{*this});
}
Expand Down Expand Up @@ -445,9 +445,10 @@ namespace Roar
if constexpr (std::is_same_v<BodyT, boost::beast::http::empty_body>)
throw std::runtime_error("Attempting to read with empty_body type.");
else
{
return boost::beast::http::request_parser<BodyT>{
std::move(*session.parser()), std::forward<Forwards>(forwardArgs)...};
session.parser().reset();
}
}()}
, originalExtensions_{std::move(req).ejectExtensions()}
, onChunk_{}
Expand Down Expand Up @@ -533,6 +534,30 @@ namespace Roar
return {*promise_};
}

/**
* @brief Start reading the header here.
*
* @return Promise
*/
Detail::PromiseTypeBind<
Detail::PromiseTypeBindThen<
Detail::PromiseReferenceWrap<Session>,
Detail::PromiseReferenceWrap<boost::beast::http::request_parser<BodyT> const>,
std::shared_ptr<ReadIntermediate<BodyT>>>,
Detail::PromiseTypeBindFail<Error const&>>
commitHeaderOnly()
{
if (overallTimeout_)
{
session_->withStreamDo([this](auto& stream) {
boost::beast::get_lowest_layer(stream).expires_after(*overallTimeout_);
});
}
promise_ = std::make_unique<promise::Promise>(promise::newPromise());
readHeader();
return {*promise_};
}

/**
* @brief Set a timeout for the whole read operation.
*
Expand Down Expand Up @@ -590,6 +615,45 @@ namespace Roar
});
}

void readHeader()
{
session_->withStreamDo([this](auto& stream) {
if (!overallTimeout_)
boost::beast::get_lowest_layer(stream).expires_after(std::chrono::seconds(sessionTimeout));

boost::beast::http::async_read_header(
stream,
session_->buffer(),
req_,
[self = this->shared_from_this()](boost::beast::error_code ec, std::size_t) {
if (ec)
{
self->session_->close();
self->promise_->reject(Error{.error = ec});
}
else
{
try
{
self->promise_->template resolve(
Detail::ref(*self->session_), Detail::cref(self->req_), self);
}
catch (std::exception const& exc)
{
using namespace std::string_literals;
self->session_
->send(self->session_->standardResponseProvider().makeStandardResponse(
*self->session_,
boost::beast::http::status::internal_server_error,
"An exception was thrown in the header read completion handler: "s +
exc.what()))
->commit();
}
}
});
});
}

private:
std::shared_ptr<Session> session_;
boost::beast::http::request_parser<BodyT> req_;
Expand Down Expand Up @@ -632,6 +696,14 @@ namespace Roar
new ReadIntermediate<BodyT>{*this, std::move(req), std::forward<Forwards>(forwardArgs)...});
}

template <typename BodyT, typename OriginalBodyT, typename... Forwards>
[[nodiscard]] std::shared_ptr<ReadIntermediate<BodyT>>
readHeader(Request<OriginalBodyT> req, Forwards&&... forwardArgs)
{
return std::shared_ptr<ReadIntermediate<BodyT>>(
new ReadIntermediate<BodyT>{*this, std::move(req), std::forward<Forwards>(forwardArgs)...});
}

/**
* @brief Prepares a response with some header values already set.
*
Expand Down
Loading

0 comments on commit 9420042

Please sign in to comment.