From 7fbd8350c7e3cdc14b73f24dd23b5d60481943fd Mon Sep 17 00:00:00 2001 From: mamoniot Date: Wed, 22 Mar 2023 00:31:40 -0400 Subject: [PATCH] tested threading --- zssp/src/fragged.rs | 81 +++++++++++++++++++++++++++------------------ zssp/src/zssp.rs | 8 ++--- 2 files changed, 52 insertions(+), 37 deletions(-) diff --git a/zssp/src/fragged.rs b/zssp/src/fragged.rs index 01a0c3b20..a1d1e8c9e 100644 --- a/zssp/src/fragged.rs +++ b/zssp/src/fragged.rs @@ -1,11 +1,14 @@ +use std::cell::UnsafeCell; use std::mem::{needs_drop, size_of, zeroed, MaybeUninit}; use std::ptr::slice_from_raw_parts; +use std::sync::RwLock; +use std::sync::atomic::{AtomicU64, Ordering}; /// Fast packet defragmenter pub struct Fragged { - have: u64, - counter: u64, - frags: [MaybeUninit; MAX_FRAGMENTS], + have: AtomicU64, + counter: RwLock, + frags: UnsafeCell<[MaybeUninit; MAX_FRAGMENTS]>, } pub struct Assembled([MaybeUninit; MAX_FRAGMENTS], usize); @@ -49,42 +52,56 @@ impl Fragged { /// When a fully assembled packet is returned the internal state is reset and this object can /// be reused to assemble another packet. #[inline(always)] - pub fn assemble(&mut self, counter: u64, fragment: Fragment, fragment_no: u8, fragment_count: u8) -> Option> { + pub fn assemble(&self, counter: u64, fragment: Fragment, fragment_no: u8, fragment_count: u8) -> Option> { if fragment_no < fragment_count && (fragment_count as usize) <= MAX_FRAGMENTS { - let mut have = self.have; + let r = self.counter.read().unwrap(); + let cur_counter = *r; + let mut r_guard = Some(r); + let mut w_guard = None; // If the counter has changed, reset the structure to receive a new packet. - if counter != self.counter { - self.counter = counter; - if needs_drop::() { - let mut i = 0; - while have != 0 { - if (have & 1) != 0 { - debug_assert!(i < MAX_FRAGMENTS); - unsafe { self.frags.get_unchecked_mut(i).assume_init_drop() }; + if counter != cur_counter { + drop(r_guard.take()); + let mut w = self.counter.write().unwrap(); + if *w != counter { + *w = counter; + if needs_drop::() { + let mut have = self.have.load(Ordering::Relaxed); + let mut i = 0; + while have != 0 { + if (have & 1) != 0 { + debug_assert!(i < MAX_FRAGMENTS); + unsafe { (*self.frags.get()).get_unchecked_mut(i).assume_init_drop() }; + } + have = have.wrapping_shr(1); + i += 1; } - have = have.wrapping_shr(1); - i += 1; } - } else { - have = 0; + self.have.store(0, Ordering::Relaxed); } - } - - unsafe { - self.frags.get_unchecked_mut(fragment_no as usize).write(fragment); + w_guard = Some(w); } let want = 0xffffffffffffffffu64.wrapping_shr((64 - fragment_count) as u32); - have |= 1u64.wrapping_shl(fragment_no as u32); - if (have & want) == want { - self.have = 0; - // Setting 'have' to 0 resets the state of this object, and the fragments - // are effectively moved into the Assembled<> container and returned. That - // container will drop them when it is dropped. - return Some(Assembled(unsafe { std::mem::transmute_copy(&self.frags) }, fragment_count as usize)); - } else { - self.have = have; + let got = 1u64.wrapping_shl(fragment_no as u32); + let have = self.have.fetch_or(got, Ordering::Relaxed); + + if have & got == 0 { + unsafe { + (*self.frags.get()).get_unchecked_mut(fragment_no as usize).write(fragment); + } + if ((have | got) & want) == want { + drop(r_guard.take()); + let mut w = w_guard.unwrap_or_else(|| self.counter.write().unwrap()); + if *w == counter { + *w = 0; + self.have.store(0, Ordering::Relaxed); + // Setting 'have' to 0 resets the state of this object, and the fragments + // are effectively moved into the Assembled<> container and returned. That + // container will drop them when it is dropped. + return Some(Assembled(unsafe { std::mem::transmute_copy(&self.frags) }, fragment_count as usize)); + } + } } } return None; @@ -95,12 +112,12 @@ impl Drop for Fragged() { - let mut have = self.have; + let mut have = self.have.load(Ordering::Relaxed); let mut i = 0; while have != 0 { if (have & 1) != 0 { debug_assert!(i < MAX_FRAGMENTS); - unsafe { self.frags.get_unchecked_mut(i).assume_init_drop() }; + unsafe { (*self.frags.get()).get_unchecked_mut(i).assume_init_drop() }; } have = have.wrapping_shr(1); i += 1; diff --git a/zssp/src/zssp.rs b/zssp/src/zssp.rs index 660b3e65d..0d1f16cc4 100644 --- a/zssp/src/zssp.rs +++ b/zssp/src/zssp.rs @@ -90,7 +90,7 @@ pub struct Session { receive_window: [AtomicU64; COUNTER_WINDOW_MAX_OOO], header_protection_cipher: Aes, state: RwLock, - defrag: [Mutex>; COUNTER_WINDOW_MAX_OOO], + defrag: [Fragged; COUNTER_WINDOW_MAX_OOO], } /// Most of the mutable parts of a session state. @@ -343,7 +343,7 @@ impl Context { init_packet: [0u8; AliceNoiseXKInit::SIZE], })), }), - defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())), + defrag: std::array::from_fn(|_| Fragged::new()), }); sessions.active.insert(local_session_id, Arc::downgrade(&session)); @@ -461,8 +461,6 @@ impl Context { let (assembled_packet, incoming_packet_buf_arr); let incoming_packet = if fragment_count > 1 { assembled_packet = session.defrag[(incoming_counter as usize) % COUNTER_WINDOW_MAX_OOO] - .lock() - .unwrap() .assemble(incoming_counter, incoming_physical_packet_buf, fragment_no, fragment_count); if let Some(assembled_packet) = assembled_packet.as_ref() { assembled_packet.as_ref() @@ -1102,7 +1100,7 @@ impl Context { current_key: 0, outgoing_offer: Offer::None, }), - defrag: std::array::from_fn(|_| Mutex::new(Fragged::new())), + defrag: std::array::from_fn(|_| Fragged::new()), }); // Promote incoming session to active.