diff --git a/crypto/src/aes_fruity.rs b/crypto/src/aes_fruity.rs index 2a5e1e542..9179005ad 100644 --- a/crypto/src/aes_fruity.rs +++ b/crypto/src/aes_fruity.rs @@ -3,6 +3,7 @@ // MacOS implementation of AES primitives since CommonCrypto seems to be faster than OpenSSL, especially on ARM64. use std::os::raw::{c_int, c_void}; use std::ptr::{null, null_mut}; +use std::sync::Mutex; use crate::secret::Secret; use crate::secure_eq; @@ -173,14 +174,14 @@ impl AesGcm { -pub struct Aes(*mut c_void, *mut c_void); +pub struct Aes(Mutex<*mut c_void>, Mutex<*mut c_void>); impl Drop for Aes { #[inline(always)] fn drop(&mut self) { unsafe { - CCCryptorRelease(self.0); - CCCryptorRelease(self.1); + CCCryptorRelease(*self.0.lock().unwrap()); + CCCryptorRelease(*self.1.lock().unwrap()); } } } @@ -189,7 +190,7 @@ impl Aes { pub fn new(k: &Secret) -> Self { unsafe { debug_assert!(KEY_SIZE == 32 || KEY_SIZE == 24 || KEY_SIZE == 16, "AES supports 128, 192, or 256 bits keys"); - let mut aes: Self = std::mem::zeroed(); + let aes: Self = std::mem::zeroed(); assert_eq!( CCCryptorCreateWithMode( kCCEncrypt, @@ -203,7 +204,7 @@ impl Aes { 0, 0, kCCOptionECBMode, - &mut aes.0 + &mut *aes.0.lock().unwrap() ), 0 ); @@ -220,7 +221,7 @@ impl Aes { 0, 0, kCCOptionECBMode, - &mut aes.1 + &mut *aes.1.lock().unwrap() ), 0 ); @@ -229,20 +230,20 @@ impl Aes { } #[inline(always)] - pub fn encrypt_block_in_place(&mut self, data: &mut [u8]) { + pub fn encrypt_block_in_place(&self, data: &mut [u8]) { assert_eq!(data.len(), 16); unsafe { let mut data_out_written = 0; - CCCryptorUpdate(self.0, data.as_ptr().cast(), 16, data.as_mut_ptr().cast(), 16, &mut data_out_written); + CCCryptorUpdate(*self.0.lock().unwrap(), data.as_ptr().cast(), 16, data.as_mut_ptr().cast(), 16, &mut data_out_written); } } #[inline(always)] - pub fn decrypt_block_in_place(&mut self, data: &mut [u8]) { + pub fn decrypt_block_in_place(&self, data: &mut [u8]) { assert_eq!(data.len(), 16); unsafe { let mut data_out_written = 0; - CCCryptorUpdate(self.1, data.as_ptr().cast(), 16, data.as_mut_ptr().cast(), 16, &mut data_out_written); + CCCryptorUpdate(*self.1.lock().unwrap(), data.as_ptr().cast(), 16, data.as_mut_ptr().cast(), 16, &mut data_out_written); } } } diff --git a/crypto/src/aes_openssl.rs b/crypto/src/aes_openssl.rs index 7b46ddd9e..98cb433e2 100644 --- a/crypto/src/aes_openssl.rs +++ b/crypto/src/aes_openssl.rs @@ -116,14 +116,14 @@ impl Aes { /// Do not ever encrypt the same plaintext twice. Make sure data is always different between calls. #[inline(always)] - pub fn encrypt_block_in_place(&mut self, data: &mut [u8]) { + pub fn encrypt_block_in_place(&self, data: &mut [u8]) { debug_assert_eq!(data.len(), AES_BLOCK_SIZE, "AesEcb should not be used to encrypt more than one block at a time unless you really know what you are doing."); let ptr = data.as_mut_ptr(); unsafe { self.0.update::(data, ptr).unwrap() } } /// Do not ever encrypt the same plaintext twice. Make sure data is always different between calls. #[inline(always)] - pub fn decrypt_block_in_place(&mut self, data: &mut [u8]) { + pub fn decrypt_block_in_place(&self, data: &mut [u8]) { debug_assert_eq!(data.len(), AES_BLOCK_SIZE, "AesEcb should not be used to encrypt more than one block at a time unless you really know what you are doing."); let ptr = data.as_mut_ptr(); unsafe { self.1.update::(data, ptr).unwrap() } diff --git a/crypto/src/lib.rs b/crypto/src/lib.rs index 3163613c0..8bfdc0569 100644 --- a/crypto/src/lib.rs +++ b/crypto/src/lib.rs @@ -15,7 +15,6 @@ pub mod salsa; pub mod typestate; pub mod x25519; -/// NOTE: we assume that each aes library is threadsafe pub mod aes_fruity; pub mod aes_openssl; #[cfg(target_os = "macos")] diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 2fc9bf72f..0429b81ac 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -83,7 +83,7 @@ pub struct Session { psk: Secret, send_counter: AtomicU64, receive_window: [AtomicU64; COUNTER_WINDOW_MAX_OOO], - header_protection_cipher: Mutex, + header_protection_cipher: Aes, state: RwLock, defrag: [Mutex>; COUNTER_WINDOW_MAX_OOO], } @@ -216,7 +216,7 @@ impl Context { state.remote_session_id, 0, 2, - Some(&mut *session.get_header_cipher()), + Some(&session.header_protection_cipher), ); } false @@ -314,7 +314,7 @@ impl Context { psk, send_counter: AtomicU64::new(3), // 1 and 2 are reserved for init and final ack receive_window: std::array::from_fn(|_| AtomicU64::new(0)), - header_protection_cipher: Mutex::new(Aes::new(&header_protection_key)), + header_protection_cipher: Aes::new(&header_protection_key), state: RwLock::new(State { remote_session_id: None, keys: [None, None], @@ -443,7 +443,7 @@ impl Context { debug_assert!(!self.sessions.read().unwrap().incoming.contains_key(&local_session_id)); session - .get_header_cipher() + .header_protection_cipher .decrypt_block_in_place(&mut incoming_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); let (key_index, packet_type, fragment_count, fragment_no, incoming_counter) = parse_packet_header(&incoming_packet); @@ -834,7 +834,7 @@ impl Context { Some(alice_session_id), 0, 1, - Some(&mut Aes::new(&header_protection_key)), + Some(&Aes::new(&header_protection_key)), )?; return Ok(ReceiveResult::Ok(session)); @@ -979,7 +979,7 @@ impl Context { Some(bob_session_id), 0, 2, - Some(&mut *session.get_header_cipher()), + Some(&session.header_protection_cipher), )?; return Ok(ReceiveResult::Ok(Some(session))); @@ -1076,7 +1076,7 @@ impl Context { psk, send_counter: AtomicU64::new(2), // 1 was already used during negotiation receive_window: std::array::from_fn(|_| AtomicU64::new(0)), - header_protection_cipher: Mutex::new(Aes::new(&incoming.header_protection_key)), + header_protection_cipher: Aes::new(&incoming.header_protection_key), state: RwLock::new(State { remote_session_id: Some(incoming.alice_session_id), keys: [ @@ -1165,7 +1165,7 @@ impl Context { drop(c); session - .get_header_cipher() + .header_protection_cipher .encrypt_block_in_place(&mut reply_buf[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); send(Some(&session), &mut reply_buf); @@ -1314,7 +1314,7 @@ impl Session { fragment_size = tagged_fragment_size; } - self.get_header_cipher() + self.header_protection_cipher .encrypt_block_in_place(&mut mtu_sized_buffer[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); send(&mut mtu_sized_buffer[..fragment_size]); } @@ -1340,7 +1340,7 @@ impl Session { nop[HEADER_SIZE..].copy_from_slice(&c.finish_encrypt()); drop(c); set_packet_header(&mut nop, 1, 0, PACKET_TYPE_NOP, u64::from(remote_session_id), state.current_key, counter); - self.get_header_cipher() + self.header_protection_cipher .encrypt_block_in_place(&mut nop[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); send(&mut nop); } @@ -1402,7 +1402,7 @@ impl Session { //drop(gcm); //drop(state); - self.get_header_cipher() + self.header_protection_cipher .encrypt_block_in_place(&mut rekey_buf[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); send(&mut rekey_buf); @@ -1433,11 +1433,6 @@ impl Session { let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].fetch_max(counter, Ordering::AcqRel); prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD } - - #[inline(always)] - fn get_header_cipher<'a>(&'a self) -> MutexGuard<'a, Aes>{ - self.header_protection_cipher.lock().unwrap() - } } #[inline(always)] @@ -1500,7 +1495,7 @@ fn send_with_fragmentation( remote_session_id: Option, key_index: usize, counter: u64, - mut header_protect_cipher: Option<&mut Aes>, + header_protect_cipher: Option<&Aes>, ) -> Result<(), Error> { let packet_len = packet.len(); let recipient_session_id = remote_session_id.map_or(SessionId::NONE, |s| u64::from(s)); @@ -1518,7 +1513,7 @@ fn send_with_fragmentation( key_index, counter, ); - if let Some(hcc) = &mut header_protect_cipher { + if let Some(hcc) = header_protect_cipher { hcc.encrypt_block_in_place(&mut fragment[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]); } send(fragment);