-
Notifications
You must be signed in to change notification settings - Fork 131
/
security.rs
125 lines (111 loc) · 3.78 KB
/
security.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// Copyright 2018 TiKV Project Authors. Licensed under Apache-2.0.
use crate::Result;
use grpcio::{Channel, ChannelBuilder, ChannelCredentialsBuilder, Environment};
use regex::Regex;
use std::{
fs::File,
io::Read,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
lazy_static::lazy_static! {
    /// Matches an `http://` or `https://` scheme prefix (preceded by optional whitespace) at the
    /// start of an endpoint address, so it can be stripped before handing the address to gRPC.
    /// URL schemes are case-insensitive, hence the `(?i)` flag (e.g. `HTTPS://` is also matched).
    static ref SCHEME_REG: Regex = Regex::new(r"(?i)^\s*(https?://)").unwrap();
}
/// Opens the PEM file at `path` and returns the open handle.
///
/// `tag` names the role of the file (for example "ca" or "private key") and is used only to
/// build the error message when opening fails.
///
/// # Errors
///
/// Returns an internal error wrapping the I/O error if the file cannot be opened.
fn check_pem_file(tag: &str, path: &Path) -> Result<File> {
    match File::open(path) {
        Ok(opened) => Ok(opened),
        Err(io_err) => Err(internal_err!(
            "failed to open {} to load {}: {:?}",
            path.display(),
            tag,
            io_err
        )),
    }
}
/// Reads the whole PEM file at `path` into memory.
///
/// `tag` describes the file's role (for example "ca", "certificate" or "private key") and is
/// used only in error messages.
///
/// # Errors
///
/// Fails if the file cannot be opened (via `check_pem_file`) or if reading it fails.
fn load_pem_file(tag: &str, path: &Path) -> Result<Vec<u8>> {
    let mut pem_file = check_pem_file(tag, path)?;
    let mut contents = Vec::new();
    if let Err(io_err) = pem_file.read_to_end(&mut contents) {
        return Err(internal_err!(
            "failed to load {} from path {}: {:?}",
            tag,
            path.display(),
            io_err
        ));
    }
    Ok(contents)
}
/// Holds the TLS material used when opening gRPC channels.
///
/// A default (empty) manager has no CA certificate loaded, which makes
/// `SecurityManager::connect` open a plaintext (non-TLS) channel.
#[derive(Default)]
pub struct SecurityManager {
    /// PEM-encoded CA certificate(s) installed as the channel's root certificates, i.e. used to
    /// verify the server's certificate. An empty vector means TLS is not configured.
    ca: Vec<u8>,
    /// PEM-encoded certificate chain that this client presents to the server (mutual TLS).
    cert: Vec<u8>,
    /// Path to the PEM-encoded private key matching `cert`. Only the path is kept; the key file
    /// is read each time a secure channel is opened.
    key: PathBuf,
}
impl SecurityManager {
    /// Builds a `SecurityManager` from the CA certificate, certificate chain and private key
    /// files at the given paths.
    ///
    /// The CA and certificate files are read into memory here. The private key file is only
    /// opened to check that it is accessible; its path is stored, and the key is re-read on
    /// every call to [`SecurityManager::connect`] that uses TLS.
    ///
    /// # Errors
    ///
    /// Returns an error if the private key file cannot be opened, or if the CA or certificate
    /// file cannot be opened or read. The checks happen in that order: key, then CA, then
    /// certificate.
    pub fn load(
        ca_path: impl AsRef<Path>,
        cert_path: impl AsRef<Path>,
        key_path: impl Into<PathBuf>,
    ) -> Result<SecurityManager> {
        let key: PathBuf = key_path.into();
        // Fail early if the key cannot be opened; the returned handle is dropped right away.
        check_pem_file("private key", &key)?;
        let ca = load_pem_file("ca", ca_path.as_ref())?;
        let cert = load_pem_file("certificate", cert_path.as_ref())?;
        Ok(SecurityManager { ca, cert, key })
    }

    /// Opens a gRPC channel to `addr` and passes it to `factory` to build a client.
    ///
    /// Any scheme prefix matched by `SCHEME_REG` (an `http://`/`https://` prefix) is removed from
    /// `addr` before connecting. The channel is configured with a 10 second keepalive time and a
    /// 3 second keepalive timeout. If a CA certificate is loaded, the channel uses TLS with the
    /// stored CA, certificate chain and the private key read from `self.key`. Otherwise a
    /// plaintext connection is used.
    ///
    /// NOTE(review): a CA file that exists but is empty also yields an empty `ca`, which
    /// silently selects the plaintext path. Confirm this is intended.
    ///
    /// # Errors
    ///
    /// When TLS is configured, returns an error if the private key file cannot be opened or read.
    pub fn connect<Factory, Client>(
        &self,
        env: Arc<Environment>,
        addr: &str,
        factory: Factory,
    ) -> Result<Client>
    where
        Factory: FnOnce(Channel) -> Client,
    {
        info!("connect to rpc server at endpoint: {:?}", addr);
        let endpoint = SCHEME_REG.replace(addr, "");
        let builder = ChannelBuilder::new(env)
            .keepalive_time(Duration::from_secs(10))
            .keepalive_timeout(Duration::from_secs(3));
        let channel = if !self.ca.is_empty() {
            let private_key = load_pem_file("private key", &self.key)?;
            let credentials = ChannelCredentialsBuilder::new()
                .root_cert(self.ca.clone())
                .cert(self.cert.clone(), private_key)
                .build();
            builder.secure_connect(&endpoint, credentials)
        } else {
            builder.connect(&endpoint)
        };
        Ok(factory(channel))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{fs::File, io::Write};

    /// Loading valid files keeps the CA and certificate bytes and the key's path.
    #[test]
    fn test_security() {
        let temp = tempfile::tempdir().unwrap();
        let ca_path = temp.path().join("ca");
        let cert_path = temp.path().join("cert");
        let key_path = temp.path().join("key");
        // Give each file a distinct one-byte payload so the assertions can tell them apart.
        for (id, path) in [&ca_path, &cert_path, &key_path].iter().enumerate() {
            File::create(path).unwrap().write_all(&[id as u8]).unwrap();
        }
        let mgr = SecurityManager::load(&ca_path, &cert_path, &key_path).unwrap();
        assert_eq!(mgr.ca, vec![0]);
        assert_eq!(mgr.cert, vec![1]);
        // The private key is stored by path only; its bytes are read on demand.
        assert_eq!(mgr.key, key_path);
        let key = load_pem_file("private key", &key_path).unwrap();
        assert_eq!(key, vec![2]);
    }

    /// A missing CA, certificate or key file makes loading fail instead of panicking.
    #[test]
    fn test_security_missing_files() {
        let temp = tempfile::tempdir().unwrap();
        let present = temp.path().join("present");
        File::create(&present).unwrap().write_all(&[0]).unwrap();
        let missing = temp.path().join("missing");
        assert!(SecurityManager::load(&missing, &present, &present).is_err());
        assert!(SecurityManager::load(&present, &missing, &present).is_err());
        assert!(SecurityManager::load(&present, &present, &missing).is_err());
        assert!(load_pem_file("private key", &missing).is_err());
    }
}