-
Notifications
You must be signed in to change notification settings - Fork 166
/
main.rs
111 lines (98 loc) · 3.29 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
use axum::{
body::Bytes,
extract::{Path, State},
http::{header, HeaderValue, StatusCode},
response::IntoResponse,
routing::get,
Router,
};
use clap::Parser;
use std::{
collections::HashMap,
net::{Ipv4Addr, SocketAddr},
sync::{Arc, RwLock},
time::Duration,
};
use tokio::net::TcpListener;
use tower::ServiceBuilder;
use tower_http::{
timeout::TimeoutLayer,
trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer},
LatencyUnit, ServiceBuilderExt,
};
/// Simple key/value store with an HTTP API
// NOTE: with clap's derive, the `///` comments on this struct and its fields
// are rendered as the program's `--help` output (about text / argument help),
// so treat them as user-facing text rather than plain source documentation.
#[derive(Debug, Parser)]
struct Config {
    /// The port to listen on
    // Settable with `-p <PORT>` or the long form derived from the field name;
    // falls back to 3000 when the flag is omitted.
    #[clap(short = 'p', long, default_value = "3000")]
    port: u16,
}
/// The key/value map shared by every request handler.
///
/// `Arc` lets each clone of [`AppState`] point at the same map, and `RwLock`
/// allows concurrent readers while writers get exclusive access.
type SharedDb = Arc<RwLock<HashMap<String, Bytes>>>;

/// Router state handed to each handler through axum's `State` extractor.
#[derive(Clone, Debug)]
struct AppState {
    /// Maps the `key` path segment to the request body last stored under it.
    db: SharedDb,
}
#[tokio::main]
async fn main() {
// Setup tracing
tracing_subscriber::fmt::init();
// Parse command line arguments
let config = Config::parse();
// Run our service
let addr = SocketAddr::from((Ipv4Addr::UNSPECIFIED, config.port));
tracing::info!("Listening on {}", addr);
axum::serve(
TcpListener::bind(addr).await.expect("bind error"),
app().into_make_service(),
)
.await
.expect("server error");
}
/// Builds the application router: the `/:key` routes, their shared in-memory
/// state, and the middleware stack wrapped around them.
fn app() -> Router {
    // Start with an empty store; every clone of `AppState` shares this map.
    let state = AppState {
        db: Arc::new(RwLock::new(HashMap::new())),
    };

    // Headers whose values must not show up in logs.
    let sensitive_headers: Arc<[_]> = vec![header::AUTHORIZATION, header::COOKIE].into();

    // High-level request tracing: a span per request (with headers), a trace
    // event per outgoing body chunk, and a response event whose latency is
    // reported in microseconds.
    let tracing_layer = TraceLayer::new_for_http()
        .on_body_chunk(|chunk: &Bytes, latency: Duration, _: &tracing::Span| {
            tracing::trace!(size_bytes = chunk.len(), latency = ?latency, "sending body chunk")
        })
        .make_span_with(DefaultMakeSpan::new().include_headers(true))
        .on_response(
            DefaultOnResponse::new()
                .include_headers(true)
                .latency_unit(LatencyUnit::Micros),
        );

    // Layers added first wrap the ones added later. Request headers are marked
    // sensitive outside the trace layer (before it sees the request); response
    // headers are marked sensitive inside it (before the response reaches it).
    let middleware = ServiceBuilder::new()
        .sensitive_request_headers(Arc::clone(&sensitive_headers))
        .layer(tracing_layer)
        .sensitive_response_headers(sensitive_headers)
        // Abort requests that take longer than ten seconds.
        .layer(TimeoutLayer::new(Duration::from_secs(10)))
        // Compress response bodies.
        .compression()
        // Default the response `Content-Type` when a handler did not set one.
        .insert_response_header_if_not_present(
            header::CONTENT_TYPE,
            HeaderValue::from_static("application/octet-stream"),
        );

    Router::new()
        .route("/:key", get(get_key).post(set_key))
        .layer(middleware)
        .with_state(state)
}
/// `GET /:key` handler: returns the bytes stored under `key`, or
/// `404 Not Found` when nothing has been stored there.
///
/// # Panics
///
/// Panics if the store's lock has been poisoned.
async fn get_key(path: Path<String>, state: State<AppState>) -> impl IntoResponse {
    // Hold the read lock only while cloning the value out of the map.
    let db = state.db.read().unwrap();
    db.get(&*path).cloned().ok_or(StatusCode::NOT_FOUND)
}
/// `POST /:key` handler: stores the request body under `key`, replacing any
/// value previously stored there. Returns `()`, so the response has no body.
///
/// # Panics
///
/// Panics if the store's lock has been poisoned.
async fn set_key(Path(path): Path<String>, state: State<AppState>, value: Bytes) {
    // The write guard is a temporary, released as soon as this statement ends.
    state.db.write().unwrap().insert(path, value);
}
// See https://github.com/tokio-rs/axum/blob/main/examples/testing/src/main.rs for an example of
// how to test axum apps