diff --git a/aes/src/key_schedule.rs b/aes/src/key_schedule.rs index 9dbe36c..186f26b 100644 --- a/aes/src/key_schedule.rs +++ b/aes/src/key_schedule.rs @@ -1,8 +1,4 @@ -use super::{ - constants::*, - error::AesError, - utils::{gen_matrix, rotate_left, xor_array}, -}; +use super::{constants::*, error::AesError, utils::rotate_left}; const AES_KEY_SIZE_128: usize = 128 / 8; const AES_KEY_SIZE_192: usize = 192 / 8; @@ -53,6 +49,13 @@ impl KeySchedule { } } + pub fn round_key(&self, round: usize) -> [[u8; 4]; 4] { + let mut key: [[u8; 4]; 4] = [[0; 4]; 4]; + key.copy_from_slice(&self.keys[round..(round + 4)]); + + key + } + /// Performs key expansion for AES encryption. /// /// This function expands an initial key into a series of round keys used @@ -72,21 +75,15 @@ impl KeySchedule { /// Returns `AesError` if the initial key is too short or if any /// part of the key expansion process fails. fn key_expansion(pk: &[u8]) -> Result, AesError> { - let mut words: Vec<[u8; 4]> = vec![[0; 4]; 4]; + let mut words: Vec<[u8; 4]> = vec![]; // Generate the initial words `w0-w3` - // pk.chunks(4).for_each(|chunk| { - // let mut array = [0u8; 4]; - // let len = chunk.len().min(4); - // array[..len].copy_from_slice(&chunk[..len]); - // words.push(array); - // }); - - for (r, chunk) in pk.chunks(4).enumerate() { - for (c, &byte) in chunk.iter().enumerate() { - words[c][r] = byte; - } - } + pk.chunks(4).for_each(|chunk| { + let mut array = [0u8; 4]; + let len = chunk.len().min(4); + array[..len].copy_from_slice(&chunk[..len]); + words.push(array); + }); for round in 0..10 { let previous_key_matrix_slice = &words[words.len().saturating_sub(4)..]; @@ -123,14 +120,14 @@ impl KeySchedule { let mut new_key_matrix: [[u8; 4]; 4] = [[0u8; 4]; 4]; // Apply the g_function to the last column of the previous round key - let mut array_rc = KeySchedule::g_function(Self::get_column(key_matrix, 3), rc); + let mut array_rc = KeySchedule::g_function(key_matrix[key_matrix.len() - 1], rc); for c in 0..4 { let mut next_array_rc: [u8; 4] = [0u8; 4]; // XOR each column of the previous key with the transformed column // to create the new round key for r in 0..4 { - new_key_matrix[r][c] = array_rc[r] ^ key_matrix[r][c]; - next_array_rc[r] = new_key_matrix[r][c]; + new_key_matrix[c][r] = array_rc[r] ^ key_matrix[c][r]; + next_array_rc[r] = new_key_matrix[c][r]; } // Update array_rc for the next iteration @@ -140,18 +137,6 @@ impl KeySchedule { new_key_matrix } - /// Retrieves a column specified by index `col` from a 4x4 matrix - #[inline] - fn get_column(matrix: &[[u8; 4]; 4], col: usize) -> [u8; 4] { - let mut column: [u8; 4] = [0u8; 4]; - - for r in 0..4 { - column[r] = matrix[r][col]; - } - - column - } - /// Performs the 'g' function of the AES key expansion. /// /// This function is part of the key expansion routine for AES encryption. It @@ -197,13 +182,6 @@ mod tests { assert_eq!(new_word, [118, 123, 242, 124]); } - #[test] - fn test_get_column() { - let column = KeySchedule::get_column(&MATRIX, 3); - - assert_eq!(column, [4, 8, 12, 16]); - } - #[test] fn test_generate_new_round() { let new_round = KeySchedule::generate_new_round(&MATRIX, 1); @@ -221,5 +199,56 @@ mod tests { #[test] fn test_key_expansion() { let pk: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; + + let key_schedule = KeySchedule::new(&pk).unwrap(); + assert_eq!( + key_schedule.keys, + [ + [0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11], + [12, 13, 14, 15], + [214, 170, 116, 253], + [210, 175, 114, 250], + [218, 166, 120, 241], + [214, 171, 118, 254], + [182, 146, 207, 11], + [100, 61, 189, 241], + [190, 155, 197, 0], + [104, 48, 179, 254], + [182, 255, 116, 78], + [210, 194, 201, 191], + [108, 89, 12, 191], + [4, 105, 191, 65], + [71, 247, 247, 188], + [149, 53, 62, 3], + [249, 108, 50, 188], + [253, 5, 141, 253], + [60, 170, 163, 232], + [169, 159, 157, 235], + [80, 243, 175, 87], + [173, 246, 34, 170], + [94, 57, 15, 125], + [247, 166, 146, 150], + [167, 85, 61, 193], + [10, 163, 31, 107], + [20, 249, 112, 26], + [227, 95, 226, 140], + [68, 10, 223, 77], + [78, 169, 192, 38], + [71, 67, 135, 53], + [164, 28, 101, 185], + [224, 22, 186, 244], + [174, 191, 122, 210], + [84, 153, 50, 209], + [240, 133, 87, 104], + [16, 147, 237, 156], + [190, 44, 151, 78], + [19, 17, 29, 127], + [227, 148, 74, 23], + [243, 7, 167, 139], + [77, 43, 48, 197] + ] + ); } } diff --git a/aes/src/lib.rs b/aes/src/lib.rs index 35d9709..4f7b8d0 100644 --- a/aes/src/lib.rs +++ b/aes/src/lib.rs @@ -92,33 +92,30 @@ mod tests { use super::*; const INPUT: [u8; 16] = [ - 184, 186, 199, 73, 3, 159, 73, 223, 184, 186, 199, 73, 3, 159, 73, 223, + 0, 17, 34, 51, 68, 85, 102, 119, 136, 153, 170, 187, 204, 221, 238, 255, ]; - const PK: [u8; 16] = [ - 144, 151, 52, 80, 105, 108, 207, 250, 242, 244, 87, 51, 11, 15, 172, 153, - ]; + const PK: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; #[test] fn aes_encryption() { let mut aes = AES::new(&PK, &INPUT).unwrap(); } - // #[test] - // fn test_add_round_key() { - // let mut aes = AES::new(&PK, &INPUT).unwrap(); - // aes.add_round_key(); - - // assert_eq!( - // aes.state, - // [ - // [52, 2, 214, 48], - // [216, 13, 47, 96], - // [13, 68, 96, 217], - // [28, 63, 92, 90] - // ] - // ); - // } + #[test] + fn test_add_round_key() { + let mut aes = AES::new(&PK, &INPUT).unwrap(); + aes.add_round_key(aes.key_schedule.round_key(0)); + assert_eq!( + aes.state, + [ + [0, 16, 32, 48], + [64, 80, 96, 112], + [128, 144, 160, 176], + [192, 208, 224, 240] + ] + ); + } #[test] fn test_substitution() { @@ -152,23 +149,3 @@ mod tests { ); } } - -// [0, 1, 2, 3], -// [4, 5, 6, 7], -// [8, 9, 10, 11], -// [12, 13, 14, 15] - -// G function input: [3, 7, 11, 15], 1 -// G function output: [7, 11, 15, 3] -> [197, 43, 118, 123] -// COL_1: [196, 43, 118, 123] ^ [0, 4, 8, 12] = [196, 47, 126, 119] -// COL_2: [196, 47, 126, 119] ^ [1,5,9,13] = [197, 42, 119, 122] -// COL_3: [197, 42, 119, 122] ^ [2,6,10,14] = [199, 44, 125, 116] -// COL_4: [199, 44, 125, 116] ^ [3,7,11,15] = [196, 43, 118, 123] -// -//[[196, 197, 199, 196], [47, 42, 44, 43], [126, 119, 125, 118], [119, 122, 116, 123]] -// -// -// -// [8, 12, 16, 4] => [49, 254, 202, 242] - -// 214, 170 diff --git a/aes/src/utils.rs b/aes/src/utils.rs index faa8948..a70ee2d 100644 --- a/aes/src/utils.rs +++ b/aes/src/utils.rs @@ -4,7 +4,7 @@ pub fn gen_matrix(bytes: &[u8; 16]) -> [[u8; 4]; 4] { for (i, chunk) in bytes.chunks(4).enumerate() { for (j, &byte) in chunk.iter().enumerate() { - matrix[j][i] = byte; + matrix[i][j] = byte; } }