From 0e90092e67293b8b67855da2bb45902de8363635 Mon Sep 17 00:00:00 2001
From: Adam Binford
Date: Fri, 5 Apr 2024 14:49:21 -0400
Subject: [PATCH] Add support for encrypting data transfers with AES (#94)

---
 .../minidfs/src/main/java/main/Main.java     |   3 +
 crates/hdfs-native/src/hdfs/connection.rs    |  14 +-
 crates/hdfs-native/src/minidfs.rs            |   2 +
 crates/hdfs-native/src/security/digest.rs    |   8 +
 crates/hdfs-native/src/security/sasl.rs      | 298 +++++++++++++-----
 crates/hdfs-native/tests/test_integration.rs |  12 +
 6 files changed, 244 insertions(+), 93 deletions(-)

diff --git a/crates/hdfs-native/minidfs/src/main/java/main/Main.java b/crates/hdfs-native/minidfs/src/main/java/main/Main.java
index 99aab6a..33c58cb 100644
--- a/crates/hdfs-native/minidfs/src/main/java/main/Main.java
+++ b/crates/hdfs-native/minidfs/src/main/java/main/Main.java
@@ -57,7 +57,10 @@ public static void main(String args[]) throws Exception {
             conf.set(HADOOP_RPC_PROTECTION, "privacy");
             conf.set(DFS_DATA_TRANSFER_PROTECTION_KEY, "privacy");
             if (flags.contains("data_transfer_encryption")) {
+                // Force encryption for all connections
                 conf.set(DFSConfigKeys.DFS_ENCRYPT_DATA_TRANSFER_KEY, "true");
+            }
+            if (flags.contains("aes")) {
                 conf.set(DFS_ENCRYPT_DATA_TRANSFER_CIPHER_SUITES_KEY, "AES/CTR/NoPadding");
             }
         } else if (flags.contains("integrity")) {
diff --git a/crates/hdfs-native/src/hdfs/connection.rs b/crates/hdfs-native/src/hdfs/connection.rs
index a86e9db..50c9849 100644
--- a/crates/hdfs-native/src/hdfs/connection.rs
+++ b/crates/hdfs-native/src/hdfs/connection.rs
@@ -550,12 +550,9 @@ impl DatanodeConnection {
             .await?;
         self.writer.flush().await?;
 
-        let msg_length = self.reader.read_length_delimiter().await?;
+        let message = self.reader.read_proto().await?;
 
-        let mut response_buf = BytesMut::zeroed(msg_length);
-        self.reader.read_exact(&mut response_buf).await?;
-
-        let response = hdfs::BlockOpResponseProto::decode(response_buf.freeze())?;
+        let response = hdfs::BlockOpResponseProto::decode(message)?;
 
         Ok(response)
     }
@@ -630,12 +627,9 @@ pub(crate) struct DatanodeReader {
 
 impl DatanodeReader {
     pub(crate) async fn read_ack(&mut self) -> Result<hdfs::PipelineAckProto> {
-        let ack_length = self.reader.read_length_delimiter().await?;
-
-        let mut response_buf = BytesMut::zeroed(ack_length);
-        self.reader.read_exact(&mut response_buf).await?;
+        let message = self.reader.read_proto().await?;
 
-        let response = hdfs::PipelineAckProto::decode(response_buf.freeze())?;
+        let response = hdfs::PipelineAckProto::decode(message)?;
         Ok(response)
     }
 }
diff --git a/crates/hdfs-native/src/minidfs.rs b/crates/hdfs-native/src/minidfs.rs
index d42a0cb..9edf6fe 100644
--- a/crates/hdfs-native/src/minidfs.rs
+++ b/crates/hdfs-native/src/minidfs.rs
@@ -13,6 +13,7 @@ pub enum DfsFeatures {
     Token,
     Integrity,
     Privacy,
+    AES,
     HA,
     ViewFS,
     EC,
@@ -28,6 +29,7 @@ impl DfsFeatures {
             DfsFeatures::Privacy => "privacy",
             DfsFeatures::Security => "security",
             DfsFeatures::Integrity => "integrity",
+            DfsFeatures::AES => "aes",
             DfsFeatures::Token => "token",
             DfsFeatures::RBF => "rbf",
         }
diff --git a/crates/hdfs-native/src/security/digest.rs b/crates/hdfs-native/src/security/digest.rs
index cce5434..1b61a0c 100644
--- a/crates/hdfs-native/src/security/digest.rs
+++ b/crates/hdfs-native/src/security/digest.rs
@@ -357,6 +357,14 @@ impl DigestSaslSession {
             server: kis,
         }
     }
+
+    pub(crate) fn supports_encryption(&self) -> bool {
+        match &self.state {
+            DigestState::Stepped(ctx) => matches!(ctx.qop, Qop::AuthConf),
+            DigestState::Completed(ctx) => ctx.as_ref().is_some_and(|c| c.encryptor.is_some()),
+            _ => false,
false, + } + } } impl SaslSession for DigestSaslSession { diff --git a/crates/hdfs-native/src/security/sasl.rs b/crates/hdfs-native/src/security/sasl.rs index d60b7c7..8e87876 100644 --- a/crates/hdfs-native/src/security/sasl.rs +++ b/crates/hdfs-native/src/security/sasl.rs @@ -1,4 +1,5 @@ -use bytes::{Buf, Bytes, BytesMut}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use cipher::{KeyIvInit, StreamCipher}; use log::{debug, warn}; use prost::Message; use std::io; @@ -11,7 +12,7 @@ use tokio::{ }; use super::user::BlockTokenIdentifier; -// use crate::proto::hdfs::{CipherOptionProto, CipherSuiteProto}; +use crate::proto::hdfs::{CipherOptionProto, CipherSuiteProto}; use crate::proto::{ common::{ rpc_response_header_proto::RpcStatusProto, @@ -30,6 +31,10 @@ use crate::{HdfsError, Result}; use super::gssapi::GssapiSession; use super::user::{User, UserInfo}; +type Aes128Ctr = ctr::Ctr128BE; +type Aes192Ctr = ctr::Ctr128BE; +type Aes256Ctr = ctr::Ctr128BE; + const SASL_CALL_ID: i32 = -33; const SASL_TRANSFER_MAGIC_NUMBER: i32 = 0xDEADBEEFu32 as i32; const HDFS_DELEGATION_TOKEN: &str = "HDFS_DELEGATION_TOKEN"; @@ -388,27 +393,16 @@ impl std::fmt::Debug for SaslWriter { } } -pub(crate) struct SaslDatanodeReader { - stream: BufReader, - session: Option>>, +struct SaslDecryptor { + session: Arc>, size_buffer: [u8; 4], response_buffer: Vec, data_buffer: Bytes, } -impl SaslDatanodeReader { - fn new(stream: OwnedReadHalf, session: Option>>) -> Self { - Self { - stream: BufReader::new(stream), - session, - size_buffer: [0u8; 4], - response_buffer: Vec::with_capacity(65536), - data_buffer: Bytes::new(), - } - } - - async fn read_more_data(&mut self) -> Result<()> { - self.stream.read_exact(&mut self.size_buffer).await?; +impl SaslDecryptor { + async fn read_more_data(&mut self, stream: &mut BufReader) -> Result<()> { + stream.read_exact(&mut self.size_buffer).await?; let msg_length = u32::from_be_bytes(self.size_buffer) as usize; // Resize our internal buffer if the message is larger @@ -416,14 +410,12 @@ impl SaslDatanodeReader { self.response_buffer.resize(msg_length, 0); } - self.stream + stream .read_exact(&mut self.response_buffer[..msg_length]) .await?; self.data_buffer = self .session - .as_ref() - .unwrap() .lock() .unwrap() .decode(&self.response_buffer[..msg_length])? 
@@ -431,72 +423,177 @@ impl SaslDatanodeReader { Ok(()) } +} + +enum DatanodeDecryptor { + Sasl(SaslDecryptor), + Cipher(Box), +} + +pub(crate) struct SaslDatanodeReader { + stream: BufReader, + decryptor: Option, +} + +impl SaslDatanodeReader { + fn unencrypted(stream: OwnedReadHalf) -> Self { + Self { + stream: BufReader::new(stream), + decryptor: None, + } + } + + fn sasl(stream: OwnedReadHalf, session: Arc>) -> Self { + let decryptor = SaslDecryptor { + session, + size_buffer: [0u8; 4], + response_buffer: Vec::with_capacity(65536), + data_buffer: Bytes::new(), + }; + Self { + stream: BufReader::new(stream), + decryptor: Some(DatanodeDecryptor::Sasl(decryptor)), + } + } + + fn cipher(stream: OwnedReadHalf, cipher: Box) -> Self { + Self { + stream: BufReader::new(stream), + decryptor: Some(DatanodeDecryptor::Cipher(cipher)), + } + } pub(crate) async fn read_exact(&mut self, buf: &mut [u8]) -> Result { - if self.session.is_some() { - let read_len = buf.len(); - let mut bytes_remaining = read_len; - while bytes_remaining > 0 { - if !self.data_buffer.has_remaining() { - self.read_more_data().await?; + match &mut self.decryptor { + Some(DatanodeDecryptor::Sasl(sasl)) => { + let read_len = buf.len(); + let mut bytes_remaining = read_len; + while bytes_remaining > 0 { + if !sasl.data_buffer.has_remaining() { + sasl.read_more_data(&mut self.stream).await?; + } + let copy_len = usize::min(bytes_remaining, sasl.data_buffer.remaining()); + let copy_start = read_len - bytes_remaining; + sasl.data_buffer + .copy_to_slice(&mut buf[copy_start..(copy_start + copy_len)]); + bytes_remaining -= copy_len; } - let copy_len = usize::min(bytes_remaining, self.data_buffer.remaining()); - let copy_start = read_len - bytes_remaining; - self.data_buffer - .copy_to_slice(&mut buf[copy_start..(copy_start + copy_len)]); - bytes_remaining -= copy_len; - } - Ok(read_len) - } else { - Ok(self.stream.read_exact(buf).await?) 
+ Ok(read_len) + } + Some(DatanodeDecryptor::Cipher(cipher)) => { + let read_len = self.stream.read_exact(buf).await?; + cipher.apply_keystream(buf); + Ok(read_len) + } + None => Ok(self.stream.read_exact(buf).await?), } } - /// Reads a length delimiter from the stream without advancing the position of the stream - pub(crate) async fn read_length_delimiter(&mut self) -> Result { - if self.session.is_some() { - // assumption is we'll have the whole length in a single message - if !self.data_buffer.has_remaining() { - self.read_more_data().await?; + /// Reads a length delimiter from the stream and then reads that many bytes for a full proto message + pub(crate) async fn read_proto(&mut self) -> Result { + match &mut self.decryptor { + Some(DatanodeDecryptor::Sasl(sasl)) => { + // assumption is we'll have the whole length in a single message + if !sasl.data_buffer.has_remaining() { + sasl.read_more_data(&mut self.stream).await?; + } + let decoded_len = prost::decode_length_delimiter(&mut sasl.data_buffer)?; + + let mut buf = BytesMut::zeroed(decoded_len); + self.read_exact(&mut buf).await?; + Ok(buf.freeze()) } - let decoded_len = prost::decode_length_delimiter(&mut self.data_buffer)?; - Ok(decoded_len) - } else { - let mut buf = self.stream.fill_buf().await?; - if buf.is_empty() { - // The stream has been closed - return Err(HdfsError::DataTransferError( - "Datanode connection closed while waiting for ack".to_string(), - )); + Some(DatanodeDecryptor::Cipher(cipher)) => { + let mut msg_len = BytesMut::with_capacity(10); + // Known from varint parsing, once we either get 10 bytes or a byte less than 0x80 + // we have enough to decode the length + while msg_len.len() < 10 { + let mut byte = [self.stream.read_u8().await?]; + cipher.apply_keystream(&mut byte); + msg_len.put(&byte[..]); + if byte[0] < 0x80 { + break; + } + } + + let decoded_len = prost::decode_length_delimiter(&mut msg_len.freeze())?; + + let mut msg_buf = BytesMut::zeroed(decoded_len); + self.stream.read_exact(&mut msg_buf).await?; + cipher.apply_keystream(&mut msg_buf); + + Ok(msg_buf.freeze()) } + None => { + let mut buf = self.stream.fill_buf().await?; + if buf.is_empty() { + // The stream has been closed + return Err(HdfsError::DataTransferError( + "Datanode connection closed while waiting for ack".to_string(), + )); + } + + let decoded_len = prost::decode_length_delimiter(&mut buf)?; + self.stream + .consume(prost::length_delimiter_len(decoded_len)); - let decoded_len = prost::decode_length_delimiter(&mut buf)?; - self.stream - .consume(prost::length_delimiter_len(decoded_len)); + let mut msg_buf = BytesMut::zeroed(decoded_len); + self.stream.read_exact(&mut msg_buf).await?; - Ok(decoded_len) + Ok(msg_buf.freeze()) + } } } } +enum DatanodeEncryptor { + Sasl(Arc>), + Cipher(Box), +} + pub(crate) struct SaslDatanodeWriter { stream: OwnedWriteHalf, - session: Option>>, + encryptor: Option, } impl SaslDatanodeWriter { - fn new(stream: OwnedWriteHalf, session: Option>>) -> Self { - Self { stream, session } + fn unencrypted(stream: OwnedWriteHalf) -> Self { + Self { + stream, + encryptor: None, + } + } + + fn sasl(stream: OwnedWriteHalf, session: Arc>) -> Self { + Self { + stream, + encryptor: Some(DatanodeEncryptor::Sasl(session)), + } + } + + fn cipher(stream: OwnedWriteHalf, cipher: Box) -> Self { + Self { + stream, + encryptor: Some(DatanodeEncryptor::Cipher(cipher)), + } } pub(crate) async fn write_all(&mut self, buf: &[u8]) -> Result<()> { - if let Some(session) = self.session.as_ref() { - let wrapped = 
session.lock().unwrap().encode(buf)?; - self.stream.write_u32(wrapped.len() as u32).await?; - self.stream.write_all(&wrapped).await?; - } else { - self.stream.write_all(buf).await?; + match &mut self.encryptor { + Some(DatanodeEncryptor::Sasl(sasl)) => { + let wrapped = sasl.lock().unwrap().encode(buf)?; + self.stream.write_u32(wrapped.len() as u32).await?; + self.stream.write_all(&wrapped).await?; + } + Some(DatanodeEncryptor::Cipher(cipher)) => { + let mut encrypted = vec![0u8; buf.len()]; + cipher.apply_keystream_b2b(buf, &mut encrypted).unwrap(); + self.stream.write_all(&encrypted).await?; + } + None => { + self.stream.write_all(buf).await?; + } } Ok(()) } @@ -524,7 +621,7 @@ impl SaslDatanodeConnection { ) -> Result<(SaslDatanodeReader, SaslDatanodeWriter)> { // If there's no token identifier or it's a privileged port, don't do SASL negotation if token.identifier.is_empty() || datanode_id.xfer_port <= 1024 { - return Ok(self.split(None)); + return self.split(None, None); } self.stream.write_i32(SASL_TRANSFER_MAGIC_NUMBER).await?; @@ -563,19 +660,19 @@ impl SaslDatanodeConnection { let (payload, finished) = session.step(response.payload.as_ref().map(|p| &p[..]))?; assert!(!finished); - // let cipher_option = if session.supports_encryption() { - // vec![CipherOptionProto { - // suite: CipherSuiteProto::AesCtrNopadding as i32, - // ..Default::default() - // }] - // } else { - // vec![] - // }; + let cipher_option = if session.supports_encryption() { + vec![CipherOptionProto { + suite: CipherSuiteProto::AesCtrNopadding as i32, + ..Default::default() + }] + } else { + vec![] + }; let message = DataTransferEncryptorMessageProto { status: DataTransferEncryptorStatus::Success as i32, payload: Some(payload), - // cipher_option, + cipher_option, ..Default::default() }; @@ -594,9 +691,9 @@ impl SaslDatanodeConnection { assert!(finished); if session.has_security_layer() { - Ok(self.split(Some(session))) + self.split(Some(session), response.cipher_option.first()) } else { - Ok(self.split(None)) + self.split(None, None) } } @@ -618,14 +715,49 @@ impl SaslDatanodeConnection { )?) 
} - fn split(self, session: Option) -> (SaslDatanodeReader, SaslDatanodeWriter) { - let reader_session = session.map(|s| Arc::new(Mutex::new(s))); - let writer_session = reader_session.clone(); - + fn split( + self, + session: Option, + cipher_option: Option<&CipherOptionProto>, + ) -> Result<(SaslDatanodeReader, SaslDatanodeWriter)> { let (stream_reader, stream_writer) = self.stream.into_inner().into_split(); + if let Some(cipher) = cipher_option { + let mut session = session.unwrap(); + match cipher.suite() { + CipherSuiteProto::AesCtrNopadding => { + let in_key = session.decode(cipher.in_key())?; + let out_key = session.decode(cipher.out_key())?; + + // For the client, the in_key is used to encrypt data to send and the out_key is for decrypting incoming data + let encryptor = Self::create_aes_cipher(&in_key, cipher.in_iv()); + let decryptor = Self::create_aes_cipher(&out_key, cipher.out_iv()); + + let reader = SaslDatanodeReader::cipher(stream_reader, decryptor); + let writer = SaslDatanodeWriter::cipher(stream_writer, encryptor); + Ok((reader, writer)) + } + c => Err(HdfsError::SASLError(format!("Unsupported cipher {:?}", c))), + } + } else if let Some(session) = session { + let reader_session = Arc::new(Mutex::new(session)); + let writer_session = Arc::clone(&reader_session); + let reader = SaslDatanodeReader::sasl(stream_reader, reader_session); + let writer = SaslDatanodeWriter::sasl(stream_writer, writer_session); + Ok((reader, writer)) + } else { + Ok(( + SaslDatanodeReader::unencrypted(stream_reader), + SaslDatanodeWriter::unencrypted(stream_writer), + )) + } + } - let reader = SaslDatanodeReader::new(stream_reader, reader_session); - let writer = SaslDatanodeWriter::new(stream_writer, writer_session); - (reader, writer) + fn create_aes_cipher(key: &[u8], iv: &[u8]) -> Box { + match key.len() * 8 { + 128 => Box::new(Aes128Ctr::new(key.into(), iv.into())), + 192 => Box::new(Aes192Ctr::new(key.into(), iv.into())), + 256 => Box::new(Aes256Ctr::new(key.into(), iv.into())), + x => panic!("Unsupported AES bit length {}", x), + } } } diff --git a/crates/hdfs-native/tests/test_integration.rs b/crates/hdfs-native/tests/test_integration.rs index 453b4c5..6778093 100644 --- a/crates/hdfs-native/tests/test_integration.rs +++ b/crates/hdfs-native/tests/test_integration.rs @@ -79,6 +79,18 @@ mod test { .unwrap(); } + #[tokio::test] + #[serial] + async fn test_aes() { + test_with_features(&HashSet::from([ + DfsFeatures::Security, + DfsFeatures::Privacy, + DfsFeatures::AES, + ])) + .await + .unwrap(); + } + #[tokio::test] #[serial] async fn test_basic_ha() {
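Below is a minimal, standalone sketch (not part of the patch itself) of the AES/CTR behavior the new cipher paths rely on: CTR mode is a symmetric stream cipher, so applying the same keystream a second time undoes the first application. That is why SaslDatanodeReader::read_exact can simply call apply_keystream on incoming bytes, while SaslDatanodeWriter::write_all uses apply_keystream_b2b to encrypt into a scratch buffer. Crate versions are assumed to match what the patch pulls in (aes 0.8, ctr 0.9, cipher 0.4), and the key/IV values here are dummies; in the real flow they come from the negotiated CipherOptionProto (in_key/in_iv for sending, out_key/out_iv for receiving).

    use aes::Aes128;
    use cipher::{KeyIvInit, StreamCipher};

    // Same alias shape the patch introduces for the 128-bit case.
    type Aes128Ctr = ctr::Ctr128BE<Aes128>;

    fn main() {
        // Dummy key/IV for illustration only.
        let key = [0x42u8; 16];
        let iv = [0x24u8; 16];

        let data = b"hello datanode".to_vec();

        // Writer side: encrypt into a separate buffer (cf. apply_keystream_b2b in write_all).
        let mut encryptor = Aes128Ctr::new(&key.into(), &iv.into());
        let mut encrypted = vec![0u8; data.len()];
        encryptor.apply_keystream_b2b(&data, &mut encrypted).unwrap();

        // Reader side: applying the same keystream again decrypts in place (cf. read_exact).
        let mut decryptor = Aes128Ctr::new(&key.into(), &iv.into());
        decryptor.apply_keystream(&mut encrypted);

        assert_eq!(encrypted, data);
    }

Because CTR needs no padding, ciphertext length equals plaintext length, which is why the cipher read/write paths can exchange raw encrypted bytes without the length-prefixed framing the SASL wrap path uses.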