Skip to content

Commit

Permalink
Fix state order
Browse files Browse the repository at this point in the history
  • Loading branch information
0xphen committed Dec 2, 2023
1 parent be1e1cf commit 6d85481
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 80 deletions.
109 changes: 69 additions & 40 deletions aes/src/key_schedule.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
use super::{
constants::*,
error::AesError,
utils::{gen_matrix, rotate_left, xor_array},
};
use super::{constants::*, error::AesError, utils::rotate_left};

const AES_KEY_SIZE_128: usize = 128 / 8;
const AES_KEY_SIZE_192: usize = 192 / 8;
Expand Down Expand Up @@ -53,6 +49,13 @@ impl KeySchedule {
}
}

pub fn round_key(&self, round: usize) -> [[u8; 4]; 4] {
let mut key: [[u8; 4]; 4] = [[0; 4]; 4];
key.copy_from_slice(&self.keys[round..(round + 4)]);

key
}

/// Performs key expansion for AES encryption.
///
/// This function expands an initial key into a series of round keys used
Expand All @@ -72,21 +75,15 @@ impl KeySchedule {
/// Returns `AesError` if the initial key is too short or if any
/// part of the key expansion process fails.
fn key_expansion(pk: &[u8]) -> Result<Vec<[u8; 4]>, AesError> {
let mut words: Vec<[u8; 4]> = vec![[0; 4]; 4];
let mut words: Vec<[u8; 4]> = vec![];

// Generate the initial words `w0-w3`
// pk.chunks(4).for_each(|chunk| {
// let mut array = [0u8; 4];
// let len = chunk.len().min(4);
// array[..len].copy_from_slice(&chunk[..len]);
// words.push(array);
// });

for (r, chunk) in pk.chunks(4).enumerate() {
for (c, &byte) in chunk.iter().enumerate() {
words[c][r] = byte;
}
}
pk.chunks(4).for_each(|chunk| {
let mut array = [0u8; 4];
let len = chunk.len().min(4);
array[..len].copy_from_slice(&chunk[..len]);
words.push(array);
});

for round in 0..10 {
let previous_key_matrix_slice = &words[words.len().saturating_sub(4)..];
Expand Down Expand Up @@ -123,14 +120,14 @@ impl KeySchedule {
let mut new_key_matrix: [[u8; 4]; 4] = [[0u8; 4]; 4];

// Apply the g_function to the last column of the previous round key
let mut array_rc = KeySchedule::g_function(Self::get_column(key_matrix, 3), rc);
let mut array_rc = KeySchedule::g_function(key_matrix[key_matrix.len() - 1], rc);
for c in 0..4 {
let mut next_array_rc: [u8; 4] = [0u8; 4];
// XOR each column of the previous key with the transformed column
// to create the new round key
for r in 0..4 {
new_key_matrix[r][c] = array_rc[r] ^ key_matrix[r][c];
next_array_rc[r] = new_key_matrix[r][c];
new_key_matrix[c][r] = array_rc[r] ^ key_matrix[c][r];
next_array_rc[r] = new_key_matrix[c][r];
}

// Update array_rc for the next iteration
Expand All @@ -140,18 +137,6 @@ impl KeySchedule {
new_key_matrix
}

/// Retrieves a column specified by index `col` from a 4x4 matrix
#[inline]
fn get_column(matrix: &[[u8; 4]; 4], col: usize) -> [u8; 4] {
let mut column: [u8; 4] = [0u8; 4];

for r in 0..4 {
column[r] = matrix[r][col];
}

column
}

/// Performs the 'g' function of the AES key expansion.
///
/// This function is part of the key expansion routine for AES encryption. It
Expand Down Expand Up @@ -197,13 +182,6 @@ mod tests {
assert_eq!(new_word, [118, 123, 242, 124]);
}

#[test]
fn test_get_column() {
let column = KeySchedule::get_column(&MATRIX, 3);

assert_eq!(column, [4, 8, 12, 16]);
}

#[test]
fn test_generate_new_round() {
let new_round = KeySchedule::generate_new_round(&MATRIX, 1);
Expand All @@ -221,5 +199,56 @@ mod tests {
#[test]
fn test_key_expansion() {
let pk: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];

let key_schedule = KeySchedule::new(&pk).unwrap();
assert_eq!(
key_schedule.keys,
[
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15],
[214, 170, 116, 253],
[210, 175, 114, 250],
[218, 166, 120, 241],
[214, 171, 118, 254],
[182, 146, 207, 11],
[100, 61, 189, 241],
[190, 155, 197, 0],
[104, 48, 179, 254],
[182, 255, 116, 78],
[210, 194, 201, 191],
[108, 89, 12, 191],
[4, 105, 191, 65],
[71, 247, 247, 188],
[149, 53, 62, 3],
[249, 108, 50, 188],
[253, 5, 141, 253],
[60, 170, 163, 232],
[169, 159, 157, 235],
[80, 243, 175, 87],
[173, 246, 34, 170],
[94, 57, 15, 125],
[247, 166, 146, 150],
[167, 85, 61, 193],
[10, 163, 31, 107],
[20, 249, 112, 26],
[227, 95, 226, 140],
[68, 10, 223, 77],
[78, 169, 192, 38],
[71, 67, 135, 53],
[164, 28, 101, 185],
[224, 22, 186, 244],
[174, 191, 122, 210],
[84, 153, 50, 209],
[240, 133, 87, 104],
[16, 147, 237, 156],
[190, 44, 151, 78],
[19, 17, 29, 127],
[227, 148, 74, 23],
[243, 7, 167, 139],
[77, 43, 48, 197]
]
);
}
}
55 changes: 16 additions & 39 deletions aes/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,33 +92,30 @@ mod tests {
use super::*;

const INPUT: [u8; 16] = [
184, 186, 199, 73, 3, 159, 73, 223, 184, 186, 199, 73, 3, 159, 73, 223,
0, 17, 34, 51, 68, 85, 102, 119, 136, 153, 170, 187, 204, 221, 238, 255,
];

const PK: [u8; 16] = [
144, 151, 52, 80, 105, 108, 207, 250, 242, 244, 87, 51, 11, 15, 172, 153,
];
const PK: [u8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];

#[test]
fn aes_encryption() {
let mut aes = AES::new(&PK, &INPUT).unwrap();
}

// #[test]
// fn test_add_round_key() {
// let mut aes = AES::new(&PK, &INPUT).unwrap();
// aes.add_round_key();

// assert_eq!(
// aes.state,
// [
// [52, 2, 214, 48],
// [216, 13, 47, 96],
// [13, 68, 96, 217],
// [28, 63, 92, 90]
// ]
// );
// }
#[test]
fn test_add_round_key() {
let mut aes = AES::new(&PK, &INPUT).unwrap();
aes.add_round_key(aes.key_schedule.round_key(0));
assert_eq!(
aes.state,
[
[0, 16, 32, 48],
[64, 80, 96, 112],
[128, 144, 160, 176],
[192, 208, 224, 240]
]
);
}

#[test]
fn test_substitution() {
Expand Down Expand Up @@ -152,23 +149,3 @@ mod tests {
);
}
}

// [0, 1, 2, 3],
// [4, 5, 6, 7],
// [8, 9, 10, 11],
// [12, 13, 14, 15]

// G function input: [3, 7, 11, 15], 1
// G function output: [7, 11, 15, 3] -> [197, 43, 118, 123]
// COL_1: [196, 43, 118, 123] ^ [0, 4, 8, 12] = [196, 47, 126, 119]
// COL_2: [196, 47, 126, 119] ^ [1,5,9,13] = [197, 42, 119, 122]
// COL_3: [197, 42, 119, 122] ^ [2,6,10,14] = [199, 44, 125, 116]
// COL_4: [199, 44, 125, 116] ^ [3,7,11,15] = [196, 43, 118, 123]
//
//[[196, 197, 199, 196], [47, 42, 44, 43], [126, 119, 125, 118], [119, 122, 116, 123]]
//
//
//
// [8, 12, 16, 4] => [49, 254, 202, 242]

// 214, 170
2 changes: 1 addition & 1 deletion aes/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pub fn gen_matrix(bytes: &[u8; 16]) -> [[u8; 4]; 4] {

for (i, chunk) in bytes.chunks(4).enumerate() {
for (j, &byte) in chunk.iter().enumerate() {
matrix[j][i] = byte;
matrix[i][j] = byte;
}
}

Expand Down

0 comments on commit 6d85481

Please sign in to comment.