diff options
Diffstat (limited to 'lib/compress/zstdmt_compress.c')
-rw-r--r-- | lib/compress/zstdmt_compress.c | 120 |
1 files changed, 80 insertions, 40 deletions
diff --git a/lib/compress/zstdmt_compress.c b/lib/compress/zstdmt_compress.c index 50454a50..f564822d 100644 --- a/lib/compress/zstdmt_compress.c +++ b/lib/compress/zstdmt_compress.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2020, Yann Collet, Facebook, Inc. + * Copyright (c) Yann Collet, Facebook, Inc. * All rights reserved. * * This source code is licensed under both the BSD-style license (found in the @@ -467,29 +467,27 @@ ZSTDMT_serialState_reset(serialState_t* serialState, ZSTD_dictContentType_e dictContentType) { /* Adjust parameters */ - if (params.ldmParams.enableLdm) { + if (params.ldmParams.enableLdm == ZSTD_ps_enable) { DEBUGLOG(4, "LDM window size = %u KB", (1U << params.cParams.windowLog) >> 10); ZSTD_ldm_adjustParameters(&params.ldmParams, &params.cParams); assert(params.ldmParams.hashLog >= params.ldmParams.bucketSizeLog); assert(params.ldmParams.hashRateLog < 32); - serialState->ldmState.hashPower = - ZSTD_rollingHash_primePower(params.ldmParams.minMatchLength); } else { ZSTD_memset(&params.ldmParams, 0, sizeof(params.ldmParams)); } serialState->nextJobID = 0; if (params.fParams.checksumFlag) XXH64_reset(&serialState->xxhState, 0); - if (params.ldmParams.enableLdm) { + if (params.ldmParams.enableLdm == ZSTD_ps_enable) { ZSTD_customMem cMem = params.customMem; unsigned const hashLog = params.ldmParams.hashLog; size_t const hashSize = ((size_t)1 << hashLog) * sizeof(ldmEntry_t); unsigned const bucketLog = params.ldmParams.hashLog - params.ldmParams.bucketSizeLog; - size_t const bucketSize = (size_t)1 << bucketLog; unsigned const prevBucketLog = serialState->params.ldmParams.hashLog - serialState->params.ldmParams.bucketSizeLog; + size_t const numBuckets = (size_t)1 << bucketLog; /* Size the seq pool tables */ ZSTDMT_setNbSeq(seqPool, ZSTD_ldm_getMaxNbSeq(params.ldmParams, jobSize)); /* Reset the window */ @@ -501,20 +499,20 @@ ZSTDMT_serialState_reset(serialState_t* serialState, } if (serialState->ldmState.bucketOffsets == NULL || prevBucketLog < bucketLog) { 
ZSTD_customFree(serialState->ldmState.bucketOffsets, cMem); - serialState->ldmState.bucketOffsets = (BYTE*)ZSTD_customMalloc(bucketSize, cMem); + serialState->ldmState.bucketOffsets = (BYTE*)ZSTD_customMalloc(numBuckets, cMem); } if (!serialState->ldmState.hashTable || !serialState->ldmState.bucketOffsets) return 1; /* Zero the tables */ ZSTD_memset(serialState->ldmState.hashTable, 0, hashSize); - ZSTD_memset(serialState->ldmState.bucketOffsets, 0, bucketSize); + ZSTD_memset(serialState->ldmState.bucketOffsets, 0, numBuckets); /* Update window state and fill hash table with dict */ serialState->ldmState.loadedDictEnd = 0; if (dictSize > 0) { if (dictContentType == ZSTD_dct_rawContent) { BYTE const* const dictEnd = (const BYTE*)dict + dictSize; - ZSTD_window_update(&serialState->ldmState.window, dict, dictSize); + ZSTD_window_update(&serialState->ldmState.window, dict, dictSize, /* forceNonContiguous */ 0); ZSTD_ldm_fillHashTable(&serialState->ldmState, (const BYTE*)dict, dictEnd, &params.ldmParams); serialState->ldmState.loadedDictEnd = params.forceWindow ? 
0 : (U32)(dictEnd - serialState->ldmState.window.base); } else { @@ -566,12 +564,12 @@ static void ZSTDMT_serialState_update(serialState_t* serialState, /* A future job may error and skip our job */ if (serialState->nextJobID == jobID) { /* It is now our turn, do any processing necessary */ - if (serialState->params.ldmParams.enableLdm) { + if (serialState->params.ldmParams.enableLdm == ZSTD_ps_enable) { size_t error; assert(seqStore.seq != NULL && seqStore.pos == 0 && seqStore.size == 0 && seqStore.capacity > 0); assert(src.size <= serialState->params.jobSize); - ZSTD_window_update(&serialState->ldmState.window, src.start, src.size); + ZSTD_window_update(&serialState->ldmState.window, src.start, src.size, /* forceNonContiguous */ 0); error = ZSTD_ldm_generateSequences( &serialState->ldmState, &seqStore, &serialState->params.ldmParams, src.start, src.size); @@ -596,7 +594,7 @@ static void ZSTDMT_serialState_update(serialState_t* serialState, if (seqStore.size > 0) { size_t const err = ZSTD_referenceExternalSequences( jobCCtx, seqStore.seq, seqStore.size); - assert(serialState->params.ldmParams.enableLdm); + assert(serialState->params.ldmParams.enableLdm == ZSTD_ps_enable); assert(!ZSTD_isError(err)); (void)err; } @@ -674,7 +672,7 @@ static void ZSTDMT_compressionJob(void* jobDescription) if (dstBuff.start==NULL) JOB_ERROR(ERROR(memory_allocation)); job->dstBuff = dstBuff; /* this value can be read in ZSTDMT_flush, when it copies the whole job */ } - if (jobParams.ldmParams.enableLdm && rawSeqStore.seq == NULL) + if (jobParams.ldmParams.enableLdm == ZSTD_ps_enable && rawSeqStore.seq == NULL) JOB_ERROR(ERROR(memory_allocation)); /* Don't compute the checksum for chunks, since we compute it externally, @@ -682,7 +680,9 @@ static void ZSTDMT_compressionJob(void* jobDescription) */ if (job->jobID != 0) jobParams.fParams.checksumFlag = 0; /* Don't run LDM for the chunks, since we handle it externally */ - jobParams.ldmParams.enableLdm = 0; + jobParams.ldmParams.enableLdm 
= ZSTD_ps_disable; + /* Correct nbWorkers to 0. */ + jobParams.nbWorkers = 0; /* init */ @@ -695,6 +695,10 @@ static void ZSTDMT_compressionJob(void* jobDescription) { size_t const forceWindowError = ZSTD_CCtxParams_setParameter(&jobParams, ZSTD_c_forceMaxWindow, !job->firstJob); if (ZSTD_isError(forceWindowError)) JOB_ERROR(forceWindowError); } + if (!job->firstJob) { + size_t const err = ZSTD_CCtxParams_setParameter(&jobParams, ZSTD_c_deterministicRefPrefix, 0); + if (ZSTD_isError(err)) JOB_ERROR(err); + } { size_t const initError = ZSTD_compressBegin_advanced_internal(cctx, job->prefix.start, job->prefix.size, ZSTD_dct_rawContent, /* load dictionary in "content-only" mode (no header analysis) */ ZSTD_dtlm_fast, @@ -750,6 +754,13 @@ static void ZSTDMT_compressionJob(void* jobDescription) if (ZSTD_isError(cSize)) JOB_ERROR(cSize); lastCBlockSize = cSize; } } + if (!job->firstJob) { + /* Double check that we don't have an ext-dict, because then our + * repcode invalidation doesn't work. + */ + assert(!ZSTD_window_hasExtDict(cctx->blockState.matchState.window)); + } + ZSTD_CCtx_trace(cctx, 0); _endJob: ZSTDMT_serialState_ensureFinished(job->serial, job->jobID, job->cSize); @@ -796,6 +807,15 @@ typedef struct { static const roundBuff_t kNullRoundBuff = {NULL, 0, 0}; #define RSYNC_LENGTH 32 +/* Don't create chunks smaller than the zstd block size. + * This stops us from regressing compression ratio too much, + * and ensures our output fits in ZSTD_compressBound(). + * + * If this is shrunk < ZSTD_BLOCKSIZELOG_MIN then + * ZSTD_COMPRESSBOUND() will need to be updated. 
+ */ +#define RSYNC_MIN_BLOCK_LOG ZSTD_BLOCKSIZELOG_MAX +#define RSYNC_MIN_BLOCK_SIZE (1<<RSYNC_MIN_BLOCK_LOG) typedef struct { U64 hash; @@ -1124,7 +1144,7 @@ size_t ZSTDMT_toFlushNow(ZSTDMT_CCtx* mtctx) static unsigned ZSTDMT_computeTargetJobLog(const ZSTD_CCtx_params* params) { unsigned jobLog; - if (params->ldmParams.enableLdm) { + if (params->ldmParams.enableLdm == ZSTD_ps_enable) { /* In Long Range Mode, the windowLog is typically oversized. * In which case, it's preferable to determine the jobSize * based on cycleLog instead. */ @@ -1168,7 +1188,7 @@ static size_t ZSTDMT_computeOverlapSize(const ZSTD_CCtx_params* params) int const overlapRLog = 9 - ZSTDMT_overlapLog(params->overlapLog, params->cParams.strategy); int ovLog = (overlapRLog >= 8) ? 0 : (params->cParams.windowLog - overlapRLog); assert(0 <= overlapRLog && overlapRLog <= 8); - if (params->ldmParams.enableLdm) { + if (params->ldmParams.enableLdm == ZSTD_ps_enable) { /* In Long Range Mode, the windowLog is typically oversized. * In which case, it's preferable to determine the jobSize * based on chainLog instead. @@ -1239,9 +1259,11 @@ size_t ZSTDMT_initCStream_internal( if (params.rsyncable) { /* Aim for the targetsectionSize as the average job size. */ - U32 const jobSizeMB = (U32)(mtctx->targetSectionSize >> 20); - U32 const rsyncBits = ZSTD_highbit32(jobSizeMB) + 20; - assert(jobSizeMB >= 1); + U32 const jobSizeKB = (U32)(mtctx->targetSectionSize >> 10); + U32 const rsyncBits = (assert(jobSizeKB >= 1), ZSTD_highbit32(jobSizeKB) + 10); + /* We refuse to create jobs < RSYNC_MIN_BLOCK_SIZE bytes, so make sure our + * expected job size is at least 4x larger. 
*/ + assert(rsyncBits >= RSYNC_MIN_BLOCK_LOG + 2); DEBUGLOG(4, "rsyncLog = %u", rsyncBits); mtctx->rsync.hash = 0; mtctx->rsync.hitMask = (1ULL << rsyncBits) - 1; @@ -1253,7 +1275,7 @@ size_t ZSTDMT_initCStream_internal( ZSTDMT_setBufferSize(mtctx->bufPool, ZSTD_compressBound(mtctx->targetSectionSize)); { /* If ldm is enabled we need windowSize space. */ - size_t const windowSize = mtctx->params.ldmParams.enableLdm ? (1U << mtctx->params.cParams.windowLog) : 0; + size_t const windowSize = mtctx->params.ldmParams.enableLdm == ZSTD_ps_enable ? (1U << mtctx->params.cParams.windowLog) : 0; /* Two buffers of slack, plus extra space for the overlap * This is the minimum slack that LDM works with. One extra because * flush might waste up to targetSectionSize-1 bytes. Another extra @@ -1528,17 +1550,21 @@ static range_t ZSTDMT_getInputDataInUse(ZSTDMT_CCtx* mtctx) static int ZSTDMT_isOverlapped(buffer_t buffer, range_t range) { BYTE const* const bufferStart = (BYTE const*)buffer.start; - BYTE const* const bufferEnd = bufferStart + buffer.capacity; BYTE const* const rangeStart = (BYTE const*)range.start; - BYTE const* const rangeEnd = range.size != 0 ? 
rangeStart + range.size : rangeStart; if (rangeStart == NULL || bufferStart == NULL) return 0; - /* Empty ranges cannot overlap */ - if (bufferStart == bufferEnd || rangeStart == rangeEnd) - return 0; - return bufferStart < rangeEnd && rangeStart < bufferEnd; + { + BYTE const* const bufferEnd = bufferStart + buffer.capacity; + BYTE const* const rangeEnd = rangeStart + range.size; + + /* Empty ranges cannot overlap */ + if (bufferStart == bufferEnd || rangeStart == rangeEnd) + return 0; + + return bufferStart < rangeEnd && rangeStart < bufferEnd; + } } static int ZSTDMT_doesOverlapWindow(buffer_t buffer, ZSTD_window_t window) @@ -1565,7 +1591,7 @@ static int ZSTDMT_doesOverlapWindow(buffer_t buffer, ZSTD_window_t window) static void ZSTDMT_waitForLdmComplete(ZSTDMT_CCtx* mtctx, buffer_t buffer) { - if (mtctx->params.ldmParams.enableLdm) { + if (mtctx->params.ldmParams.enableLdm == ZSTD_ps_enable) { ZSTD_pthread_mutex_t* mutex = &mtctx->serial.ldmWindowMutex; DEBUGLOG(5, "ZSTDMT_waitForLdmComplete"); DEBUGLOG(5, "source [0x%zx, 0x%zx)", @@ -1668,6 +1694,11 @@ findSynchronizationPoint(ZSTDMT_CCtx const* mtctx, ZSTD_inBuffer const input) if (!mtctx->params.rsyncable) /* Rsync is disabled. */ return syncPoint; + if (mtctx->inBuff.filled + input.size - input.pos < RSYNC_MIN_BLOCK_SIZE) + /* We don't emit synchronization points if it would produce too small blocks. + * We don't have enough input to find a synchronization point, so don't look. + */ + return syncPoint; if (mtctx->inBuff.filled + syncPoint.toLoad < RSYNC_LENGTH) /* Not enough to compute the hash. * We will miss any synchronization points in this RSYNC_LENGTH byte @@ -1678,10 +1709,28 @@ findSynchronizationPoint(ZSTDMT_CCtx const* mtctx, ZSTD_inBuffer const input) */ return syncPoint; /* Initialize the loop variables. */ - if (mtctx->inBuff.filled >= RSYNC_LENGTH) { - /* We have enough bytes buffered to initialize the hash. 
+ if (mtctx->inBuff.filled < RSYNC_MIN_BLOCK_SIZE) { + /* We don't need to scan the first RSYNC_MIN_BLOCK_SIZE positions + * because they can't possibly be a sync point. So we can start + * part way through the input buffer. + */ + pos = RSYNC_MIN_BLOCK_SIZE - mtctx->inBuff.filled; + if (pos >= RSYNC_LENGTH) { + prev = istart + pos - RSYNC_LENGTH; + hash = ZSTD_rollingHash_compute(prev, RSYNC_LENGTH); + } else { + assert(mtctx->inBuff.filled >= RSYNC_LENGTH); + prev = (BYTE const*)mtctx->inBuff.buffer.start + mtctx->inBuff.filled - RSYNC_LENGTH; + hash = ZSTD_rollingHash_compute(prev + pos, (RSYNC_LENGTH - pos)); + hash = ZSTD_rollingHash_append(hash, istart, pos); + } + } else { + /* We have enough bytes buffered to initialize the hash, + * and have processed enough bytes to find a sync point. * Start scanning at the beginning of the input. */ + assert(mtctx->inBuff.filled >= RSYNC_MIN_BLOCK_SIZE); + assert(RSYNC_MIN_BLOCK_SIZE >= RSYNC_LENGTH); pos = 0; prev = (BYTE const*)mtctx->inBuff.buffer.start + mtctx->inBuff.filled - RSYNC_LENGTH; hash = ZSTD_rollingHash_compute(prev, RSYNC_LENGTH); @@ -1695,16 +1744,6 @@ findSynchronizationPoint(ZSTDMT_CCtx const* mtctx, ZSTD_inBuffer const input) syncPoint.flush = 1; return syncPoint; } - } else { - /* We don't have enough bytes buffered to initialize the hash, but - * we know we have at least RSYNC_LENGTH bytes total. - * Start scanning after the first RSYNC_LENGTH bytes less the bytes - * already buffered. - */ - pos = RSYNC_LENGTH - mtctx->inBuff.filled; - prev = (BYTE const*)mtctx->inBuff.buffer.start - pos; - hash = ZSTD_rollingHash_compute(mtctx->inBuff.buffer.start, mtctx->inBuff.filled); - hash = ZSTD_rollingHash_append(hash, istart, pos); } /* Starting with the hash of the previous RSYNC_LENGTH bytes, roll * through the input. 
If we hit a synchronization point, then cut the @@ -1716,8 +1755,9 @@ findSynchronizationPoint(ZSTDMT_CCtx const* mtctx, ZSTD_inBuffer const input) */ for (; pos < syncPoint.toLoad; ++pos) { BYTE const toRemove = pos < RSYNC_LENGTH ? prev[pos] : istart[pos - RSYNC_LENGTH]; - /* if (pos >= RSYNC_LENGTH) assert(ZSTD_rollingHash_compute(istart + pos - RSYNC_LENGTH, RSYNC_LENGTH) == hash); */ + assert(pos < RSYNC_LENGTH || ZSTD_rollingHash_compute(istart + pos - RSYNC_LENGTH, RSYNC_LENGTH) == hash); hash = ZSTD_rollingHash_rotate(hash, toRemove, istart[pos], primePower); + assert(mtctx->inBuff.filled + pos >= RSYNC_MIN_BLOCK_SIZE); if ((hash & hitMask) == hitMask) { syncPoint.toLoad = pos + 1; syncPoint.flush = 1; |