ZSSP basically works...

This commit is contained in:
Adam Ierymenko 2023-02-28 17:52:18 -05:00
commit c7388879ee
2 changed files with 48 additions and 9 deletions

View file

@ -55,6 +55,9 @@ fn alice_main(
println!("[alice] opening session {}", alice_session.id.to_string());
let test_data = [1u8; 10000];
let mut up = false;
while run.load(Ordering::Relaxed) {
let pkt = alice_in.try_recv();
let current_time = ms_monotonic();
@ -74,10 +77,10 @@ fn alice_main(
current_time,
) {
Ok(zssp::ReceiveResult::Ok) => {
println!("[alice] ok");
//println!("[alice] ok");
}
Ok(zssp::ReceiveResult::OkData(_, data)) => {
println!("[alice] received {}", data.len());
Ok(zssp::ReceiveResult::OkData(_, _)) => {
//println!("[alice] received {}", data.len());
}
Ok(zssp::ReceiveResult::OkNewSession(s)) => {
println!("[alice] new session {}", s.id.to_string());
@ -89,6 +92,22 @@ fn alice_main(
}
}
if up {
assert!(alice_session
.send(
|b| {
let _ = alice_out.send(b.to_vec());
},
&mut data_buf[..TEST_MTU],
&test_data[..2048 + ((zerotier_crypto::random::xorshift64_random() as usize) % (test_data.len() - 2048))],
)
.is_ok());
} else {
if alice_session.established() {
up = true;
}
}
if current_time >= next_service {
next_service = current_time
+ context.service(
@ -111,7 +130,11 @@ fn bob_main(
) {
let context = zssp::Context::<TestApplication>::new(16);
let mut data_buf = [0u8; 65536];
let mut next_service = ms_monotonic() + 500;
let mut last_speed_metric = ms_monotonic();
let mut next_service = last_speed_metric + 500;
let mut transferred = 0u64;
let mut bob_session = None;
while run.load(Ordering::Relaxed) {
let pkt = bob_in.recv_timeout(Duration::from_millis(10));
@ -132,13 +155,15 @@ fn bob_main(
current_time,
) {
Ok(zssp::ReceiveResult::Ok) => {
println!("[bob] ok");
//println!("[bob] ok");
}
Ok(zssp::ReceiveResult::OkData(_, data)) => {
println!("[bob] received {}", data.len());
//println!("[bob] received {}", data.len());
transferred += data.len() as u64;
}
Ok(zssp::ReceiveResult::OkNewSession(s)) => {
println!("[bob] new session {}", s.id.to_string());
let _ = bob_session.replace(s);
}
Ok(zssp::ReceiveResult::Rejected) => {}
Err(e) => {
@ -147,6 +172,16 @@ fn bob_main(
}
}
let speed_metric_elapsed = current_time - last_speed_metric;
if speed_metric_elapsed >= 1000 {
last_speed_metric = current_time;
println!(
"[bob] RX speed {} MiB/sec",
((transferred as f64) / 1048576.0) / ((speed_metric_elapsed as f64) / 1000.0)
);
transferred = 0;
}
if current_time >= next_service {
next_service = current_time
+ context.service(

View file

@ -465,7 +465,6 @@ impl<Application: ApplicationLayer> Context<Application> {
.decrypt_block_in_place(&mut incoming_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
incoming = Some(i);
} else {
println!("unknown {}", local_session_id.to_string());
return Err(Error::UnknownLocalSessionId);
}
}
@ -928,7 +927,7 @@ impl<Application: ApplicationLayer> Context<Application> {
let mut state = session.state.write().unwrap();
let _ = state.remote_session_id.insert(bob_session_id);
let _ =
state.keys[0].insert(SessionKey::new::<Application>(noise_es_ee_se_hk_psk, current_time, 2, true));
state.keys[0].insert(SessionKey::new::<Application>(noise_es_ee_se_hk_psk, current_time, 2, false));
state.current_key = 0;
state.current_offer = Offer::None;
}
@ -1244,7 +1243,7 @@ impl<Application: ApplicationLayer> Session<Application> {
let state = self.state.read().unwrap();
if let Some(remote_session_id) = state.remote_session_id {
if let Some(session_key) = state.keys[state.current_key].as_ref() {
let counter = self.send_counter.fetch_add(1, Ordering::SeqCst);
let counter = self.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?.get();
let mut c = session_key.get_send_cipher(counter)?;
c.reset_init_gcm(&create_message_nonce(PACKET_TYPE_DATA, counter));
@ -1253,9 +1252,11 @@ impl<Application: ApplicationLayer> Session<Application> {
(((data.len() + AES_GCM_TAG_SIZE) as f32) / (mtu_sized_buffer.len() - HEADER_SIZE) as f32).ceil() as usize;
let fragment_max_chunk_size = mtu_sized_buffer.len() - HEADER_SIZE;
let last_fragment_no = fragment_count - 1;
for fragment_no in 0..fragment_count {
let chunk_size = fragment_max_chunk_size.min(data.len());
let mut fragment_size = chunk_size + HEADER_SIZE;
set_packet_header(
mtu_sized_buffer,
fragment_count,
@ -1265,14 +1266,17 @@ impl<Application: ApplicationLayer> Session<Application> {
state.current_key,
counter,
);
c.crypt(&data[..chunk_size], &mut mtu_sized_buffer[HEADER_SIZE..fragment_size]);
data = &data[chunk_size..];
if fragment_no == last_fragment_no {
debug_assert!(data.is_empty());
let tagged_fragment_size = fragment_size + AES_GCM_TAG_SIZE;
mtu_sized_buffer[fragment_size..tagged_fragment_size].copy_from_slice(&c.finish_encrypt());
fragment_size = tagged_fragment_size;
}
self.header_protection_cipher
.encrypt_block_in_place(&mut mtu_sized_buffer[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
send(&mut mtu_sized_buffer[..fragment_size]);