// Copyright 2023 The Chromium Authors // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/filter/zstd_source_stream.h" #include #include #include #define ZSTD_STATIC_LINKING_ONLY #include "base/bits.h" #include "base/check_op.h" #include "base/metrics/histogram_macros.h" #include "base/numerics/safe_conversions.h" #include "net/base/io_buffer.h" #include "third_party/zstd/src/lib/zstd.h" #include "third_party/zstd/src/lib/zstd_errors.h" namespace net { namespace { const char kZstd[] = "ZSTD"; struct FreeContextDeleter { inline void operator()(ZSTD_DCtx* ptr) const { ZSTD_freeDCtx(ptr); } }; // ZstdSourceStream applies Zstd content decoding to a data stream. // Zstd format speciication: https://datatracker.ietf.org/doc/html/rfc8878 class ZstdSourceStream : public FilterSourceStream { public: explicit ZstdSourceStream(std::unique_ptr upstream, scoped_refptr dictionary = nullptr, size_t dictionary_size = 0u) : FilterSourceStream(SourceStream::TYPE_ZSTD, std::move(upstream)), dictionary_(std::move(dictionary)), dictionary_size_(dictionary_size) { ZSTD_customMem custom_mem = {&customMalloc, &customFree, this}; dctx_.reset(ZSTD_createDCtx_advanced(custom_mem)); CHECK(dctx_); // Following RFC 8878 recommendation (see section 3.1.1.1.2 Window // Descriptor) of using a maximum 8MB memory buffer to decompress frames // to '... protect decoders from unreasonable memory requirements'. int window_log_max = 23; if (dictionary_) { // For shared dictionary case, allow using larger window size (Log2Ceiling // of `dictionary_size`). It is safe because we have the size limit per // shared dictionary and the total dictionary size limit. window_log_max = std::max(base::bits::Log2Ceiling( base::checked_cast(dictionary_size_)), window_log_max); } ZSTD_DCtx_setParameter(dctx_.get(), ZSTD_d_windowLogMax, window_log_max); if (dictionary_) { size_t result = ZSTD_DCtx_loadDictionary_advanced( dctx_.get(), reinterpret_cast(dictionary_->data()), dictionary_size_, ZSTD_dlm_byRef, ZSTD_dct_rawContent); DCHECK(!ZSTD_isError(result)); } } ZstdSourceStream(const ZstdSourceStream&) = delete; ZstdSourceStream& operator=(const ZstdSourceStream&) = delete; ~ZstdSourceStream() override { if (ZSTD_isError(decoding_result_)) { ZSTD_ErrorCode error_code = ZSTD_getErrorCode(decoding_result_); UMA_HISTOGRAM_ENUMERATION( "Net.ZstdFilter.ErrorCode", static_cast(error_code), static_cast(ZSTD_ErrorCode::ZSTD_error_maxCode)); } UMA_HISTOGRAM_ENUMERATION("Net.ZstdFilter.Status", decoding_status_); if (decoding_status_ == ZstdDecodingStatus::kEndOfFrame) { // CompressionRatio is undefined when there is no output produced. if (produced_bytes_ != 0) { UMA_HISTOGRAM_PERCENTAGE( "Net.ZstdFilter.CompressionRatio", static_cast((consumed_bytes_ * 100) / produced_bytes_)); } } UMA_HISTOGRAM_MEMORY_KB("Net.ZstdFilter.MaxMemoryUsage", (max_allocated_ / 1024)); } private: static void* customMalloc(void* opaque, size_t size) { return reinterpret_cast(opaque)->customMalloc(size); } void* customMalloc(size_t size) { void* address = malloc(size); CHECK(address); malloc_sizes_.insert(std::make_pair(address, size)); total_allocated_ += size; if (total_allocated_ > max_allocated_) { max_allocated_ = total_allocated_; } return address; } static void customFree(void* opaque, void* address) { return reinterpret_cast(opaque)->customFree(address); } void customFree(void* address) { free(address); auto it = malloc_sizes_.find(address); CHECK(it != malloc_sizes_.end()); const size_t size = it->second; total_allocated_ -= size; malloc_sizes_.erase(it); } // SourceStream implementation std::string GetTypeAsString() const override { return kZstd; } base::expected FilterData(IOBuffer* output_buffer, size_t output_buffer_size, IOBuffer* input_buffer, size_t input_buffer_size, size_t* consumed_bytes, bool upstream_end_reached) override { CHECK(dctx_); ZSTD_inBuffer input = {input_buffer->data(), input_buffer_size, 0}; ZSTD_outBuffer output = {output_buffer->data(), output_buffer_size, 0}; const size_t result = ZSTD_decompressStream(dctx_.get(), &output, &input); decoding_result_ = result; produced_bytes_ += output.pos; consumed_bytes_ += input.pos; *consumed_bytes = input.pos; if (ZSTD_isError(result)) { decoding_status_ = ZstdDecodingStatus::kDecodingError; return base::unexpected(ERR_CONTENT_DECODING_FAILED); } else if (input.pos < input.size) { // Given a valid frame, zstd won't consume the last byte of the frame // until it has flushed all of the decompressed data of the frame. // Therefore, instead of checking if the return code is 0, we can // just check if input.pos < input.size. return output.pos; } else { CHECK_EQ(input.pos, input.size); if (result != 0u) { // The return value from ZSTD_decompressStream did not end on a frame, // but we reached the end of the file. We assume this is an error, and // the input was truncated. if (upstream_end_reached) { decoding_status_ = ZstdDecodingStatus::kDecodingError; } } else { CHECK_EQ(result, 0u); CHECK_LE(output.pos, output.size); // Finished decoding a frame. decoding_status_ = ZstdDecodingStatus::kEndOfFrame; } return output.pos; } } size_t total_allocated_ = 0; size_t max_allocated_ = 0; std::unordered_map malloc_sizes_; const scoped_refptr dictionary_; const size_t dictionary_size_; std::unique_ptr dctx_; ZstdDecodingStatus decoding_status_ = ZstdDecodingStatus::kDecodingInProgress; size_t decoding_result_ = 0; size_t consumed_bytes_ = 0; size_t produced_bytes_ = 0; }; } // namespace std::unique_ptr CreateZstdSourceStream( std::unique_ptr previous) { return std::make_unique(std::move(previous)); } std::unique_ptr CreateZstdSourceStreamWithDictionary( std::unique_ptr previous, scoped_refptr dictionary, size_t dictionary_size) { return std::make_unique( std::move(previous), std::move(dictionary), dictionary_size); } } // namespace net