Skip to content

Commit

Permalink
Add IMDSv1 fallback (#2609)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Aug 30, 2022
1 parent 6ab208c commit fb7c505
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 69 deletions.
131 changes: 111 additions & 20 deletions object_store/src/aws/credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ use bytes::Buf;
use chrono::{DateTime, Utc};
use futures::TryFutureExt;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{Client, Method, Request, RequestBuilder};
use reqwest::{Client, Method, Request, RequestBuilder, StatusCode};
use serde::Deserialize;
use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Instant;
use tracing::warn;

type StdError = Box<dyn std::error::Error + Send + Sync>;

Expand Down Expand Up @@ -365,31 +366,39 @@ async fn instance_creds(
const AWS_EC2_METADATA_TOKEN_HEADER: &str = "X-aws-ec2-metadata-token";

let token_url = format!("{}/latest/api/token", endpoint);
let token = client

let token_result = client
.request(Method::PUT, token_url)
.header("X-aws-ec2-metadata-token-ttl-seconds", "600") // 10 minute TTL
.send_retry(retry_config)
.await?
.text()
.await?;
.await;

let token = match token_result {
Ok(t) => Some(t.text().await?),
Err(e) if matches!(e.status(), Some(StatusCode::FORBIDDEN)) => {
warn!("received 403 from metadata endpoint, falling back to IMDSv1");
None
}
Err(e) => return Err(e.into()),
};

let role_url = format!("{}/{}/", endpoint, CREDENTIALS_PATH);
let role = client
.request(Method::GET, role_url)
.header(AWS_EC2_METADATA_TOKEN_HEADER, &token)
.send_retry(retry_config)
.await?
.text()
.await?;
let mut role_request = client.request(Method::GET, role_url);

if let Some(token) = &token {
role_request = role_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token);
}

let role = role_request.send_retry(retry_config).await?.text().await?;

let creds_url = format!("{}/{}/{}", endpoint, CREDENTIALS_PATH, role);
let creds: InstanceCredentials = client
.request(Method::GET, creds_url)
.header(AWS_EC2_METADATA_TOKEN_HEADER, &token)
.send_retry(retry_config)
.await?
.json()
.await?;
let mut creds_request = client.request(Method::GET, creds_url);
if let Some(token) = &token {
creds_request = creds_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token);
}

let creds: InstanceCredentials =
creds_request.send_retry(retry_config).await?.json().await?;

