aboutsummaryrefslogtreecommitdiff
path: root/src/packet.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/packet.rs')
-rw-r--r--src/packet.rs145
1 files changed, 108 insertions, 37 deletions
diff --git a/src/packet.rs b/src/packet.rs
index cc06031..39194f0 100644
--- a/src/packet.rs
+++ b/src/packet.rs
@@ -24,6 +24,10 @@
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+use std::fmt::Display;
+use std::ops::Index;
+use std::ops::IndexMut;
+use std::ops::RangeInclusive;
use std::time;
use ring::aead;
@@ -49,20 +53,62 @@ pub const MAX_PKT_NUM_LEN: usize = 4;
const SAMPLE_LEN: usize = 16;
-pub const EPOCH_INITIAL: usize = 0;
-pub const EPOCH_HANDSHAKE: usize = 1;
-pub const EPOCH_APPLICATION: usize = 2;
-pub const EPOCH_COUNT: usize = 3;
+#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
+pub enum Epoch {
+ Initial = 0,
+ Handshake = 1,
+ Application = 2,
+}
+
+static EPOCHS: [Epoch; 3] =
+ [Epoch::Initial, Epoch::Handshake, Epoch::Application];
+
+impl Epoch {
+ /// Returns an ordered slice containing the `Epoch`s that fit in the
+ /// provided `range`.
+ pub fn epochs(range: RangeInclusive<Epoch>) -> &'static [Epoch] {
+ &EPOCHS[*range.start() as usize..=*range.end() as usize]
+ }
+
+ pub const fn count() -> usize {
+ 3
+ }
+}
+
+impl Display for Epoch {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", usize::from(*self))
+ }
+}
+
+impl From<Epoch> for usize {
+ fn from(e: Epoch) -> Self {
+ e as usize
+ }
+}
+
+impl<T> Index<Epoch> for [T]
+where
+ T: Sized,
+{
+ type Output = T;
-/// Packet number space epoch.
-///
-/// This should only ever be one of `EPOCH_INITIAL`, `EPOCH_HANDSHAKE` or
-/// `EPOCH_APPLICATION`, and can be used to index state specific to a packet
-/// number space in `Connection` and `Recovery`.
-pub type Epoch = usize;
+ fn index(&self, index: Epoch) -> &Self::Output {
+ self.index(usize::from(index))
+ }
+}
+
+impl<T> IndexMut<Epoch> for [T]
+where
+ T: Sized,
+{
+ fn index_mut(&mut self, index: Epoch) -> &mut Self::Output {
+ self.index_mut(usize::from(index))
+ }
+}
/// QUIC packet type.
-#[derive(Clone, Copy, Debug, PartialEq)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Type {
/// Initial packet.
Initial,
@@ -86,25 +132,23 @@ pub enum Type {
impl Type {
pub(crate) fn from_epoch(e: Epoch) -> Type {
match e {
- EPOCH_INITIAL => Type::Initial,
-
- EPOCH_HANDSHAKE => Type::Handshake,
+ Epoch::Initial => Type::Initial,
- EPOCH_APPLICATION => Type::Short,
+ Epoch::Handshake => Type::Handshake,
- _ => unreachable!(),
+ Epoch::Application => Type::Short,
}
}
pub(crate) fn to_epoch(self) -> Result<Epoch> {
match self {
- Type::Initial => Ok(EPOCH_INITIAL),
+ Type::Initial => Ok(Epoch::Initial),
- Type::ZeroRTT => Ok(EPOCH_APPLICATION),
+ Type::ZeroRTT => Ok(Epoch::Application),
- Type::Handshake => Ok(EPOCH_HANDSHAKE),
+ Type::Handshake => Ok(Epoch::Handshake),
- Type::Short => Ok(EPOCH_APPLICATION),
+ Type::Short => Ok(Epoch::Application),
_ => Err(Error::InvalidPacket),
}
@@ -230,7 +274,7 @@ impl<'a> std::fmt::Debug for ConnectionId<'a> {
#[inline]
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
for c in self.as_ref() {
- write!(f, "{:02x}", c)?;
+ write!(f, "{c:02x}")?;
}
Ok(())
@@ -238,7 +282,7 @@ impl<'a> std::fmt::Debug for ConnectionId<'a> {
}
/// A QUIC packet's header.
-#[derive(Clone, PartialEq)]
+#[derive(Clone, PartialEq, Eq)]
pub struct Header<'a> {
/// The type of the packet.
pub ty: Type,
@@ -495,12 +539,12 @@ impl<'a> std::fmt::Debug for Header<'a> {
if let Some(ref token) = self.token {
write!(f, " token=")?;
for b in token {
- write!(f, "{:02x}", b)?;
+ write!(f, "{b:02x}")?;
}
}
if let Some(ref versions) = self.versions {
- write!(f, " versions={:x?}", versions)?;
+ write!(f, " versions={versions:x?}")?;
}
if self.ty == Type::Short {
@@ -512,11 +556,13 @@ impl<'a> std::fmt::Debug for Header<'a> {
}
pub fn pkt_num_len(pn: u64) -> Result<usize> {
- let len = if pn < u64::from(std::u8::MAX) {
+ let len = if pn < u64::from(u8::MAX) {
1
- } else if pn < u64::from(std::u16::MAX) {
+ } else if pn < u64::from(u16::MAX) {
2
- } else if pn < u64::from(std::u32::MAX) {
+ } else if pn < 16_777_215u64 {
+ 3
+ } else if pn < u64::from(u32::MAX) {
4
} else {
return Err(Error::InvalidPacket);
@@ -625,7 +671,8 @@ pub fn decrypt_pkt<'a>(
pub fn encrypt_hdr(
b: &mut octets::OctetsMut, pn_len: usize, payload: &[u8], aead: &crypto::Seal,
) -> Result<()> {
- let sample = &payload[4 - pn_len..16 + (4 - pn_len)];
+ let sample = &payload
+ [MAX_PKT_NUM_LEN - pn_len..SAMPLE_LEN + (MAX_PKT_NUM_LEN - pn_len)];
let mask = aead.new_mask(sample)?;
@@ -814,11 +861,29 @@ fn compute_retry_integrity_tag(
.map_err(|_| Error::CryptoFail)
}
+pub struct KeyUpdate {
+ /// 1-RTT key used prior to a key update.
+ pub crypto_open: crypto::Open,
+
+ /// The packet number triggered the latest key-update.
+ ///
+ /// Incoming packets with lower pn should use this (prev) crypto key.
+ pub pn_on_update: u64,
+
+ /// Whether ACK frame for key-update has been sent.
+ pub update_acked: bool,
+
+ /// When the old key should be discarded.
+ pub timer: time::Instant,
+}
+
pub struct PktNumSpace {
pub largest_rx_pkt_num: u64,
pub largest_rx_pkt_time: time::Instant,
+ pub largest_rx_non_probing_pkt_num: u64,
+
pub next_pkt_num: u64,
pub recv_pkt_need_ack: ranges::RangeSet,
@@ -827,6 +892,8 @@ pub struct PktNumSpace {
pub ack_elicited: bool,
+ pub key_update: Option<KeyUpdate>,
+
pub crypto_open: Option<crypto::Open>,
pub crypto_seal: Option<crypto::Seal>,
@@ -843,6 +910,8 @@ impl PktNumSpace {
largest_rx_pkt_time: time::Instant::now(),
+ largest_rx_non_probing_pkt_num: 0,
+
next_pkt_num: 0,
recv_pkt_need_ack: ranges::RangeSet::new(crate::MAX_ACK_RANGES),
@@ -851,6 +920,8 @@ impl PktNumSpace {
ack_elicited: false,
+ key_update: None,
+
crypto_open: None,
crypto_seal: None,
@@ -858,8 +929,8 @@ impl PktNumSpace {
crypto_0rtt_seal: None,
crypto_stream: stream::Stream::new(
- std::u64::MAX,
- std::u64::MAX,
+ u64::MAX,
+ u64::MAX,
true,
true,
stream::MAX_STREAM_WINDOW,
@@ -869,8 +940,8 @@ impl PktNumSpace {
pub fn clear(&mut self) {
self.crypto_stream = stream::Stream::new(
- std::u64::MAX,
- std::u64::MAX,
+ u64::MAX,
+ u64::MAX,
true,
true,
stream::MAX_STREAM_WINDOW,
@@ -966,7 +1037,7 @@ mod tests {
assert!(hdr.to_bytes(&mut b).is_ok());
// Add fake retry integrity token.
- b.put_bytes(&vec![0xba; 16]).unwrap();
+ b.put_bytes(&[0xba; 16]).unwrap();
let mut b = octets::OctetsMut::with_slice(&mut d);
assert_eq!(Header::from_bytes(&mut b, 9).unwrap(), hdr);
@@ -1213,7 +1284,7 @@ mod tests {
assert!(!win.contains(1025));
assert!(!win.contains(1026));
- win.insert(std::u64::MAX - 1);
+ win.insert(u64::MAX - 1);
assert!(win.contains(0));
assert!(win.contains(1));
assert!(win.contains(2));
@@ -1235,8 +1306,8 @@ mod tests {
assert!(win.contains(1024));
assert!(win.contains(1025));
assert!(win.contains(1026));
- assert!(!win.contains(std::u64::MAX - 2));
- assert!(win.contains(std::u64::MAX - 1));
+ assert!(!win.contains(u64::MAX - 2));
+ assert!(win.contains(u64::MAX - 1));
}
fn assert_decrypt_initial_pkt(
@@ -1875,7 +1946,7 @@ mod tests {
.unwrap();
assert_eq!(written, expected_pkt.len());
- assert_eq!(&out[..written], &expected_pkt[..]);
+ assert_eq!(&out[..written], expected_pkt);
}
#[test]