diff options
Diffstat (limited to 'src/packet.rs')
-rw-r--r-- | src/packet.rs | 145 |
1 files changed, 108 insertions, 37 deletions
diff --git a/src/packet.rs b/src/packet.rs index cc06031..39194f0 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -24,6 +24,10 @@ // NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use std::fmt::Display; +use std::ops::Index; +use std::ops::IndexMut; +use std::ops::RangeInclusive; use std::time; use ring::aead; @@ -49,20 +53,62 @@ pub const MAX_PKT_NUM_LEN: usize = 4; const SAMPLE_LEN: usize = 16; -pub const EPOCH_INITIAL: usize = 0; -pub const EPOCH_HANDSHAKE: usize = 1; -pub const EPOCH_APPLICATION: usize = 2; -pub const EPOCH_COUNT: usize = 3; +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)] +pub enum Epoch { + Initial = 0, + Handshake = 1, + Application = 2, +} + +static EPOCHS: [Epoch; 3] = + [Epoch::Initial, Epoch::Handshake, Epoch::Application]; + +impl Epoch { + /// Returns an ordered slice containing the `Epoch`s that fit in the + /// provided `range`. + pub fn epochs(range: RangeInclusive<Epoch>) -> &'static [Epoch] { + &EPOCHS[*range.start() as usize..=*range.end() as usize] + } + + pub const fn count() -> usize { + 3 + } +} + +impl Display for Epoch { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", usize::from(*self)) + } +} + +impl From<Epoch> for usize { + fn from(e: Epoch) -> Self { + e as usize + } +} + +impl<T> Index<Epoch> for [T] +where + T: Sized, +{ + type Output = T; -/// Packet number space epoch. -/// -/// This should only ever be one of `EPOCH_INITIAL`, `EPOCH_HANDSHAKE` or -/// `EPOCH_APPLICATION`, and can be used to index state specific to a packet -/// number space in `Connection` and `Recovery`. -pub type Epoch = usize; + fn index(&self, index: Epoch) -> &Self::Output { + self.index(usize::from(index)) + } +} + +impl<T> IndexMut<Epoch> for [T] +where + T: Sized, +{ + fn index_mut(&mut self, index: Epoch) -> &mut Self::Output { + self.index_mut(usize::from(index)) + } +} /// QUIC packet type. -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Type { /// Initial packet. Initial, @@ -86,25 +132,23 @@ pub enum Type { impl Type { pub(crate) fn from_epoch(e: Epoch) -> Type { match e { - EPOCH_INITIAL => Type::Initial, - - EPOCH_HANDSHAKE => Type::Handshake, + Epoch::Initial => Type::Initial, - EPOCH_APPLICATION => Type::Short, + Epoch::Handshake => Type::Handshake, - _ => unreachable!(), + Epoch::Application => Type::Short, } } pub(crate) fn to_epoch(self) -> Result<Epoch> { match self { - Type::Initial => Ok(EPOCH_INITIAL), + Type::Initial => Ok(Epoch::Initial), - Type::ZeroRTT => Ok(EPOCH_APPLICATION), + Type::ZeroRTT => Ok(Epoch::Application), - Type::Handshake => Ok(EPOCH_HANDSHAKE), + Type::Handshake => Ok(Epoch::Handshake), - Type::Short => Ok(EPOCH_APPLICATION), + Type::Short => Ok(Epoch::Application), _ => Err(Error::InvalidPacket), } @@ -230,7 +274,7 @@ impl<'a> std::fmt::Debug for ConnectionId<'a> { #[inline] fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { for c in self.as_ref() { - write!(f, "{:02x}", c)?; + write!(f, "{c:02x}")?; } Ok(()) @@ -238,7 +282,7 @@ impl<'a> std::fmt::Debug for ConnectionId<'a> { } /// A QUIC packet's header. -#[derive(Clone, PartialEq)] +#[derive(Clone, PartialEq, Eq)] pub struct Header<'a> { /// The type of the packet. pub ty: Type, @@ -495,12 +539,12 @@ impl<'a> std::fmt::Debug for Header<'a> { if let Some(ref token) = self.token { write!(f, " token=")?; for b in token { - write!(f, "{:02x}", b)?; + write!(f, "{b:02x}")?; } } if let Some(ref versions) = self.versions { - write!(f, " versions={:x?}", versions)?; + write!(f, " versions={versions:x?}")?; } if self.ty == Type::Short { @@ -512,11 +556,13 @@ impl<'a> std::fmt::Debug for Header<'a> { } pub fn pkt_num_len(pn: u64) -> Result<usize> { - let len = if pn < u64::from(std::u8::MAX) { + let len = if pn < u64::from(u8::MAX) { 1 - } else if pn < u64::from(std::u16::MAX) { + } else if pn < u64::from(u16::MAX) { 2 - } else if pn < u64::from(std::u32::MAX) { + } else if pn < 16_777_215u64 { + 3 + } else if pn < u64::from(u32::MAX) { 4 } else { return Err(Error::InvalidPacket); @@ -625,7 +671,8 @@ pub fn decrypt_pkt<'a>( pub fn encrypt_hdr( b: &mut octets::OctetsMut, pn_len: usize, payload: &[u8], aead: &crypto::Seal, ) -> Result<()> { - let sample = &payload[4 - pn_len..16 + (4 - pn_len)]; + let sample = &payload + [MAX_PKT_NUM_LEN - pn_len..SAMPLE_LEN + (MAX_PKT_NUM_LEN - pn_len)]; let mask = aead.new_mask(sample)?; @@ -814,11 +861,29 @@ fn compute_retry_integrity_tag( .map_err(|_| Error::CryptoFail) } +pub struct KeyUpdate { + /// 1-RTT key used prior to a key update. + pub crypto_open: crypto::Open, + + /// The packet number triggered the latest key-update. + /// + /// Incoming packets with lower pn should use this (prev) crypto key. + pub pn_on_update: u64, + + /// Whether ACK frame for key-update has been sent. + pub update_acked: bool, + + /// When the old key should be discarded. + pub timer: time::Instant, +} + pub struct PktNumSpace { pub largest_rx_pkt_num: u64, pub largest_rx_pkt_time: time::Instant, + pub largest_rx_non_probing_pkt_num: u64, + pub next_pkt_num: u64, pub recv_pkt_need_ack: ranges::RangeSet, @@ -827,6 +892,8 @@ pub struct PktNumSpace { pub ack_elicited: bool, + pub key_update: Option<KeyUpdate>, + pub crypto_open: Option<crypto::Open>, pub crypto_seal: Option<crypto::Seal>, @@ -843,6 +910,8 @@ impl PktNumSpace { largest_rx_pkt_time: time::Instant::now(), + largest_rx_non_probing_pkt_num: 0, + next_pkt_num: 0, recv_pkt_need_ack: ranges::RangeSet::new(crate::MAX_ACK_RANGES), @@ -851,6 +920,8 @@ impl PktNumSpace { ack_elicited: false, + key_update: None, + crypto_open: None, crypto_seal: None, @@ -858,8 +929,8 @@ impl PktNumSpace { crypto_0rtt_seal: None, crypto_stream: stream::Stream::new( - std::u64::MAX, - std::u64::MAX, + u64::MAX, + u64::MAX, true, true, stream::MAX_STREAM_WINDOW, @@ -869,8 +940,8 @@ impl PktNumSpace { pub fn clear(&mut self) { self.crypto_stream = stream::Stream::new( - std::u64::MAX, - std::u64::MAX, + u64::MAX, + u64::MAX, true, true, stream::MAX_STREAM_WINDOW, @@ -966,7 +1037,7 @@ mod tests { assert!(hdr.to_bytes(&mut b).is_ok()); // Add fake retry integrity token. - b.put_bytes(&vec![0xba; 16]).unwrap(); + b.put_bytes(&[0xba; 16]).unwrap(); let mut b = octets::OctetsMut::with_slice(&mut d); assert_eq!(Header::from_bytes(&mut b, 9).unwrap(), hdr); @@ -1213,7 +1284,7 @@ mod tests { assert!(!win.contains(1025)); assert!(!win.contains(1026)); - win.insert(std::u64::MAX - 1); + win.insert(u64::MAX - 1); assert!(win.contains(0)); assert!(win.contains(1)); assert!(win.contains(2)); @@ -1235,8 +1306,8 @@ mod tests { assert!(win.contains(1024)); assert!(win.contains(1025)); assert!(win.contains(1026)); - assert!(!win.contains(std::u64::MAX - 2)); - assert!(win.contains(std::u64::MAX - 1)); + assert!(!win.contains(u64::MAX - 2)); + assert!(win.contains(u64::MAX - 1)); } fn assert_decrypt_initial_pkt( @@ -1875,7 +1946,7 @@ mod tests { .unwrap(); assert_eq!(written, expected_pkt.len()); - assert_eq!(&out[..written], &expected_pkt[..]); + assert_eq!(&out[..written], expected_pkt); } #[test] |