let now = Utc::now();
let ttl = (creds.expiration - now).to_std().unwrap_or_default();
Expand Down Expand Up @@ -470,6 +479,8 @@ async fn web_identity(
#[cfg(test)]
mod tests {
use super::*;
use crate::client::mock_server::MockServer;
use hyper::{Body, Response};
use reqwest::{Client, Method};
use std::env;

Expand Down Expand Up @@ -567,7 +578,7 @@ mod tests {

assert_eq!(
resp.status(),
reqwest::StatusCode::UNAUTHORIZED,
StatusCode::UNAUTHORIZED,
"Ensure metadata endpoint is set to only allow IMDSv2"
);

Expand All @@ -583,4 +594,84 @@ mod tests {
assert!(!secret.is_empty());
assert!(!token.is_empty())
}

#[tokio::test]
async fn test_mock() {
let server = MockServer::new();

const IMDSV2_HEADER: &str = "X-aws-ec2-metadata-token";

let secret_access_key = "SECRET";
let access_key_id = "KEYID";
let token = "TOKEN";

let endpoint = server.url();
let client = Client::new();
let retry_config = RetryConfig::default();

// Test IMDSv2
server.push_fn(|req| {
assert_eq!(req.uri().path(), "/latest/api/token");
assert_eq!(req.method(), &Method::PUT);
Response::new(Body::from("cupcakes"))
});
server.push_fn(|req| {
assert_eq!(
req.uri().path(),
"/latest/meta-data/iam/security-credentials/"
);
assert_eq!(req.method(), &Method::GET);
let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap();
assert_eq!(t, "cupcakes");
Response::new(Body::from("myrole"))
});
server.push_fn(|req| {
assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole");
assert_eq!(req.method(), &Method::GET);
let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap();
assert_eq!(t, "cupcakes");
Response::new(Body::from(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#))
});

let creds = instance_creds(&client, &retry_config, endpoint)
.await
.unwrap();

assert_eq!(creds.token.token.as_deref().unwrap(), token);
assert_eq!(&creds.token.key_id, access_key_id);
assert_eq!(&creds.token.secret_key, secret_access_key);

// Test IMDSv1
server.push_fn(|req| {
assert_eq!(req.uri().path(), "/latest/api/token");
assert_eq!(req.method(), &Method::PUT);
Response::builder()
.status(StatusCode::FORBIDDEN)
.body(Body::empty())
.unwrap()
});
server.push_fn(|req| {
assert_eq!(
req.uri().path(),
"/latest/meta-data/iam/security-credentials/"
);
assert_eq!(req.method(), &Method::GET);
assert!(req.headers().get(IMDSV2_HEADER).is_none());
Response::new(Body::from("myrole"))
});
server.push_fn(|req| {
assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole");
assert_eq!(req.method(), &Method::GET);
assert!(req.headers().get(IMDSV2_HEADER).is_none());
Response::new(Body::from(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#))
});

let creds = instance_creds(&client, &retry_config, endpoint)
.await
.unwrap();

assert_eq!(creds.token.token.as_deref().unwrap(), token);
assert_eq!(&creds.token.key_id, access_key_id);
assert_eq!(&creds.token.secret_key, secret_access_key);
}
}
105 changes: 105 additions & 0 deletions object_store/src/client/mock_server.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server};
use parking_lot::Mutex;
use std::collections::VecDeque;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;

pub type ResponseFn = Box<dyn FnOnce(Request<Body>) -> Response<Body> + Send>;

/// A mock server
pub struct MockServer {
responses: Arc<Mutex<VecDeque<ResponseFn>>>,
shutdown: oneshot::Sender<()>,
handle: JoinHandle<()>,
url: String,
}

impl MockServer {
pub fn new() -> Self {
let responses: Arc<Mutex<VecDeque<ResponseFn>>> =
Arc::new(Mutex::new(VecDeque::with_capacity(10)));

let r = Arc::clone(&responses);
let make_service = make_service_fn(move |_conn| {
let r = Arc::clone(&r);
async move {
Ok::<_, Infallible>(service_fn(move |req| {
let r = Arc::clone(&r);
async move {
Ok::<_, Infallible>(match r.lock().pop_front() {
Some(r) => r(req),
None => Response::new(Body::from("Hello World")),
})
}
}))
}
});

let (shutdown, rx) = oneshot::channel::<()>();
let server =
Server::bind(&SocketAddr::from(([127, 0, 0, 1], 0))).serve(make_service);

let url = format!("http://{}", server.local_addr());

let handle = tokio::spawn(async move {
server
.with_graceful_shutdown(async {
rx.await.ok();
})
.await
.unwrap()
});

Self {
responses,
shutdown,
handle,
url,
}
}

/// The url of the mock server
pub fn url(&self) -> &str {
&self.url
}

/// Add a response
pub fn push(&self, response: Response<Body>) {
self.push_fn(|_| response)
}

/// Add a response function
pub fn push_fn<F>(&self, f: F)
where
F: FnOnce(Request<Body>) -> Response<Body> + Send + 'static,
{
self.responses.lock().push_back(Box::new(f))
}

/// Shutdown the mock server
pub async fn shutdown(self) {
let _ = self.shutdown.send(());
self.handle.await.unwrap()
}
}
2 changes: 2 additions & 0 deletions object_store/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
//! Generic utilities reqwest based ObjectStore implementations
pub mod backoff;
#[cfg(test)]
pub mod mock_server;
pub mod pagination;
pub mod retry;
pub mod token;
Loading

0 comments on commit fb7c505

Please sign in to comment.