Add automatic retransmission in the earliest stages of session init.

This commit is contained in:
Adam Ierymenko 2023-02-27 17:52:49 -05:00
commit 5143aa6a4e
4 changed files with 332 additions and 205 deletions

View file

@ -49,6 +49,12 @@ pub trait ApplicationLayer: Sized {
/// over very long distances.
const INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS: i64 = 2000;
/// Retry interval for outgoing connection initiation or rekey attempts.
///
/// Retry attepmpts will be no more often than this, but the delay may end up being slightly more
/// in some cases depending on where in the cycle the initial attempt falls.
const RETRY_INTERVAL: i64 = 500;
/// Type for arbitrary opaque object for use by the application that is attached to each session.
type Data;

View file

@ -15,6 +15,6 @@ mod zssp;
pub use crate::applicationlayer::ApplicationLayer;
pub use crate::error::Error;
pub use crate::proto::{MAX_METADATA_SIZE, MIN_PACKET_SIZE, MIN_TRANSPORT_MTU};
pub use crate::proto::{MAX_INIT_PAYLOAD_SIZE, MIN_PACKET_SIZE, MIN_TRANSPORT_MTU};
pub use crate::sessionid::SessionId;
pub use crate::zssp::{Context, ReceiveResult, Session};

View file

@ -22,8 +22,8 @@ pub const MIN_PACKET_SIZE: usize = HEADER_SIZE + AES_GCM_TAG_SIZE;
/// Minimum physical MTU for ZSSP to function.
pub const MIN_TRANSPORT_MTU: usize = 128;
/// Maximum size of init meta-data objects.
pub const MAX_METADATA_SIZE: usize = 256;
/// Maximum combined size of static public blob and metadata.
pub const MAX_INIT_PAYLOAD_SIZE: usize = MAX_NOISE_HANDSHAKE_SIZE - ALICE_NOISE_XK_ACK_MIN_SIZE;
pub(crate) const SESSION_PROTOCOL_VERSION: u8 = 0x00;

View file

@ -11,11 +11,11 @@
use std::collections::HashMap;
use std::num::NonZeroU64;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock, Weak};
use zerotier_crypto::aes::{Aes, AesCtr, AesGcm};
use zerotier_crypto::hash::{hmac_sha512, HMACSHA384, HMAC_SHA384_SIZE, SHA384};
use zerotier_crypto::hash::{hmac_sha512, HMACSHA384, HMAC_SHA384_SIZE, SHA384, SHA384_HASH_SIZE};
use zerotier_crypto::p384::{P384KeyPair, P384PublicKey, P384_ECDH_SHARED_SECRET_SIZE, P384_PUBLIC_KEY_SIZE};
use zerotier_crypto::secret::Secret;
use zerotier_crypto::{random, secure_eq};
@ -25,7 +25,7 @@ use zerotier_utils::gatherarray::GatherArray;
use zerotier_utils::memory;
use zerotier_utils::ringbuffermap::RingBufferMap;
use pqc_kyber::{KYBER_SECRETKEYBYTES, KYBER_SSBYTES};
use pqc_kyber::{KYBER_CIPHERTEXTBYTES, KYBER_SECRETKEYBYTES, KYBER_SSBYTES};
use crate::applicationlayer::ApplicationLayer;
use crate::error::Error;
@ -43,6 +43,15 @@ pub struct Context<Application: ApplicationLayer> {
sessions: RwLock<SessionMaps<Application>>,
}
/// Lookup maps for sessions within a session context.
struct SessionMaps<Application: ApplicationLayer> {
// Active sessions, automatically closed if the application no longer holds their Arc<>.
active: HashMap<SessionId, Weak<Session<Application>>>,
// Incomplete sessions in the middle of three-phase Noise_XK negotiation, expired after timeout.
incomplete: HashMap<SessionId, Arc<IncompleteIncomingSession>>,
}
/// Result generated by the context packet receive function, with possible payloads.
pub enum ReceiveResult<'b, Application: ApplicationLayer> {
/// Packet was valid, but no action needs to be taken.
@ -76,45 +85,45 @@ pub struct Session<Application: ApplicationLayer> {
defrag: Mutex<RingBufferMap<u64, GatherArray<Application::IncomingPacketBuffer, MAX_FRAGMENTS>, 16, 16>>,
}
struct SessionMaps<Application: ApplicationLayer> {
// Active sessions, automatically closed if the application no longer holds their Arc<>.
active: HashMap<SessionId, Weak<Session<Application>>>,
// Incomplete sessions in the middle of three-phase Noise_XK negotiation, expired after timeout.
incomplete: HashMap<SessionId, Arc<NoiseXKIncoming>>,
/// Most of the mutable parts of a session state.
struct State {
remote_session_id: Option<SessionId>,
keys: [Option<SessionKey>; 2],
current_key: usize,
current_offer: Offer,
}
struct NoiseXKIncoming {
/// State related to an incoming session not yet fully established.
struct IncompleteIncomingSession {
timestamp: i64,
request_hash: [u8; SHA384_HASH_SIZE],
alice_session_id: SessionId,
bob_session_id: SessionId,
noise_es_ee: Secret<BASE_KEY_SIZE>,
bob_hk_ciphertext: [u8; KYBER_CIPHERTEXTBYTES],
hk: Secret<KYBER_SSBYTES>,
header_protection_key: Secret<AES_HEADER_PROTECTION_KEY_SIZE>,
bob_noise_e_secret: P384KeyPair,
}
struct NoiseXKOutgoing {
timestamp: i64,
/// State related to an outgoing session attempt.
struct OutgoingSessionInit {
last_retry_time: AtomicI64,
alice_noise_e_secret: P384KeyPair,
noise_es: Secret<P384_ECDH_SHARED_SECRET_SIZE>,
alice_hk_secret: Secret<KYBER_SECRETKEYBYTES>,
metadata: Option<ArrayVec<u8, MAX_METADATA_SIZE>>,
metadata: Option<ArrayVec<u8, MAX_INIT_PAYLOAD_SIZE>>,
init_packet: [u8; AliceNoiseXKInit::SIZE],
}
enum EphemeralOffer {
/// Latest outgoing offer, either an outgoing attempt or a rekey attempt.
enum Offer {
None,
NoiseXKInit(Box<NoiseXKOutgoing>),
RekeyInit(P384KeyPair),
}
struct State {
remote_session_id: Option<SessionId>,
keys: [Option<SessionKey>; 2],
current_key: usize,
offer: EphemeralOffer,
NoiseXKInit(Box<OutgoingSessionInit>),
RekeyInit(P384KeyPair, [u8; AliceRekeyInit::SIZE], AtomicI64),
}
/// An ephemeral session key with expiration info.
struct SessionKey {
ratchet_key: Secret<BASE_KEY_SIZE>, // Key used in derivation of the next session key
receive_key: Secret<AES_KEY_SIZE>, // Receive side AES-GCM key
@ -126,12 +135,12 @@ struct SessionKey {
rekey_at_counter: u64, // Rekey at or after this counter
expire_at_counter: u64, // Hard error when this counter value is reached or exceeded
confirmed: bool, // We have confirmed that the other side has this key
role_is_bob: bool, // Was this side "Bob" in this exchange?
bob: bool, // Was this side "Bob" in this exchange?
}
impl<Application: ApplicationLayer> Context<Application> {
/// Create a new session context.
pub fn new(_: &Application, max_incomplete_session_queue_size: usize) -> Self {
pub fn new(max_incomplete_session_queue_size: usize) -> Self {
Self {
max_incomplete_session_queue_size,
initial_offer_defrag: Mutex::new(RingBufferMap::new(random::next_u32_secure())),
@ -145,28 +154,67 @@ impl<Application: ApplicationLayer> Context<Application> {
/// Perform periodic background service and cleanup tasks.
///
/// This returns the number of milliseconds until it should be called again.
pub fn service<SendFunction: FnMut(&Arc<Session<Application>>, &mut [u8])>(&self, mut send: SendFunction, current_time: i64) -> i64 {
///
/// * `send` - Function to send packets to remote sessions
/// * `mtu` - Physical MTU
/// * `current_time` - Current monotonic time in milliseconds
pub fn service<SendFunction: FnMut(&Arc<Session<Application>>, &mut [u8])>(
&self,
mut send: SendFunction,
mtu: usize,
current_time: i64,
) -> i64 {
let mut dead_active = Vec::new();
let mut dead_pending = Vec::new();
let retry_cutoff = current_time - Application::RETRY_INTERVAL;
let negotiation_timeout_cutoff = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS;
{
let sessions = self.sessions.read().unwrap();
for (id, s) in sessions.active.iter() {
if let Some(session) = s.upgrade() {
let state = session.state.read().unwrap();
if let Some(key) = state.keys[state.current_key].as_ref() {
if key.role_is_bob
&& (current_time >= key.rekey_at_time || session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter)
{
session.send_rekey(|b| send(&session, b));
match &state.current_offer {
Offer::None => {
if let Some(key) = state.keys[state.current_key].as_ref() {
if key.bob
&& (current_time >= key.rekey_at_time
|| session.send_counter.load(Ordering::Relaxed) >= key.rekey_at_counter)
{
session.initiate_rekey(|b| send(&session, b), current_time);
}
}
}
Offer::NoiseXKInit(offer) => {
if offer.last_retry_time.load(Ordering::Relaxed) < retry_cutoff {
offer.last_retry_time.store(current_time, Ordering::Relaxed);
let _ = send_with_fragmentation(
|b| send(&session, b),
&mut (offer.init_packet.clone()),
mtu,
PACKET_TYPE_ALICE_NOISE_XK_INIT,
None,
0,
1,
None,
);
}
}
Offer::RekeyInit(_, rekey_packet, last_retry_time) => {
if last_retry_time.load(Ordering::Relaxed) < retry_cutoff {
last_retry_time.store(current_time, Ordering::Relaxed);
send(&session, &mut (rekey_packet.clone()));
}
}
}
} else {
dead_active.push(*id);
}
}
for (id, p) in sessions.incomplete.iter() {
if (p.timestamp - current_time) > Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS {
for (id, incomplete) in sessions.incomplete.iter() {
if incomplete.timestamp < negotiation_timeout_cutoff {
dead_pending.push(*id);
}
}
@ -182,20 +230,22 @@ impl<Application: ApplicationLayer> Context<Application> {
}
}
Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS * 2
Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS.min(Application::RETRY_INTERVAL)
}
/// Create a new session and send initial packet(s) to other side.
///
/// This will return Error::DataTooLarge if the combined size of the metadata and the local static public
/// blob (as retrieved from the application layer) exceed MAX_INIT_PAYLOAD_SIZE.
///
/// * `app` - Application layer instance
/// * `send` - User-supplied packet sending function
/// * `mtu` - Physical MTU for calls to send()
/// * `local_session_id` - This side's session ID
/// * `remote_s_public_p384` - Remote side's static public NIST P-384 key
/// * `psk` - Pre-shared key (use all zero if none)
/// * `metadata` - Optional metadata to be included in initial handshake
/// * `application_data` - Arbitrary opaque data to include with session object
#[allow(unused_variables)]
/// * `current_time` - Current monotonic time in milliseconds
pub fn open<SendFunction: FnMut(&mut [u8])>(
&self,
app: &Application,
@ -207,10 +257,8 @@ impl<Application: ApplicationLayer> Context<Application> {
application_data: Application::Data,
current_time: i64,
) -> Result<Arc<Session<Application>>, Error> {
if let Some(md) = metadata.as_ref() {
if md.len() > MAX_METADATA_SIZE {
return Err(Error::DataTooLarge);
}
if (metadata.map(|md| md.len()).unwrap_or(0) + app.get_local_s_public_blob().len()) > MAX_INIT_PAYLOAD_SIZE {
return Err(Error::DataTooLarge);
}
let alice_noise_e_secret = P384KeyPair::generate();
@ -234,19 +282,20 @@ impl<Application: ApplicationLayer> Context<Application> {
id: local_session_id,
application_data,
psk,
send_counter: AtomicU64::new(2), // 1 is the counter value for this INIT message
send_counter: AtomicU64::new(3), // 1 and 2 are reserved for init and final ack
receive_window: std::array::from_fn(|_| AtomicU64::new(0)),
header_protection_cipher: Aes::new(&header_protection_key),
state: RwLock::new(State {
remote_session_id: None,
keys: [None, None],
current_key: 0,
offer: EphemeralOffer::NoiseXKInit(Box::new(NoiseXKOutgoing {
timestamp: current_time,
current_offer: Offer::NoiseXKInit(Box::new(OutgoingSessionInit {
last_retry_time: AtomicI64::new(current_time),
alice_noise_e_secret,
noise_es: noise_es.clone(),
alice_hk_secret: Secret(alice_hk_secret.secret),
metadata: metadata.map(|md| ArrayVec::try_from(md).unwrap()),
init_packet: [0u8; AliceNoiseXKInit::SIZE],
})),
}),
defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
@ -257,26 +306,43 @@ impl<Application: ApplicationLayer> Context<Application> {
(local_session_id, session)
};
let mut init_buffer = [0u8; AliceNoiseXKInit::SIZE];
let init: &mut AliceNoiseXKInit = byte_array_as_proto_buffer_mut(&mut init_buffer).unwrap();
init.session_protocol_version = SESSION_PROTOCOL_VERSION;
init.alice_noise_e = alice_noise_e;
init.alice_session_id = *local_session_id.as_bytes();
init.alice_hk_public = alice_hk_secret.public;
init.header_protection_key = header_protection_key;
{
let mut state = session.state.write().unwrap();
let init_packet = if let Offer::NoiseXKInit(offer) = &mut state.current_offer {
&mut offer.init_packet
} else {
panic!();
};
let mut ctr = AesCtr::new(kbkdf::<AES_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ENCRYPTION>(noise_es.as_bytes()).as_bytes());
ctr.reset_set_iv(&alice_noise_e[P384_PUBLIC_KEY_SIZE - AES_CTR_NONCE_SIZE..]);
ctr.crypt_in_place(&mut init_buffer[AliceNoiseXKInit::ENC_START..AliceNoiseXKInit::AUTH_START]);
let init: &mut AliceNoiseXKInit = byte_array_as_proto_buffer_mut(init_packet).unwrap();
init.session_protocol_version = SESSION_PROTOCOL_VERSION;
init.alice_noise_e = alice_noise_e;
init.alice_session_id = *local_session_id.as_bytes();
init.alice_hk_public = alice_hk_secret.public;
init.header_protection_key = header_protection_key;
let hmac = hmac_sha384_2(
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_AUTHENTICATION>(noise_es.as_bytes()).as_bytes(),
&create_message_nonce(PACKET_TYPE_ALICE_NOISE_XK_INIT, 1),
&init_buffer[HEADER_SIZE..AliceNoiseXKInit::AUTH_START],
);
init_buffer[AliceNoiseXKInit::AUTH_START..AliceNoiseXKInit::AUTH_START + HMAC_SHA384_SIZE].copy_from_slice(&hmac);
let mut ctr = AesCtr::new(kbkdf::<AES_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ENCRYPTION>(noise_es.as_bytes()).as_bytes());
ctr.reset_set_iv(&alice_noise_e[P384_PUBLIC_KEY_SIZE - AES_CTR_NONCE_SIZE..]);
ctr.crypt_in_place(&mut init_packet[AliceNoiseXKInit::ENC_START..AliceNoiseXKInit::AUTH_START]);
send_with_fragmentation(&mut send, &mut init_buffer, mtu, PACKET_TYPE_ALICE_NOISE_XK_INIT, None, 0, 1, None)?;
let hmac = hmac_sha384_2(
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_AUTHENTICATION>(noise_es.as_bytes()).as_bytes(),
&create_message_nonce(PACKET_TYPE_ALICE_NOISE_XK_INIT, 1),
&init_packet[HEADER_SIZE..AliceNoiseXKInit::AUTH_START],
);
init_packet[AliceNoiseXKInit::AUTH_START..AliceNoiseXKInit::AUTH_START + HMAC_SHA384_SIZE].copy_from_slice(&hmac);
send_with_fragmentation(
&mut send,
&mut (init_packet.clone()),
mtu,
PACKET_TYPE_ALICE_NOISE_XK_INIT,
None,
0,
1,
None,
)?;
}
return Ok(session);
}
@ -308,7 +374,6 @@ impl<Application: ApplicationLayer> Context<Application> {
/// * `incoming_packet_buf` - Buffer containing incoming wire packet (receive() takes ownership)
/// * `mtu` - Physical wire MTU for sending packets
/// * `current_time` - Current monotonic time in milliseconds
#[inline]
pub fn receive<
'b,
SendFunction: FnMut(Option<&Arc<Session<Application>>>, &mut [u8]),
@ -332,7 +397,16 @@ impl<Application: ApplicationLayer> Context<Application> {
let mut incomplete = None;
if let Some(local_session_id) = SessionId::new_from_u64_le(memory::load_raw(incoming_packet)) {
if let Some(session) = self.look_up_session(local_session_id) {
if let Some(session) = self
.sessions
.read()
.unwrap()
.active
.get(&local_session_id)
.and_then(|s| s.upgrade())
{
debug_assert!(self.sessions.read().unwrap().incomplete.contains_key(&local_session_id));
session
.header_protection_cipher
.decrypt_block_in_place(&mut incoming_packet[HEADER_PROTECT_ENCRYPT_START..HEADER_PROTECT_ENCRYPT_END]);
@ -455,18 +529,18 @@ impl<Application: ApplicationLayer> Context<Application> {
check_allow_incoming_session: &mut CheckAllowIncomingSession,
check_accept_session: &mut CheckAcceptSession,
data_buf: &'b mut [u8],
counter: u64,
incoming_counter: u64,
fragments: &[Application::IncomingPacketBuffer],
packet_type: u8,
session: Option<Arc<Session<Application>>>,
incomplete: Option<Arc<NoiseXKIncoming>>,
incomplete: Option<Arc<IncompleteIncomingSession>>,
key_index: usize,
mtu: usize,
current_time: i64,
) -> Result<ReceiveResult<'b, Application>, Error> {
debug_assert!(fragments.len() >= 1);
let incoming_message_nonce = create_message_nonce(packet_type, counter);
let incoming_message_nonce = create_message_nonce(packet_type, incoming_counter);
if packet_type == PACKET_TYPE_DATA {
if let Some(session) = session {
let state = session.state.read().unwrap();
@ -510,7 +584,7 @@ impl<Application: ApplicationLayer> Context<Application> {
key.return_receive_cipher(c);
if aead_authentication_ok {
if session.update_receive_window(counter) {
if session.update_receive_window(incoming_counter) {
// If the packet authenticated, this confirms that the other side indeed
// knows this session key. In that case mark the session key as confirmed
// and if the current active key is older switch it to point to this one.
@ -575,132 +649,166 @@ impl<Application: ApplicationLayer> Context<Application> {
* to the current exchange.
*/
if session.is_some() || incomplete.is_some() || counter != 1 {
if incoming_counter != 1 || session.is_some() {
return Err(Error::OutOfSequence);
}
let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?;
let alice_noise_e = P384PublicKey::from_bytes(&pkt.alice_noise_e).ok_or(Error::FailedAuthentication)?;
let noise_es = app.get_local_s_keypair().agree(&alice_noise_e).ok_or(Error::FailedAuthentication)?;
// Hash the init packet so we can check to see if it's just being retransmitted.
let request_hash = SHA384::hash(&pkt_assembled);
// Authenticate packet and prove that Alice knows our static public key.
if !secure_eq(
&pkt.hmac_es,
&hmac_sha384_2(
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_AUTHENTICATION>(noise_es.as_bytes()).as_bytes(),
&incoming_message_nonce,
&pkt_assembled[HEADER_SIZE..AliceNoiseXKInit::AUTH_START],
),
) {
return Err(Error::FailedAuthentication);
}
let (alice_session_id, bob_session_id, noise_es_ee, bob_hk_ciphertext, header_protection_key, bob_noise_e) =
if let Some(incomplete) = incomplete {
// If we already have an incoming incomplete session record and the hash matches, recall the
// previous state so we can send an identical reply in response to a retransmit.
// Let application filter incoming connection attempt by whatever criteria it wants.
if !check_allow_incoming_session() {
return Ok(ReceiveResult::Rejected);
}
// Decrypt encrypted part of payload.
let mut ctr = AesCtr::new(kbkdf::<AES_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ENCRYPTION>(noise_es.as_bytes()).as_bytes());
ctr.reset_set_iv(&SHA384::hash(&pkt.alice_noise_e)[..AES_CTR_NONCE_SIZE]);
ctr.crypt_in_place(&mut pkt_assembled[AliceNoiseXKInit::ENC_START..AliceNoiseXKInit::AUTH_START]);
let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?;
let alice_session_id = SessionId::new_from_bytes(&pkt.alice_session_id).ok_or(Error::InvalidPacket)?;
// Create Bob's ephemeral keys and derive noise_es_ee by agreeing with Alice's. Also create
// a Kyber ciphertext to send back to Alice.
let bob_noise_e_secret = P384KeyPair::generate();
let bob_noise_e = bob_noise_e_secret.public_key_bytes().clone();
let noise_es_ee = Secret(hmac_sha512(
noise_es.as_bytes(),
bob_noise_e_secret
.agree(&alice_noise_e)
.ok_or(Error::FailedAuthentication)?
.as_bytes(),
));
let (bob_hk_ciphertext, hk) = pqc_kyber::encapsulate(&pkt.alice_hk_public, &mut random::SecureRandom::default())
.map_err(|_| Error::FailedAuthentication)
.map(|(ct, hk)| (ct, Secret(hk)))?;
// Pick a session ID for our side and save the intermediate ephemeral state for this exchange.
let bob_session_id = {
let mut sessions = self.sessions.write().unwrap();
let mut bob_session_id;
loop {
bob_session_id = SessionId::random();
if !sessions.active.contains_key(&bob_session_id) && !sessions.incomplete.contains_key(&bob_session_id) {
break;
if secure_eq(&request_hash, &incomplete.request_hash) {
(
incomplete.alice_session_id,
incomplete.bob_session_id,
incomplete.noise_es_ee.clone(),
incomplete.bob_hk_ciphertext,
incomplete.header_protection_key.clone(),
*incomplete.bob_noise_e_secret.public_key_bytes(),
)
} else {
return Err(Error::FailedAuthentication);
}
}
} else {
// Otherwise parse the packet, authenticate, generate keys, etc. and record state in an
// incomplete state object until this phase of the negotiation is done.
if sessions.incomplete.len() >= self.max_incomplete_session_queue_size {
// If this queue is too big, we remove the latest entry and replace it. The latest
// is used because under flood conditions this is most likely to be another bogus
// entry. If we find one that is actually timed out, that one is replaced instead.
let mut newest = i64::MIN;
let mut replace_id = None;
let cutoff_time = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS;
for (id, s) in sessions.incomplete.iter() {
if s.timestamp <= cutoff_time {
replace_id = Some(*id);
let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?;
let alice_noise_e = P384PublicKey::from_bytes(&pkt.alice_noise_e).ok_or(Error::FailedAuthentication)?;
let noise_es = app.get_local_s_keypair().agree(&alice_noise_e).ok_or(Error::FailedAuthentication)?;
// Authenticate packet and also prove that Alice knows our static public key.
if !secure_eq(
&pkt.hmac_es,
&hmac_sha384_2(
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_AUTHENTICATION>(noise_es.as_bytes()).as_bytes(),
&incoming_message_nonce,
&pkt_assembled[HEADER_SIZE..AliceNoiseXKInit::AUTH_START],
),
) {
return Err(Error::FailedAuthentication);
}
// Let application filter incoming connection attempt by whatever criteria it wants.
if !check_allow_incoming_session() {
return Ok(ReceiveResult::Rejected);
}
// Decrypt encrypted part of payload.
let mut ctr =
AesCtr::new(kbkdf::<AES_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ENCRYPTION>(noise_es.as_bytes()).as_bytes());
ctr.reset_set_iv(&SHA384::hash(&pkt.alice_noise_e)[..AES_CTR_NONCE_SIZE]);
ctr.crypt_in_place(&mut pkt_assembled[AliceNoiseXKInit::ENC_START..AliceNoiseXKInit::AUTH_START]);
let pkt: &AliceNoiseXKInit = byte_array_as_proto_buffer(pkt_assembled)?;
let alice_session_id = SessionId::new_from_bytes(&pkt.alice_session_id).ok_or(Error::InvalidPacket)?;
// Create Bob's ephemeral keys and derive noise_es_ee by agreeing with Alice's. Also create
// a Kyber ciphertext to send back to Alice.
let bob_noise_e_secret = P384KeyPair::generate();
let bob_noise_e = bob_noise_e_secret.public_key_bytes().clone();
let noise_es_ee = Secret(hmac_sha512(
noise_es.as_bytes(),
bob_noise_e_secret
.agree(&alice_noise_e)
.ok_or(Error::FailedAuthentication)?
.as_bytes(),
));
let (bob_hk_ciphertext, hk) =
pqc_kyber::encapsulate(&pkt.alice_hk_public, &mut random::SecureRandom::default())
.map_err(|_| Error::FailedAuthentication)
.map(|(ct, hk)| (ct, Secret(hk)))?;
let mut sessions = self.sessions.write().unwrap();
let mut bob_session_id;
loop {
bob_session_id = SessionId::random();
if !sessions.active.contains_key(&bob_session_id) && !sessions.incomplete.contains_key(&bob_session_id) {
break;
} else if s.timestamp >= newest {
newest = s.timestamp;
replace_id = Some(*id);
}
}
let _ = sessions.incomplete.remove(replace_id.as_ref().unwrap());
}
sessions.incomplete.insert(
bob_session_id,
Arc::new(NoiseXKIncoming {
timestamp: current_time,
if sessions.incomplete.len() >= self.max_incomplete_session_queue_size {
// If this queue is too big, we remove the latest entry and replace it. The latest
// is used because under flood conditions this is most likely to be another bogus
// entry. If we find one that is actually timed out, that one is replaced instead.
let mut newest = i64::MIN;
let mut replace_id = None;
let cutoff_time = current_time - Application::INCOMING_SESSION_NEGOTIATION_TIMEOUT_MS;
for (id, s) in sessions.incomplete.iter() {
if s.timestamp <= cutoff_time {
replace_id = Some(*id);
break;
} else if s.timestamp >= newest {
newest = s.timestamp;
replace_id = Some(*id);
}
}
let _ = sessions.incomplete.remove(replace_id.as_ref().unwrap());
}
// Reserve session ID on this side and record incomplete session state.
sessions.incomplete.insert(
bob_session_id,
Arc::new(IncompleteIncomingSession {
timestamp: current_time,
request_hash,
alice_session_id,
bob_session_id,
noise_es_ee: noise_es_ee.clone(),
bob_hk_ciphertext,
hk,
bob_noise_e_secret,
header_protection_key: Secret(pkt.header_protection_key),
}),
);
(
alice_session_id,
bob_session_id,
noise_es_ee: noise_es_ee.clone(),
hk,
bob_noise_e_secret,
header_protection_key: Secret(pkt.header_protection_key),
}),
);
bob_session_id
};
noise_es_ee,
bob_hk_ciphertext,
Secret(pkt.header_protection_key),
bob_noise_e,
)
};
// Create Bob's ephemeral counter-offer reply.
let mut reply_buffer = [0u8; BobNoiseXKAck::SIZE];
let reply: &mut BobNoiseXKAck = byte_array_as_proto_buffer_mut(&mut reply_buffer)?;
reply.session_protocol_version = SESSION_PROTOCOL_VERSION;
reply.bob_noise_e = bob_noise_e;
reply.bob_session_id = *bob_session_id.as_bytes();
reply.bob_hk_ciphertext = bob_hk_ciphertext;
let mut ack_packet = [0u8; BobNoiseXKAck::SIZE];
let ack: &mut BobNoiseXKAck = byte_array_as_proto_buffer_mut(&mut ack_packet)?;
ack.session_protocol_version = SESSION_PROTOCOL_VERSION;
ack.bob_noise_e = bob_noise_e;
ack.bob_session_id = *bob_session_id.as_bytes();
ack.bob_hk_ciphertext = bob_hk_ciphertext;
// Encrypt main section of reply. Technically we could get away without this but why not?
let mut ctr =
AesCtr::new(kbkdf::<AES_KEY_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_ENCRYPTION>(noise_es_ee.as_bytes()).as_bytes());
ctr.reset_set_iv(&bob_noise_e[P384_PUBLIC_KEY_SIZE - AES_CTR_NONCE_SIZE..]);
ctr.crypt_in_place(&mut reply_buffer[BobNoiseXKAck::ENC_START..BobNoiseXKAck::AUTH_START]);
ctr.crypt_in_place(&mut ack_packet[BobNoiseXKAck::ENC_START..BobNoiseXKAck::AUTH_START]);
// Add HMAC-SHA384 to reply packet.
let reply_hmac = hmac_sha384_2(
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_AUTHENTICATION>(noise_es_ee.as_bytes()).as_bytes(),
&create_message_nonce(PACKET_TYPE_BOB_NOISE_XK_ACK, 1),
&reply_buffer[HEADER_SIZE..BobNoiseXKAck::AUTH_START],
&ack_packet[HEADER_SIZE..BobNoiseXKAck::AUTH_START],
);
reply_buffer[BobNoiseXKAck::AUTH_START..].copy_from_slice(&reply_hmac);
ack_packet[BobNoiseXKAck::AUTH_START..].copy_from_slice(&reply_hmac);
send_with_fragmentation(
|b| send(None, b),
&mut reply_buffer,
&mut ack_packet,
mtu,
PACKET_TYPE_BOB_NOISE_XK_ACK,
Some(alice_session_id),
0,
1,
Some(&Aes::new(&pkt.header_protection_key)),
Some(&Aes::new(header_protection_key.as_bytes())),
)?;
return Ok(ReceiveResult::Ok);
@ -715,13 +823,13 @@ impl<Application: ApplicationLayer> Context<Application> {
* the negotiation.
*/
if counter != 1 || incomplete.is_some() {
if incoming_counter != 1 || incomplete.is_some() {
return Err(Error::OutOfSequence);
}
if let Some(session) = session {
let state = session.state.read().unwrap();
if let EphemeralOffer::NoiseXKInit(outgoing_offer) = &state.offer {
if let Offer::NoiseXKInit(outgoing_offer) = &state.current_offer {
let pkt: &BobNoiseXKAck = byte_array_as_proto_buffer(pkt_assembled)?;
if let Some(bob_session_id) = SessionId::new_from_bytes(&pkt.bob_session_id) {
@ -776,12 +884,7 @@ impl<Application: ApplicationLayer> Context<Application> {
&hmac_sha512(session.psk.as_bytes(), hk.as_bytes()),
));
let noise_es_ee_se_hk_psk_hmac_key =
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_AUTHENTICATION>(noise_es_ee_se_hk_psk.as_bytes());
let reply_counter = session.get_next_outgoing_counter().ok_or(Error::MaxKeyLifetimeExceeded)?;
debug_assert_eq!(reply_counter.get(), 2);
let reply_message_nonce = create_message_nonce(PACKET_TYPE_ALICE_NOISE_XK_ACK, reply_counter.get());
let reply_message_nonce = create_message_nonce(PACKET_TYPE_ALICE_NOISE_XK_ACK, 2);
// Create reply informing Bob of our static identity now that we've verified Bob and set
// up forward secrecy. Also return Bob's opaque note.
@ -826,7 +929,8 @@ impl<Application: ApplicationLayer> Context<Application> {
// key exchange. Bob won't be able to do this until he decrypts and parses Alice's
// identity, so the first HMAC is to let him authenticate that first.
let hmac_es_ee_se_hk_psk = hmac_sha384_2(
noise_es_ee_se_hk_psk_hmac_key.as_bytes(),
kbkdf::<HMAC_SHA384_SIZE, KBKDF_KEY_USAGE_LABEL_KEX_AUTHENTICATION>(noise_es_ee_se_hk_psk.as_bytes())
.as_bytes(),
&reply_message_nonce,
&reply_buffer[HEADER_SIZE..reply_len],
);
@ -840,12 +944,12 @@ impl<Application: ApplicationLayer> Context<Application> {
let _ = state.keys[0].insert(SessionKey::new::<Application>(
noise_es_ee_se_hk_psk,
current_time,
reply_counter.get(),
2,
true,
false,
));
state.current_key = 0;
state.offer = EphemeralOffer::None;
state.current_offer = Offer::None;
}
send_with_fragmentation(
@ -855,7 +959,7 @@ impl<Application: ApplicationLayer> Context<Application> {
PACKET_TYPE_ALICE_NOISE_XK_ACK,
Some(bob_session_id),
0,
reply_counter.get(),
2,
Some(&session.header_protection_cipher),
)?;
@ -882,7 +986,7 @@ impl<Application: ApplicationLayer> Context<Application> {
* that Alice must return.
*/
if session.is_some() || counter != 2 {
if incoming_counter != 2 || session.is_some() {
return Err(Error::OutOfSequence);
}
if pkt_assembled.len() < ALICE_NOISE_XK_ACK_MIN_SIZE {
@ -999,11 +1103,12 @@ impl<Application: ApplicationLayer> Context<Application> {
None,
],
current_key: 0,
offer: EphemeralOffer::None,
current_offer: Offer::None,
}),
defrag: Mutex::new(RingBufferMap::new(random::xorshift64_random() as u32)),
});
// Promote this from an incomplete session to an established session.
{
let mut sessions = self.sessions.write().unwrap();
sessions.incomplete.remove(&incomplete.bob_session_id);
@ -1020,11 +1125,15 @@ impl<Application: ApplicationLayer> Context<Application> {
if pkt_assembled.len() != AliceRekeyInit::SIZE {
return Err(Error::InvalidPacket);
}
if incomplete.is_some() {
return Err(Error::OutOfSequence);
}
if let Some(session) = session {
let state = session.state.read().unwrap();
if let Some(key) = state.keys[key_index].as_ref() {
// Only the current "Alice" accepts rekeys initiated by the current "Bob."
if !key.role_is_bob {
if !key.bob {
let mut c = key.get_receive_cipher();
c.reset_init_gcm(&incoming_message_nonce);
c.crypt_in_place(&mut pkt_assembled[AliceRekeyInit::ENC_START..AliceRekeyInit::AUTH_START]);
@ -1078,12 +1187,16 @@ impl<Application: ApplicationLayer> Context<Application> {
if pkt_assembled.len() != BobRekeyAck::SIZE {
return Err(Error::InvalidPacket);
}
if incomplete.is_some() {
return Err(Error::OutOfSequence);
}
if let Some(session) = session {
let state = session.state.read().unwrap();
if let EphemeralOffer::RekeyInit(alice_e_secret) = &state.offer {
if let Offer::RekeyInit(alice_e_secret, _, _) = &state.current_offer {
if let Some(key) = state.keys[key_index].as_ref() {
// Only the current "Bob" initiates rekeys and expects this ACK.
if key.role_is_bob {
if key.bob {
let mut c = key.get_receive_cipher();
c.reset_init_gcm(&incoming_message_nonce);
c.crypt_in_place(&mut pkt_assembled[BobRekeyAck::ENC_START..BobRekeyAck::AUTH_START]);
@ -1102,11 +1215,11 @@ impl<Application: ApplicationLayer> Context<Application> {
let _ = state.keys[key_index ^ 1].replace(SessionKey::new::<Application>(
next_session_key,
current_time,
counter,
session.send_counter.load(Ordering::Acquire),
true,
false,
));
state.offer = EphemeralOffer::None;
state.current_offer = Offer::None;
return Ok(ReceiveResult::Ok);
}
@ -1128,11 +1241,6 @@ impl<Application: ApplicationLayer> Context<Application> {
}
}
}
/// Look up a session by local session ID.
fn look_up_session(&self, id: SessionId) -> Option<Arc<Session<Application>>> {
self.sessions.read().unwrap().active.get(&id).and_then(|s| s.upgrade())
}
}
impl<Application: ApplicationLayer> Session<Application> {
@ -1206,7 +1314,7 @@ impl<Application: ApplicationLayer> Session<Application> {
/// This is called from the session context's service() method when it's time to rekey.
/// It should only be called when the current key was established in the 'bob' role. This
/// is checked when rekey time is checked.
fn send_rekey<SendFunction: FnMut(&mut [u8])>(&self, mut send: SendFunction) {
fn initiate_rekey<SendFunction: FnMut(&mut [u8])>(&self, mut send: SendFunction, current_time: i64) {
let rekey_e = P384KeyPair::generate();
let mut rekey_buf = [0u8; AliceRekeyInit::SIZE];
@ -1214,18 +1322,31 @@ impl<Application: ApplicationLayer> Session<Application> {
pkt.alice_e = *rekey_e.public_key_bytes();
let state = self.state.read().unwrap();
if let Some(key) = state.keys[state.current_key].as_ref() {
if let Some(counter) = self.get_next_outgoing_counter() {
if let Ok(mut gcm) = key.get_send_cipher(counter.get()) {
gcm.reset_init_gcm(&create_message_nonce(PACKET_TYPE_ALICE_REKEY_INIT, counter.get()));
gcm.crypt_in_place(&mut rekey_buf[AliceRekeyInit::ENC_START..AliceRekeyInit::AUTH_START]);
rekey_buf[AliceRekeyInit::AUTH_START..].copy_from_slice(&gcm.finish_encrypt());
key.return_send_cipher(gcm);
if let Some(remote_session_id) = state.remote_session_id {
if let Some(key) = state.keys[state.current_key].as_ref() {
if let Some(counter) = self.get_next_outgoing_counter() {
if let Ok(mut gcm) = key.get_send_cipher(counter.get()) {
gcm.reset_init_gcm(&create_message_nonce(PACKET_TYPE_ALICE_REKEY_INIT, counter.get()));
gcm.crypt_in_place(&mut rekey_buf[AliceRekeyInit::ENC_START..AliceRekeyInit::AUTH_START]);
rekey_buf[AliceRekeyInit::AUTH_START..].copy_from_slice(&gcm.finish_encrypt());
key.return_send_cipher(gcm);
send(&mut rekey_buf);
debug_assert!(rekey_buf.len() <= MIN_TRANSPORT_MTU);
set_packet_header(
&mut rekey_buf,
1,
0,
PACKET_TYPE_ALICE_REKEY_INIT,
u64::from(remote_session_id),
state.current_key,
counter.get(),
);
drop(state);
self.state.write().unwrap().offer = EphemeralOffer::RekeyInit(rekey_e);
send(&mut rekey_buf);
drop(state);
self.state.write().unwrap().current_offer = Offer::RekeyInit(rekey_e, rekey_buf, AtomicI64::new(current_time));
}
}
}
}
@ -1240,16 +1361,16 @@ impl<Application: ApplicationLayer> Session<Application> {
/// Check the receive window without mutating state.
#[inline(always)]
fn check_receive_window(&self, counter: u64) -> bool {
let c = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].load(Ordering::Acquire);
c < counter && counter.wrapping_sub(c) < COUNTER_WINDOW_MAX_SKIP_AHEAD
let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].load(Ordering::Acquire);
prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD
}
/// Update the receive window, returning true if the packet is still valid.
/// This should only be called after the packet is authenticated.
#[inline(always)]
fn update_receive_window(&self, counter: u64) -> bool {
let c = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].fetch_max(counter, Ordering::AcqRel);
c < counter && counter.wrapping_sub(c) < COUNTER_WINDOW_MAX_SKIP_AHEAD
let prev_counter = self.receive_window[(counter as usize) % COUNTER_WINDOW_MAX_OOO].fetch_max(counter, Ordering::AcqRel);
prev_counter < counter && counter.wrapping_sub(prev_counter) < COUNTER_WINDOW_MAX_SKIP_AHEAD
}
}
@ -1258,7 +1379,7 @@ fn set_packet_header(
fragment_count: usize,
fragment_no: usize,
packet_type: u8,
recipient_session_id: u64,
remote_session_id: u64,
key_index: usize,
counter: u64,
) {
@ -1276,7 +1397,7 @@ fn set_packet_header(
// [58-63] fragment number (0..63)
// [64-127] 64-bit counter
memory::store_raw(
(u64::from(recipient_session_id)
(u64::from(remote_session_id)
| ((key_index & 1) as u64).wrapping_shl(48)
| (packet_type as u64).wrapping_shl(49)
| ((fragment_count - 1) as u64).wrapping_shl(52)
@ -1323,13 +1444,13 @@ fn send_with_fragmentation<SendFunction: FnMut(&mut [u8])>(
packet: &mut [u8],
mtu: usize,
packet_type: u8,
recipient_session_id: Option<SessionId>,
remote_session_id: Option<SessionId>,
key_index: usize,
counter: u64,
header_protect_cipher: Option<&Aes>,
) -> Result<(), Error> {
let packet_len = packet.len();
let recipient_session_id = recipient_session_id.map_or(SessionId::NONE, |s| u64::from(s));
let recipient_session_id = remote_session_id.map_or(SessionId::NONE, |s| u64::from(s));
let fragment_count = ((packet_len as f32) / (mtu as f32)).ceil() as usize;
let mut fragment_start = 0;
let mut fragment_end = packet_len.min(mtu);
@ -1385,7 +1506,7 @@ impl SessionKey {
rekey_at_counter: current_counter.checked_add(Application::REKEY_AFTER_USES).unwrap(),
expire_at_counter: current_counter.checked_add(Application::EXPIRE_AFTER_USES).unwrap(),
confirmed,
role_is_bob,
bob: role_is_bob,
}
}