aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.cargo_vcs_info.json2
-rw-r--r--.circleci/config.yml2
-rw-r--r--Android.bp6
-rw-r--r--Cargo.toml2
-rw-r--r--Cargo.toml.orig2
-rw-r--r--METADATA10
-rw-r--r--RELEASE-NOTES.md6
-rw-r--r--benches/benchmarks.rs3
-rw-r--r--src/decode.rs74
-rw-r--r--src/engine/general_purpose/decode.rs392
-rw-r--r--src/engine/general_purpose/decode_suffix.rs134
-rw-r--r--src/engine/general_purpose/mod.rs6
-rw-r--r--src/engine/mod.rs41
-rw-r--r--src/engine/naive.rs77
-rw-r--r--src/engine/tests.rs684
-rw-r--r--src/lib.rs3
-rw-r--r--src/read/decoder.rs93
17 files changed, 713 insertions, 824 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json
index d61e543..c7677af 100644
--- a/.cargo_vcs_info.json
+++ b/.cargo_vcs_info.json
@@ -1,6 +1,6 @@
{
"git": {
- "sha1": "9652c787730e58515ce7b44fcafd2430ab424628"
+ "sha1": "5d70ba7576f9aafcbf02bd8acfcb9973411fb95f"
},
"path_in_vcs": ""
} \ No newline at end of file
diff --git a/.circleci/config.yml b/.circleci/config.yml
index ac0fae1..4d2576d 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -20,7 +20,7 @@ workflows:
# get a nightly or stable toolchain via rustup instead of a mutable docker tag
toolchain_override: [
'__msrv__', # won't add any other toolchains, just uses what's in the docker image
- '1.65.0', # minimum needed to build dev-dependencies
+ '1.70.0', # minimum needed to build dev-dependencies
'stable',
'beta',
'nightly'
diff --git a/Android.bp b/Android.bp
index 22a37c0..70d8968 100644
--- a/Android.bp
+++ b/Android.bp
@@ -1,5 +1,7 @@
// This file is generated by cargo_embargo.
-// Do not modify this file as changes will be overridden on upgrade.
+// Do not modify this file after the first "rust_*" or "genrule" module
+// because the changes will be overridden on upgrade.
+// Content before the first "rust_*" or "genrule" module is preserved.
package {
default_applicable_licenses: ["external_rust_crates_base64_license"],
@@ -42,7 +44,7 @@ rust_library {
host_supported: true,
crate_name: "base64",
cargo_env_compat: true,
- cargo_pkg_version: "0.21.7",
+ cargo_pkg_version: "0.22.0",
srcs: ["src/lib.rs"],
edition: "2018",
features: [
diff --git a/Cargo.toml b/Cargo.toml
index e508297..409d6d9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,7 +13,7 @@
edition = "2018"
rust-version = "1.48.0"
name = "base64"
-version = "0.21.7"
+version = "0.22.0"
authors = [
"Alice Maz <alice@alicemaz.com>",
"Marshall Pierce <marshall@mpierce.org>",
diff --git a/Cargo.toml.orig b/Cargo.toml.orig
index 4db5d26..c2670d3 100644
--- a/Cargo.toml.orig
+++ b/Cargo.toml.orig
@@ -1,6 +1,6 @@
[package]
name = "base64"
-version = "0.21.7"
+version = "0.22.0"
authors = ["Alice Maz <alice@alicemaz.com>", "Marshall Pierce <marshall@mpierce.org>"]
description = "encodes and decodes base64 as bytes or utf8"
repository = "https://github.com/marshallpierce/rust-base64"
diff --git a/METADATA b/METADATA
index b4a6927..e5d8d65 100644
--- a/METADATA
+++ b/METADATA
@@ -1,5 +1,5 @@
# This project was upgraded with external_updater.
-# Usage: tools/external_updater/updater.sh update external/rust/crates/base64
+# Usage: tools/external_updater/updater.sh update external/<absolute path to project>
# For more info, check https://cs.android.com/android/platform/superproject/+/main:tools/external_updater/README.md
name: "base64"
@@ -8,13 +8,13 @@ third_party {
license_type: NOTICE
last_upgrade_date {
year: 2024
- month: 1
- day: 31
+ month: 4
+ day: 24
}
homepage: "https://crates.io/crates/base64"
identifier {
type: "Archive"
- value: "https://static.crates.io/crates/base64/base64-0.21.7.crate"
- version: "0.21.7"
+ value: "https://static.crates.io/crates/base64/base64-0.22.0.crate"
+ version: "0.22.0"
}
}
diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index 0031215..46e281e 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -1,3 +1,9 @@
+# 0.22.0
+
+- `DecodeSliceError::OutputSliceTooSmall` is now conservative rather than precise. That is, the error will only occur if the decoded output _cannot_ fit, meaning that `Engine::decode_slice` can now be used with exactly-sized output slices. As part of this, `Engine::internal_decode` now returns `DecodeSliceError` instead of `DecodeError`, but that is not expected to affect any external callers.
+- `DecodeError::InvalidLength` now refers specifically to the _number of valid symbols_ being invalid (i.e. `len % 4 == 1`), rather than just the number of input bytes. This avoids confusing scenarios when based on interpretation you could make a case for either `InvalidLength` or `InvalidByte` being appropriate.
+- Decoding is somewhat faster (5-10%)
+
# 0.21.7
- Support getting an alphabet's contents as a str via `Alphabet::as_str()`
diff --git a/benches/benchmarks.rs b/benches/benchmarks.rs
index 802c8cc..8f04185 100644
--- a/benches/benchmarks.rs
+++ b/benches/benchmarks.rs
@@ -102,9 +102,8 @@ fn do_encode_bench_slice(b: &mut Bencher, &size: &usize) {
fn do_encode_bench_stream(b: &mut Bencher, &size: &usize) {
let mut v: Vec<u8> = Vec::with_capacity(size);
fill(&mut v);
- let mut buf = Vec::new();
+ let mut buf = Vec::with_capacity(size * 2);
- buf.reserve(size * 2);
b.iter(|| {
buf.clear();
let mut stream_enc = write::EncoderWriter::new(&mut buf, &STANDARD);
diff --git a/src/decode.rs b/src/decode.rs
index 5230fd3..6df8aba 100644
--- a/src/decode.rs
+++ b/src/decode.rs
@@ -9,18 +9,20 @@ use std::error;
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DecodeError {
/// An invalid byte was found in the input. The offset and offending byte are provided.
- /// Padding characters (`=`) interspersed in the encoded form will be treated as invalid bytes.
+ ///
+ /// Padding characters (`=`) interspersed in the encoded form are invalid, as they may only
+ /// be present as the last 0-2 bytes of input.
+ ///
+ /// This error may also indicate that extraneous trailing input bytes are present, causing
+ /// otherwise valid padding to no longer be the last bytes of input.
InvalidByte(usize, u8),
- /// The length of the input is invalid.
- /// A typical cause of this is stray trailing whitespace or other separator bytes.
- /// In the case where excess trailing bytes have produced an invalid length *and* the last byte
- /// is also an invalid base64 symbol (as would be the case for whitespace, etc), `InvalidByte`
- /// will be emitted instead of `InvalidLength` to make the issue easier to debug.
- InvalidLength,
+ /// The length of the input, as measured in valid base64 symbols, is invalid.
+ /// There must be 2-4 symbols in the last input quad.
+ InvalidLength(usize),
/// The last non-padding input symbol's encoded 6 bits have nonzero bits that will be discarded.
/// This is indicative of corrupted or truncated Base64.
- /// Unlike `InvalidByte`, which reports symbols that aren't in the alphabet, this error is for
- /// symbols that are in the alphabet but represent nonsensical encodings.
+ /// Unlike [DecodeError::InvalidByte], which reports symbols that aren't in the alphabet,
+ /// this error is for symbols that are in the alphabet but represent nonsensical encodings.
InvalidLastSymbol(usize, u8),
/// The nature of the padding was not as configured: absent or incorrect when it must be
/// canonical, or present when it must be absent, etc.
@@ -30,8 +32,10 @@ pub enum DecodeError {
impl fmt::Display for DecodeError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
- Self::InvalidByte(index, byte) => write!(f, "Invalid byte {}, offset {}.", byte, index),
- Self::InvalidLength => write!(f, "Encoded text cannot have a 6-bit remainder."),
+ Self::InvalidByte(index, byte) => {
+ write!(f, "Invalid symbol {}, offset {}.", byte, index)
+ }
+ Self::InvalidLength(len) => write!(f, "Invalid input length: {}", len),
Self::InvalidLastSymbol(index, byte) => {
write!(f, "Invalid last symbol {}, offset {}.", byte, index)
}
@@ -48,9 +52,7 @@ impl error::Error for DecodeError {}
pub enum DecodeSliceError {
/// A [DecodeError] occurred
DecodeError(DecodeError),
- /// The provided slice _may_ be too small.
- ///
- /// The check is conservative (assumes the last triplet of output bytes will all be needed).
+ /// The provided slice is too small.
OutputSliceTooSmall,
}
@@ -338,3 +340,47 @@ mod tests {
}
}
}
+
+#[allow(deprecated)]
+#[cfg(test)]
+mod coverage_gaming {
+ use super::*;
+ use std::error::Error;
+
+ #[test]
+ fn decode_error() {
+ let _ = format!("{:?}", DecodeError::InvalidPadding.clone());
+ let _ = format!(
+ "{} {} {} {}",
+ DecodeError::InvalidByte(0, 0),
+ DecodeError::InvalidLength(0),
+ DecodeError::InvalidLastSymbol(0, 0),
+ DecodeError::InvalidPadding,
+ );
+ }
+
+ #[test]
+ fn decode_slice_error() {
+ let _ = format!("{:?}", DecodeSliceError::OutputSliceTooSmall.clone());
+ let _ = format!(
+ "{} {}",
+ DecodeSliceError::OutputSliceTooSmall,
+ DecodeSliceError::DecodeError(DecodeError::InvalidPadding)
+ );
+ let _ = DecodeSliceError::OutputSliceTooSmall.source();
+ let _ = DecodeSliceError::DecodeError(DecodeError::InvalidPadding).source();
+ }
+
+ #[test]
+ fn deprecated_fns() {
+ let _ = decode("");
+ let _ = decode_engine("", &crate::prelude::BASE64_STANDARD);
+ let _ = decode_engine_vec("", &mut Vec::new(), &crate::prelude::BASE64_STANDARD);
+ let _ = decode_engine_slice("", &mut [], &crate::prelude::BASE64_STANDARD);
+ }
+
+ #[test]
+ fn decoded_len_est() {
+ assert_eq!(3, decoded_len_estimate(4));
+ }
+}
diff --git a/src/engine/general_purpose/decode.rs b/src/engine/general_purpose/decode.rs
index 21a386f..b55d3fc 100644
--- a/src/engine/general_purpose/decode.rs
+++ b/src/engine/general_purpose/decode.rs
@@ -1,47 +1,28 @@
use crate::{
engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode},
- DecodeError, PAD_BYTE,
+ DecodeError, DecodeSliceError, PAD_BYTE,
};
-// decode logic operates on chunks of 8 input bytes without padding
-const INPUT_CHUNK_LEN: usize = 8;
-const DECODED_CHUNK_LEN: usize = 6;
-
-// we read a u64 and write a u64, but a u64 of input only yields 6 bytes of output, so the last
-// 2 bytes of any output u64 should not be counted as written to (but must be available in a
-// slice).
-const DECODED_CHUNK_SUFFIX: usize = 2;
-
-// how many u64's of input to handle at a time
-const CHUNKS_PER_FAST_LOOP_BLOCK: usize = 4;
-
-const INPUT_BLOCK_LEN: usize = CHUNKS_PER_FAST_LOOP_BLOCK * INPUT_CHUNK_LEN;
-
-// includes the trailing 2 bytes for the final u64 write
-const DECODED_BLOCK_LEN: usize =
- CHUNKS_PER_FAST_LOOP_BLOCK * DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX;
-
#[doc(hidden)]
pub struct GeneralPurposeEstimate {
- /// Total number of decode chunks, including a possibly partial last chunk
- num_chunks: usize,
- decoded_len_estimate: usize,
+ /// input len % 4
+ rem: usize,
+ conservative_decoded_len: usize,
}
impl GeneralPurposeEstimate {
pub(crate) fn new(encoded_len: usize) -> Self {
- // Formulas that won't overflow
+ let rem = encoded_len % 4;
Self {
- num_chunks: encoded_len / INPUT_CHUNK_LEN
- + (encoded_len % INPUT_CHUNK_LEN > 0) as usize,
- decoded_len_estimate: (encoded_len / 4 + (encoded_len % 4 > 0) as usize) * 3,
+ rem,
+ conservative_decoded_len: (encoded_len / 4 + (rem > 0) as usize) * 3,
}
}
}
impl DecodeEstimate for GeneralPurposeEstimate {
fn decoded_len_estimate(&self) -> usize {
- self.decoded_len_estimate
+ self.conservative_decoded_len
}
}
@@ -58,263 +39,262 @@ pub(crate) fn decode_helper(
decode_table: &[u8; 256],
decode_allow_trailing_bits: bool,
padding_mode: DecodePaddingMode,
-) -> Result<DecodeMetadata, DecodeError> {
- let remainder_len = input.len() % INPUT_CHUNK_LEN;
-
- // Because the fast decode loop writes in groups of 8 bytes (unrolled to
- // CHUNKS_PER_FAST_LOOP_BLOCK times 8 bytes, where possible) and outputs 8 bytes at a time (of
- // which only 6 are valid data), we need to be sure that we stop using the fast decode loop
- // soon enough that there will always be 2 more bytes of valid data written after that loop.
- let trailing_bytes_to_skip = match remainder_len {
- // if input is a multiple of the chunk size, ignore the last chunk as it may have padding,
- // and the fast decode logic cannot handle padding
- 0 => INPUT_CHUNK_LEN,
- // 1 and 5 trailing bytes are illegal: can't decode 6 bits of input into a byte
- 1 | 5 => {
- // trailing whitespace is so common that it's worth it to check the last byte to
- // possibly return a better error message
- if let Some(b) = input.last() {
- if *b != PAD_BYTE && decode_table[*b as usize] == INVALID_VALUE {
- return Err(DecodeError::InvalidByte(input.len() - 1, *b));
- }
- }
+) -> Result<DecodeMetadata, DecodeSliceError> {
+ let input_complete_nonterminal_quads_len =
+ complete_quads_len(input, estimate.rem, output.len(), decode_table)?;
- return Err(DecodeError::InvalidLength);
- }
- // This will decode to one output byte, which isn't enough to overwrite the 2 extra bytes
- // written by the fast decode loop. So, we have to ignore both these 2 bytes and the
- // previous chunk.
- 2 => INPUT_CHUNK_LEN + 2,
- // If this is 3 un-padded chars, then it would actually decode to 2 bytes. However, if this
- // is an erroneous 2 chars + 1 pad char that would decode to 1 byte, then it should fail
- // with an error, not panic from going past the bounds of the output slice, so we let it
- // use stage 3 + 4.
- 3 => INPUT_CHUNK_LEN + 3,
- // This can also decode to one output byte because it may be 2 input chars + 2 padding
- // chars, which would decode to 1 byte.
- 4 => INPUT_CHUNK_LEN + 4,
- // Everything else is a legal decode len (given that we don't require padding), and will
- // decode to at least 2 bytes of output.
- _ => remainder_len,
- };
+ const UNROLLED_INPUT_CHUNK_SIZE: usize = 32;
+ const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3;
- // rounded up to include partial chunks
- let mut remaining_chunks = estimate.num_chunks;
+ let input_complete_quads_after_unrolled_chunks_len =
+ input_complete_nonterminal_quads_len % UNROLLED_INPUT_CHUNK_SIZE;
- let mut input_index = 0;
- let mut output_index = 0;
+ let input_unrolled_loop_len =
+ input_complete_nonterminal_quads_len - input_complete_quads_after_unrolled_chunks_len;
+ // chunks of 32 bytes
+ for (chunk_index, chunk) in input[..input_unrolled_loop_len]
+ .chunks_exact(UNROLLED_INPUT_CHUNK_SIZE)
+ .enumerate()
{
- let length_of_fast_decode_chunks = input.len().saturating_sub(trailing_bytes_to_skip);
-
- // Fast loop, stage 1
- // manual unroll to CHUNKS_PER_FAST_LOOP_BLOCK of u64s to amortize slice bounds checks
- if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_BLOCK_LEN) {
- while input_index <= max_start_index {
- let input_slice = &input[input_index..(input_index + INPUT_BLOCK_LEN)];
- let output_slice = &mut output[output_index..(output_index + DECODED_BLOCK_LEN)];
-
- decode_chunk(
- &input_slice[0..],
- input_index,
- decode_table,
- &mut output_slice[0..],
- )?;
- decode_chunk(
- &input_slice[8..],
- input_index + 8,
- decode_table,
- &mut output_slice[6..],
- )?;
- decode_chunk(
- &input_slice[16..],
- input_index + 16,
- decode_table,
- &mut output_slice[12..],
- )?;
- decode_chunk(
- &input_slice[24..],
- input_index + 24,
- decode_table,
- &mut output_slice[18..],
- )?;
-
- input_index += INPUT_BLOCK_LEN;
- output_index += DECODED_BLOCK_LEN - DECODED_CHUNK_SUFFIX;
- remaining_chunks -= CHUNKS_PER_FAST_LOOP_BLOCK;
- }
- }
-
- // Fast loop, stage 2 (aka still pretty fast loop)
- // 8 bytes at a time for whatever we didn't do in stage 1.
- if let Some(max_start_index) = length_of_fast_decode_chunks.checked_sub(INPUT_CHUNK_LEN) {
- while input_index < max_start_index {
- decode_chunk(
- &input[input_index..(input_index + INPUT_CHUNK_LEN)],
- input_index,
- decode_table,
- &mut output
- [output_index..(output_index + DECODED_CHUNK_LEN + DECODED_CHUNK_SUFFIX)],
- )?;
-
- output_index += DECODED_CHUNK_LEN;
- input_index += INPUT_CHUNK_LEN;
- remaining_chunks -= 1;
- }
- }
- }
+ let input_index = chunk_index * UNROLLED_INPUT_CHUNK_SIZE;
+ let chunk_output = &mut output[chunk_index * UNROLLED_OUTPUT_CHUNK_SIZE
+ ..(chunk_index + 1) * UNROLLED_OUTPUT_CHUNK_SIZE];
- // Stage 3
- // If input length was such that a chunk had to be deferred until after the fast loop
- // because decoding it would have produced 2 trailing bytes that wouldn't then be
- // overwritten, we decode that chunk here. This way is slower but doesn't write the 2
- // trailing bytes.
- // However, we still need to avoid the last chunk (partial or complete) because it could
- // have padding, so we always do 1 fewer to avoid the last chunk.
- for _ in 1..remaining_chunks {
- decode_chunk_precise(
- &input[input_index..],
+ decode_chunk_8(
+ &chunk[0..8],
input_index,
decode_table,
- &mut output[output_index..(output_index + DECODED_CHUNK_LEN)],
+ &mut chunk_output[0..6],
+ )?;
+ decode_chunk_8(
+ &chunk[8..16],
+ input_index + 8,
+ decode_table,
+ &mut chunk_output[6..12],
+ )?;
+ decode_chunk_8(
+ &chunk[16..24],
+ input_index + 16,
+ decode_table,
+ &mut chunk_output[12..18],
+ )?;
+ decode_chunk_8(
+ &chunk[24..32],
+ input_index + 24,
+ decode_table,
+ &mut chunk_output[18..24],
)?;
-
- input_index += INPUT_CHUNK_LEN;
- output_index += DECODED_CHUNK_LEN;
}
- // always have one more (possibly partial) block of 8 input
- debug_assert!(input.len() - input_index > 1 || input.is_empty());
- debug_assert!(input.len() - input_index <= 8);
+ // remaining quads, except for the last possibly partial one, as it may have padding
+ let output_unrolled_loop_len = input_unrolled_loop_len / 4 * 3;
+ let output_complete_quad_len = input_complete_nonterminal_quads_len / 4 * 3;
+ {
+ let output_after_unroll = &mut output[output_unrolled_loop_len..output_complete_quad_len];
+
+ for (chunk_index, chunk) in input
+ [input_unrolled_loop_len..input_complete_nonterminal_quads_len]
+ .chunks_exact(4)
+ .enumerate()
+ {
+ let chunk_output = &mut output_after_unroll[chunk_index * 3..chunk_index * 3 + 3];
+
+ decode_chunk_4(
+ chunk,
+ input_unrolled_loop_len + chunk_index * 4,
+ decode_table,
+ chunk_output,
+ )?;
+ }
+ }
super::decode_suffix::decode_suffix(
input,
- input_index,
+ input_complete_nonterminal_quads_len,
output,
- output_index,
+ output_complete_quad_len,
decode_table,
decode_allow_trailing_bits,
padding_mode,
)
}
-/// Decode 8 bytes of input into 6 bytes of output. 8 bytes of output will be written, but only the
-/// first 6 of those contain meaningful data.
+/// Returns the length of complete quads, except for the last one, even if it is complete.
+///
+/// Returns an error if the output len is not big enough for decoding those complete quads, or if
+/// the input % 4 == 1, and that last byte is an invalid value other than a pad byte.
+///
+/// - `input` is the base64 input
+/// - `input_len_rem` is input len % 4
+/// - `output_len` is the length of the output slice
+pub(crate) fn complete_quads_len(
+ input: &[u8],
+ input_len_rem: usize,
+ output_len: usize,
+ decode_table: &[u8; 256],
+) -> Result<usize, DecodeSliceError> {
+ debug_assert!(input.len() % 4 == input_len_rem);
+
+ // detect a trailing invalid byte, like a newline, as a user convenience
+ if input_len_rem == 1 {
+ let last_byte = input[input.len() - 1];
+ // exclude pad bytes; might be part of padding that extends from earlier in the input
+ if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE {
+ return Err(DecodeError::InvalidByte(input.len() - 1, last_byte).into());
+ }
+ };
+
+ // skip last quad, even if it's complete, as it may have padding
+ let input_complete_nonterminal_quads_len = input
+ .len()
+ .saturating_sub(input_len_rem)
+ // if rem was 0, subtract 4 to avoid padding
+ .saturating_sub((input_len_rem == 0) as usize * 4);
+ debug_assert!(
+ input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len))
+ );
+
+ // check that everything except the last quad handled by decode_suffix will fit
+ if output_len < input_complete_nonterminal_quads_len / 4 * 3 {
+ return Err(DecodeSliceError::OutputSliceTooSmall);
+ };
+ Ok(input_complete_nonterminal_quads_len)
+}
+
+/// Decode 8 bytes of input into 6 bytes of output.
///
-/// `input` is the bytes to decode, of which the first 8 bytes will be processed.
+/// `input` is the 8 bytes to decode.
/// `index_at_start_of_input` is the offset in the overall input (used for reporting errors
/// accurately)
/// `decode_table` is the lookup table for the particular base64 alphabet.
-/// `output` will have its first 8 bytes overwritten, of which only the first 6 are valid decoded
-/// data.
+/// `output` will have its first 6 bytes overwritten
// yes, really inline (worth 30-50% speedup)
#[inline(always)]
-fn decode_chunk(
+fn decode_chunk_8(
input: &[u8],
index_at_start_of_input: usize,
decode_table: &[u8; 256],
output: &mut [u8],
) -> Result<(), DecodeError> {
- let morsel = decode_table[input[0] as usize];
+ let morsel = decode_table[usize::from(input[0])];
if morsel == INVALID_VALUE {
return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
}
- let mut accum = (morsel as u64) << 58;
+ let mut accum = u64::from(morsel) << 58;
- let morsel = decode_table[input[1] as usize];
+ let morsel = decode_table[usize::from(input[1])];
if morsel == INVALID_VALUE {
return Err(DecodeError::InvalidByte(
index_at_start_of_input + 1,
input[1],
));
}
- accum |= (morsel as u64) << 52;
+ accum |= u64::from(morsel) << 52;
- let morsel = decode_table[input[2] as usize];
+ let morsel = decode_table[usize::from(input[2])];
if morsel == INVALID_VALUE {
return Err(DecodeError::InvalidByte(
index_at_start_of_input + 2,
input[2],
));
}
- accum |= (morsel as u64) << 46;
+ accum |= u64::from(morsel) << 46;
- let morsel = decode_table[input[3] as usize];
+ let morsel = decode_table[usize::from(input[3])];
if morsel == INVALID_VALUE {
return Err(DecodeError::InvalidByte(
index_at_start_of_input + 3,
input[3],
));
}
- accum |= (morsel as u64) << 40;
+ accum |= u64::from(morsel) << 40;
- let morsel = decode_table[input[4] as usize];
+ let morsel = decode_table[usize::from(input[4])];
if morsel == INVALID_VALUE {
return Err(DecodeError::InvalidByte(
index_at_start_of_input + 4,
input[4],
));
}
- accum |= (morsel as u64) << 34;
+ accum |= u64::from(morsel) << 34;
- let morsel = decode_table[input[5] as usize];
+ let morsel = decode_table[usize::from(input[5])];
if morsel == INVALID_VALUE {
return Err(DecodeError::InvalidByte(
index_at_start_of_input + 5,
input[5],
));
}
- accum |= (morsel as u64) << 28;
+ accum |= u64::from(morsel) << 28;
- let morsel = decode_table[input[6] as usize];
+ let morsel = decode_table[usize::from(input[6])];
if morsel == INVALID_VALUE {
return Err(DecodeError::InvalidByte(
index_at_start_of_input + 6,
input[6],
));
}
- accum |= (morsel as u64) << 22;
+ accum |= u64::from(morsel) << 22;
- let morsel = decode_table[input[7] as usize];
+ let morsel = decode_table[usize::from(input[7])];
if morsel == INVALID_VALUE {
return Err(DecodeError::InvalidByte(
index_at_start_of_input + 7,
input[7],
));
}
- accum |= (morsel as u64) << 16;
+ accum |= u64::from(morsel) << 16;
- write_u64(output, accum);
+ output[..6].copy_from_slice(&accum.to_be_bytes()[..6]);
Ok(())
}
-/// Decode an 8-byte chunk, but only write the 6 bytes actually decoded instead of including 2
-/// trailing garbage bytes.
-#[inline]
-fn decode_chunk_precise(
+/// Like [decode_chunk_8] but for 4 bytes of input and 3 bytes of output.
+#[inline(always)]
+fn decode_chunk_4(
input: &[u8],
index_at_start_of_input: usize,
decode_table: &[u8; 256],
output: &mut [u8],
) -> Result<(), DecodeError> {
- let mut tmp_buf = [0_u8; 8];
+ let morsel = decode_table[usize::from(input[0])];
+ if morsel == INVALID_VALUE {
+ return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
+ }
+ let mut accum = u32::from(morsel) << 26;
- decode_chunk(
- input,
- index_at_start_of_input,
- decode_table,
- &mut tmp_buf[..],
- )?;
+ let morsel = decode_table[usize::from(input[1])];
+ if morsel == INVALID_VALUE {
+ return Err(DecodeError::InvalidByte(
+ index_at_start_of_input + 1,
+ input[1],
+ ));
+ }
+ accum |= u32::from(morsel) << 20;
+
+ let morsel = decode_table[usize::from(input[2])];
+ if morsel == INVALID_VALUE {
+ return Err(DecodeError::InvalidByte(
+ index_at_start_of_input + 2,
+ input[2],
+ ));
+ }
+ accum |= u32::from(morsel) << 14;
- output[0..6].copy_from_slice(&tmp_buf[0..6]);
+ let morsel = decode_table[usize::from(input[3])];
+ if morsel == INVALID_VALUE {
+ return Err(DecodeError::InvalidByte(
+ index_at_start_of_input + 3,
+ input[3],
+ ));
+ }
+ accum |= u32::from(morsel) << 8;
- Ok(())
-}
+ output[..3].copy_from_slice(&accum.to_be_bytes()[..3]);
-#[inline]
-fn write_u64(output: &mut [u8], value: u64) {
- output[..8].copy_from_slice(&value.to_be_bytes());
+ Ok(())
}
#[cfg(test)]
@@ -324,37 +304,36 @@ mod tests {
use crate::engine::general_purpose::STANDARD;
#[test]
- fn decode_chunk_precise_writes_only_6_bytes() {
+ fn decode_chunk_8_writes_only_6_bytes() {
let input = b"Zm9vYmFy"; // "foobar"
let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
- decode_chunk_precise(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
+ decode_chunk_8(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output);
}
#[test]
- fn decode_chunk_writes_8_bytes() {
- let input = b"Zm9vYmFy"; // "foobar"
- let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];
+ fn decode_chunk_4_writes_only_3_bytes() {
+ let input = b"Zm9v"; // "foobar"
+ let mut output = [0_u8, 1, 2, 3];
- decode_chunk(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
- assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 0, 0], &output);
+ decode_chunk_4(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
+ assert_eq!(&vec![b'f', b'o', b'o', 3], &output);
}
#[test]
fn estimate_short_lengths() {
- for (range, (num_chunks, decoded_len_estimate)) in [
- (0..=0, (0, 0)),
- (1..=4, (1, 3)),
- (5..=8, (1, 6)),
- (9..=12, (2, 9)),
- (13..=16, (2, 12)),
- (17..=20, (3, 15)),
+ for (range, decoded_len_estimate) in [
+ (0..=0, 0),
+ (1..=4, 3),
+ (5..=8, 6),
+ (9..=12, 9),
+ (13..=16, 12),
+ (17..=20, 15),
] {
for encoded_len in range {
let estimate = GeneralPurposeEstimate::new(encoded_len);
- assert_eq!(num_chunks, estimate.num_chunks);
- assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate);
+ assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate());
}
}
}
@@ -370,13 +349,8 @@ mod tests {
let estimate = GeneralPurposeEstimate::new(encoded_len);
assert_eq!(
- ((len_128 + (INPUT_CHUNK_LEN - 1) as u128) / (INPUT_CHUNK_LEN as u128))
- as usize,
- estimate.num_chunks
- );
- assert_eq!(
- ((len_128 + 3) / 4 * 3) as usize,
- estimate.decoded_len_estimate
+ (len_128 + 3) / 4 * 3,
+ estimate.conservative_decoded_len as u128
);
})
}
diff --git a/src/engine/general_purpose/decode_suffix.rs b/src/engine/general_purpose/decode_suffix.rs
index e1e005d..02aaf51 100644
--- a/src/engine/general_purpose/decode_suffix.rs
+++ b/src/engine/general_purpose/decode_suffix.rs
@@ -1,9 +1,9 @@
use crate::{
engine::{general_purpose::INVALID_VALUE, DecodeMetadata, DecodePaddingMode},
- DecodeError, PAD_BYTE,
+ DecodeError, DecodeSliceError, PAD_BYTE,
};
-/// Decode the last 1-8 bytes, checking for trailing set bits and padding per the provided
+/// Decode the last 0-4 bytes, checking for trailing set bits and padding per the provided
/// parameters.
///
/// Returns the decode metadata representing the total number of bytes decoded, including the ones
@@ -16,17 +16,19 @@ pub(crate) fn decode_suffix(
decode_table: &[u8; 256],
decode_allow_trailing_bits: bool,
padding_mode: DecodePaddingMode,
-) -> Result<DecodeMetadata, DecodeError> {
- // Decode any leftovers that aren't a complete input block of 8 bytes.
- // Use a u64 as a stack-resident 8 byte buffer.
- let mut leftover_bits: u64 = 0;
+) -> Result<DecodeMetadata, DecodeSliceError> {
+ debug_assert!((input.len() - input_index) <= 4);
+
+ // Decode any leftovers that might not be a complete input chunk of 4 bytes.
+ // Use a u32 as a stack-resident 4 byte buffer.
let mut morsels_in_leftover = 0;
- let mut padding_bytes = 0;
- let mut first_padding_index: usize = 0;
+ let mut padding_bytes_count = 0;
+ // offset from input_index
+ let mut first_padding_offset: usize = 0;
let mut last_symbol = 0_u8;
- let start_of_leftovers = input_index;
+ let mut morsels = [0_u8; 4];
- for (i, &b) in input[start_of_leftovers..].iter().enumerate() {
+ for (leftover_index, &b) in input[input_index..].iter().enumerate() {
// '=' padding
if b == PAD_BYTE {
// There can be bad padding bytes in a few ways:
@@ -41,30 +43,22 @@ pub(crate) fn decode_suffix(
// Per config, non-canonical but still functional non- or partially-padded base64
// may be treated as an error condition.
- if i % 4 < 2 {
- // Check for case #2.
- let bad_padding_index = start_of_leftovers
- + if padding_bytes > 0 {
- // If we've already seen padding, report the first padding index.
- // This is to be consistent with the normal decode logic: it will report an
- // error on the first padding character (since it doesn't expect to see
- // anything but actual encoded data).
- // This could only happen if the padding started in the previous quad since
- // otherwise this case would have been hit at i % 4 == 0 if it was the same
- // quad.
- first_padding_index
- } else {
- // haven't seen padding before, just use where we are now
- i
- };
- return Err(DecodeError::InvalidByte(bad_padding_index, b));
+ if leftover_index < 2 {
+ // Check for error #2.
+ // Either the previous byte was padding, in which case we would have already hit
+ // this case, or it wasn't, in which case this is the first such error.
+ debug_assert!(
+ leftover_index == 0 || (leftover_index == 1 && padding_bytes_count == 0)
+ );
+ let bad_padding_index = input_index + leftover_index;
+ return Err(DecodeError::InvalidByte(bad_padding_index, b).into());
}
- if padding_bytes == 0 {
- first_padding_index = i;
+ if padding_bytes_count == 0 {
+ first_padding_offset = leftover_index;
}
- padding_bytes += 1;
+ padding_bytes_count += 1;
continue;
}
@@ -72,39 +66,44 @@ pub(crate) fn decode_suffix(
// To make '=' handling consistent with the main loop, don't allow
// non-suffix '=' in trailing chunk either. Report error as first
// erroneous padding.
- if padding_bytes > 0 {
- return Err(DecodeError::InvalidByte(
- start_of_leftovers + first_padding_index,
- PAD_BYTE,
- ));
+ if padding_bytes_count > 0 {
+ return Err(
+ DecodeError::InvalidByte(input_index + first_padding_offset, PAD_BYTE).into(),
+ );
}
last_symbol = b;
// can use up to 8 * 6 = 48 bits of the u64, if last chunk has no padding.
// Pack the leftovers from left to right.
- let shift = 64 - (morsels_in_leftover + 1) * 6;
let morsel = decode_table[b as usize];
if morsel == INVALID_VALUE {
- return Err(DecodeError::InvalidByte(start_of_leftovers + i, b));
+ return Err(DecodeError::InvalidByte(input_index + leftover_index, b).into());
}
- leftover_bits |= (morsel as u64) << shift;
+ morsels[morsels_in_leftover] = morsel;
morsels_in_leftover += 1;
}
+ // If there was 1 trailing byte, and it was valid, and we got to this point without hitting
+ // an invalid byte, now we can report invalid length
+ if !input.is_empty() && morsels_in_leftover < 2 {
+ return Err(DecodeError::InvalidLength(input_index + morsels_in_leftover).into());
+ }
+
match padding_mode {
DecodePaddingMode::Indifferent => { /* everything we care about was already checked */ }
DecodePaddingMode::RequireCanonical => {
- if (padding_bytes + morsels_in_leftover) % 4 != 0 {
- return Err(DecodeError::InvalidPadding);
+ // allow empty input
+ if (padding_bytes_count + morsels_in_leftover) % 4 != 0 {
+ return Err(DecodeError::InvalidPadding.into());
}
}
DecodePaddingMode::RequireNone => {
- if padding_bytes > 0 {
+ if padding_bytes_count > 0 {
// check at the end to make sure we let the cases of padding that should be InvalidByte
// get hit
- return Err(DecodeError::InvalidPadding);
+ return Err(DecodeError::InvalidPadding.into());
}
}
}
@@ -117,50 +116,45 @@ pub(crate) fn decode_suffix(
// bits in the bottom 6, but would be a non-canonical encoding. So, we calculate a
// mask based on how many bits are used for just the canonical encoding, and optionally
// error if any other bits are set. In the example of one encoded byte -> 2 symbols,
- // 2 symbols can technically encode 12 bits, but the last 4 are non canonical, and
+ // 2 symbols can technically encode 12 bits, but the last 4 are non-canonical, and
// useless since there are no more symbols to provide the necessary 4 additional bits
// to finish the second original byte.
- let leftover_bits_ready_to_append = match morsels_in_leftover {
- 0 => 0,
- 2 => 8,
- 3 => 16,
- 4 => 24,
- 6 => 32,
- 7 => 40,
- 8 => 48,
- // can also be detected as case #2 bad padding above
- _ => unreachable!(
- "Impossible: must only have 0 to 8 input bytes in last chunk, with no invalid lengths"
- ),
- };
+ let leftover_bytes_to_append = morsels_in_leftover * 6 / 8;
+ // Put the up to 6 complete bytes as the high bytes.
+ // Gain a couple percent speedup from nudging these ORs to use more ILP with a two-way split.
+ let mut leftover_num = (u32::from(morsels[0]) << 26)
+ | (u32::from(morsels[1]) << 20)
+ | (u32::from(morsels[2]) << 14)
+ | (u32::from(morsels[3]) << 8);
// if there are bits set outside the bits we care about, last symbol encodes trailing bits that
// will not be included in the output
- let mask = !0 >> leftover_bits_ready_to_append;
- if !decode_allow_trailing_bits && (leftover_bits & mask) != 0 {
+ let mask = !0_u32 >> (leftover_bytes_to_append * 8);
+ if !decode_allow_trailing_bits && (leftover_num & mask) != 0 {
// last morsel is at `morsels_in_leftover` - 1
return Err(DecodeError::InvalidLastSymbol(
- start_of_leftovers + morsels_in_leftover - 1,
+ input_index + morsels_in_leftover - 1,
last_symbol,
- ));
+ )
+ .into());
}
- // TODO benchmark simply converting to big endian bytes
- let mut leftover_bits_appended_to_buf = 0;
- while leftover_bits_appended_to_buf < leftover_bits_ready_to_append {
- // `as` simply truncates the higher bits, which is what we want here
- let selected_bits = (leftover_bits >> (56 - leftover_bits_appended_to_buf)) as u8;
- output[output_index] = selected_bits;
+ // Strangely, this approach benchmarks better than writing bytes one at a time,
+ // or copy_from_slice into output.
+ for _ in 0..leftover_bytes_to_append {
+ let hi_byte = (leftover_num >> 24) as u8;
+ leftover_num <<= 8;
+ *output
+ .get_mut(output_index)
+ .ok_or(DecodeSliceError::OutputSliceTooSmall)? = hi_byte;
output_index += 1;
-
- leftover_bits_appended_to_buf += 8;
}
Ok(DecodeMetadata::new(
output_index,
- if padding_bytes > 0 {
- Some(input_index + first_padding_index)
+ if padding_bytes_count > 0 {
+ Some(input_index + first_padding_offset)
} else {
None
},
diff --git a/src/engine/general_purpose/mod.rs b/src/engine/general_purpose/mod.rs
index e0227f3..6fe9580 100644
--- a/src/engine/general_purpose/mod.rs
+++ b/src/engine/general_purpose/mod.rs
@@ -3,11 +3,11 @@ use crate::{
alphabet,
alphabet::Alphabet,
engine::{Config, DecodeMetadata, DecodePaddingMode},
- DecodeError,
+ DecodeSliceError,
};
use core::convert::TryInto;
-mod decode;
+pub(crate) mod decode;
pub(crate) mod decode_suffix;
pub use decode::GeneralPurposeEstimate;
@@ -173,7 +173,7 @@ impl super::Engine for GeneralPurpose {
input: &[u8],
output: &mut [u8],
estimate: Self::DecodeEstimate,
- ) -> Result<DecodeMetadata, DecodeError> {
+ ) -> Result<DecodeMetadata, DecodeSliceError> {
decode::decode_helper(
input,
estimate,
diff --git a/src/engine/mod.rs b/src/engine/mod.rs
index 16c05d7..77dcd14 100644
--- a/src/engine/mod.rs
+++ b/src/engine/mod.rs
@@ -83,17 +83,13 @@ pub trait Engine: Send + Sync {
///
/// Non-canonical trailing bits in the final tokens or non-canonical padding must be reported as
/// errors unless the engine is configured otherwise.
- ///
- /// # Panics
- ///
- /// Panics if `output` is too small.
#[doc(hidden)]
fn internal_decode(
&self,
input: &[u8],
output: &mut [u8],
decode_estimate: Self::DecodeEstimate,
- ) -> Result<DecodeMetadata, DecodeError>;
+ ) -> Result<DecodeMetadata, DecodeSliceError>;
/// Returns the config for this engine.
fn config(&self) -> &Self::Config;
@@ -253,7 +249,13 @@ pub trait Engine: Send + Sync {
let mut buffer = vec![0; estimate.decoded_len_estimate()];
let bytes_written = engine
- .internal_decode(input_bytes, &mut buffer, estimate)?
+ .internal_decode(input_bytes, &mut buffer, estimate)
+ .map_err(|e| match e {
+ DecodeSliceError::DecodeError(e) => e,
+ DecodeSliceError::OutputSliceTooSmall => {
+ unreachable!("Vec is sized conservatively")
+ }
+ })?
.decoded_len;
buffer.truncate(bytes_written);
@@ -318,7 +320,13 @@ pub trait Engine: Send + Sync {
let buffer_slice = &mut buffer.as_mut_slice()[starting_output_len..];
let bytes_written = engine
- .internal_decode(input_bytes, buffer_slice, estimate)?
+ .internal_decode(input_bytes, buffer_slice, estimate)
+ .map_err(|e| match e {
+ DecodeSliceError::DecodeError(e) => e,
+ DecodeSliceError::OutputSliceTooSmall => {
+ unreachable!("Vec is sized conservatively")
+ }
+ })?
.decoded_len;
buffer.truncate(starting_output_len + bytes_written);
@@ -354,15 +362,12 @@ pub trait Engine: Send + Sync {
where
E: Engine + ?Sized,
{
- let estimate = engine.internal_decoded_len_estimate(input_bytes.len());
-
- if output.len() < estimate.decoded_len_estimate() {
- return Err(DecodeSliceError::OutputSliceTooSmall);
- }
-
engine
- .internal_decode(input_bytes, output, estimate)
- .map_err(|e| e.into())
+ .internal_decode(
+ input_bytes,
+ output,
+ engine.internal_decoded_len_estimate(input_bytes.len()),
+ )
.map(|dm| dm.decoded_len)
}
@@ -400,6 +405,12 @@ pub trait Engine: Send + Sync {
engine.internal_decoded_len_estimate(input_bytes.len()),
)
.map(|dm| dm.decoded_len)
+ .map_err(|e| match e {
+ DecodeSliceError::DecodeError(e) => e,
+ DecodeSliceError::OutputSliceTooSmall => {
+ panic!("Output slice is too small")
+ }
+ })
}
inner(self, input.as_ref(), output)
diff --git a/src/engine/naive.rs b/src/engine/naive.rs
index 6a50cbe..af509bf 100644
--- a/src/engine/naive.rs
+++ b/src/engine/naive.rs
@@ -4,7 +4,7 @@ use crate::{
general_purpose::{self, decode_table, encode_table},
Config, DecodeEstimate, DecodeMetadata, DecodePaddingMode, Engine,
},
- DecodeError, PAD_BYTE,
+ DecodeError, DecodeSliceError,
};
use std::ops::{BitAnd, BitOr, Shl, Shr};
@@ -111,63 +111,40 @@ impl Engine for Naive {
input: &[u8],
output: &mut [u8],
estimate: Self::DecodeEstimate,
- ) -> Result<DecodeMetadata, DecodeError> {
- if estimate.rem == 1 {
- // trailing whitespace is so common that it's worth it to check the last byte to
- // possibly return a better error message
- if let Some(b) = input.last() {
- if *b != PAD_BYTE
- && self.decode_table[*b as usize] == general_purpose::INVALID_VALUE
- {
- return Err(DecodeError::InvalidByte(input.len() - 1, *b));
- }
- }
-
- return Err(DecodeError::InvalidLength);
- }
+ ) -> Result<DecodeMetadata, DecodeSliceError> {
+ let complete_nonterminal_quads_len = general_purpose::decode::complete_quads_len(
+ input,
+ estimate.rem,
+ output.len(),
+ &self.decode_table,
+ )?;
- let mut input_index = 0_usize;
- let mut output_index = 0_usize;
const BOTTOM_BYTE: u32 = 0xFF;
- // can only use the main loop on non-trailing chunks
- if input.len() > Self::DECODE_INPUT_CHUNK_SIZE {
- // skip the last chunk, whether it's partial or full, since it might
- // have padding, and start at the beginning of the chunk before that
- let last_complete_chunk_start_index = estimate.complete_chunk_len
- - if estimate.rem == 0 {
- // Trailing chunk is also full chunk, so there must be at least 2 chunks, and
- // this won't underflow
- Self::DECODE_INPUT_CHUNK_SIZE * 2
- } else {
- // Trailing chunk is partial, so it's already excluded in
- // complete_chunk_len
- Self::DECODE_INPUT_CHUNK_SIZE
- };
-
- while input_index <= last_complete_chunk_start_index {
- let chunk = &input[input_index..input_index + Self::DECODE_INPUT_CHUNK_SIZE];
- let decoded_int: u32 = self.decode_byte_into_u32(input_index, chunk[0])?.shl(18)
- | self
- .decode_byte_into_u32(input_index + 1, chunk[1])?
- .shl(12)
- | self.decode_byte_into_u32(input_index + 2, chunk[2])?.shl(6)
- | self.decode_byte_into_u32(input_index + 3, chunk[3])?;
-
- output[output_index] = decoded_int.shr(16_u8).bitand(BOTTOM_BYTE) as u8;
- output[output_index + 1] = decoded_int.shr(8_u8).bitand(BOTTOM_BYTE) as u8;
- output[output_index + 2] = decoded_int.bitand(BOTTOM_BYTE) as u8;
-
- input_index += Self::DECODE_INPUT_CHUNK_SIZE;
- output_index += 3;
- }
+ for (chunk_index, chunk) in input[..complete_nonterminal_quads_len]
+ .chunks_exact(4)
+ .enumerate()
+ {
+ let input_index = chunk_index * 4;
+ let output_index = chunk_index * 3;
+
+ let decoded_int: u32 = self.decode_byte_into_u32(input_index, chunk[0])?.shl(18)
+ | self
+ .decode_byte_into_u32(input_index + 1, chunk[1])?
+ .shl(12)
+ | self.decode_byte_into_u32(input_index + 2, chunk[2])?.shl(6)
+ | self.decode_byte_into_u32(input_index + 3, chunk[3])?;
+
+ output[output_index] = decoded_int.shr(16_u8).bitand(BOTTOM_BYTE) as u8;
+ output[output_index + 1] = decoded_int.shr(8_u8).bitand(BOTTOM_BYTE) as u8;
+ output[output_index + 2] = decoded_int.bitand(BOTTOM_BYTE) as u8;
}
general_purpose::decode_suffix::decode_suffix(
input,
- input_index,
+ complete_nonterminal_quads_len,
output,
- output_index,
+ complete_nonterminal_quads_len / 4 * 3,
&self.decode_table,
self.config.decode_allow_trailing_bits,
self.config.decode_padding_mode,
diff --git a/src/engine/tests.rs b/src/engine/tests.rs
index b048005..72bbf4b 100644
--- a/src/engine/tests.rs
+++ b/src/engine/tests.rs
@@ -19,7 +19,7 @@ use crate::{
},
read::DecoderReader,
tests::{assert_encode_sanity, random_alphabet, random_config},
- DecodeError, PAD_BYTE,
+ DecodeError, DecodeSliceError, PAD_BYTE,
};
// the case::foo syntax includes the "foo" in the generated test method names
@@ -365,26 +365,49 @@ fn decode_detect_invalid_last_symbol<E: EngineWrapper>(engine_wrapper: E) {
}
#[apply(all_engines)]
-fn decode_detect_invalid_last_symbol_when_length_is_also_invalid<E: EngineWrapper>(
- engine_wrapper: E,
-) {
- let mut rng = seeded_rng();
-
- // check across enough lengths that it would likely cover any implementation's various internal
- // small/large input division
+fn decode_detect_1_valid_symbol_in_last_quad_invalid_length<E: EngineWrapper>(engine_wrapper: E) {
for len in (0_usize..256).map(|len| len * 4 + 1) {
- let engine = E::random_alphabet(&mut rng, &STANDARD);
+ for mode in all_pad_modes() {
+ let mut input = vec![b'A'; len];
- let mut input = vec![b'A'; len];
+ let engine = E::standard_with_pad_mode(true, mode);
- // with a valid last char, it's InvalidLength
- assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&input));
- // after mangling the last char, it's InvalidByte
- input[len - 1] = b'"';
- assert_eq!(
- Err(DecodeError::InvalidByte(len - 1, b'"')),
- engine.decode(&input)
- );
+ assert_eq!(Err(DecodeError::InvalidLength(len)), engine.decode(&input));
+ // if we add padding, then the first pad byte in the quad is invalid because it should
+ // be the second symbol
+ for _ in 0..3 {
+ input.push(PAD_BYTE);
+ assert_eq!(
+ Err(DecodeError::InvalidByte(len, PAD_BYTE)),
+ engine.decode(&input)
+ );
+ }
+ }
+ }
+}
+
+#[apply(all_engines)]
+fn decode_detect_1_invalid_byte_in_last_quad_invalid_byte<E: EngineWrapper>(engine_wrapper: E) {
+ for prefix_len in (0_usize..256).map(|len| len * 4) {
+ for mode in all_pad_modes() {
+ let mut input = vec![b'A'; prefix_len];
+ input.push(b'*');
+
+ let engine = E::standard_with_pad_mode(true, mode);
+
+ assert_eq!(
+ Err(DecodeError::InvalidByte(prefix_len, b'*')),
+ engine.decode(&input)
+ );
+ // adding padding doesn't matter
+ for _ in 0..3 {
+ input.push(PAD_BYTE);
+ assert_eq!(
+ Err(DecodeError::InvalidByte(prefix_len, b'*')),
+ engine.decode(&input)
+ );
+ }
+ }
}
}
@@ -471,8 +494,10 @@ fn decode_detect_invalid_last_symbol_every_possible_three_symbols<E: EngineWrapp
// every possible combination of symbols must either decode to 2 bytes or get InvalidLastSymbol, with or without any leading chunks
let mut prefix = Vec::new();
+ let mut input = Vec::new();
for _ in 0..256 {
- let mut input = prefix.clone();
+ input.clear();
+ input.extend_from_slice(&prefix);
let mut symbols = [0_u8; 4];
for &s1 in STANDARD.symbols.iter() {
@@ -613,75 +638,119 @@ fn decode_invalid_byte_error<E: EngineWrapper>(engine_wrapper: E) {
/// Any amount of padding anywhere before the final non padding character = invalid byte at first
/// pad byte.
-/// From this, we know padding must extend to the end of the input.
-// DecoderReader pseudo-engine detects InvalidLastSymbol instead of InvalidLength because it
-// can end a decode on the quad that happens to contain the start of the padding
-#[apply(all_engines_except_decoder_reader)]
-fn decode_padding_before_final_non_padding_char_error_invalid_byte<E: EngineWrapper>(
+/// From this and [decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_non_canonical_padding_suffix_all_modes],
+/// we know padding must extend contiguously to the end of the input.
+#[apply(all_engines)]
+fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_all_modes<
+ E: EngineWrapper,
+>(
engine_wrapper: E,
) {
- let mut rng = seeded_rng();
+ // Different amounts of padding, w/ offset from end for the last non-padding char.
+ // Only canonical padding, so Canonical mode will work.
+ let suffixes = &[("AA==", 2), ("AAA=", 1), ("AAAA", 0)];
- // the different amounts of proper padding, w/ offset from end for the last non-padding char
- let suffixes = [("/w==", 2), ("iYu=", 1), ("zzzz", 0)];
+ for mode in pad_modes_allowing_padding() {
+ // We don't encode, so we don't care about encode padding.
+ let engine = E::standard_with_pad_mode(true, mode);
- let prefix_quads_range = distributions::Uniform::from(0..=256);
+ decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad(
+ engine,
+ suffixes.as_slice(),
+ );
+ }
+}
- for mode in all_pad_modes() {
- // we don't encode so we don't care about encode padding
- let engine = E::standard_with_pad_mode(true, mode);
+/// See [decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_all_modes]
+#[apply(all_engines)]
+fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_non_canonical_padding_suffix<
+ E: EngineWrapper,
+>(
+ engine_wrapper: E,
+) {
+ // Different amounts of padding, w/ offset from end for the last non-padding char, and
+ // non-canonical padding.
+ let suffixes = [
+ ("AA==", 2),
+ ("AA=", 1),
+ ("AA", 0),
+ ("AAA=", 1),
+ ("AAA", 0),
+ ("AAAA", 0),
+ ];
- for _ in 0..100_000 {
- for (suffix, offset) in suffixes.iter() {
- let mut s = "ABCD".repeat(prefix_quads_range.sample(&mut rng));
- s.push_str(suffix);
- let mut encoded = s.into_bytes();
+ // We don't encode, so we don't care about encode padding.
+ // Decoding is indifferent so that we don't get caught by missing padding on the last quad
+ let engine = E::standard_with_pad_mode(true, DecodePaddingMode::Indifferent);
- // calculate a range to write padding into that leaves at least one non padding char
- let last_non_padding_offset = encoded.len() - 1 - offset;
+ decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad(
+ engine,
+ suffixes.as_slice(),
+ )
+}
- // don't include last non padding char as it must stay not padding
- let padding_end = rng.gen_range(0..last_non_padding_offset);
+fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad(
+ engine: impl Engine,
+ suffixes: &[(&str, usize)],
+) {
+ let mut rng = seeded_rng();
- // don't use more than 100 bytes of padding, but also use shorter lengths when
- // padding_end is near the start of the encoded data to avoid biasing to padding
- // the entire prefix on short lengths
- let padding_len = rng.gen_range(1..=usize::min(100, padding_end + 1));
- let padding_start = padding_end.saturating_sub(padding_len);
+ let prefix_quads_range = distributions::Uniform::from(0..=256);
- encoded[padding_start..=padding_end].fill(PAD_BYTE);
+ for _ in 0..100_000 {
+ for (suffix, suffix_offset) in suffixes.iter() {
+ let mut s = "AAAA".repeat(prefix_quads_range.sample(&mut rng));
+ s.push_str(suffix);
+ let mut encoded = s.into_bytes();
- assert_eq!(
- Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)),
- engine.decode(&encoded),
- );
- }
+ // calculate a range to write padding into that leaves at least one non padding char
+ let last_non_padding_offset = encoded.len() - 1 - suffix_offset;
+
+ // don't include last non padding char as it must stay not padding
+ let padding_end = rng.gen_range(0..last_non_padding_offset);
+
+ // don't use more than 100 bytes of padding, but also use shorter lengths when
+ // padding_end is near the start of the encoded data to avoid biasing to padding
+ // the entire prefix on short lengths
+ let padding_len = rng.gen_range(1..=usize::min(100, padding_end + 1));
+ let padding_start = padding_end.saturating_sub(padding_len);
+
+ encoded[padding_start..=padding_end].fill(PAD_BYTE);
+
+ // should still have non-padding before any final padding
+ assert_ne!(PAD_BYTE, encoded[last_non_padding_offset]);
+ assert_eq!(
+ Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)),
+ engine.decode(&encoded),
+ "len: {}, input: {}",
+ encoded.len(),
+ String::from_utf8(encoded).unwrap()
+ );
}
}
}
-/// Any amount of padding before final chunk that crosses over into final chunk with 2-4 bytes =
+/// Any amount of padding before final chunk that crosses over into final chunk with 1-4 bytes =
/// invalid byte at first pad byte.
-/// From this and [decode_padding_starts_before_final_chunk_error_invalid_length] we know the
-/// padding must start in the final chunk.
-// DecoderReader pseudo-engine detects InvalidLastSymbol instead of InvalidLength because it
-// can end a decode on the quad that happens to contain the start of the padding
-#[apply(all_engines_except_decoder_reader)]
-fn decode_padding_starts_before_final_chunk_error_invalid_byte<E: EngineWrapper>(
+/// From this we know the padding must start in the final chunk.
+#[apply(all_engines)]
+fn decode_padding_starts_before_final_chunk_error_invalid_byte_at_first_pad<E: EngineWrapper>(
engine_wrapper: E,
) {
let mut rng = seeded_rng();
// must have at least one prefix quad
let prefix_quads_range = distributions::Uniform::from(1..256);
- // excluding 1 since we don't care about invalid length in this test
- let suffix_pad_len_range = distributions::Uniform::from(2..=4);
- for mode in all_pad_modes() {
+ let suffix_pad_len_range = distributions::Uniform::from(1..=4);
+ // don't use no-padding mode, as the reader decode might decode a block that ends with
+ // valid padding, which should then be referenced when encountering the later invalid byte
+ for mode in pad_modes_allowing_padding() {
// we don't encode so we don't care about encode padding
let engine = E::standard_with_pad_mode(true, mode);
for _ in 0..100_000 {
let suffix_len = suffix_pad_len_range.sample(&mut rng);
- let mut encoded = "ABCD"
+ // all 0 bits so we don't hit InvalidLastSymbol with the reader decoder
+ let mut encoded = "AAAA"
.repeat(prefix_quads_range.sample(&mut rng))
.into_bytes();
encoded.resize(encoded.len() + suffix_len, PAD_BYTE);
@@ -705,40 +774,6 @@ fn decode_padding_starts_before_final_chunk_error_invalid_byte<E: EngineWrapper>
}
}
-/// Any amount of padding before final chunk that crosses over into final chunk with 1 byte =
-/// invalid length.
-/// From this we know the padding must start in the final chunk.
-// DecoderReader pseudo-engine detects InvalidByte instead of InvalidLength because it starts by
-// decoding only the available complete quads
-#[apply(all_engines_except_decoder_reader)]
-fn decode_padding_starts_before_final_chunk_error_invalid_length<E: EngineWrapper>(
- engine_wrapper: E,
-) {
- let mut rng = seeded_rng();
-
- // must have at least one prefix quad
- let prefix_quads_range = distributions::Uniform::from(1..256);
- for mode in all_pad_modes() {
- // we don't encode so we don't care about encode padding
- let engine = E::standard_with_pad_mode(true, mode);
- for _ in 0..100_000 {
- let mut encoded = "ABCD"
- .repeat(prefix_quads_range.sample(&mut rng))
- .into_bytes();
- encoded.resize(encoded.len() + 1, PAD_BYTE);
-
- // amount of padding must be long enough to extend back from suffix into previous
- // quads
- let padding_len = rng.gen_range(1 + 1..encoded.len());
- // no non-padding after padding in this test, so padding goes to the end
- let padding_start = encoded.len() - padding_len;
- encoded[padding_start..].fill(PAD_BYTE);
-
- assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&encoded),);
- }
- }
-}
-
/// 0-1 bytes of data before any amount of padding in final chunk = invalid byte, since padding
/// is not valid data (consistent with error for pad bytes in earlier chunks).
/// From this we know there must be 2-3 bytes of data before padding
@@ -756,29 +791,23 @@ fn decode_too_little_data_before_padding_error_invalid_byte<E: EngineWrapper>(en
let suffix_data_len = suffix_data_len_range.sample(&mut rng);
let prefix_quad_len = prefix_quads_range.sample(&mut rng);
- // ensure there is a suffix quad
- let min_padding = usize::from(suffix_data_len == 0);
-
// for all possible padding lengths
- for padding_len in min_padding..=(4 - suffix_data_len) {
+ for padding_len in 1..=(4 - suffix_data_len) {
let mut encoded = "ABCD".repeat(prefix_quad_len).into_bytes();
encoded.resize(encoded.len() + suffix_data_len, b'A');
encoded.resize(encoded.len() + padding_len, PAD_BYTE);
- if suffix_data_len + padding_len == 1 {
- assert_eq!(Err(DecodeError::InvalidLength), engine.decode(&encoded),);
- } else {
- assert_eq!(
- Err(DecodeError::InvalidByte(
- prefix_quad_len * 4 + suffix_data_len,
- PAD_BYTE,
- )),
- engine.decode(&encoded),
- "suffix data len {} pad len {}",
- suffix_data_len,
- padding_len
- );
- }
+ assert_eq!(
+ Err(DecodeError::InvalidByte(
+ prefix_quad_len * 4 + suffix_data_len,
+ PAD_BYTE,
+ )),
+ engine.decode(&encoded),
+ "input {} suffix data len {} pad len {}",
+ String::from_utf8(encoded).unwrap(),
+ suffix_data_len,
+ padding_len
+ );
}
}
}
@@ -918,258 +947,64 @@ fn decode_pad_mode_indifferent_padding_accepts_anything<E: EngineWrapper>(engine
);
}
-//this is a MAY in the rfc: https://tools.ietf.org/html/rfc4648#section-3.3
-// DecoderReader pseudo-engine finds the first padding, but doesn't report it as an error,
-// because in the next decode it finds more padding, which is reported as InvalidByte, just
-// with an offset at its position in the second decode, rather than being linked to the start
-// of the padding that was first seen in the previous decode.
-#[apply(all_engines_except_decoder_reader)]
-fn decode_pad_byte_in_penultimate_quad_error<E: EngineWrapper>(engine_wrapper: E) {
- for mode in all_pad_modes() {
- // we don't encode so we don't care about encode padding
- let engine = E::standard_with_pad_mode(true, mode);
-
- for num_prefix_quads in 0..256 {
- // leave room for at least one pad byte in penultimate quad
- for num_valid_bytes_penultimate_quad in 0..4 {
- // can't have 1 or it would be invalid length
- for num_pad_bytes_in_final_quad in 2..=4 {
- let mut s: String = "ABCD".repeat(num_prefix_quads);
-
- // varying amounts of padding in the penultimate quad
- for _ in 0..num_valid_bytes_penultimate_quad {
- s.push('A');
- }
- // finish penultimate quad with padding
- for _ in num_valid_bytes_penultimate_quad..4 {
- s.push('=');
- }
- // and more padding in the final quad
- for _ in 0..num_pad_bytes_in_final_quad {
- s.push('=');
- }
-
- // padding should be an invalid byte before the final quad.
- // Could argue that the *next* padding byte (in the next quad) is technically the first
- // erroneous one, but reporting that accurately is more complex and probably nobody cares
- assert_eq!(
- DecodeError::InvalidByte(
- num_prefix_quads * 4 + num_valid_bytes_penultimate_quad,
- b'=',
- ),
- engine.decode(&s).unwrap_err(),
- );
- }
- }
- }
- }
-}
-
-#[apply(all_engines)]
-fn decode_bytes_after_padding_in_final_quad_error<E: EngineWrapper>(engine_wrapper: E) {
- for mode in all_pad_modes() {
- // we don't encode so we don't care about encode padding
- let engine = E::standard_with_pad_mode(true, mode);
-
- for num_prefix_quads in 0..256 {
- // leave at least one byte in the quad for padding
- for bytes_after_padding in 1..4 {
- let mut s: String = "ABCD".repeat(num_prefix_quads);
-
- // every invalid padding position with a 3-byte final quad: 1 to 3 bytes after padding
- for _ in 0..(3 - bytes_after_padding) {
- s.push('A');
- }
- s.push('=');
- for _ in 0..bytes_after_padding {
- s.push('A');
- }
-
- // First (and only) padding byte is invalid.
- assert_eq!(
- DecodeError::InvalidByte(
- num_prefix_quads * 4 + (3 - bytes_after_padding),
- b'='
- ),
- engine.decode(&s).unwrap_err()
- );
- }
- }
- }
-}
-
-#[apply(all_engines)]
-fn decode_absurd_pad_error<E: EngineWrapper>(engine_wrapper: E) {
- for mode in all_pad_modes() {
- // we don't encode so we don't care about encode padding
- let engine = E::standard_with_pad_mode(true, mode);
-
- for num_prefix_quads in 0..256 {
- let mut s: String = "ABCD".repeat(num_prefix_quads);
- s.push_str("==Y=Wx===pY=2U=====");
-
- // first padding byte
- assert_eq!(
- DecodeError::InvalidByte(num_prefix_quads * 4, b'='),
- engine.decode(&s).unwrap_err()
- );
- }
- }
-}
-
-// DecoderReader pseudo-engine detects InvalidByte instead of InvalidLength because it starts by
-// decoding only the available complete quads
-#[apply(all_engines_except_decoder_reader)]
-fn decode_too_much_padding_returns_error<E: EngineWrapper>(engine_wrapper: E) {
- for mode in all_pad_modes() {
- // we don't encode so we don't care about encode padding
- let engine = E::standard_with_pad_mode(true, mode);
-
- for num_prefix_quads in 0..256 {
- // add enough padding to ensure that we'll hit all decode stages at the different lengths
- for pad_bytes in 1..=64 {
- let mut s: String = "ABCD".repeat(num_prefix_quads);
- let padding: String = "=".repeat(pad_bytes);
- s.push_str(&padding);
-
- if pad_bytes % 4 == 1 {
- assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err());
- } else {
- assert_eq!(
- DecodeError::InvalidByte(num_prefix_quads * 4, b'='),
- engine.decode(&s).unwrap_err()
- );
- }
- }
- }
- }
-}
-
-// DecoderReader pseudo-engine detects InvalidByte instead of InvalidLength because it starts by
-// decoding only the available complete quads
-#[apply(all_engines_except_decoder_reader)]
-fn decode_padding_followed_by_non_padding_returns_error<E: EngineWrapper>(engine_wrapper: E) {
- for mode in all_pad_modes() {
- // we don't encode so we don't care about encode padding
- let engine = E::standard_with_pad_mode(true, mode);
-
- for num_prefix_quads in 0..256 {
- for pad_bytes in 0..=32 {
- let mut s: String = "ABCD".repeat(num_prefix_quads);
- let padding: String = "=".repeat(pad_bytes);
- s.push_str(&padding);
- s.push('E');
-
- if pad_bytes % 4 == 0 {
- assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err());
- } else {
- assert_eq!(
- DecodeError::InvalidByte(num_prefix_quads * 4, b'='),
- engine.decode(&s).unwrap_err()
- );
- }
- }
- }
- }
-}
-
-#[apply(all_engines)]
-fn decode_one_char_in_final_quad_with_padding_error<E: EngineWrapper>(engine_wrapper: E) {
- for mode in all_pad_modes() {
- // we don't encode so we don't care about encode padding
- let engine = E::standard_with_pad_mode(true, mode);
-
- for num_prefix_quads in 0..256 {
- let mut s: String = "ABCD".repeat(num_prefix_quads);
- s.push_str("E=");
-
- assert_eq!(
- DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='),
- engine.decode(&s).unwrap_err()
- );
-
- // more padding doesn't change the error
- s.push('=');
- assert_eq!(
- DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='),
- engine.decode(&s).unwrap_err()
- );
-
- s.push('=');
- assert_eq!(
- DecodeError::InvalidByte(num_prefix_quads * 4 + 1, b'='),
- engine.decode(&s).unwrap_err()
- );
- }
- }
-}
-
-#[apply(all_engines)]
-fn decode_too_few_symbols_in_final_quad_error<E: EngineWrapper>(engine_wrapper: E) {
- for mode in all_pad_modes() {
- // we don't encode so we don't care about encode padding
- let engine = E::standard_with_pad_mode(true, mode);
-
- for num_prefix_quads in 0..256 {
- // <2 is invalid
- for final_quad_symbols in 0..2 {
- for padding_symbols in 0..=(4 - final_quad_symbols) {
- let mut s: String = "ABCD".repeat(num_prefix_quads);
-
- for _ in 0..final_quad_symbols {
- s.push('A');
- }
- for _ in 0..padding_symbols {
- s.push('=');
- }
-
- match final_quad_symbols + padding_symbols {
- 0 => continue,
- 1 => {
- assert_eq!(DecodeError::InvalidLength, engine.decode(&s).unwrap_err());
- }
- _ => {
- // error reported at first padding byte
- assert_eq!(
- DecodeError::InvalidByte(
- num_prefix_quads * 4 + final_quad_symbols,
- b'=',
- ),
- engine.decode(&s).unwrap_err()
- );
- }
- }
- }
- }
- }
- }
-}
-
+/// 1 trailing byte that's not padding is detected as invalid byte even though there's padding
+/// in the middle of the input. This is essentially mandating the eager check for 1 trailing byte
+/// to catch the \n suffix case.
// DecoderReader pseudo-engine can't handle DecodePaddingMode::RequireNone since it will decode
// a complete quad with padding in it before encountering the stray byte that makes it an invalid
// length
#[apply(all_engines_except_decoder_reader)]
-fn decode_invalid_trailing_bytes<E: EngineWrapper>(engine_wrapper: E) {
+fn decode_invalid_trailing_bytes_all_pad_modes_invalid_byte<E: EngineWrapper>(engine_wrapper: E) {
for mode in all_pad_modes() {
do_invalid_trailing_byte(E::standard_with_pad_mode(true, mode), mode);
}
}
#[apply(all_engines)]
-fn decode_invalid_trailing_bytes_all_modes<E: EngineWrapper>(engine_wrapper: E) {
+fn decode_invalid_trailing_bytes_invalid_byte<E: EngineWrapper>(engine_wrapper: E) {
// excluding no padding mode because the DecoderWrapper pseudo-engine will fail with
// InvalidPadding because it will decode the last complete quad with padding first
for mode in pad_modes_allowing_padding() {
do_invalid_trailing_byte(E::standard_with_pad_mode(true, mode), mode);
}
}
+fn do_invalid_trailing_byte(engine: impl Engine, mode: DecodePaddingMode) {
+ for last_byte in [b'*', b'\n'] {
+ for num_prefix_quads in 0..256 {
+ let mut s: String = "ABCD".repeat(num_prefix_quads);
+ s.push_str("Cg==");
+ let mut input = s.into_bytes();
+ input.push(last_byte);
+ // The case of trailing newlines is common enough to warrant a test for a good error
+ // message.
+ assert_eq!(
+ Err(DecodeError::InvalidByte(
+ num_prefix_quads * 4 + 4,
+ last_byte
+ )),
+ engine.decode(&input),
+ "mode: {:?}, input: {}",
+ mode,
+ String::from_utf8(input).unwrap()
+ );
+ }
+ }
+}
+
+/// When there's 1 trailing byte, but it's padding, it's only InvalidByte if there isn't padding
+/// earlier.
#[apply(all_engines)]
-fn decode_invalid_trailing_padding_as_invalid_length<E: EngineWrapper>(engine_wrapper: E) {
+fn decode_invalid_trailing_padding_as_invalid_byte_at_first_pad_byte<E: EngineWrapper>(
+ engine_wrapper: E,
+) {
// excluding no padding mode because the DecoderWrapper pseudo-engine will fail with
// InvalidPadding because it will decode the last complete quad with padding first
for mode in pad_modes_allowing_padding() {
- do_invalid_trailing_padding_as_invalid_length(E::standard_with_pad_mode(true, mode), mode);
+ do_invalid_trailing_padding_as_invalid_byte_at_first_padding(
+ E::standard_with_pad_mode(true, mode),
+ mode,
+ );
}
}
@@ -1177,48 +1012,36 @@ fn decode_invalid_trailing_padding_as_invalid_length<E: EngineWrapper>(engine_wr
// a complete quad with padding in it before encountering the stray byte that makes it an invalid
// length
#[apply(all_engines_except_decoder_reader)]
-fn decode_invalid_trailing_padding_as_invalid_length_all_modes<E: EngineWrapper>(
+fn decode_invalid_trailing_padding_as_invalid_byte_at_first_byte_all_modes<E: EngineWrapper>(
engine_wrapper: E,
) {
for mode in all_pad_modes() {
- do_invalid_trailing_padding_as_invalid_length(E::standard_with_pad_mode(true, mode), mode);
+ do_invalid_trailing_padding_as_invalid_byte_at_first_padding(
+ E::standard_with_pad_mode(true, mode),
+ mode,
+ );
}
}
-
-#[apply(all_engines)]
-fn decode_wrong_length_error<E: EngineWrapper>(engine_wrapper: E) {
- let engine = E::standard_with_pad_mode(true, DecodePaddingMode::Indifferent);
-
+fn do_invalid_trailing_padding_as_invalid_byte_at_first_padding(
+ engine: impl Engine,
+ mode: DecodePaddingMode,
+) {
for num_prefix_quads in 0..256 {
- // at least one token, otherwise it wouldn't be a final quad
- for num_tokens_final_quad in 1..=4 {
- for num_padding in 0..=(4 - num_tokens_final_quad) {
- let mut s: String = "IIII".repeat(num_prefix_quads);
- for _ in 0..num_tokens_final_quad {
- s.push('g');
- }
- for _ in 0..num_padding {
- s.push('=');
- }
+ for (suffix, pad_offset) in [("AA===", 2), ("AAA==", 3), ("AAAA=", 4)] {
+ let mut s: String = "ABCD".repeat(num_prefix_quads);
+ s.push_str(suffix);
- let res = engine.decode(&s);
- if num_tokens_final_quad >= 2 {
- assert!(res.is_ok());
- } else if num_tokens_final_quad == 1 && num_padding > 0 {
- // = is invalid if it's too early
- assert_eq!(
- Err(DecodeError::InvalidByte(
- num_prefix_quads * 4 + num_tokens_final_quad,
- 61
- )),
- res
- );
- } else if num_padding > 2 {
- assert_eq!(Err(DecodeError::InvalidPadding), res);
- } else {
- assert_eq!(Err(DecodeError::InvalidLength), res);
- }
- }
+ assert_eq!(
+ // pad after `g`, not the last one
+ Err(DecodeError::InvalidByte(
+ num_prefix_quads * 4 + pad_offset,
+ PAD_BYTE
+ )),
+ engine.decode(&s),
+ "mode: {:?}, input: {}",
+ mode,
+ s
+ );
}
}
}
@@ -1248,12 +1071,20 @@ fn decode_into_slice_fits_in_precisely_sized_slice<E: EngineWrapper>(engine_wrap
assert_encode_sanity(&encoded_data, engine.config().encode_padding(), input_len);
decode_buf.resize(input_len, 0);
-
// decode into the non-empty buf
let decode_bytes_written = engine
.decode_slice_unchecked(encoded_data.as_bytes(), &mut decode_buf[..])
.unwrap();
+ assert_eq!(orig_data.len(), decode_bytes_written);
+ assert_eq!(orig_data, decode_buf);
+ // same for checked variant
+ decode_buf.clear();
+ decode_buf.resize(input_len, 0);
+ // decode into the non-empty buf
+ let decode_bytes_written = engine
+ .decode_slice(encoded_data.as_bytes(), &mut decode_buf[..])
+ .unwrap();
assert_eq!(orig_data.len(), decode_bytes_written);
assert_eq!(orig_data, decode_buf);
}
@@ -1287,7 +1118,10 @@ fn inner_decode_reports_padding_position<E: EngineWrapper>(engine_wrapper: E) {
if pad_position % 4 < 2 {
// impossible padding
assert_eq!(
- Err(DecodeError::InvalidByte(pad_position, PAD_BYTE)),
+ Err(DecodeSliceError::DecodeError(DecodeError::InvalidByte(
+ pad_position,
+ PAD_BYTE
+ ))),
decode_res
);
} else {
@@ -1355,35 +1189,60 @@ fn estimate_via_u128_inflation<E: EngineWrapper>(engine_wrapper: E) {
})
}
-fn do_invalid_trailing_byte(engine: impl Engine, mode: DecodePaddingMode) {
- for num_prefix_quads in 0..256 {
- let mut s: String = "ABCD".repeat(num_prefix_quads);
- s.push_str("Cg==\n");
-
- // The case of trailing newlines is common enough to warrant a test for a good error
- // message.
- assert_eq!(
- Err(DecodeError::InvalidByte(num_prefix_quads * 4 + 4, b'\n')),
- engine.decode(&s),
- "mode: {:?}, input: {}",
- mode,
- s
- );
- }
-}
+#[apply(all_engines)]
+fn decode_slice_checked_fails_gracefully_at_all_output_lengths<E: EngineWrapper>(
+ engine_wrapper: E,
+) {
+ let mut rng = seeded_rng();
+ for original_len in 0..1000 {
+ let mut original = vec![0; original_len];
+ rng.fill(&mut original[..]);
+
+ for mode in all_pad_modes() {
+ let engine = E::standard_with_pad_mode(
+ match mode {
+ DecodePaddingMode::Indifferent | DecodePaddingMode::RequireCanonical => true,
+ DecodePaddingMode::RequireNone => false,
+ },
+ mode,
+ );
-fn do_invalid_trailing_padding_as_invalid_length(engine: impl Engine, mode: DecodePaddingMode) {
- for num_prefix_quads in 0..256 {
- let mut s: String = "ABCD".repeat(num_prefix_quads);
- s.push_str("Cg===");
+ let encoded = engine.encode(&original);
+ let mut decode_buf = Vec::with_capacity(original_len);
+ for decode_buf_len in 0..original_len {
+ decode_buf.resize(decode_buf_len, 0);
+ assert_eq!(
+ DecodeSliceError::OutputSliceTooSmall,
+ engine
+ .decode_slice(&encoded, &mut decode_buf[..])
+ .unwrap_err(),
+ "original len: {}, encoded len: {}, buf len: {}, mode: {:?}",
+ original_len,
+ encoded.len(),
+ decode_buf_len,
+ mode
+ );
+ // internal method works the same
+ assert_eq!(
+ DecodeSliceError::OutputSliceTooSmall,
+ engine
+ .internal_decode(
+ encoded.as_bytes(),
+ &mut decode_buf[..],
+ engine.internal_decoded_len_estimate(encoded.len())
+ )
+ .unwrap_err()
+ );
+ }
- assert_eq!(
- Err(DecodeError::InvalidLength),
- engine.decode(&s),
- "mode: {:?}, input: {}",
- mode,
- s
- );
+ decode_buf.resize(original_len, 0);
+ rng.fill(&mut decode_buf[..]);
+ assert_eq!(
+ original_len,
+ engine.decode_slice(&encoded, &mut decode_buf[..]).unwrap()
+ );
+ assert_eq!(original, decode_buf);
+ }
}
}
@@ -1547,7 +1406,7 @@ impl EngineWrapper for NaiveWrapper {
naive::Naive::new(
&STANDARD,
naive::NaiveConfig {
- encode_padding: false,
+ encode_padding: encode_pad,
decode_allow_trailing_bits: false,
decode_padding_mode: decode_pad_mode,
},
@@ -1616,7 +1475,7 @@ impl<E: Engine> Engine for DecoderReaderEngine<E> {
input: &[u8],
output: &mut [u8],
decode_estimate: Self::DecodeEstimate,
- ) -> Result<DecodeMetadata, DecodeError> {
+ ) -> Result<DecodeMetadata, DecodeSliceError> {
let mut reader = DecoderReader::new(input, &self.engine);
let mut buf = vec![0; input.len()];
// to avoid effects like not detecting invalid length due to progressively growing
@@ -1635,6 +1494,9 @@ impl<E: Engine> Engine for DecoderReaderEngine<E> {
.and_then(|inner| inner.downcast::<DecodeError>().ok())
.unwrap()
})?;
+ if output.len() < buf.len() {
+ return Err(DecodeSliceError::OutputSliceTooSmall);
+ }
output[..buf.len()].copy_from_slice(&buf);
Ok(DecodeMetadata::new(
buf.len(),
diff --git a/src/lib.rs b/src/lib.rs
index 6ec3c12..579a722 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -228,8 +228,7 @@
unused_extern_crates,
unused_import_braces,
unused_results,
- variant_size_differences,
- warnings
+ variant_size_differences
)]
#![forbid(unsafe_code)]
// Allow globally until https://github.com/rust-lang/rust-clippy/issues/8768 is resolved.
diff --git a/src/read/decoder.rs b/src/read/decoder.rs
index b656ae3..781f6f8 100644
--- a/src/read/decoder.rs
+++ b/src/read/decoder.rs
@@ -1,4 +1,4 @@
-use crate::{engine::Engine, DecodeError, PAD_BYTE};
+use crate::{engine::Engine, DecodeError, DecodeSliceError, PAD_BYTE};
use std::{cmp, fmt, io};
// This should be large, but it has to fit on the stack.
@@ -35,37 +35,39 @@ pub struct DecoderReader<'e, E: Engine, R: io::Read> {
/// Where b64 data is read from
inner: R,
- // Holds b64 data read from the delegate reader.
+ /// Holds b64 data read from the delegate reader.
b64_buffer: [u8; BUF_SIZE],
- // The start of the pending buffered data in b64_buffer.
+ /// The start of the pending buffered data in `b64_buffer`.
b64_offset: usize,
- // The amount of buffered b64 data.
+ /// The amount of buffered b64 data after `b64_offset` in `b64_len`.
b64_len: usize,
- // Since the caller may provide us with a buffer of size 1 or 2 that's too small to copy a
- // decoded chunk in to, we have to be able to hang on to a few decoded bytes.
- // Technically we only need to hold 2 bytes but then we'd need a separate temporary buffer to
- // decode 3 bytes into and then juggle copying one byte into the provided read buf and the rest
- // into here, which seems like a lot of complexity for 1 extra byte of storage.
- decoded_buffer: [u8; DECODED_CHUNK_SIZE],
- // index of start of decoded data
+ /// Since the caller may provide us with a buffer of size 1 or 2 that's too small to copy a
+ /// decoded chunk in to, we have to be able to hang on to a few decoded bytes.
+ /// Technically we only need to hold 2 bytes, but then we'd need a separate temporary buffer to
+ /// decode 3 bytes into and then juggle copying one byte into the provided read buf and the rest
+ /// into here, which seems like a lot of complexity for 1 extra byte of storage.
+ decoded_chunk_buffer: [u8; DECODED_CHUNK_SIZE],
+ /// Index of start of decoded data in `decoded_chunk_buffer`
decoded_offset: usize,
- // length of decoded data
+ /// Length of decoded data after `decoded_offset` in `decoded_chunk_buffer`
decoded_len: usize,
- // used to provide accurate offsets in errors
- total_b64_decoded: usize,
- // offset of previously seen padding, if any
+ /// Input length consumed so far.
+ /// Used to provide accurate offsets in errors
+ input_consumed_len: usize,
+ /// offset of previously seen padding, if any
padding_offset: Option<usize>,
}
+// exclude b64_buffer as it's uselessly large
impl<'e, E: Engine, R: io::Read> fmt::Debug for DecoderReader<'e, E, R> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("DecoderReader")
.field("b64_offset", &self.b64_offset)
.field("b64_len", &self.b64_len)
- .field("decoded_buffer", &self.decoded_buffer)
+ .field("decoded_chunk_buffer", &self.decoded_chunk_buffer)
.field("decoded_offset", &self.decoded_offset)
.field("decoded_len", &self.decoded_len)
- .field("total_b64_decoded", &self.total_b64_decoded)
+ .field("input_consumed_len", &self.input_consumed_len)
.field("padding_offset", &self.padding_offset)
.finish()
}
@@ -80,10 +82,10 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> {
b64_buffer: [0; BUF_SIZE],
b64_offset: 0,
b64_len: 0,
- decoded_buffer: [0; DECODED_CHUNK_SIZE],
+ decoded_chunk_buffer: [0; DECODED_CHUNK_SIZE],
decoded_offset: 0,
decoded_len: 0,
- total_b64_decoded: 0,
+ input_consumed_len: 0,
padding_offset: None,
}
}
@@ -100,7 +102,7 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> {
debug_assert!(copy_len <= self.decoded_len);
buf[..copy_len].copy_from_slice(
- &self.decoded_buffer[self.decoded_offset..self.decoded_offset + copy_len],
+ &self.decoded_chunk_buffer[self.decoded_offset..self.decoded_offset + copy_len],
);
self.decoded_offset += copy_len;
@@ -131,6 +133,10 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> {
/// caller's responsibility to choose the number of b64 bytes to decode correctly.
///
/// Returns a Result with the number of decoded bytes written to `buf`.
+ ///
+ /// # Panics
+ ///
+ /// panics if `buf` is too small
fn decode_to_buf(&mut self, b64_len_to_decode: usize, buf: &mut [u8]) -> io::Result<usize> {
debug_assert!(self.b64_len >= b64_len_to_decode);
debug_assert!(self.b64_offset + self.b64_len <= BUF_SIZE);
@@ -144,22 +150,35 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> {
buf,
self.engine.internal_decoded_len_estimate(b64_len_to_decode),
)
- .map_err(|e| match e {
- DecodeError::InvalidByte(offset, byte) => {
- // This can be incorrect, but not in a way that probably matters to anyone:
- // if there was padding handled in a previous decode, and we are now getting
- // InvalidByte due to more padding, we should arguably report InvalidByte with
- // PAD_BYTE at the original padding position (`self.padding_offset`), but we
- // don't have a good way to tie those two cases together, so instead we
- // just report the invalid byte as if the previous padding, and its possibly
- // related downgrade to a now invalid byte, didn't happen.
- DecodeError::InvalidByte(self.total_b64_decoded + offset, byte)
+ .map_err(|dse| match dse {
+ DecodeSliceError::DecodeError(de) => {
+ match de {
+ DecodeError::InvalidByte(offset, byte) => {
+ match (byte, self.padding_offset) {
+ // if there was padding in a previous block of decoding that happened to
+ // be correct, and we now find more padding that happens to be incorrect,
+ // to be consistent with non-reader decodes, record the error at the first
+ // padding
+ (PAD_BYTE, Some(first_pad_offset)) => {
+ DecodeError::InvalidByte(first_pad_offset, PAD_BYTE)
+ }
+ _ => {
+ DecodeError::InvalidByte(self.input_consumed_len + offset, byte)
+ }
+ }
+ }
+ DecodeError::InvalidLength(len) => {
+ DecodeError::InvalidLength(self.input_consumed_len + len)
+ }
+ DecodeError::InvalidLastSymbol(offset, byte) => {
+ DecodeError::InvalidLastSymbol(self.input_consumed_len + offset, byte)
+ }
+ DecodeError::InvalidPadding => DecodeError::InvalidPadding,
+ }
}
- DecodeError::InvalidLength => DecodeError::InvalidLength,
- DecodeError::InvalidLastSymbol(offset, byte) => {
- DecodeError::InvalidLastSymbol(self.total_b64_decoded + offset, byte)
+ DecodeSliceError::OutputSliceTooSmall => {
+ unreachable!("buf is sized correctly in calling code")
}
- DecodeError::InvalidPadding => DecodeError::InvalidPadding,
})
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
@@ -176,8 +195,8 @@ impl<'e, E: Engine, R: io::Read> DecoderReader<'e, E, R> {
self.padding_offset = self.padding_offset.or(decode_metadata
.padding_offset
- .map(|offset| self.total_b64_decoded + offset));
- self.total_b64_decoded += b64_len_to_decode;
+ .map(|offset| self.input_consumed_len + offset));
+ self.input_consumed_len += b64_len_to_decode;
self.b64_offset += b64_len_to_decode;
self.b64_len -= b64_len_to_decode;
@@ -283,7 +302,7 @@ impl<'e, E: Engine, R: io::Read> io::Read for DecoderReader<'e, E, R> {
let to_decode = cmp::min(self.b64_len, BASE64_CHUNK_SIZE);
let decoded = self.decode_to_buf(to_decode, &mut decoded_chunk[..])?;
- self.decoded_buffer[..decoded].copy_from_slice(&decoded_chunk[..decoded]);
+ self.decoded_chunk_buffer[..decoded].copy_from_slice(&decoded_chunk[..decoded]);
self.decoded_offset = 0;
self.decoded_len = decoded;