Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for typed HTTP headers #1829

Merged
merged 5 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions sdk/core/azure_core/src/request_options/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ impl From<RangeFrom<usize>> for Range {
}

impl AsHeaders for Range {
type Error = std::convert::Infallible;
type Iter = std::vec::IntoIter<(HeaderName, HeaderValue)>;

fn as_headers(&self) -> Self::Iter {
fn as_headers(&self) -> Result<Self::Iter, Self::Error> {
let mut headers = vec![(headers::MS_RANGE, format!("{self}").into())];
if let Some(len) = self.optional_len() {
if len < 1024 * 1024 * 4 {
Expand All @@ -63,7 +64,7 @@ impl AsHeaders for Range {
));
}
}
headers.into_iter()
Ok(headers.into_iter())
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
Query, QueryPartitionStrategy,
};

use azure_core::{headers::HeaderValue, Context, Request};
use azure_core::{Context, Request};
use serde::{de::DeserializeOwned, Deserialize};
use url::Url;

Expand Down Expand Up @@ -177,10 +177,7 @@ impl ContainerClientMethods for ContainerClient {
base_req.add_mandatory_header(&constants::QUERY_CONTENT_TYPE);

let QueryPartitionStrategy::SinglePartition(partition_key) = partition_key.into();
base_req.insert_header(
constants::PARTITION_KEY,
HeaderValue::from_cow(partition_key.into_header_value()?),
);
base_req.insert_headers(&partition_key)?;

base_req.set_json(&query.into())?;

Expand Down
44 changes: 28 additions & 16 deletions sdk/cosmos/azure_data_cosmos/src/partition_key.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

use azure_core::headers::{AsHeaders, HeaderName, HeaderValue};

use crate::constants;

/// Describes the partition strategy that will be used when querying.
///
/// Currently, the only supported strategy is [`QueryPartitionStrategy::SinglePartition`], which executes the query against a single partition, specified by the [`PartitionKey`] provided.
Expand Down Expand Up @@ -90,8 +94,11 @@ impl<T: Into<PartitionKey>> From<T> for QueryPartitionStrategy {
#[derive(Debug, Clone)]
pub struct PartitionKey(Vec<PartitionKeyValue>);

impl PartitionKey {
pub(crate) fn into_header_value(self) -> azure_core::Result<String> {
impl AsHeaders for PartitionKey {
type Error = azure_core::Error;
type Iter = std::iter::Once<(HeaderName, HeaderValue)>;

fn as_headers(&self) -> Result<Self::Iter, Self::Error> {
// We have to do some manual JSON serialization here.
// The partition key is sent in an HTTP header, when used to set the partition key for a query.
// It's not safe to use non-ASCII characters in HTTP headers, and serde_json will not escape non-ASCII characters if they are otherwise valid as UTF-8.
Expand All @@ -100,10 +107,10 @@ impl PartitionKey {
let mut json = String::new();
let mut utf_buf = [0; 2]; // A buffer for encoding UTF-16 characters.
json.push('[');
for key in self.0 {
for key in &self.0 {
match key.0 {
InnerPartitionKeyValue::Null => json.push_str("null"),
InnerPartitionKeyValue::String(string_key) => {
InnerPartitionKeyValue::String(ref string_key) => {
json.push('"');
for char in string_key.chars() {
match char {
Expand All @@ -125,8 +132,10 @@ impl PartitionKey {
}
json.push('"');
}
InnerPartitionKeyValue::Number(num) => {
json.push_str(serde_json::to_string(&serde_json::Value::Number(num))?.as_str());
InnerPartitionKeyValue::Number(ref num) => {
json.push_str(
serde_json::to_string(&serde_json::Value::Number(num.clone()))?.as_str(),
);
}
}

Expand All @@ -137,7 +146,10 @@ impl PartitionKey {
json.pop();
json.push(']');

Ok(json)
Ok(std::iter::once((
constants::PARTITION_KEY,
HeaderValue::from_cow(json),
)))
}
}

Expand Down Expand Up @@ -265,19 +277,22 @@ impl_from_tuple!(0 A 1 B 2 C);

#[cfg(test)]
mod tests {
use crate::PartitionKey;

use super::QueryPartitionStrategy;
use crate::{constants, PartitionKey, QueryPartitionStrategy};
use typespec_client_core::http::headers::AsHeaders;

fn key_to_string(v: impl Into<PartitionKey>) -> String {
v.into().into_header_value().unwrap()
let key = v.into();
let mut headers_iter = key.as_headers().unwrap();
let (name, value) = headers_iter.next().unwrap();
assert_eq!(constants::PARTITION_KEY, name);
value.as_str().into()
}

/// Validates that a given value is `impl Into<QueryPartitionStrategy>` and works as-expected.
fn key_to_single_partition_strategy_string(v: impl Into<QueryPartitionStrategy>) -> String {
let strategy = v.into();
let QueryPartitionStrategy::SinglePartition(key) = strategy;
key.into_header_value().unwrap()
key_to_string(key)
}

#[test]
Expand Down Expand Up @@ -327,10 +342,7 @@ mod tests {
#[test]
pub fn non_ascii_string() {
let key = PartitionKey::from("smile 😀");
assert_eq!(
key.into_header_value().unwrap().as_str(),
r#"["smile \ud83d\ude00"]"#
);
assert_eq!(key_to_string(key), r#"["smile \ud83d\ude00"]"#);
}

#[test]
Expand Down
187 changes: 179 additions & 8 deletions sdk/typespec/typespec_client_core/src/http/headers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,42 +9,66 @@ mod microsoft;
pub use common::*;
pub use microsoft::*;

use std::{borrow::Cow, fmt::Debug, str::FromStr};
use std::{borrow::Cow, convert::Infallible, fmt::Debug, str::FromStr};
use typespec::error::{Error, ErrorKind, ResultExt};

/// A trait for converting a type into request headers.
pub trait AsHeaders {
type Error: std::error::Error + Send + Sync + 'static;
type Iter: Iterator<Item = (HeaderName, HeaderValue)>;
fn as_headers(&self) -> Self::Iter;

fn as_headers(&self) -> Result<Self::Iter, Self::Error>;
}

impl<T> AsHeaders for T
where
T: Header,
{
type Error = Infallible;
type Iter = std::vec::IntoIter<(HeaderName, HeaderValue)>;

/// Iterate over all the header name/value pairs.
fn as_headers(&self) -> Self::Iter {
vec![(self.name(), self.value())].into_iter()
fn as_headers(&self) -> Result<Self::Iter, Self::Error> {
Ok(vec![(self.name(), self.value())].into_iter())
}
}

impl<T> AsHeaders for Option<T>
where
T: AsHeaders<Iter = std::vec::IntoIter<(HeaderName, HeaderValue)>>,
{
type Error = T::Error;
type Iter = T::Iter;

/// Iterate over all the header name/value pairs.
fn as_headers(&self) -> Self::Iter {
fn as_headers(&self) -> Result<Self::Iter, T::Error> {
match self {
Some(h) => h.as_headers(),
None => vec![].into_iter(),
None => Ok(vec![].into_iter()),
}
}
}

/// Extract a value from the [`Headers`] collection.
///
/// The [`FromHeaders::from_headers()`] method is usually used implicitly, through [`Headers::get()`] or [`Headers::get_optional()`].
pub trait FromHeaders: Sized {
type Error: std::error::Error + Send + Sync + 'static;
analogrelay marked this conversation as resolved.
Show resolved Hide resolved

/// Gets a list of the header names that [`FromHeaders::from_headers`] expects.
///
/// Used by [`Headers::get()`] to generate an error if the headers are not present.
fn header_names() -> &'static [&'static str];
heaths marked this conversation as resolved.
Show resolved Hide resolved

/// Extracts the value from the provided [`Headers`] collection.
///
/// This method returns one of the following three values:
/// * `Ok(Some(...))` if the relevant headers are present and could be parsed into the value.
/// * `Ok(None)` if the relevant headers are not present, so no attempt to parse them can be made.
/// * `Err(...)` if an error occurred when trying to parse the headers. This likely indicates that the headers are present but invalid.
fn from_headers(headers: &Headers) -> Result<Option<Self>, Self::Error>;
analogrelay marked this conversation as resolved.
Show resolved Hide resolved
}

/// View a type as an HTTP header.
///
// Ad interim there are two default functions: `add_to_builder` and `add_to_request`.
Expand All @@ -69,6 +93,34 @@ impl Headers {
Self::default()
}

/// Gets the headers represented by `H`, or return an error if the header is not found.
pub fn get<H: FromHeaders>(&self) -> crate::Result<H> {
match H::from_headers(self) {
Ok(Some(x)) => Ok(x),
Ok(None) => Err(crate::Error::with_message(
ErrorKind::DataConversion,
|| {
let required_headers = H::header_names();
format!(
"required header(s) not found: {}",
required_headers.join(", ")
)
},
)),
Err(e) => Err(crate::Error::new(ErrorKind::DataConversion, e)),
}
}

/// Gets the headers represented by `H`, if they are present.
///
/// This method returns one of the following three values:
/// * `Ok(Some(...))` if the relevant headers are present and could be parsed into the value.
/// * `Ok(None)` if the relevant headers are not present, so no attempt to parse them can be made.
/// * `Err(...)` if an error occurred when trying to parse the headers. This likely indicates that the headers are present but invalid.
pub fn get_optional<H: FromHeaders>(&self) -> Result<Option<H>, H::Error> {
H::from_headers(self)
}

/// Optionally get a header value as a `String`.
pub fn get_optional_string(&self, key: &HeaderName) -> Option<String> {
self.get_as(key).ok()
Expand Down Expand Up @@ -146,13 +198,20 @@ impl Headers {
}

/// Add headers to the headers collection.
pub fn add<H>(&mut self, header: H)
///
/// ## Errors
///
/// The error this returns depends on the type `H`.
/// Many header types are infallible, return a `Result` with [`Infallible`] as the error type.
/// In this case, you can safely `.unwrap()` the value without risking a panic.
pub fn add<H>(&mut self, header: H) -> Result<(), H::Error>
where
H: AsHeaders,
{
for (key, value) in header.as_headers() {
for (key, value) in header.as_headers()? {
self.insert(key, value);
}
Ok(())
}

/// Iterate over all the header name/value pairs.
Expand Down Expand Up @@ -273,3 +332,115 @@ impl From<&String> for HeaderValue {
s.clone().into()
}
}

#[cfg(test)]
mod tests {
use crate::http::Url;
use typespec::error::ErrorKind;

use super::{FromHeaders, HeaderName, Headers};

// Just in case we add a ContentLocation struct later, this one is named "ForTest" to indicate it's just here for this test.
#[derive(Debug)]
struct ContentLocationForTest(Url);

impl FromHeaders for ContentLocationForTest {
type Error = url::ParseError;

fn header_names() -> &'static [&'static str] {
&["content-location"]
}

fn from_headers(headers: &super::Headers) -> Result<Option<Self>, Self::Error> {
let Some(loc) = headers.get_optional_str(&HeaderName::from("content-location")) else {
return Ok(None);
};

Ok(Some(ContentLocationForTest(loc.parse()?)))
}
}

#[test]
pub fn headers_get_optional_returns_ok_some_if_header_present_and_valid() {
let mut headers = Headers::new();
headers.insert("content-location", "https://example.com");
let content_location: ContentLocationForTest = headers.get_optional().unwrap().unwrap();
assert_eq!("https://example.com/", content_location.0.as_str())
}

#[test]
pub fn headers_get_optional_returns_ok_none_if_header_not_present() {
let headers = Headers::new();
let content_location: Option<ContentLocationForTest> = headers.get_optional().unwrap();
assert!(content_location.is_none())
}

#[test]
pub fn headers_get_optional_returns_err_if_conversion_fails() {
let mut headers = Headers::new();
headers.insert("content-location", "not a URL");
let err = headers
.get_optional::<ContentLocationForTest>()
.unwrap_err();
assert_eq!(url::ParseError::RelativeUrlWithoutBase, err)
}

#[test]
pub fn headers_get_returns_ok_if_header_present_and_valid() {
let mut headers = Headers::new();
headers.insert("content-location", "https://example.com");
let content_location: ContentLocationForTest = headers.get().unwrap();
assert_eq!("https://example.com/", content_location.0.as_str())
}

#[test]
pub fn headers_get_returns_err_if_header_not_present() {
let headers = Headers::new();
let err = headers.get::<ContentLocationForTest>().unwrap_err();
assert_eq!(&ErrorKind::DataConversion, err.kind());

// The "Display" implementation is the canonical way to get an error's "message"
assert_eq!(
"required header(s) not found: content-location",
format!("{}", err)
);
}

#[test]
pub fn headers_get_returns_err_if_header_requiring_multiple_headers_not_present() {
#[derive(Debug)]
struct HasTwoHeaders;

impl FromHeaders for HasTwoHeaders {
type Error = std::convert::Infallible;

fn header_names() -> &'static [&'static str] {
&["header-a", "header-b"]
}

fn from_headers(_: &Headers) -> Result<Option<Self>, Self::Error> {
Ok(None)
}
}

let headers = Headers::new();
let err = headers.get::<HasTwoHeaders>().unwrap_err();
assert_eq!(&ErrorKind::DataConversion, err.kind());

// The "Display" implementation is the canonical way to get an error's "message"
assert_eq!(
"required header(s) not found: header-a, header-b",
format!("{}", err)
);
}

#[test]
pub fn headers_get_returns_err_if_conversion_fails() {
let mut headers = Headers::new();
headers.insert("content-location", "not a URL");
let err = headers.get::<ContentLocationForTest>().unwrap_err();
assert_eq!(&ErrorKind::DataConversion, err.kind());
let inner: Box<url::ParseError> = err.into_inner().unwrap().downcast().unwrap();
assert_eq!(Box::new(url::ParseError::RelativeUrlWithoutBase), inner)
}
}
Loading