diff options
Diffstat (limited to 'okio/src/jvmTest/kotlin')
55 files changed, 9192 insertions, 71 deletions
diff --git a/okio/src/jvmTest/kotlin/okio/AsyncTimeoutTest.kt b/okio/src/jvmTest/kotlin/okio/AsyncTimeoutTest.kt new file mode 100644 index 00000000..f3a90a2b --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/AsyncTimeoutTest.kt @@ -0,0 +1,500 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.IOException +import java.io.InterruptedIOException +import java.util.Random +import java.util.concurrent.LinkedBlockingDeque +import java.util.concurrent.TimeUnit +import okio.ByteString.Companion.of +import okio.TestUtil.bufferWithRandomSegmentLayout +import org.junit.Assert +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Assert.fail +import org.junit.Before +import org.junit.Ignore +import org.junit.Test + +/** + * This test uses four timeouts of varying durations: 250ms, 500ms, 750ms and + * 1000ms, named 'a', 'b', 'c' and 'd'. + */ +class AsyncTimeoutTest { + private val timedOut = LinkedBlockingDeque<AsyncTimeout>() + private val a = RecordingAsyncTimeout() + private val b = RecordingAsyncTimeout() + private val c = RecordingAsyncTimeout() + private val d = RecordingAsyncTimeout() + + @Before + fun setUp() { + a.timeout(250, TimeUnit.MILLISECONDS) + b.timeout(500, TimeUnit.MILLISECONDS) + c.timeout(750, TimeUnit.MILLISECONDS) + d.timeout(1000, TimeUnit.MILLISECONDS) + } + + @Test + fun zeroTimeoutIsNoTimeout() { + val timeout = RecordingAsyncTimeout() + timeout.timeout(0, TimeUnit.MILLISECONDS) + timeout.enter() + Thread.sleep(250) + assertFalse(timeout.exit()) + assertTimedOut() + } + + @Test + fun singleInstanceTimedOut() { + a.enter() + Thread.sleep(500) + assertTrue(a.exit()) + assertTimedOut(a) + } + + @Test + fun singleInstanceNotTimedOut() { + b.enter() + Thread.sleep(250) + b.exit() + assertFalse(b.exit()) + assertTimedOut() + } + + @Test + fun instancesAddedAtEnd() { + a.enter() + b.enter() + c.enter() + d.enter() + Thread.sleep(1250) + assertTrue(a.exit()) + assertTrue(b.exit()) + assertTrue(c.exit()) + assertTrue(d.exit()) + assertTimedOut(a, b, c, d) + } + + @Test + fun instancesAddedAtFront() { + d.enter() + c.enter() + b.enter() + a.enter() + Thread.sleep(1250) + assertTrue(d.exit()) + assertTrue(c.exit()) + assertTrue(b.exit()) + assertTrue(a.exit()) + assertTimedOut(a, b, c, d) + } + + @Test + fun instancesRemovedAtFront() { + a.enter() + b.enter() + c.enter() + d.enter() + assertFalse(a.exit()) + assertFalse(b.exit()) + assertFalse(c.exit()) + assertFalse(d.exit()) + assertTimedOut() + } + + @Test + fun instancesRemovedAtEnd() { + a.enter() + b.enter() + c.enter() + d.enter() + assertFalse(d.exit()) + assertFalse(c.exit()) + assertFalse(b.exit()) + assertFalse(a.exit()) + assertTimedOut() + } + + @Test + fun doubleEnter() { + a.enter() + try { + a.enter() + fail() + } catch (expected: IllegalStateException) { + } + } + + @Test + fun reEnter() { + a.timeout(10, TimeUnit.SECONDS) + a.enter() + assertFalse(a.exit()) + a.enter() + assertFalse(a.exit()) + } + + @Test + fun reEnterAfterTimeout() { + a.timeout(1, TimeUnit.MILLISECONDS) + a.enter() + Assert.assertSame(a, timedOut.take()) + assertTrue(a.exit()) + a.enter() + assertFalse(a.exit()) + } + + @Test + fun deadlineOnly() { + val timeout = RecordingAsyncTimeout() + timeout.deadline(250, TimeUnit.MILLISECONDS) + timeout.enter() + Thread.sleep(500) + assertTrue(timeout.exit()) + assertTimedOut(timeout) + } + + @Test + fun deadlineBeforeTimeout() { + val timeout = RecordingAsyncTimeout() + timeout.deadline(250, TimeUnit.MILLISECONDS) + timeout.timeout(750, TimeUnit.MILLISECONDS) + timeout.enter() + Thread.sleep(500) + assertTrue(timeout.exit()) + assertTimedOut(timeout) + } + + @Test + fun deadlineAfterTimeout() { + val timeout = RecordingAsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + timeout.deadline(750, TimeUnit.MILLISECONDS) + timeout.enter() + Thread.sleep(500) + assertTrue(timeout.exit()) + assertTimedOut(timeout) + } + + @Test + fun deadlineStartsBeforeEnter() { + val timeout = RecordingAsyncTimeout() + timeout.deadline(500, TimeUnit.MILLISECONDS) + Thread.sleep(500) + timeout.enter() + Thread.sleep(250) + assertTrue(timeout.exit()) + assertTimedOut(timeout) + } + + @Test + fun deadlineInThePast() { + val timeout = RecordingAsyncTimeout() + timeout.deadlineNanoTime(System.nanoTime() - 1) + timeout.enter() + Thread.sleep(250) + assertTrue(timeout.exit()) + assertTimedOut(timeout) + } + + @Test + fun wrappedSinkTimesOut() { + val sink: Sink = object : ForwardingSink(Buffer()) { + override fun write(source: Buffer, byteCount: Long) { + Thread.sleep(500) + } + } + val timeout = AsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + val timeoutSink = timeout.sink(sink) + val data = Buffer().writeUtf8("a") + try { + timeoutSink.write(data, 1) + fail() + } catch (expected: InterruptedIOException) { + } + } + + @Test + fun wrappedSinkFlushTimesOut() { + val sink: Sink = object : ForwardingSink(Buffer()) { + override fun flush() { + Thread.sleep(500) + } + } + val timeout = AsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + val timeoutSink = timeout.sink(sink) + try { + timeoutSink.flush() + fail() + } catch (expected: InterruptedIOException) { + } + } + + @Test + fun wrappedSinkCloseTimesOut() { + val sink: Sink = object : ForwardingSink(Buffer()) { + override fun close() { + Thread.sleep(500) + } + } + val timeout = AsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + val timeoutSink = timeout.sink(sink) + try { + timeoutSink.close() + fail() + } catch (expected: InterruptedIOException) { + } + } + + @Test + fun wrappedSourceTimesOut() { + val source: Source = object : ForwardingSource(Buffer()) { + override fun read(sink: Buffer, byteCount: Long): Long { + Thread.sleep(500) + return -1 + } + } + val timeout = AsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + val timeoutSource = timeout.source(source) + try { + timeoutSource.read(Buffer(), 0) + fail() + } catch (expected: InterruptedIOException) { + } + } + + @Test + fun wrappedSourceCloseTimesOut() { + val source: Source = object : ForwardingSource(Buffer()) { + override fun close() { + Thread.sleep(500) + } + } + val timeout = AsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + val timeoutSource = timeout.source(source) + try { + timeoutSource.close() + fail() + } catch (expected: InterruptedIOException) { + } + } + + @Test + fun wrappedThrowsWithTimeout() { + val sink: Sink = object : ForwardingSink(Buffer()) { + override fun write(source: Buffer, byteCount: Long) { + Thread.sleep(500) + throw IOException("exception and timeout") + } + } + val timeout = AsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + val timeoutSink = timeout.sink(sink) + val data = Buffer().writeUtf8("a") + try { + timeoutSink.write(data, 1) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + assertEquals("exception and timeout", expected.cause!!.message) + } + } + + @Test + fun wrappedThrowsWithoutTimeout() { + val sink: Sink = object : ForwardingSink(Buffer()) { + override fun write(source: Buffer, byteCount: Long) { + throw IOException("no timeout occurred") + } + } + val timeout = AsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + val timeoutSink = timeout.sink(sink) + val data = Buffer().writeUtf8("a") + try { + timeoutSink.write(data, 1) + fail() + } catch (expected: IOException) { + assertEquals("no timeout occurred", expected.message) + } + } + + /** + * We had a bug where writing a very large buffer would fail with an + * unexpected timeout because although the sink was making steady forward + * progress, doing it all as a single write caused a timeout. + */ + @Ignore("Flaky") + @Test + fun sinkSplitsLargeWrites() { + val data = ByteArray(512 * 1024) + val dice = Random(0) + dice.nextBytes(data) + val source = bufferWithRandomSegmentLayout(dice, data) + val target = Buffer() + val sink: Sink = object : ForwardingSink(Buffer()) { + override fun write(source: Buffer, byteCount: Long) { + Thread.sleep(byteCount / 500) // ~500 KiB/s. + target.write(source, byteCount) + } + } + + // Timeout after 250 ms of inactivity. + val timeout = AsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + val timeoutSink = timeout.sink(sink) + + // Transmit 500 KiB of data, which should take ~1 second. But expect no timeout! + timeoutSink.write(source, source.size) + + // The data should all have arrived. + assertEquals(of(*data), target.readByteString()) + } + + @Test + fun enterCancelSleepExit() { + val timeout = RecordingAsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + timeout.enter() + timeout.cancel() + Thread.sleep(500) + + // Call didn't time out because the timeout was canceled. + assertFalse(timeout.exit()) + assertTimedOut() + } + + @Test + fun enterCancelCancelSleepExit() { + val timeout = RecordingAsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + timeout.enter() + timeout.cancel() + timeout.cancel() + Thread.sleep(500) + + // Call didn't time out because the timeout was canceled. + assertFalse(timeout.exit()) + assertTimedOut() + } + + @Test + fun enterSleepCancelExit() { + val timeout = RecordingAsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + timeout.enter() + Thread.sleep(500) + timeout.cancel() + + // Call timed out because the cancel was too late. + assertTrue(timeout.exit()) + assertTimedOut(timeout) + } + + @Test + fun enterSleepCancelCancelExit() { + val timeout = RecordingAsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + timeout.enter() + Thread.sleep(500) + timeout.cancel() + timeout.cancel() + + // Call timed out because both cancels were too late. + assertTrue(timeout.exit()) + assertTimedOut(timeout) + } + + @Test + fun enterCancelSleepCancelExit() { + val timeout = RecordingAsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + timeout.enter() + timeout.cancel() + Thread.sleep(500) + timeout.cancel() + + // Call didn't time out because the timeout was canceled. + assertFalse(timeout.exit()) + assertTimedOut() + } + + @Test + fun cancelEnterSleepExit() { + val timeout = RecordingAsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + timeout.cancel() + timeout.enter() + Thread.sleep(500) + + // Call timed out because the cancel was too early. + assertTrue(timeout.exit()) + assertTimedOut(timeout) + } + + @Test + fun cancelEnterCancelSleepExit() { + val timeout = RecordingAsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + timeout.cancel() + timeout.enter() + timeout.cancel() + Thread.sleep(500) + + // Call didn't time out because the timeout was canceled. + assertFalse(timeout.exit()) + assertTimedOut() + } + + @Test + fun enterCancelSleepExitEnterSleepExit() { + val timeout = RecordingAsyncTimeout() + timeout.timeout(250, TimeUnit.MILLISECONDS) + + // First call doesn't time out because we cancel it. + timeout.enter() + timeout.cancel() + Thread.sleep(500) + assertFalse(timeout.exit()) + assertTimedOut() + + // Second call does time out because it isn't canceled a second time. + timeout.enter() + Thread.sleep(500) + assertTrue(timeout.exit()) + assertTimedOut(timeout) + } + + /** Asserts which timeouts fired, and in which order. */ + private fun assertTimedOut(vararg expected: Timeout) { + assertEquals(expected.toList(), timedOut.toList()) + timedOut.clear() + } + + internal inner class RecordingAsyncTimeout : AsyncTimeout() { + override fun timedOut() { + timedOut.add(this) + } + } +} diff --git a/okio/src/jvmTest/kotlin/okio/AwaitSignalTest.kt b/okio/src/jvmTest/kotlin/okio/AwaitSignalTest.kt new file mode 100644 index 00000000..a04e8cfe --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/AwaitSignalTest.kt @@ -0,0 +1,224 @@ +/* + * Copyright (C) 2023 Block Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.InterruptedIOException +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.Condition +import java.util.concurrent.locks.ReentrantLock +import okio.TestUtil.assumeNotWindows +import org.junit.After +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Assert.fail +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters + +@RunWith(Parameterized::class) +class AwaitSignalTest( + factory: TimeoutFactory, +) { + private val timeout = factory.newTimeout() + val executorService = TestingExecutors.newScheduledExecutorService(0) + + val lock: ReentrantLock = ReentrantLock() + val condition: Condition = lock.newCondition() + + @After + fun tearDown() { + executorService.shutdown() + } + + @Test + fun signaled() = lock.withLock { + timeout.timeout(5000, TimeUnit.MILLISECONDS) + val start = now() + executorService.schedule( + { lock.withLock { condition.signal() } }, + 1000, + TimeUnit.MILLISECONDS, + ) + timeout.awaitSignal(condition) + assertElapsed(1000.0, start) + } + + @Test + fun timeout() = lock.withLock { + assumeNotWindows() + timeout.timeout(1000, TimeUnit.MILLISECONDS) + val start = now() + try { + timeout.awaitSignal(condition) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + } + assertElapsed(1000.0, start) + } + + @Test + fun deadline() = lock.withLock { + assumeNotWindows() + timeout.deadline(1000, TimeUnit.MILLISECONDS) + val start = now() + try { + timeout.awaitSignal(condition) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + } + assertElapsed(1000.0, start) + } + + @Test + fun deadlineBeforeTimeout() = lock.withLock { + assumeNotWindows() + timeout.timeout(5000, TimeUnit.MILLISECONDS) + timeout.deadline(1000, TimeUnit.MILLISECONDS) + val start = now() + try { + timeout.awaitSignal(condition) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + } + assertElapsed(1000.0, start) + } + + @Test + fun timeoutBeforeDeadline() = lock.withLock { + assumeNotWindows() + timeout.timeout(1000, TimeUnit.MILLISECONDS) + timeout.deadline(5000, TimeUnit.MILLISECONDS) + val start = now() + try { + timeout.awaitSignal(condition) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + } + assertElapsed(1000.0, start) + } + + @Test + fun deadlineAlreadyReached() = lock.withLock { + assumeNotWindows() + timeout.deadlineNanoTime(System.nanoTime()) + val start = now() + try { + timeout.awaitSignal(condition) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + } + assertElapsed(0.0, start) + } + + @Test + fun threadInterrupted() = lock.withLock { + assumeNotWindows() + val start = now() + Thread.currentThread().interrupt() + try { + timeout.awaitSignal(condition) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("interrupted", expected.message) + assertTrue(Thread.interrupted()) + } + assertElapsed(0.0, start) + } + + @Test + fun threadInterruptedOnThrowIfReached() = lock.withLock { + assumeNotWindows() + Thread.currentThread().interrupt() + try { + timeout.throwIfReached() + fail() + } catch (expected: InterruptedIOException) { + assertEquals("interrupted", expected.message) + assertTrue(Thread.interrupted()) + } + } + + @Test + fun cancelBeforeWaitDoesNothing() = lock.withLock { + timeout.timeout(1000, TimeUnit.MILLISECONDS) + timeout.cancel() + val start = now() + try { + timeout.awaitSignal(condition) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + } + assertElapsed(1000.0, start) + } + + @Test + fun canceledTimeoutDoesNotThrowWhenNotNotifiedOnTime() = lock.withLock { + assumeNotWindows() + timeout.timeout(1000, TimeUnit.MILLISECONDS) + timeout.cancelLater(500) + + val start = now() + timeout.awaitSignal(condition) // Returns early but doesn't throw. + assertElapsed(1000.0, start) + } + + @Test + @Synchronized + fun multipleCancelsAreIdempotent() = lock.withLock { + timeout.timeout(1000, TimeUnit.MILLISECONDS) + timeout.cancelLater(250) + timeout.cancelLater(500) + timeout.cancelLater(750) + + val start = now() + timeout.awaitSignal(condition) // Returns early but doesn't throw. + assertElapsed(1000.0, start) + } + + /** Returns the nanotime in milliseconds as a double for measuring timeouts. */ + private fun now(): Double { + return System.nanoTime() / 1000000.0 + } + + /** + * Fails the test unless the time from start until now is duration, accepting differences in + * -50..+450 milliseconds. + */ + private fun assertElapsed(duration: Double, start: Double) { + assertEquals(duration, now() - start - 200.0, 250.0) + } + + private fun Timeout.cancelLater(delay: Long) { + executorService.schedule( + { cancel() }, + delay, + TimeUnit.MILLISECONDS, + ) + } + + companion object { + @Parameters(name = "{0}") + @JvmStatic + fun parameters(): List<Array<out Any?>> = TimeoutFactory.entries.map { arrayOf(it) } + } +} diff --git a/okio/src/jvmTest/kotlin/okio/BufferCursorKotlinTest.kt b/okio/src/jvmTest/kotlin/okio/BufferCursorKotlinTest.kt index ddbc4c53..0b50b005 100644 --- a/okio/src/jvmTest/kotlin/okio/BufferCursorKotlinTest.kt +++ b/okio/src/jvmTest/kotlin/okio/BufferCursorKotlinTest.kt @@ -15,6 +15,11 @@ */ package okio +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertNotSame +import kotlin.test.assertSame +import kotlin.test.assertTrue import okio.Buffer.UnsafeCursor import okio.TestUtil.deepCopy import org.junit.Assume.assumeTrue @@ -23,11 +28,6 @@ import org.junit.runner.RunWith import org.junit.runners.Parameterized import org.junit.runners.Parameterized.Parameter import org.junit.runners.Parameterized.Parameters -import kotlin.test.assertEquals -import kotlin.test.assertFalse -import kotlin.test.assertNotSame -import kotlin.test.assertSame -import kotlin.test.assertTrue @RunWith(Parameterized::class) class BufferCursorKotlinTest { @@ -97,7 +97,7 @@ class BufferCursorKotlinTest { buffer.readAndWriteUnsafe().use { cursor -> while (cursor.next() != -1) { - cursor.data!!.fill('x'.toByte(), cursor.start, cursor.end) + cursor.data!!.fill('x'.code.toByte(), cursor.start, cursor.end) } } diff --git a/okio/src/jvmTest/kotlin/okio/BufferCursorTest.kt b/okio/src/jvmTest/kotlin/okio/BufferCursorTest.kt new file mode 100644 index 00000000..54873286 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/BufferCursorTest.kt @@ -0,0 +1,459 @@ +/* + * Copyright (C) 2018 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.util.Arrays +import okio.ByteString.Companion.of +import okio.TestUtil.SEGMENT_SIZE +import okio.TestUtil.deepCopy +import org.junit.Assert +import org.junit.Assert.assertEquals +import org.junit.Assert.assertNotEquals +import org.junit.Assert.assertNotNull +import org.junit.Assert.assertNull +import org.junit.Assert.fail +import org.junit.Assume.assumeTrue +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters + +@RunWith(Parameterized::class) +class BufferCursorTest( + private var bufferFactory: BufferFactory, +) { + + @Test + fun apiExample() { + val buffer = Buffer() + buffer.readAndWriteUnsafe().use { cursor -> + cursor.resizeBuffer(1000000) + do { + Arrays.fill(cursor.data, cursor.start, cursor.end, 'x'.code.toByte()) + } while (cursor.next() != -1) + cursor.seek(3) + cursor.data!![cursor.start] = 'o'.code.toByte() + cursor.seek(1) + cursor.data!![cursor.start] = 'o'.code.toByte() + cursor.resizeBuffer(4) + } + assertEquals(Buffer().writeUtf8("xoxo"), buffer) + } + + @Test + fun accessSegmentBySegment() { + val buffer = bufferFactory.newBuffer() + buffer.readUnsafe().use { cursor -> + val actual = Buffer() + while (cursor.next().toLong() != -1L) { + actual.write(cursor.data!!, cursor.start, cursor.end - cursor.start) + } + assertEquals(buffer, actual) + } + } + + @Test + fun seekToNegativeOneSeeksBeforeFirstSegment() { + val buffer = bufferFactory.newBuffer() + buffer.readUnsafe().use { cursor -> + cursor.seek(-1L) + assertEquals(-1, cursor.offset) + assertNull(cursor.data) + assertEquals(-1, cursor.start.toLong()) + assertEquals(-1, cursor.end.toLong()) + cursor.next() + assertEquals(0, cursor.offset) + } + } + + @Test + fun accessByteByByte() { + val buffer = bufferFactory.newBuffer() + buffer.readUnsafe().use { cursor -> + val actual = ByteArray(buffer.size.toInt()) + for (i in 0 until buffer.size) { + cursor.seek(i) + actual[i.toInt()] = cursor.data!![cursor.start] + } + assertEquals(of(*actual), buffer.snapshot()) + } + } + + @Test + fun accessByteByByteReverse() { + val buffer = bufferFactory.newBuffer() + buffer.readUnsafe().use { cursor -> + val actual = ByteArray(buffer.size.toInt()) + for (i in (buffer.size - 1).toInt() downTo 0) { + cursor.seek(i.toLong()) + actual[i] = cursor.data!![cursor.start] + } + assertEquals(of(*actual), buffer.snapshot()) + } + } + + @Test + fun accessByteByByteAlwaysResettingToZero() { + val buffer = bufferFactory.newBuffer() + buffer.readUnsafe().use { cursor -> + val actual = ByteArray(buffer.size.toInt()) + for (i in 0 until buffer.size) { + cursor.seek(i) + actual[i.toInt()] = cursor.data!![cursor.start] + cursor.seek(0L) + } + assertEquals(of(*actual), buffer.snapshot()) + } + } + + @Test + fun segmentBySegmentNavigation() { + val buffer = bufferFactory.newBuffer() + val cursor = buffer.readUnsafe() + assertEquals(-1, cursor.offset) + try { + var lastOffset = cursor.offset + while (cursor.next().toLong() != -1L) { + Assert.assertTrue(cursor.offset > lastOffset) + lastOffset = cursor.offset + } + assertEquals(buffer.size, cursor.offset) + assertNull(cursor.data) + assertEquals(-1, cursor.start.toLong()) + assertEquals(-1, cursor.end.toLong()) + } finally { + cursor.close() + } + } + + @Test + fun seekWithinSegment() { + assumeTrue(bufferFactory === BufferFactory.SMALL_SEGMENTED_BUFFER) + val buffer = bufferFactory.newBuffer() + assertEquals("abcdefghijkl", buffer.clone().readUtf8()) + buffer.readUnsafe().use { cursor -> + assertEquals(2, cursor.seek(5).toLong()) // 2 for 2 bytes left in the segment: "fg". + assertEquals(5, cursor.offset) + assertEquals(2, (cursor.end - cursor.start).toLong()) + assertEquals('d'.code.toLong(), Char(cursor.data!![cursor.start - 2].toUShort()).code.toLong()) // Out of bounds! + assertEquals('e'.code.toLong(), Char(cursor.data!![cursor.start - 1].toUShort()).code.toLong()) // Out of bounds! + assertEquals('f'.code.toLong(), Char(cursor.data!![cursor.start].toUShort()).code.toLong()) + assertEquals('g'.code.toLong(), Char(cursor.data!![cursor.start + 1].toUShort()).code.toLong()) + } + } + + @Test + fun acquireAndRelease() { + val buffer = bufferFactory.newBuffer() + val cursor = Buffer.UnsafeCursor() + + // Nothing initialized before acquire. + assertEquals(-1, cursor.offset) + assertNull(cursor.data) + assertEquals(-1, cursor.start.toLong()) + assertEquals(-1, cursor.end.toLong()) + buffer.readUnsafe(cursor) + cursor.close() + + // Nothing initialized after close. + assertEquals(-1, cursor.offset) + assertNull(cursor.data) + assertEquals(-1, cursor.start.toLong()) + assertEquals(-1, cursor.end.toLong()) + } + + @Test + fun doubleAcquire() { + val buffer = bufferFactory.newBuffer() + try { + buffer.readUnsafe().use { cursor -> + buffer.readUnsafe(cursor) + fail() + } + } catch (expected: IllegalStateException) { + } + } + + @Test + fun releaseWithoutAcquire() { + val cursor = Buffer.UnsafeCursor() + try { + cursor.close() + fail() + } catch (expected: IllegalStateException) { + } + } + + @Test + fun releaseAfterRelease() { + val buffer = bufferFactory.newBuffer() + val cursor = buffer.readUnsafe() + cursor.close() + try { + cursor.close() + fail() + } catch (expected: IllegalStateException) { + } + } + + @Test + fun enlarge() { + val buffer = bufferFactory.newBuffer() + val originalSize = buffer.size + val expected = deepCopy(buffer) + expected.writeUtf8("abc") + buffer.readAndWriteUnsafe().use { cursor -> + assertEquals(originalSize, cursor.resizeBuffer(originalSize + 3)) + cursor.seek(originalSize) + cursor.data!![cursor.start] = 'a'.code.toByte() + cursor.seek(originalSize + 1) + cursor.data!![cursor.start] = 'b'.code.toByte() + cursor.seek(originalSize + 2) + cursor.data!![cursor.start] = 'c'.code.toByte() + } + assertEquals(expected, buffer) + } + + @Test + fun enlargeByManySegments() { + val buffer = bufferFactory.newBuffer() + val originalSize = buffer.size + val expected = deepCopy(buffer) + expected.writeUtf8("x".repeat(1000000)) + buffer.readAndWriteUnsafe().use { cursor -> + cursor.resizeBuffer(originalSize + 1000000) + cursor.seek(originalSize) + do { + Arrays.fill(cursor.data, cursor.start, cursor.end, 'x'.code.toByte()) + } while (cursor.next() != -1) + } + assertEquals(expected, buffer) + } + + @Test + fun resizeNotAcquired() { + val cursor = Buffer.UnsafeCursor() + try { + cursor.resizeBuffer(10) + fail() + } catch (expected: IllegalStateException) { + } + } + + @Test + fun expandNotAcquired() { + val cursor = Buffer.UnsafeCursor() + try { + cursor.expandBuffer(10) + fail() + } catch (expected: IllegalStateException) { + } + } + + @Test + fun resizeAcquiredReadOnly() { + val buffer = bufferFactory.newBuffer() + try { + buffer.readUnsafe().use { cursor -> + cursor.resizeBuffer(10) + fail() + } + } catch (expected: IllegalStateException) { + } + } + + @Test + fun expandAcquiredReadOnly() { + val buffer = bufferFactory.newBuffer() + try { + buffer.readUnsafe().use { cursor -> + cursor.expandBuffer(10) + fail() + } + } catch (expected: IllegalStateException) { + } + } + + @Test + fun shrink() { + val buffer = bufferFactory.newBuffer() + assumeTrue(buffer.size > 3) + val originalSize = buffer.size + val expected = Buffer() + deepCopy(buffer).copyTo(expected, 0, originalSize - 3) + buffer.readAndWriteUnsafe().use { cursor -> + assertEquals(originalSize, cursor.resizeBuffer(originalSize - 3)) + } + assertEquals(expected, buffer) + } + + @Test + fun shrinkByManySegments() { + val buffer = bufferFactory.newBuffer() + assumeTrue(buffer.size <= 1000000) + val originalSize = buffer.size + val toShrink = Buffer() + toShrink.writeUtf8("x".repeat(1000000)) + deepCopy(buffer).copyTo(toShrink, 0, originalSize) + val cursor = Buffer.UnsafeCursor() + toShrink.readAndWriteUnsafe(cursor) + try { + cursor.resizeBuffer(originalSize) + } finally { + cursor.close() + } + val expected = Buffer() + expected.writeUtf8("x".repeat(originalSize.toInt())) + assertEquals(expected, toShrink) + } + + @Test + fun shrinkAdjustOffset() { + val buffer = bufferFactory.newBuffer() + assumeTrue(buffer.size > 4) + buffer.readAndWriteUnsafe().use { cursor -> + cursor.seek(buffer.size - 1) + cursor.resizeBuffer(3) + assertEquals(3, cursor.offset) + assertNull(cursor.data) + assertEquals(-1, cursor.start.toLong()) + assertEquals(-1, cursor.end.toLong()) + } + } + + @Test + fun resizeToSameSizeSeeksToEnd() { + val buffer = bufferFactory.newBuffer() + val originalSize = buffer.size + buffer.readAndWriteUnsafe().use { cursor -> + cursor.seek(buffer.size / 2) + assertEquals(originalSize, buffer.size) + cursor.resizeBuffer(originalSize) + assertEquals(originalSize, buffer.size) + assertEquals(originalSize, cursor.offset) + assertNull(cursor.data) + assertEquals(-1, cursor.start.toLong()) + assertEquals(-1, cursor.end.toLong()) + } + } + + @Test + fun resizeEnlargeMovesCursorToOldSize() { + val buffer = bufferFactory.newBuffer() + val originalSize = buffer.size + val expected = deepCopy(buffer) + expected.writeUtf8("a") + buffer.readAndWriteUnsafe().use { cursor -> + cursor.seek(buffer.size / 2) + assertEquals(originalSize, buffer.size) + cursor.resizeBuffer(originalSize + 1) + assertEquals(originalSize, cursor.offset) + assertNotNull(cursor.data) + assertNotEquals(-1, cursor.start.toLong()) + assertEquals((cursor.start + 1).toLong(), cursor.end.toLong()) + cursor.data!![cursor.start] = 'a'.code.toByte() + } + assertEquals(expected, buffer) + } + + @Test + fun resizeShrinkMovesCursorToEnd() { + val buffer = bufferFactory.newBuffer() + assumeTrue(buffer.size > 0) + val originalSize = buffer.size + buffer.readAndWriteUnsafe().use { cursor -> + cursor.seek(buffer.size / 2) + assertEquals(originalSize, buffer.size) + cursor.resizeBuffer(originalSize - 1) + assertEquals(originalSize - 1, cursor.offset) + assertNull(cursor.data) + assertEquals(-1, cursor.start.toLong()) + assertEquals(-1, cursor.end.toLong()) + } + } + + @Test + fun expand() { + val buffer = bufferFactory.newBuffer() + val originalSize = buffer.size + val expected = deepCopy(buffer) + expected.writeUtf8("abcde") + buffer.readAndWriteUnsafe().use { cursor -> + cursor.expandBuffer(5) + for (i in 0..4) { + cursor.data!![cursor.start + i] = ('a'.code + i).toByte() + } + cursor.resizeBuffer(originalSize + 5) + } + assertEquals(expected, buffer) + } + + @Test + fun expandSameSegment() { + val buffer = bufferFactory.newBuffer() + val originalSize = buffer.size + assumeTrue(originalSize > 0) + buffer.readAndWriteUnsafe().use { cursor -> + cursor.seek(originalSize - 1) + val originalEnd = cursor.end + assumeTrue(originalEnd < SEGMENT_SIZE) + val addedByteCount = cursor.expandBuffer(1) + assertEquals((SEGMENT_SIZE - originalEnd).toLong(), addedByteCount) + assertEquals(originalSize + addedByteCount, buffer.size) + assertEquals(originalSize, cursor.offset) + assertEquals(originalEnd.toLong(), cursor.start.toLong()) + assertEquals(SEGMENT_SIZE.toLong(), cursor.end.toLong()) + } + } + + @Test + fun expandNewSegment() { + val buffer = bufferFactory.newBuffer() + val originalSize = buffer.size + buffer.readAndWriteUnsafe().use { cursor -> + val addedByteCount = cursor.expandBuffer(SEGMENT_SIZE) + assertEquals(SEGMENT_SIZE.toLong(), addedByteCount) + assertEquals(originalSize, cursor.offset) + assertEquals(0, cursor.start.toLong()) + assertEquals(SEGMENT_SIZE.toLong(), cursor.end.toLong()) + } + } + + @Test + fun expandMovesOffsetToOldSize() { + val buffer = bufferFactory.newBuffer() + val originalSize = buffer.size + buffer.readAndWriteUnsafe().use { cursor -> + cursor.seek(buffer.size / 2) + assertEquals(originalSize, buffer.size) + val addedByteCount = cursor.expandBuffer(5) + assertEquals(originalSize + addedByteCount, buffer.size) + assertEquals(originalSize, cursor.offset) + } + } + + companion object { + @JvmStatic + @Parameters(name = "{0}") + fun parameters(): List<Array<Any>> { + val result = mutableListOf<Array<Any>>() + for (bufferFactory in BufferFactory.values()) { + result += arrayOf(bufferFactory) + } + return result + } + } +} diff --git a/okio/src/jvmTest/kotlin/okio/BufferFactory.kt b/okio/src/jvmTest/kotlin/okio/BufferFactory.kt index e1533d23..0e6ce906 100644 --- a/okio/src/jvmTest/kotlin/okio/BufferFactory.kt +++ b/okio/src/jvmTest/kotlin/okio/BufferFactory.kt @@ -15,9 +15,9 @@ */ package okio +import java.util.Random import okio.TestUtil.bufferWithRandomSegmentLayout import okio.TestUtil.bufferWithSegments -import java.util.Random enum class BufferFactory { EMPTY { @@ -59,7 +59,8 @@ enum class BufferFactory { return bufferWithRandomSegmentLayout(dice, largeByteArray) } - }; + }, + ; @Throws(Exception::class) abstract fun newBuffer(): Buffer diff --git a/okio/src/jvmTest/kotlin/okio/BufferKotlinTest.kt b/okio/src/jvmTest/kotlin/okio/BufferKotlinTest.kt index eda7989d..3bc8193c 100644 --- a/okio/src/jvmTest/kotlin/okio/BufferKotlinTest.kt +++ b/okio/src/jvmTest/kotlin/okio/BufferKotlinTest.kt @@ -15,16 +15,16 @@ */ package okio +import kotlin.test.assertFailsWith import org.assertj.core.api.Assertions.assertThat import org.junit.Test -import kotlin.test.assertFailsWith class BufferKotlinTest { @Test fun get() { val actual = Buffer().writeUtf8("abc") - assertThat(actual[0]).isEqualTo('a'.toByte()) - assertThat(actual[1]).isEqualTo('b'.toByte()) - assertThat(actual[2]).isEqualTo('c'.toByte()) + assertThat(actual[0]).isEqualTo('a'.code.toByte()) + assertThat(actual[1]).isEqualTo('b'.code.toByte()) + assertThat(actual[2]).isEqualTo('c'.code.toByte()) assertFailsWith<IndexOutOfBoundsException> { actual[-1] } diff --git a/okio/src/jvmTest/kotlin/okio/BufferTest.kt b/okio/src/jvmTest/kotlin/okio/BufferTest.kt new file mode 100644 index 00000000..6f3501bd --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/BufferTest.kt @@ -0,0 +1,612 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream +import java.io.EOFException +import java.io.InputStream +import java.util.Arrays +import java.util.Random +import kotlin.text.Charsets.UTF_8 +import okio.ByteString.Companion.decodeHex +import okio.TestUtil.SEGMENT_POOL_MAX_SIZE +import okio.TestUtil.SEGMENT_SIZE +import okio.TestUtil.assertNoEmptySegments +import okio.TestUtil.bufferWithRandomSegmentLayout +import okio.TestUtil.segmentPoolByteCount +import okio.TestUtil.segmentSizes +import org.junit.Assert +import org.junit.Assert.assertEquals +import org.junit.Assert.fail +import org.junit.Test + +/** + * Tests solely for the behavior of Buffer's implementation. For generic BufferedSink or + * BufferedSource behavior use BufferedSinkTest or BufferedSourceTest, respectively. + */ +class BufferTest { + @Test + fun readAndWriteUtf8() { + val buffer = Buffer() + buffer.writeUtf8("ab") + assertEquals(2, buffer.size) + buffer.writeUtf8("cdef") + assertEquals(6, buffer.size) + assertEquals("abcd", buffer.readUtf8(4)) + assertEquals(2, buffer.size) + assertEquals("ef", buffer.readUtf8(2)) + assertEquals(0, buffer.size) + try { + buffer.readUtf8(1) + fail() + } catch (expected: EOFException) { + } + } + + /** Buffer's toString is the same as ByteString's. */ + @Test + fun bufferToString() { + assertEquals("[size=0]", Buffer().toString()) + assertEquals( + "[text=a\\r\\nb\\nc\\rd\\\\e]", + Buffer().writeUtf8("a\r\nb\nc\rd\\e").toString(), + ) + assertEquals( + "[text=Tyrannosaur]", + Buffer().writeUtf8("Tyrannosaur").toString(), + ) + assertEquals( + "[text=təˈranəˌsôr]", + Buffer() + .write("74c999cb8872616ec999cb8c73c3b472".decodeHex()) + .toString(), + ) + assertEquals( + "[hex=0000000000000000000000000000000000000000000000000000000000000000000000000000" + + "0000000000000000000000000000000000000000000000000000]", + Buffer().write(ByteArray(64)).toString(), + ) + } + + @Test + fun multipleSegmentBuffers() { + val buffer = Buffer() + buffer.writeUtf8("a".repeat(1000)) + buffer.writeUtf8("b".repeat(2500)) + buffer.writeUtf8("c".repeat(5000)) + buffer.writeUtf8("d".repeat(10000)) + buffer.writeUtf8("e".repeat(25000)) + buffer.writeUtf8("f".repeat(50000)) + assertEquals("a".repeat(999), buffer.readUtf8(999)) // a...a + assertEquals("a" + "b".repeat(2500) + "c", buffer.readUtf8(2502)) // ab...bc + assertEquals("c".repeat(4998), buffer.readUtf8(4998)) // c...c + assertEquals("c" + "d".repeat(10000) + "e", buffer.readUtf8(10002)) // cd...de + assertEquals("e".repeat(24998), buffer.readUtf8(24998)) // e...e + assertEquals("e" + "f".repeat(50000), buffer.readUtf8(50001)) // ef...f + assertEquals(0, buffer.size) + } + + @Test + fun fillAndDrainPool() { + val buffer = Buffer() + + // Take 2 * MAX_SIZE segments. This will drain the pool, even if other tests filled it. + buffer.write(ByteArray(SEGMENT_POOL_MAX_SIZE)) + buffer.write(ByteArray(SEGMENT_POOL_MAX_SIZE)) + assertEquals(0, segmentPoolByteCount().toLong()) + + // Recycle MAX_SIZE segments. They're all in the pool. + buffer.skip(SEGMENT_POOL_MAX_SIZE.toLong()) + assertEquals(SEGMENT_POOL_MAX_SIZE.toLong(), segmentPoolByteCount().toLong()) + + // Recycle MAX_SIZE more segments. The pool is full so they get garbage collected. + buffer.skip(SEGMENT_POOL_MAX_SIZE.toLong()) + assertEquals(SEGMENT_POOL_MAX_SIZE.toLong(), segmentPoolByteCount().toLong()) + + // Take MAX_SIZE segments to drain the pool. + buffer.write(ByteArray(SEGMENT_POOL_MAX_SIZE)) + assertEquals(0, segmentPoolByteCount().toLong()) + + // Take MAX_SIZE more segments. The pool is drained so these will need to be allocated. + buffer.write(ByteArray(SEGMENT_POOL_MAX_SIZE)) + assertEquals(0, segmentPoolByteCount().toLong()) + } + + @Test + fun moveBytesBetweenBuffersShareSegment() { + val size: Int = SEGMENT_SIZE / 2 - 1 + val segmentSizes = moveBytesBetweenBuffers("a".repeat(size), "b".repeat(size)) + assertEquals(Arrays.asList(size * 2), segmentSizes) + } + + @Test + fun moveBytesBetweenBuffersReassignSegment() { + val size: Int = SEGMENT_SIZE / 2 + 1 + val segmentSizes = moveBytesBetweenBuffers("a".repeat(size), "b".repeat(size)) + assertEquals(Arrays.asList(size, size), segmentSizes) + } + + @Test + fun moveBytesBetweenBuffersMultipleSegments() { + val size: Int = 3 * SEGMENT_SIZE + 1 + val segmentSizes = moveBytesBetweenBuffers("a".repeat(size), "b".repeat(size)) + assertEquals( + listOf( + SEGMENT_SIZE, + SEGMENT_SIZE, + SEGMENT_SIZE, + 1, + SEGMENT_SIZE, + SEGMENT_SIZE, + SEGMENT_SIZE, + 1, + ), + segmentSizes, + ) + } + + private fun moveBytesBetweenBuffers(vararg contents: String): List<Int> { + val expected = StringBuilder() + val buffer = Buffer() + for (s in contents) { + val source = Buffer() + source.writeUtf8(s) + buffer.writeAll(source) + expected.append(s) + } + val segmentSizes = segmentSizes(buffer) + assertEquals(expected.toString(), buffer.readUtf8(expected.length.toLong())) + return segmentSizes + } + + /** The big part of source's first segment is being moved. */ + @Test + fun writeSplitSourceBufferLeft() { + val writeSize: Int = SEGMENT_SIZE / 2 + 1 + val sink = Buffer() + sink.writeUtf8("b".repeat(SEGMENT_SIZE - 10)) + val source = Buffer() + source.writeUtf8("a".repeat(SEGMENT_SIZE * 2)) + sink.write(source, writeSize.toLong()) + assertEquals(Arrays.asList<Int>(SEGMENT_SIZE - 10, writeSize), segmentSizes(sink)) + assertEquals(Arrays.asList<Int>(SEGMENT_SIZE - writeSize, SEGMENT_SIZE), segmentSizes(source)) + } + + /** The big part of source's first segment is staying put. */ + @Test + fun writeSplitSourceBufferRight() { + val writeSize: Int = SEGMENT_SIZE / 2 - 1 + val sink = Buffer() + sink.writeUtf8("b".repeat(SEGMENT_SIZE - 10)) + val source = Buffer() + source.writeUtf8("a".repeat(SEGMENT_SIZE * 2)) + sink.write(source, writeSize.toLong()) + assertEquals(Arrays.asList<Int>(SEGMENT_SIZE - 10, writeSize), segmentSizes(sink)) + assertEquals(Arrays.asList<Int>(SEGMENT_SIZE - writeSize, SEGMENT_SIZE), segmentSizes(source)) + } + + @Test + fun writePrefixDoesntSplit() { + val sink = Buffer() + sink.writeUtf8("b".repeat(10)) + val source = Buffer() + source.writeUtf8("a".repeat(SEGMENT_SIZE * 2)) + sink.write(source, 20) + assertEquals(mutableListOf(30), segmentSizes(sink)) + assertEquals(Arrays.asList<Int>(SEGMENT_SIZE - 20, SEGMENT_SIZE), segmentSizes(source)) + assertEquals(30, sink.size) + assertEquals((SEGMENT_SIZE * 2 - 20).toLong(), source.size) + } + + @Test + fun writePrefixDoesntSplitButRequiresCompact() { + val sink = Buffer() + sink.writeUtf8("b".repeat(SEGMENT_SIZE - 10)) // limit = size - 10 + sink.readUtf8((SEGMENT_SIZE - 20).toLong()) // pos = size = 20 + val source = Buffer() + source.writeUtf8("a".repeat(SEGMENT_SIZE * 2)) + sink.write(source, 20) + assertEquals(mutableListOf(30), segmentSizes(sink)) + assertEquals(Arrays.asList<Int>(SEGMENT_SIZE - 20, SEGMENT_SIZE), segmentSizes(source)) + assertEquals(30, sink.size) + assertEquals((SEGMENT_SIZE * 2 - 20).toLong(), source.size) + } + + @Test + fun copyToSpanningSegments() { + val source = Buffer() + source.writeUtf8("a".repeat(SEGMENT_SIZE * 2)) + source.writeUtf8("b".repeat(SEGMENT_SIZE * 2)) + val out = ByteArrayOutputStream() + source.copyTo(out, 10, (SEGMENT_SIZE * 3).toLong()) + assertEquals( + "a".repeat(SEGMENT_SIZE * 2 - 10) + "b".repeat(SEGMENT_SIZE + 10), + out.toString(), + ) + assertEquals( + "a".repeat(SEGMENT_SIZE * 2) + "b".repeat(SEGMENT_SIZE * 2), + source.readUtf8((SEGMENT_SIZE * 4).toLong()), + ) + } + + @Test + fun copyToStream() { + val buffer = Buffer().writeUtf8("hello, world!") + val out = ByteArrayOutputStream() + buffer.copyTo(out) + val outString = out.toByteArray().toString(UTF_8) + assertEquals("hello, world!", outString) + assertEquals("hello, world!", buffer.readUtf8()) + } + + @Test + fun writeToSpanningSegments() { + val buffer = Buffer() + buffer.writeUtf8("a".repeat(SEGMENT_SIZE * 2)) + buffer.writeUtf8("b".repeat(SEGMENT_SIZE * 2)) + val out = ByteArrayOutputStream() + buffer.skip(10) + buffer.writeTo(out, (SEGMENT_SIZE * 3).toLong()) + assertEquals( + "a".repeat(SEGMENT_SIZE * 2 - 10) + "b".repeat(SEGMENT_SIZE + 10), + out.toString(), + ) + assertEquals("b".repeat(SEGMENT_SIZE - 10), buffer.readUtf8(buffer.size)) + } + + @Test + fun writeToStream() { + val buffer = Buffer().writeUtf8("hello, world!") + val out = ByteArrayOutputStream() + buffer.writeTo(out) + val outString = out.toByteArray().toString(UTF_8) + assertEquals("hello, world!", outString) + assertEquals(0, buffer.size) + } + + @Test + fun readFromStream() { + val `in`: InputStream = ByteArrayInputStream("hello, world!".toByteArray(UTF_8)) + val buffer = Buffer() + buffer.readFrom(`in`) + val out = buffer.readUtf8() + assertEquals("hello, world!", out) + } + + @Test + fun readFromSpanningSegments() { + val `in`: InputStream = ByteArrayInputStream("hello, world!".toByteArray(UTF_8)) + val buffer = Buffer().writeUtf8("a".repeat(SEGMENT_SIZE - 10)) + buffer.readFrom(`in`) + val out = buffer.readUtf8() + assertEquals("a".repeat(SEGMENT_SIZE - 10) + "hello, world!", out) + } + + @Test + fun readFromStreamWithCount() { + val `in`: InputStream = ByteArrayInputStream("hello, world!".toByteArray(UTF_8)) + val buffer = Buffer() + buffer.readFrom(`in`, 10) + val out = buffer.readUtf8() + assertEquals("hello, wor", out) + } + + @Test + fun readFromDoesNotLeaveEmptyTailSegment() { + val buffer = Buffer() + buffer.readFrom(ByteArrayInputStream(ByteArray(SEGMENT_SIZE))) + assertNoEmptySegments(buffer) + } + + @Test + fun moveAllRequestedBytesWithRead() { + val sink = Buffer() + sink.writeUtf8("a".repeat(10)) + val source = Buffer() + source.writeUtf8("b".repeat(15)) + assertEquals(10, source.read(sink, 10)) + assertEquals(20, sink.size) + assertEquals(5, source.size) + assertEquals("a".repeat(10) + "b".repeat(10), sink.readUtf8(20)) + } + + @Test + fun moveFewerThanRequestedBytesWithRead() { + val sink = Buffer() + sink.writeUtf8("a".repeat(10)) + val source = Buffer() + source.writeUtf8("b".repeat(20)) + assertEquals(20, source.read(sink, 25)) + assertEquals(30, sink.size) + assertEquals(0, source.size) + assertEquals("a".repeat(10) + "b".repeat(20), sink.readUtf8(30)) + } + + @Test + fun indexOfWithOffset() { + val buffer = Buffer() + val halfSegment: Int = SEGMENT_SIZE / 2 + buffer.writeUtf8("a".repeat(halfSegment)) + buffer.writeUtf8("b".repeat(halfSegment)) + buffer.writeUtf8("c".repeat(halfSegment)) + buffer.writeUtf8("d".repeat(halfSegment)) + assertEquals(0, buffer.indexOf('a'.code.toByte(), 0)) + assertEquals((halfSegment - 1).toLong(), buffer.indexOf('a'.code.toByte(), (halfSegment - 1).toLong())) + assertEquals(halfSegment.toLong(), buffer.indexOf('b'.code.toByte(), (halfSegment - 1).toLong())) + assertEquals((halfSegment * 2).toLong(), buffer.indexOf('c'.code.toByte(), (halfSegment - 1).toLong())) + assertEquals((halfSegment * 3).toLong(), buffer.indexOf('d'.code.toByte(), (halfSegment - 1).toLong())) + assertEquals((halfSegment * 3).toLong(), buffer.indexOf('d'.code.toByte(), (halfSegment * 2).toLong())) + assertEquals((halfSegment * 3).toLong(), buffer.indexOf('d'.code.toByte(), (halfSegment * 3).toLong())) + assertEquals((halfSegment * 4 - 1).toLong(), buffer.indexOf('d'.code.toByte(), (halfSegment * 4 - 1).toLong())) + } + + @Test + fun byteAt() { + val buffer = Buffer() + buffer.writeUtf8("a") + buffer.writeUtf8("b".repeat(SEGMENT_SIZE)) + buffer.writeUtf8("c") + assertEquals('a'.code.toLong(), buffer[0].toLong()) + assertEquals('a'.code.toLong(), buffer[0].toLong()) // getByte doesn't mutate! + assertEquals('c'.code.toLong(), buffer[buffer.size - 1].toLong()) + assertEquals('b'.code.toLong(), buffer[buffer.size - 2].toLong()) + assertEquals('b'.code.toLong(), buffer[buffer.size - 3].toLong()) + } + + @Test + fun getByteOfEmptyBuffer() { + val buffer = Buffer() + try { + buffer[0] + fail() + } catch (expected: IndexOutOfBoundsException) { + } + } + + @Test + fun writePrefixToEmptyBuffer() { + val sink = Buffer() + val source = Buffer() + source.writeUtf8("abcd") + sink.write(source, 2) + assertEquals("ab", sink.readUtf8(2)) + } + + @Test + fun cloneDoesNotObserveWritesToOriginal() { + val original = Buffer() + val clone = original.clone() + original.writeUtf8("abc") + assertEquals(0, clone.size) + } + + @Test + fun cloneDoesNotObserveReadsFromOriginal() { + val original = Buffer() + original.writeUtf8("abc") + val clone = original.clone() + assertEquals("abc", original.readUtf8(3)) + assertEquals(3, clone.size) + assertEquals("ab", clone.readUtf8(2)) + } + + @Test + fun originalDoesNotObserveWritesToClone() { + val original = Buffer() + val clone = original.clone() + clone.writeUtf8("abc") + assertEquals(0, original.size) + } + + @Test + fun originalDoesNotObserveReadsFromClone() { + val original = Buffer() + original.writeUtf8("abc") + val clone = original.clone() + assertEquals("abc", clone.readUtf8(3)) + assertEquals(3, original.size) + assertEquals("ab", original.readUtf8(2)) + } + + @Test + fun cloneMultipleSegments() { + val original = Buffer() + original.writeUtf8("a".repeat(SEGMENT_SIZE * 3)) + val clone = original.clone() + original.writeUtf8("b".repeat(SEGMENT_SIZE * 3)) + clone.writeUtf8("c".repeat(SEGMENT_SIZE * 3)) + assertEquals( + "a".repeat(SEGMENT_SIZE * 3) + "b".repeat(SEGMENT_SIZE * 3), + original.readUtf8((SEGMENT_SIZE * 6).toLong()), + ) + assertEquals( + "a".repeat(SEGMENT_SIZE * 3) + "c".repeat(SEGMENT_SIZE * 3), + clone.readUtf8((SEGMENT_SIZE * 6).toLong()), + ) + } + + @Test + fun equalsAndHashCodeEmpty() { + val a = Buffer() + val b = Buffer() + assertEquals(a, b) + assertEquals(a.hashCode().toLong(), b.hashCode().toLong()) + } + + @Test + fun equalsAndHashCode() { + val a = Buffer().writeUtf8("dog") + val b = Buffer().writeUtf8("hotdog") + Assert.assertNotEquals(a, b) + Assert.assertNotEquals(a.hashCode().toLong(), b.hashCode().toLong()) + b.readUtf8(3) // Leaves b containing 'dog'. + assertEquals(a, b) + assertEquals(a.hashCode().toLong(), b.hashCode().toLong()) + } + + @Test + fun equalsAndHashCodeSpanningSegments() { + val data = ByteArray(1024 * 1024) + val dice = Random(0) + dice.nextBytes(data) + val a = bufferWithRandomSegmentLayout(dice, data) + val b = bufferWithRandomSegmentLayout(dice, data) + assertEquals(a, b) + assertEquals(a.hashCode().toLong(), b.hashCode().toLong()) + data[data.size / 2]++ // Change a single byte. + val c = bufferWithRandomSegmentLayout(dice, data) + Assert.assertNotEquals(a, c) + Assert.assertNotEquals(a.hashCode().toLong(), c.hashCode().toLong()) + } + + @Test + fun bufferInputStreamByteByByte() { + val source = Buffer() + source.writeUtf8("abc") + val `in` = source.inputStream() + assertEquals(3, `in`.available().toLong()) + assertEquals('a'.code.toLong(), `in`.read().toLong()) + assertEquals('b'.code.toLong(), `in`.read().toLong()) + assertEquals('c'.code.toLong(), `in`.read().toLong()) + assertEquals(-1, `in`.read().toLong()) + assertEquals(0, `in`.available().toLong()) + } + + @Test + fun bufferInputStreamBulkReads() { + val source = Buffer() + source.writeUtf8("abc") + val byteArray = ByteArray(4) + byteArray.fill(-5) + val `in` = source.inputStream() + assertEquals(3, `in`.read(byteArray).toLong()) + assertEquals("[97, 98, 99, -5]", Arrays.toString(byteArray)) + byteArray.fill(-7) + assertEquals(-1, `in`.read(byteArray).toLong()) + assertEquals("[-7, -7, -7, -7]", Arrays.toString(byteArray)) + } + + /** + * When writing data that's already buffered, there's no reason to page the + * data by segment. + */ + @Test + fun readAllWritesAllSegmentsAtOnce() { + val write1 = Buffer().writeUtf8( + "" + + "a".repeat(SEGMENT_SIZE) + + "b".repeat(SEGMENT_SIZE) + + "c".repeat(SEGMENT_SIZE), + ) + val source = Buffer().writeUtf8( + "" + + "a".repeat(SEGMENT_SIZE) + + "b".repeat(SEGMENT_SIZE) + + "c".repeat(SEGMENT_SIZE), + ) + val mockSink = MockSink() + assertEquals((SEGMENT_SIZE * 3).toLong(), source.readAll(mockSink)) + assertEquals(0, source.size) + mockSink.assertLog("write(" + write1 + ", " + write1.size + ")") + } + + @Test + fun writeAllMultipleSegments() { + val source = Buffer().writeUtf8("a".repeat(SEGMENT_SIZE * 3)) + val sink = Buffer() + assertEquals((SEGMENT_SIZE * 3).toLong(), sink.writeAll(source)) + assertEquals(0, source.size) + assertEquals("a".repeat(SEGMENT_SIZE * 3), sink.readUtf8()) + } + + @Test + fun copyTo() { + val source = Buffer() + source.writeUtf8("party") + val target = Buffer() + source.copyTo(target, 1, 3) + assertEquals("art", target.readUtf8()) + assertEquals("party", source.readUtf8()) + } + + @Test + fun copyToOnSegmentBoundary() { + val `as` = "a".repeat(SEGMENT_SIZE) + val bs = "b".repeat(SEGMENT_SIZE) + val cs = "c".repeat(SEGMENT_SIZE) + val ds = "d".repeat(SEGMENT_SIZE) + val source = Buffer() + source.writeUtf8(`as`) + source.writeUtf8(bs) + source.writeUtf8(cs) + val target = Buffer() + target.writeUtf8(ds) + source.copyTo(target, `as`.length.toLong(), (bs.length + cs.length).toLong()) + assertEquals(ds + bs + cs, target.readUtf8()) + } + + @Test + fun copyToOffSegmentBoundary() { + val `as` = "a".repeat(SEGMENT_SIZE - 1) + val bs = "b".repeat(SEGMENT_SIZE + 2) + val cs = "c".repeat(SEGMENT_SIZE - 4) + val ds = "d".repeat(SEGMENT_SIZE + 8) + val source = Buffer() + source.writeUtf8(`as`) + source.writeUtf8(bs) + source.writeUtf8(cs) + val target = Buffer() + target.writeUtf8(ds) + source.copyTo(target, `as`.length.toLong(), (bs.length + cs.length).toLong()) + assertEquals(ds + bs + cs, target.readUtf8()) + } + + @Test + fun copyToSourceAndTargetCanBeTheSame() { + val `as` = "a".repeat(SEGMENT_SIZE) + val bs = "b".repeat(SEGMENT_SIZE) + val source = Buffer() + source.writeUtf8(`as`) + source.writeUtf8(bs) + source.copyTo(source, 0, source.size) + assertEquals(`as` + bs + `as` + bs, source.readUtf8()) + } + + @Test + fun copyToEmptySource() { + val source = Buffer() + val target = Buffer().writeUtf8("aaa") + source.copyTo(target, 0L, 0L) + assertEquals("", source.readUtf8()) + assertEquals("aaa", target.readUtf8()) + } + + @Test + fun copyToEmptyTarget() { + val source = Buffer().writeUtf8("aaa") + val target = Buffer() + source.copyTo(target, 0L, 3L) + assertEquals("aaa", source.readUtf8()) + assertEquals("aaa", target.readUtf8()) + } + + @Test + fun snapshotReportsAccurateSize() { + val buf = Buffer().write(byteArrayOf(0, 1, 2, 3)) + assertEquals(1, buf.snapshot(1).size.toLong()) + } +} diff --git a/okio/src/jvmTest/kotlin/okio/BufferedSinkJavaTest.kt b/okio/src/jvmTest/kotlin/okio/BufferedSinkJavaTest.kt new file mode 100644 index 00000000..5faf7605 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/BufferedSinkJavaTest.kt @@ -0,0 +1,250 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.IOException +import okio.TestUtil.SEGMENT_SIZE +import org.junit.Assert.assertEquals +import org.junit.Assert.fail +import org.junit.Test + +/** + * Tests solely for the behavior of RealBufferedSink's implementation. For generic + * BufferedSink behavior use BufferedSinkTest. + */ +class BufferedSinkJavaTest { + @Test + fun inputStreamCloses() { + val sink = (Buffer() as Sink).buffer() + val out = sink.outputStream() + out.close() + try { + sink.writeUtf8("Hi!") + fail() + } catch (e: IllegalStateException) { + assertEquals("closed", e.message) + } + } + + @Test + fun bufferedSinkEmitsTailWhenItIsComplete() { + val sink = Buffer() + val bufferedSink = (sink as Sink).buffer() + bufferedSink.writeUtf8("a".repeat(SEGMENT_SIZE - 1)) + assertEquals(0, sink.size) + bufferedSink.writeByte(0) + assertEquals(SEGMENT_SIZE.toLong(), sink.size) + assertEquals(0, bufferedSink.buffer.size) + } + + @Test + fun bufferedSinkEmitMultipleSegments() { + val sink = Buffer() + val bufferedSink = (sink as Sink).buffer() + bufferedSink.writeUtf8("a".repeat(SEGMENT_SIZE * 4 - 1)) + assertEquals((SEGMENT_SIZE * 3).toLong(), sink.size) + assertEquals((SEGMENT_SIZE - 1).toLong(), bufferedSink.buffer.size) + } + + @Test + fun bufferedSinkFlush() { + val sink = Buffer() + val bufferedSink = (sink as Sink).buffer() + bufferedSink.writeByte('a'.code) + assertEquals(0, sink.size) + bufferedSink.flush() + assertEquals(0, bufferedSink.buffer.size) + assertEquals(1, sink.size) + } + + @Test + fun bytesEmittedToSinkWithFlush() { + val sink = Buffer() + val bufferedSink = (sink as Sink).buffer() + bufferedSink.writeUtf8("abc") + bufferedSink.flush() + assertEquals(3, sink.size) + } + + @Test + fun bytesNotEmittedToSinkWithoutFlush() { + val sink = Buffer() + val bufferedSink = (sink as Sink).buffer() + bufferedSink.writeUtf8("abc") + assertEquals(0, sink.size) + } + + @Test + fun bytesEmittedToSinkWithEmit() { + val sink = Buffer() + val bufferedSink = (sink as Sink).buffer() + bufferedSink.writeUtf8("abc") + bufferedSink.emit() + assertEquals(3, sink.size) + } + + @Test + fun completeSegmentsEmitted() { + val sink = Buffer() + val bufferedSink = (sink as Sink).buffer() + bufferedSink.writeUtf8("a".repeat(SEGMENT_SIZE * 3)) + assertEquals((SEGMENT_SIZE * 3).toLong(), sink.size) + } + + @Test + fun incompleteSegmentsNotEmitted() { + val sink = Buffer() + val bufferedSink = (sink as Sink).buffer() + bufferedSink.writeUtf8("a".repeat(SEGMENT_SIZE * 3 - 1)) + assertEquals((SEGMENT_SIZE * 2).toLong(), sink.size) + } + + @Test + fun closeWithExceptionWhenWriting() { + val mockSink = MockSink() + mockSink.scheduleThrow(0, IOException()) + val bufferedSink = mockSink.buffer() + bufferedSink.writeByte('a'.code) + try { + bufferedSink.close() + fail() + } catch (expected: IOException) { + } + mockSink.assertLog("write([text=a], 1)", "close()") + } + + @Test + fun closeWithExceptionWhenClosing() { + val mockSink = MockSink() + mockSink.scheduleThrow(1, IOException()) + val bufferedSink = mockSink.buffer() + bufferedSink.writeByte('a'.code) + try { + bufferedSink.close() + fail() + } catch (expected: IOException) { + } + mockSink.assertLog("write([text=a], 1)", "close()") + } + + @Test + fun closeWithExceptionWhenWritingAndClosing() { + val mockSink = MockSink() + mockSink.scheduleThrow(0, IOException("first")) + mockSink.scheduleThrow(1, IOException("second")) + val bufferedSink = mockSink.buffer() + bufferedSink.writeByte('a'.code) + try { + bufferedSink.close() + fail() + } catch (expected: IOException) { + assertEquals("first", expected.message) + } + mockSink.assertLog("write([text=a], 1)", "close()") + } + + @Test + fun operationsAfterClose() { + val mockSink = MockSink() + val bufferedSink = mockSink.buffer() + bufferedSink.writeByte('a'.code) + bufferedSink.close() + + // Test a sample set of methods. + try { + bufferedSink.writeByte('a'.code) + fail() + } catch (expected: IllegalStateException) { + } + try { + bufferedSink.write(ByteArray(10)) + fail() + } catch (expected: IllegalStateException) { + } + try { + bufferedSink.emitCompleteSegments() + fail() + } catch (expected: IllegalStateException) { + } + try { + bufferedSink.emit() + fail() + } catch (expected: IllegalStateException) { + } + try { + bufferedSink.flush() + fail() + } catch (expected: IllegalStateException) { + } + + // Test a sample set of methods on the OutputStream. + val os = bufferedSink.outputStream() + try { + os.write('a'.code) + fail() + } catch (expected: IOException) { + } + try { + os.write(ByteArray(10)) + fail() + } catch (expected: IOException) { + } + + // Permitted + os.flush() + } + + @Test + fun writeAll() { + val mockSink = MockSink() + val bufferedSink = mockSink.buffer() + bufferedSink.buffer.writeUtf8("abc") + assertEquals(3, bufferedSink.writeAll(Buffer().writeUtf8("def"))) + assertEquals(6, bufferedSink.buffer.size) + assertEquals("abcdef", bufferedSink.buffer.readUtf8(6)) + mockSink.assertLog() // No writes. + } + + @Test + fun writeAllExhausted() { + val mockSink = MockSink() + val bufferedSink = mockSink.buffer() + assertEquals(0, bufferedSink.writeAll(Buffer())) + assertEquals(0, bufferedSink.buffer.size) + mockSink.assertLog() // No writes. + } + + @Test + fun writeAllWritesOneSegmentAtATime() { + val write1 = Buffer().writeUtf8("a".repeat(SEGMENT_SIZE)) + val write2 = Buffer().writeUtf8("b".repeat(SEGMENT_SIZE)) + val write3 = Buffer().writeUtf8("c".repeat(SEGMENT_SIZE)) + val source = Buffer().writeUtf8( + "" + + "a".repeat(SEGMENT_SIZE) + + "b".repeat(SEGMENT_SIZE) + + "c".repeat(SEGMENT_SIZE), + ) + val mockSink = MockSink() + val bufferedSink = mockSink.buffer() + assertEquals((SEGMENT_SIZE * 3).toLong(), bufferedSink.writeAll(source)) + mockSink.assertLog( + "write(" + write1 + ", " + write1.size + ")", + "write(" + write2 + ", " + write2.size + ")", + "write(" + write3 + ", " + write3.size + ")", + ) + } +} diff --git a/okio/src/jvmTest/kotlin/okio/BufferedSinkTest.kt b/okio/src/jvmTest/kotlin/okio/BufferedSinkTest.kt new file mode 100644 index 00000000..c9b3d187 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/BufferedSinkTest.kt @@ -0,0 +1,389 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.EOFException +import java.math.BigInteger +import java.nio.ByteBuffer +import java.nio.charset.Charset +import kotlin.text.Charsets.UTF_8 +import okio.ByteString.Companion.decodeHex +import okio.ByteString.Companion.encodeUtf8 +import okio.TestUtil.SEGMENT_SIZE +import okio.TestUtil.segmentSizes +import org.junit.Assert.assertEquals +import org.junit.Assert.fail +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters + +@RunWith(Parameterized::class) +class BufferedSinkTest( + factory: Factory, +) { + interface Factory { + fun create(data: Buffer): BufferedSink + + companion object { + val BUFFER: Factory = object : Factory { + override fun create(data: Buffer) = data + override fun toString() = "Buffer" + } + val REAL_BUFFERED_SINK: Factory = object : Factory { + override fun create(data: Buffer) = (data as Sink).buffer() + override fun toString() = "RealBufferedSink" + } + } + } + + private val data: Buffer = Buffer() + private val sink: BufferedSink = factory.create(data) + + @Test + fun writeNothing() { + sink.writeUtf8("") + sink.flush() + assertEquals(0, data.size) + } + + @Test + fun writeBytes() { + sink.writeByte(0xab) + sink.writeByte(0xcd) + sink.flush() + assertEquals("[hex=abcd]", data.toString()) + } + + @Test + fun writeLastByteInSegment() { + sink.writeUtf8("a".repeat(SEGMENT_SIZE - 1)) + sink.writeByte(0x20) + sink.writeByte(0x21) + sink.flush() + assertEquals(listOf(SEGMENT_SIZE, 1), segmentSizes(data)) + assertEquals("a".repeat(SEGMENT_SIZE - 1), data.readUtf8((SEGMENT_SIZE - 1).toLong())) + assertEquals("[text= !]", data.toString()) + } + + @Test + fun writeShort() { + sink.writeShort(0xabcd) + sink.writeShort(0x4321) + sink.flush() + assertEquals("[hex=abcd4321]", data.toString()) + } + + @Test + fun writeShortLe() { + sink.writeShortLe(0xcdab) + sink.writeShortLe(0x2143) + sink.flush() + assertEquals("[hex=abcd4321]", data.toString()) + } + + @Test + fun writeInt() { + sink.writeInt(-0x543210ff) + sink.writeInt(-0x789abcdf) + sink.flush() + assertEquals("[hex=abcdef0187654321]", data.toString()) + } + + @Test + fun writeLastIntegerInSegment() { + sink.writeUtf8("a".repeat(SEGMENT_SIZE - 4)) + sink.writeInt(-0x543210ff) + sink.writeInt(-0x789abcdf) + sink.flush() + assertEquals(listOf(SEGMENT_SIZE, 4), segmentSizes(data)) + assertEquals("a".repeat(SEGMENT_SIZE - 4), data.readUtf8((SEGMENT_SIZE - 4).toLong())) + assertEquals("[hex=abcdef0187654321]", data.toString()) + } + + @Test + fun writeIntegerDoesNotQuiteFitInSegment() { + sink.writeUtf8("a".repeat(SEGMENT_SIZE - 3)) + sink.writeInt(-0x543210ff) + sink.writeInt(-0x789abcdf) + sink.flush() + assertEquals(listOf(SEGMENT_SIZE - 3, 8), segmentSizes(data)) + assertEquals("a".repeat(SEGMENT_SIZE - 3), data.readUtf8((SEGMENT_SIZE - 3).toLong())) + assertEquals("[hex=abcdef0187654321]", data.toString()) + } + + @Test + fun writeIntLe() { + sink.writeIntLe(-0x543210ff) + sink.writeIntLe(-0x789abcdf) + sink.flush() + assertEquals("[hex=01efcdab21436587]", data.toString()) + } + + @Test + fun writeLong() { + sink.writeLong(-0x543210fe789abcdfL) + sink.writeLong(-0x350145414f4ea400L) + sink.flush() + assertEquals("[hex=abcdef0187654321cafebabeb0b15c00]", data.toString()) + } + + @Test + fun writeLongLe() { + sink.writeLongLe(-0x543210fe789abcdfL) + sink.writeLongLe(-0x350145414f4ea400L) + sink.flush() + assertEquals("[hex=2143658701efcdab005cb1b0bebafeca]", data.toString()) + } + + @Test + fun writeByteString() { + sink.write("təˈranəˌsôr".encodeUtf8()) + sink.flush() + assertEquals("74c999cb8872616ec999cb8c73c3b472".decodeHex(), data.readByteString()) + } + + @Test + fun writeByteStringOffset() { + sink.write("təˈranəˌsôr".encodeUtf8(), 5, 5) + sink.flush() + assertEquals("72616ec999".decodeHex(), data.readByteString()) + } + + @Test + fun writeSegmentedByteString() { + sink.write(Buffer().write("təˈranəˌsôr".encodeUtf8()).snapshot()) + sink.flush() + assertEquals("74c999cb8872616ec999cb8c73c3b472".decodeHex(), data.readByteString()) + } + + @Test + fun writeSegmentedByteStringOffset() { + sink.write(Buffer().write("təˈranəˌsôr".encodeUtf8()).snapshot(), 5, 5) + sink.flush() + assertEquals("72616ec999".decodeHex(), data.readByteString()) + } + + @Test + fun writeStringUtf8() { + sink.writeUtf8("təˈranəˌsôr") + sink.flush() + assertEquals("74c999cb8872616ec999cb8c73c3b472".decodeHex(), data.readByteString()) + } + + @Test + fun writeSubstringUtf8() { + sink.writeUtf8("təˈranəˌsôr", 3, 7) + sink.flush() + assertEquals("72616ec999".decodeHex(), data.readByteString()) + } + + @Test + fun writeStringWithCharset() { + sink.writeString("təˈranəˌsôr", Charset.forName("utf-32be")) + sink.flush() + assertEquals( + ( + "0000007400000259000002c800000072000000610000006e00000259" + + "000002cc00000073000000f400000072" + ).decodeHex(), + data.readByteString(), + ) + } + + @Test + fun writeSubstringWithCharset() { + sink.writeString("təˈranəˌsôr", 3, 7, Charset.forName("utf-32be")) + sink.flush() + assertEquals("00000072000000610000006e00000259".decodeHex(), data.readByteString()) + } + + @Test + fun writeUtf8SubstringWithCharset() { + sink.writeString("təˈranəˌsôr", 3, 7, Charset.forName("utf-8")) + sink.flush() + assertEquals("ranə".encodeUtf8(), data.readByteString()) + } + + @Test + fun writeAll() { + val source = Buffer().writeUtf8("abcdef") + assertEquals(6, sink.writeAll(source)) + assertEquals(0, source.size) + sink.flush() + assertEquals("abcdef", data.readUtf8()) + } + + @Test + fun writeSource() { + val source = Buffer().writeUtf8("abcdef") + + // Force resolution of the Source method overload. + sink.write((source as Source), 4) + sink.flush() + assertEquals("abcd", data.readUtf8()) + assertEquals("ef", source.readUtf8()) + } + + @Test + fun writeSourceReadsFully() { + val source: Source = object : ForwardingSource(Buffer()) { + override fun read(sink: Buffer, byteCount: Long): Long { + sink.writeUtf8("abcd") + return 4 + } + } + sink.write(source, 8) + sink.flush() + assertEquals("abcdabcd", data.readUtf8()) + } + + @Test + fun writeSourcePropagatesEof() { + val source: Source = Buffer().writeUtf8("abcd") + try { + sink.write(source, 8) + fail() + } catch (expected: EOFException) { + } + + // Ensure that whatever was available was correctly written. + sink.flush() + assertEquals("abcd", data.readUtf8()) + } + + @Test + fun writeSourceWithZeroIsNoOp() { + // This test ensures that a zero byte count never calls through to read the source. It may be + // tied to something like a socket which will potentially block trying to read a segment when + // ultimately we don't want any data. + val source: Source = object : ForwardingSource(Buffer()) { + override fun read(sink: Buffer, byteCount: Long): Long { + throw AssertionError() + } + } + sink.write(source, 0) + assertEquals(0, data.size) + } + + @Test + fun writeAllExhausted() { + val source = Buffer() + assertEquals(0, sink.writeAll(source)) + assertEquals(0, source.size) + } + + @Test + fun closeEmitsBufferedBytes() { + sink.writeByte('a'.code) + sink.close() + assertEquals('a'.code.toLong(), data.readByte().toLong()) + } + + @Test + fun outputStream() { + val out = sink.outputStream() + out.write('a'.code) + out.write("b".repeat(9998).toByteArray(UTF_8)) + out.write('c'.code) + out.flush() + assertEquals("a" + "b".repeat(9998) + "c", data.readUtf8()) + } + + @Test + fun outputStreamBounds() { + val out = sink.outputStream() + try { + out.write(ByteArray(100), 50, 51) + fail() + } catch (expected: ArrayIndexOutOfBoundsException) { + } + } + + @Test + fun longDecimalString() { + assertLongDecimalString(0) + assertLongDecimalString(Long.MIN_VALUE) + assertLongDecimalString(Long.MAX_VALUE) + for (i in 1..19) { + val value = BigInteger.valueOf(10L).pow(i).toLong() + assertLongDecimalString(value - 1) + assertLongDecimalString(value) + } + } + + private fun assertLongDecimalString(value: Long) { + sink.writeDecimalLong(value).writeUtf8("zzz").flush() + val expected = java.lang.Long.toString(value) + "zzz" + val actual = data.readUtf8() + assertEquals("$value expected $expected but was $actual", actual, expected) + } + + @Test + fun longHexString() { + assertLongHexString(0) + assertLongHexString(Long.MIN_VALUE) + assertLongHexString(Long.MAX_VALUE) + for (i in 0..62) { + assertLongHexString((1L shl i) - 1) + assertLongHexString(1L shl i) + } + } + + @Test + fun writeNioBuffer() { + val expected = "abcdefg" + val nioByteBuffer = ByteBuffer.allocate(1024) + nioByteBuffer.put("abcdefg".toByteArray(UTF_8)) + (nioByteBuffer as java.nio.Buffer).flip() // Cast necessary for Java 8. + val byteCount = sink.write(nioByteBuffer) + assertEquals(expected.length.toLong(), byteCount.toLong()) + assertEquals(expected.length.toLong(), nioByteBuffer.position().toLong()) + assertEquals(expected.length.toLong(), nioByteBuffer.limit().toLong()) + sink.flush() + assertEquals(expected, data.readUtf8()) + } + + @Test + fun writeLargeNioBufferWritesAllData() { + val expected = "a".repeat(SEGMENT_SIZE * 3) + val nioByteBuffer = ByteBuffer.allocate(SEGMENT_SIZE * 4) + nioByteBuffer.put("a".repeat(SEGMENT_SIZE * 3).toByteArray(UTF_8)) + (nioByteBuffer as java.nio.Buffer).flip() // Cast necessary for Java 8. + val byteCount = sink.write(nioByteBuffer) + assertEquals(expected.length.toLong(), byteCount.toLong()) + assertEquals(expected.length.toLong(), nioByteBuffer.position().toLong()) + assertEquals(expected.length.toLong(), nioByteBuffer.limit().toLong()) + sink.flush() + assertEquals(expected, data.readUtf8()) + } + + private fun assertLongHexString(value: Long) { + sink.writeHexadecimalUnsignedLong(value).writeUtf8("zzz").flush() + val expected = String.format("%x", value) + "zzz" + val actual = data.readUtf8() + assertEquals("$value expected $expected but was $actual", actual, expected) + } + + companion object { + @JvmStatic + @Parameters(name = "{0}") + fun parameters(): List<Array<Any>> = listOf( + arrayOf(Factory.BUFFER), + arrayOf(Factory.REAL_BUFFERED_SINK), + ) + } +} diff --git a/okio/src/jvmTest/kotlin/okio/BufferedSourceJavaTest.kt b/okio/src/jvmTest/kotlin/okio/BufferedSourceJavaTest.kt new file mode 100644 index 00000000..cd7f15ea --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/BufferedSourceJavaTest.kt @@ -0,0 +1,225 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.EOFException +import java.io.IOException +import kotlin.text.Charsets.UTF_8 +import okio.TestUtil.SEGMENT_SIZE +import org.junit.Assert.assertEquals +import org.junit.Assert.fail +import org.junit.Test + +/** + * Tests solely for the behavior of RealBufferedSource's implementation. For generic + * BufferedSource behavior use BufferedSourceTest. + */ +class BufferedSourceJavaTest { + @Test + fun inputStreamTracksSegments() { + val source = Buffer() + source.writeUtf8("a") + source.writeUtf8("b".repeat(SEGMENT_SIZE)) + source.writeUtf8("c") + val `in` = (source as Source).buffer().inputStream() + assertEquals(0, `in`.available().toLong()) + assertEquals((SEGMENT_SIZE + 2).toLong(), source.size) + + // Reading one byte buffers a full segment. + assertEquals('a'.code.toLong(), `in`.read().toLong()) + assertEquals((SEGMENT_SIZE - 1).toLong(), `in`.available().toLong()) + assertEquals(2, source.size) + + // Reading as much as possible reads the rest of that buffered segment. + val data = ByteArray(SEGMENT_SIZE * 2) + assertEquals((SEGMENT_SIZE - 1).toLong(), `in`.read(data, 0, data.size).toLong()) + assertEquals("b".repeat(SEGMENT_SIZE - 1), String(data, 0, SEGMENT_SIZE - 1, UTF_8)) + assertEquals(2, source.size) + + // Continuing to read buffers the next segment. + assertEquals('b'.code.toLong(), `in`.read().toLong()) + assertEquals(1, `in`.available().toLong()) + assertEquals(0, source.size) + + // Continuing to read reads from the buffer. + assertEquals('c'.code.toLong(), `in`.read().toLong()) + assertEquals(0, `in`.available().toLong()) + assertEquals(0, source.size) + + // Once we've exhausted the source, we're done. + assertEquals(-1, `in`.read().toLong()) + assertEquals(0, source.size) + } + + @Test + fun inputStreamCloses() { + val source = (Buffer() as Source).buffer() + val inputStream = source.inputStream() + inputStream.close() + try { + source.require(1) + fail() + } catch (e: IllegalStateException) { + assertEquals("closed", e.message) + } + } + + @Test + fun indexOfStopsReadingAtLimit() { + val buffer = Buffer().writeUtf8("abcdef") + val bufferedSource = object : ForwardingSource(buffer) { + override fun read(sink: Buffer, byteCount: Long): Long { + return super.read(sink, Math.min(1, byteCount)) + } + }.buffer() + assertEquals(6, buffer.size) + assertEquals(-1, bufferedSource.indexOf('e'.code.toByte(), 0, 4)) + assertEquals(2, buffer.size) + } + + @Test + fun requireTracksBufferFirst() { + val source = Buffer() + source.writeUtf8("bb") + val bufferedSource = (source as Source).buffer() + bufferedSource.buffer.writeUtf8("aa") + bufferedSource.require(2) + assertEquals(2, bufferedSource.buffer.size) + assertEquals(2, source.size) + } + + @Test + fun requireIncludesBufferBytes() { + val source = Buffer() + source.writeUtf8("b") + val bufferedSource = (source as Source).buffer() + bufferedSource.buffer.writeUtf8("a") + bufferedSource.require(2) + assertEquals("ab", bufferedSource.buffer.readUtf8(2)) + } + + @Test + fun requireInsufficientData() { + val source = Buffer() + source.writeUtf8("a") + val bufferedSource = (source as Source).buffer() + try { + bufferedSource.require(2) + fail() + } catch (expected: EOFException) { + } + } + + @Test + fun requireReadsOneSegmentAtATime() { + val source = Buffer() + source.writeUtf8("a".repeat(SEGMENT_SIZE)) + source.writeUtf8("b".repeat(SEGMENT_SIZE)) + val bufferedSource = (source as Source).buffer() + bufferedSource.require(2) + assertEquals(SEGMENT_SIZE.toLong(), source.size) + assertEquals(SEGMENT_SIZE.toLong(), bufferedSource.buffer.size) + } + + @Test + fun skipReadsOneSegmentAtATime() { + val source = Buffer() + source.writeUtf8("a".repeat(SEGMENT_SIZE)) + source.writeUtf8("b".repeat(SEGMENT_SIZE)) + val bufferedSource = (source as Source).buffer() + bufferedSource.skip(2) + assertEquals(SEGMENT_SIZE.toLong(), source.size) + assertEquals((SEGMENT_SIZE - 2).toLong(), bufferedSource.buffer.size) + } + + @Test + fun skipTracksBufferFirst() { + val source = Buffer() + source.writeUtf8("bb") + val bufferedSource = (source as Source).buffer() + bufferedSource.buffer.writeUtf8("aa") + bufferedSource.skip(2) + assertEquals(0, bufferedSource.buffer.size) + assertEquals(2, source.size) + } + + @Test + fun operationsAfterClose() { + val source = Buffer() + val bufferedSource = (source as Source).buffer() + bufferedSource.close() + + // Test a sample set of methods. + try { + bufferedSource.indexOf(1.toByte()) + fail() + } catch (expected: IllegalStateException) { + } + try { + bufferedSource.skip(1) + fail() + } catch (expected: IllegalStateException) { + } + try { + bufferedSource.readByte() + fail() + } catch (expected: IllegalStateException) { + } + try { + bufferedSource.readByteString(10) + fail() + } catch (expected: IllegalStateException) { + } + + // Test a sample set of methods on the InputStream. + val inputStream = bufferedSource.inputStream() + try { + inputStream.read() + fail() + } catch (expected: IOException) { + } + try { + inputStream.read(ByteArray(10)) + fail() + } catch (expected: IOException) { + } + } + + /** + * We don't want readAll to buffer an unbounded amount of data. Instead it + * should buffer a segment, write it, and repeat. + */ + @Test + fun readAllReadsOneSegmentAtATime() { + val write1 = Buffer().writeUtf8("a".repeat(SEGMENT_SIZE)) + val write2 = Buffer().writeUtf8("b".repeat(SEGMENT_SIZE)) + val write3 = Buffer().writeUtf8("c".repeat(SEGMENT_SIZE)) + val source = Buffer().writeUtf8( + "" + + "a".repeat(SEGMENT_SIZE) + + "b".repeat(SEGMENT_SIZE) + + "c".repeat(SEGMENT_SIZE), + ) + val mockSink = MockSink() + val bufferedSource = (source as Source).buffer() + assertEquals((SEGMENT_SIZE * 3).toLong(), bufferedSource.readAll(mockSink)) + mockSink.assertLog( + "write(" + write1 + ", " + write1.size + ")", + "write(" + write2 + ", " + write2.size + ")", + "write(" + write3 + ", " + write3.size + ")", + ) + } +} diff --git a/okio/src/jvmTest/kotlin/okio/BufferedSourceTest.kt b/okio/src/jvmTest/kotlin/okio/BufferedSourceTest.kt new file mode 100644 index 00000000..b30944b7 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/BufferedSourceTest.kt @@ -0,0 +1,1503 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.EOFException +import java.nio.ByteBuffer +import java.nio.charset.Charset +import java.util.Arrays +import kotlin.text.Charsets.US_ASCII +import kotlin.text.Charsets.UTF_8 +import okio.ByteString.Companion.decodeHex +import okio.ByteString.Companion.encodeUtf8 +import okio.Options.Companion.of +import okio.TestUtil.SEGMENT_SIZE +import okio.TestUtil.assertByteArrayEquals +import okio.TestUtil.assertByteArraysEquals +import okio.TestUtil.randomBytes +import okio.TestUtil.segmentSizes +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Assert.fail +import org.junit.Assume +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters + +@RunWith(Parameterized::class) +class BufferedSourceTest( + private val factory: Factory, +) { + interface Factory { + fun pipe(): Pipe + val isOneByteAtATime: Boolean + + companion object { + val BUFFER: Factory = object : Factory { + override fun pipe(): Pipe { + val buffer = Buffer() + return Pipe(buffer, buffer) + } + + override val isOneByteAtATime: Boolean get() = false + + override fun toString() = "Buffer" + } + + val REAL_BUFFERED_SOURCE: Factory = object : Factory { + override fun pipe(): Pipe { + val buffer = Buffer() + return Pipe( + sink = buffer, + source = (buffer as Source).buffer(), + ) + } + + override val isOneByteAtATime: Boolean get() = false + + override fun toString() = "RealBufferedSource" + } + + /** + * A factory deliberately written to create buffers whose internal segments are always 1 byte + * long. We like testing with these segments because are likely to trigger bugs! + */ + val ONE_BYTE_AT_A_TIME_BUFFERED_SOURCE: Factory = object : Factory { + override fun pipe(): Pipe { + val buffer = Buffer() + return Pipe( + sink = buffer, + source = object : ForwardingSource(buffer) { + override fun read(sink: Buffer, byteCount: Long): Long { + // Read one byte into a new buffer, then clone it so that the segment is shared. + // Shared segments cannot be compacted so we'll get a long chain of short segments. + val box = Buffer() + val result = super.read(box, Math.min(byteCount, 1L)) + if (result > 0L) sink.write(box.clone(), result) + return result + } + }.buffer(), + ) + } + + override val isOneByteAtATime: Boolean get() = true + + override fun toString() = "OneByteAtATimeBufferedSource" + } + + val ONE_BYTE_AT_A_TIME_BUFFER: Factory = object : Factory { + override fun pipe(): Pipe { + val buffer = Buffer() + val sink = object : ForwardingSink(buffer) { + override fun write(source: Buffer, byteCount: Long) { + // Write each byte into a new buffer, then clone it so that the segments are shared. + // Shared segments cannot be compacted so we'll get a long chain of short segments. + for (i in 0 until byteCount) { + val box = Buffer() + box.write(source, 1) + super.write(box.clone(), 1) + } + } + }.buffer() + return Pipe( + sink = sink, + source = buffer, + ) + } + + override val isOneByteAtATime: Boolean get() = true + + override fun toString() = "OneByteAtATimeBuffer" + } + + val PEEK_BUFFER: Factory = object : Factory { + override fun pipe(): Pipe { + val buffer = Buffer() + return Pipe( + sink = buffer, + source = buffer.peek(), + ) + } + + override val isOneByteAtATime: Boolean get() = false + + override fun toString() = "PeekBuffer" + } + + val PEEK_BUFFERED_SOURCE: Factory = object : Factory { + override fun pipe(): Pipe { + val buffer = Buffer() + return Pipe( + sink = buffer, + source = (buffer as Source).buffer().peek(), + ) + } + + override val isOneByteAtATime: Boolean get() = false + + override fun toString() = "PeekBufferedSource" + } + } + } + + class Pipe( + var sink: BufferedSink, + var source: BufferedSource, + ) + + private val pipe = factory.pipe() + private val sink: BufferedSink = pipe.sink + private val source: BufferedSource = pipe.source + + @Test + fun readBytes() { + sink.write(byteArrayOf(0xab.toByte(), 0xcd.toByte())) + sink.emit() + assertEquals(0xab, (source.readByte().toInt() and 0xff).toLong()) + assertEquals(0xcd, (source.readByte().toInt() and 0xff).toLong()) + assertTrue(source.exhausted()) + } + + @Test + fun readByteTooShortThrows() { + try { + source.readByte() + fail() + } catch (expected: EOFException) { + } + } + + @Test + fun readShort() { + sink.write(byteArrayOf(0xab.toByte(), 0xcd.toByte(), 0xef.toByte(), 0x01.toByte())) + sink.emit() + assertEquals(0xabcd.toShort().toLong(), source.readShort().toLong()) + assertEquals(0xef01.toShort().toLong(), source.readShort().toLong()) + assertTrue(source.exhausted()) + } + + @Test + fun readShortLe() { + sink.write(byteArrayOf(0xab.toByte(), 0xcd.toByte(), 0xef.toByte(), 0x10.toByte())) + sink.emit() + assertEquals(0xcdab.toShort().toLong(), source.readShortLe().toLong()) + assertEquals(0x10ef.toShort().toLong(), source.readShortLe().toLong()) + assertTrue(source.exhausted()) + } + + @Test + fun readShortSplitAcrossMultipleSegments() { + sink.writeUtf8("a".repeat(SEGMENT_SIZE - 1)) + sink.write(byteArrayOf(0xab.toByte(), 0xcd.toByte())) + sink.emit() + source.skip((SEGMENT_SIZE - 1).toLong()) + assertEquals(0xabcd.toShort().toLong(), source.readShort().toLong()) + assertTrue(source.exhausted()) + } + + @Test + fun readShortTooShortThrows() { + sink.writeShort(Short.MAX_VALUE.toInt()) + sink.emit() + source.readByte() + try { + source.readShort() + fail() + } catch (expected: EOFException) { + } + } + + @Test + fun readShortLeTooShortThrows() { + sink.writeShortLe(Short.MAX_VALUE.toInt()) + sink.emit() + source.readByte() + try { + source.readShortLe() + fail() + } catch (expected: EOFException) { + } + } + + @Test + fun readInt() { + sink.write(byteArrayOf(0xab.toByte(), 0xcd.toByte(), 0xef.toByte(), 0x01.toByte(), 0x87.toByte(), 0x65.toByte(), 0x43.toByte(), 0x21.toByte())) + sink.emit() + assertEquals(-0x543210ff, source.readInt().toLong()) + assertEquals(-0x789abcdf, source.readInt().toLong()) + assertTrue(source.exhausted()) + } + + @Test + fun readIntLe() { + sink.write(byteArrayOf(0xab.toByte(), 0xcd.toByte(), 0xef.toByte(), 0x10.toByte(), 0x87.toByte(), 0x65.toByte(), 0x43.toByte(), 0x21.toByte())) + sink.emit() + assertEquals(0x10efcdab, source.readIntLe().toLong()) + assertEquals(0x21436587, source.readIntLe().toLong()) + assertTrue(source.exhausted()) + } + + @Test + fun readIntSplitAcrossMultipleSegments() { + sink.writeUtf8("a".repeat(SEGMENT_SIZE - 3)) + sink.write(byteArrayOf(0xab.toByte(), 0xcd.toByte(), 0xef.toByte(), 0x01.toByte())) + sink.emit() + source.skip((SEGMENT_SIZE - 3).toLong()) + assertEquals(-0x543210ff, source.readInt().toLong()) + assertTrue(source.exhausted()) + } + + @Test + fun readIntTooShortThrows() { + sink.writeInt(Int.MAX_VALUE) + sink.emit() + source.readByte() + try { + source.readInt() + fail() + } catch (expected: EOFException) { + } + } + + @Test + fun readIntLeTooShortThrows() { + sink.writeIntLe(Int.MAX_VALUE) + sink.emit() + source.readByte() + try { + source.readIntLe() + fail() + } catch (expected: EOFException) { + } + } + + @Test + fun readLong() { + sink.write(byteArrayOf(0xab.toByte(), 0xcd.toByte(), 0xef.toByte(), 0x10.toByte(), 0x87.toByte(), 0x65.toByte(), 0x43.toByte(), 0x21.toByte(), 0x36.toByte(), 0x47.toByte(), 0x58.toByte(), 0x69.toByte(), 0x12.toByte(), 0x23.toByte(), 0x34.toByte(), 0x45.toByte())) + sink.emit() + assertEquals(-0x543210ef789abcdfL, source.readLong()) + assertEquals(0x3647586912233445L, source.readLong()) + assertTrue(source.exhausted()) + } + + @Test + fun readLongLe() { + sink.write(byteArrayOf(0xab.toByte(), 0xcd.toByte(), 0xef.toByte(), 0x10.toByte(), 0x87.toByte(), 0x65.toByte(), 0x43.toByte(), 0x21.toByte(), 0x36.toByte(), 0x47.toByte(), 0x58.toByte(), 0x69.toByte(), 0x12.toByte(), 0x23.toByte(), 0x34.toByte(), 0x45.toByte())) + sink.emit() + assertEquals(0x2143658710efcdabL, source.readLongLe()) + assertEquals(0x4534231269584736L, source.readLongLe()) + assertTrue(source.exhausted()) + } + + @Test + fun readLongSplitAcrossMultipleSegments() { + sink.writeUtf8("a".repeat(SEGMENT_SIZE - 7)) + sink.write(byteArrayOf(0xab.toByte(), 0xcd.toByte(), 0xef.toByte(), 0x01.toByte(), 0x87.toByte(), 0x65.toByte(), 0x43.toByte(), 0x21.toByte())) + sink.emit() + source.skip((SEGMENT_SIZE - 7).toLong()) + assertEquals(-0x543210fe789abcdfL, source.readLong()) + assertTrue(source.exhausted()) + } + + @Test + fun readLongTooShortThrows() { + sink.writeLong(Long.MAX_VALUE) + sink.emit() + source.readByte() + try { + source.readLong() + fail() + } catch (expected: EOFException) { + } + } + + @Test + fun readLongLeTooShortThrows() { + sink.writeLongLe(Long.MAX_VALUE) + sink.emit() + source.readByte() + try { + source.readLongLe() + fail() + } catch (expected: EOFException) { + } + } + + @Test + fun readAll() { + source.buffer.writeUtf8("abc") + sink.writeUtf8("def") + sink.emit() + val sink = Buffer() + assertEquals(6, source.readAll(sink)) + assertEquals("abcdef", sink.readUtf8()) + assertTrue(source.exhausted()) + } + + @Test + fun readAllExhausted() { + val mockSink = MockSink() + assertEquals(0, source.readAll(mockSink)) + assertTrue(source.exhausted()) + mockSink.assertLog() + } + + @Test + fun readExhaustedSource() { + val sink = Buffer() + sink.writeUtf8("a".repeat(10)) + assertEquals(-1, source.read(sink, 10)) + assertEquals(10, sink.size) + assertTrue(source.exhausted()) + } + + @Test + fun readZeroBytesFromSource() { + val sink = Buffer() + sink.writeUtf8("a".repeat(10)) + + // Either 0 or -1 is reasonable here. For consistency with Android's + // ByteArrayInputStream we return 0. + assertEquals(-1, source.read(sink, 0)) + assertEquals(10, sink.size) + assertTrue(source.exhausted()) + } + + @Test + fun readFully() { + sink.writeUtf8("a".repeat(10000)) + sink.emit() + val sink = Buffer() + source.readFully(sink, 9999) + assertEquals("a".repeat(9999), sink.readUtf8()) + assertEquals("a", source.readUtf8()) + } + + @Test + fun readFullyTooShortThrows() { + sink.writeUtf8("Hi") + sink.emit() + val sink = Buffer() + try { + source.readFully(sink, 5) + fail() + } catch (ignored: EOFException) { + } + + // Verify we read all that we could from the source. + assertEquals("Hi", sink.readUtf8()) + } + + @Test + fun readFullyByteArray() { + val data = Buffer() + data.writeUtf8("Hello").writeUtf8("e".repeat(SEGMENT_SIZE)) + val expected = data.clone().readByteArray() + sink.write(data, data.size) + sink.emit() + val sink = ByteArray(SEGMENT_SIZE + 5) + source.readFully(sink) + assertByteArraysEquals(expected, sink) + } + + @Test + fun readFullyByteArrayTooShortThrows() { + sink.writeUtf8("Hello") + sink.emit() + val array = ByteArray(6) + try { + source.readFully(array) + fail() + } catch (ignored: EOFException) { + } + + // Verify we read all that we could from the source. + assertByteArraysEquals(byteArrayOf('H'.code.toByte(), 'e'.code.toByte(), 'l'.code.toByte(), 'l'.code.toByte(), 'o'.code.toByte(), 0), array) + } + + @Test + fun readIntoByteArray() { + sink.writeUtf8("abcd") + sink.emit() + val sink = ByteArray(3) + val read = source.read(sink) + if (factory.isOneByteAtATime) { + assertEquals(1, read.toLong()) + val expected = byteArrayOf('a'.code.toByte(), 0, 0) + assertByteArraysEquals(expected, sink) + } else { + assertEquals(3, read.toLong()) + val expected = byteArrayOf('a'.code.toByte(), 'b'.code.toByte(), 'c'.code.toByte()) + assertByteArraysEquals(expected, sink) + } + } + + @Test + fun readIntoByteArrayNotEnough() { + sink.writeUtf8("abcd") + sink.emit() + val sink = ByteArray(5) + val read = source.read(sink) + if (factory.isOneByteAtATime) { + assertEquals(1, read.toLong()) + val expected = byteArrayOf('a'.code.toByte(), 0, 0, 0, 0) + assertByteArraysEquals(expected, sink) + } else { + assertEquals(4, read.toLong()) + val expected = byteArrayOf('a'.code.toByte(), 'b'.code.toByte(), 'c'.code.toByte(), 'd'.code.toByte(), 0) + assertByteArraysEquals(expected, sink) + } + } + + @Test + fun readIntoByteArrayOffsetAndCount() { + sink.writeUtf8("abcd") + sink.emit() + val sink = ByteArray(7) + val read = source.read(sink, 2, 3) + if (factory.isOneByteAtATime) { + assertEquals(1, read.toLong()) + val expected = byteArrayOf(0, 0, 'a'.code.toByte(), 0, 0, 0, 0) + assertByteArraysEquals(expected, sink) + } else { + assertEquals(3, read.toLong()) + val expected = byteArrayOf(0, 0, 'a'.code.toByte(), 'b'.code.toByte(), 'c'.code.toByte(), 0, 0) + assertByteArraysEquals(expected, sink) + } + } + + @Test + fun readByteArray() { + val string = "abcd" + "e".repeat(SEGMENT_SIZE) + sink.writeUtf8(string) + sink.emit() + assertByteArraysEquals(string.toByteArray(UTF_8), source.readByteArray()) + } + + @Test + fun readByteArrayPartial() { + sink.writeUtf8("abcd") + sink.emit() + assertEquals("[97, 98, 99]", Arrays.toString(source.readByteArray(3))) + assertEquals("d", source.readUtf8(1)) + } + + @Test + fun readByteArrayTooShortThrows() { + sink.writeUtf8("abc") + sink.emit() + try { + source.readByteArray(4) + fail() + } catch (expected: EOFException) { + } + assertEquals("abc", source.readUtf8()) // The read shouldn't consume any data. + } + + @Test + fun readByteString() { + sink.writeUtf8("abcd").writeUtf8("e".repeat(SEGMENT_SIZE)) + sink.emit() + assertEquals("abcd" + "e".repeat(SEGMENT_SIZE), source.readByteString().utf8()) + } + + @Test + fun readByteStringPartial() { + sink.writeUtf8("abcd").writeUtf8("e".repeat(SEGMENT_SIZE)) + sink.emit() + assertEquals("abc", source.readByteString(3).utf8()) + assertEquals("d", source.readUtf8(1)) + } + + @Test + fun readByteStringTooShortThrows() { + sink.writeUtf8("abc") + sink.emit() + try { + source.readByteString(4) + fail() + } catch (expected: EOFException) { + } + assertEquals("abc", source.readUtf8()) // The read shouldn't consume any data. + } + + @Test + fun readSpecificCharsetPartial() { + sink.write( + ( + "0000007600000259000002c80000006c000000e40000007300000259" + + "000002cc000000720000006100000070000000740000025900000072" + ).decodeHex(), + ) + sink.emit() + assertEquals("vəˈläsə", source.readString((7 * 4).toLong(), Charset.forName("utf-32"))) + } + + @Test + fun readSpecificCharset() { + sink.write( + ( + "0000007600000259000002c80000006c000000e40000007300000259" + + "000002cc000000720000006100000070000000740000025900000072" + ).decodeHex(), + ) + sink.emit() + assertEquals("vəˈläsəˌraptər", source.readString(Charset.forName("utf-32"))) + } + + @Test + fun readStringTooShortThrows() { + sink.writeString("abc", US_ASCII) + sink.emit() + try { + source.readString(4, US_ASCII) + fail() + } catch (expected: EOFException) { + } + assertEquals("abc", source.readUtf8()) // The read shouldn't consume any data. + } + + @Test + fun readUtf8SpansSegments() { + sink.writeUtf8("a".repeat(SEGMENT_SIZE * 2)) + sink.emit() + source.skip((SEGMENT_SIZE - 1).toLong()) + assertEquals("aa", source.readUtf8(2)) + } + + @Test + fun readUtf8Segment() { + sink.writeUtf8("a".repeat(SEGMENT_SIZE)) + sink.emit() + assertEquals("a".repeat(SEGMENT_SIZE), source.readUtf8(SEGMENT_SIZE.toLong())) + } + + @Test + fun readUtf8PartialBuffer() { + sink.writeUtf8("a".repeat(SEGMENT_SIZE + 20)) + sink.emit() + assertEquals("a".repeat(SEGMENT_SIZE + 10), source.readUtf8((SEGMENT_SIZE + 10).toLong())) + } + + @Test + fun readUtf8EntireBuffer() { + sink.writeUtf8("a".repeat(SEGMENT_SIZE * 2)) + sink.emit() + assertEquals("a".repeat(SEGMENT_SIZE * 2), source.readUtf8()) + } + + @Test + fun readUtf8TooShortThrows() { + sink.writeUtf8("abc") + sink.emit() + try { + source.readUtf8(4L) + fail() + } catch (expected: EOFException) { + } + assertEquals("abc", source.readUtf8()) // The read shouldn't consume any data. + } + + @Test + fun skip() { + sink.writeUtf8("a") + sink.writeUtf8("b".repeat(SEGMENT_SIZE)) + sink.writeUtf8("c") + sink.emit() + source.skip(1) + assertEquals('b'.code.toLong(), (source.readByte().toInt() and 0xff).toLong()) + source.skip((SEGMENT_SIZE - 2).toLong()) + assertEquals('b'.code.toLong(), (source.readByte().toInt() and 0xff).toLong()) + source.skip(1) + assertTrue(source.exhausted()) + } + + @Test + fun skipInsufficientData() { + sink.writeUtf8("a") + sink.emit() + try { + source.skip(2) + fail() + } catch (ignored: EOFException) { + } + } + + @Test + fun indexOf() { + // The segment is empty. + assertEquals(-1, source.indexOf('a'.code.toByte())) + + // The segment has one value. + sink.writeUtf8("a") // a + sink.emit() + assertEquals(0, source.indexOf('a'.code.toByte())) + assertEquals(-1, source.indexOf('b'.code.toByte())) + + // The segment has lots of data. + sink.writeUtf8("b".repeat(SEGMENT_SIZE - 2)) // ab...b + sink.emit() + assertEquals(0, source.indexOf('a'.code.toByte())) + assertEquals(1, source.indexOf('b'.code.toByte())) + assertEquals(-1, source.indexOf('c'.code.toByte())) + + // The segment doesn't start at 0, it starts at 2. + source.skip(2) // b...b + assertEquals(-1, source.indexOf('a'.code.toByte())) + assertEquals(0, source.indexOf('b'.code.toByte())) + assertEquals(-1, source.indexOf('c'.code.toByte())) + + // The segment is full. + sink.writeUtf8("c") // b...bc + sink.emit() + assertEquals(-1, source.indexOf('a'.code.toByte())) + assertEquals(0, source.indexOf('b'.code.toByte())) + assertEquals((SEGMENT_SIZE - 3).toLong(), source.indexOf('c'.code.toByte())) + + // The segment doesn't start at 2, it starts at 4. + source.skip(2) // b...bc + assertEquals(-1, source.indexOf('a'.code.toByte())) + assertEquals(0, source.indexOf('b'.code.toByte())) + assertEquals((SEGMENT_SIZE - 5).toLong(), source.indexOf('c'.code.toByte())) + + // Two segments. + sink.writeUtf8("d") // b...bcd, d is in the 2nd segment. + sink.emit() + assertEquals((SEGMENT_SIZE - 4).toLong(), source.indexOf('d'.code.toByte())) + assertEquals(-1, source.indexOf('e'.code.toByte())) + } + + @Test + fun indexOfByteWithStartOffset() { + sink.writeUtf8("a").writeUtf8("b".repeat(SEGMENT_SIZE)).writeUtf8("c") + sink.emit() + assertEquals(-1, source.indexOf('a'.code.toByte(), 1)) + assertEquals(15, source.indexOf('b'.code.toByte(), 15)) + } + + @Test + fun indexOfByteWithBothOffsets() { + if (factory.isOneByteAtATime) { + // When run on Travis this causes out-of-memory errors. + return + } + val a = 'a'.code.toByte() + val c = 'c'.code.toByte() + val size: Int = SEGMENT_SIZE * 5 + val bytes = ByteArray(size) + Arrays.fill(bytes, a) + + // These are tricky places where the buffer + // starts, ends, or segments come together. + val points = intArrayOf( + 0, 1, 2, + SEGMENT_SIZE - 1, SEGMENT_SIZE, SEGMENT_SIZE + 1, + size / 2 - 1, size / 2, size / 2 + 1, + size - SEGMENT_SIZE - 1, size - SEGMENT_SIZE, size - SEGMENT_SIZE + 1, + size - 3, size - 2, size - 1, + ) + + // In each iteration, we write c to the known point and then search for it using different + // windows. Some of the windows don't overlap with c's position, and therefore a match shouldn't + // be found. + for (p in points) { + bytes[p] = c + sink.write(bytes) + sink.emit() + assertEquals(p.toLong(), source.indexOf(c, 0, size.toLong())) + assertEquals(p.toLong(), source.indexOf(c, 0, (p + 1).toLong())) + assertEquals(p.toLong(), source.indexOf(c, p.toLong(), size.toLong())) + assertEquals(p.toLong(), source.indexOf(c, p.toLong(), (p + 1).toLong())) + assertEquals(p.toLong(), source.indexOf(c, (p / 2).toLong(), (p * 2 + 1).toLong())) + assertEquals(-1, source.indexOf(c, 0, (p / 2).toLong())) + assertEquals(-1, source.indexOf(c, 0, p.toLong())) + assertEquals(-1, source.indexOf(c, 0, 0)) + assertEquals(-1, source.indexOf(c, p.toLong(), p.toLong())) + + // Reset. + source.readUtf8() + bytes[p] = a + } + } + + @Test + fun indexOfByteInvalidBoundsThrows() { + sink.writeUtf8("abc") + sink.emit() + try { + source.indexOf('a'.code.toByte(), -1) + fail("Expected failure: fromIndex < 0") + } catch (expected: IllegalArgumentException) { + } + try { + source.indexOf('a'.code.toByte(), 10, 0) + fail("Expected failure: fromIndex > toIndex") + } catch (expected: IllegalArgumentException) { + } + } + + @Test + fun indexOfByteString() { + assertEquals(-1, source.indexOf("flop".encodeUtf8())) + sink.writeUtf8("flip flop") + sink.emit() + assertEquals(5, source.indexOf("flop".encodeUtf8())) + source.readUtf8() // Clear stream. + + // Make sure we backtrack and resume searching after partial match. + sink.writeUtf8("hi hi hi hey") + sink.emit() + assertEquals(3, source.indexOf("hi hi hey".encodeUtf8())) + } + + @Test + fun indexOfByteStringAtSegmentBoundary() { + sink.writeUtf8("a".repeat(SEGMENT_SIZE - 1)) + sink.writeUtf8("bcd") + sink.emit() + assertEquals((SEGMENT_SIZE - 3).toLong(), source.indexOf("aabc".encodeUtf8(), (SEGMENT_SIZE - 4).toLong())) + assertEquals((SEGMENT_SIZE - 3).toLong(), source.indexOf("aabc".encodeUtf8(), (SEGMENT_SIZE - 3).toLong())) + assertEquals((SEGMENT_SIZE - 2).toLong(), source.indexOf("abcd".encodeUtf8(), (SEGMENT_SIZE - 2).toLong())) + assertEquals((SEGMENT_SIZE - 2).toLong(), source.indexOf("abc".encodeUtf8(), (SEGMENT_SIZE - 2).toLong())) + assertEquals((SEGMENT_SIZE - 2).toLong(), source.indexOf("abc".encodeUtf8(), (SEGMENT_SIZE - 2).toLong())) + assertEquals((SEGMENT_SIZE - 2).toLong(), source.indexOf("ab".encodeUtf8(), (SEGMENT_SIZE - 2).toLong())) + assertEquals((SEGMENT_SIZE - 2).toLong(), source.indexOf("a".encodeUtf8(), (SEGMENT_SIZE - 2).toLong())) + assertEquals((SEGMENT_SIZE - 1).toLong(), source.indexOf("bc".encodeUtf8(), (SEGMENT_SIZE - 2).toLong())) + assertEquals((SEGMENT_SIZE - 1).toLong(), source.indexOf("b".encodeUtf8(), (SEGMENT_SIZE - 2).toLong())) + assertEquals(SEGMENT_SIZE.toLong(), source.indexOf("c".encodeUtf8(), (SEGMENT_SIZE - 2).toLong())) + assertEquals(SEGMENT_SIZE.toLong(), source.indexOf("c".encodeUtf8(), SEGMENT_SIZE.toLong())) + assertEquals((SEGMENT_SIZE + 1).toLong(), source.indexOf("d".encodeUtf8(), (SEGMENT_SIZE - 2).toLong())) + assertEquals((SEGMENT_SIZE + 1).toLong(), source.indexOf("d".encodeUtf8(), (SEGMENT_SIZE + 1).toLong())) + } + + @Test + fun indexOfDoesNotWrapAround() { + sink.writeUtf8("a".repeat(SEGMENT_SIZE - 1)) + sink.writeUtf8("bcd") + sink.emit() + assertEquals(-1, source.indexOf("abcda".encodeUtf8(), (SEGMENT_SIZE - 3).toLong())) + } + + @Test + fun indexOfByteStringWithOffset() { + assertEquals(-1, source.indexOf("flop".encodeUtf8(), 1)) + sink.writeUtf8("flop flip flop") + sink.emit() + assertEquals(10, source.indexOf("flop".encodeUtf8(), 1)) + source.readUtf8() // Clear stream + + // Make sure we backtrack and resume searching after partial match. + sink.writeUtf8("hi hi hi hi hey") + sink.emit() + assertEquals(6, source.indexOf("hi hi hey".encodeUtf8(), 1)) + } + + @Test + fun indexOfByteStringInvalidArgumentsThrows() { + try { + source.indexOf(ByteString.of()) + fail() + } catch (e: IllegalArgumentException) { + assertEquals("bytes is empty", e.message) + } + try { + source.indexOf("hi".encodeUtf8(), -1) + fail() + } catch (e: IllegalArgumentException) { + assertEquals("fromIndex < 0: -1", e.message) + } + } + + /** + * With [Factory.ONE_BYTE_AT_A_TIME_BUFFERED_SOURCE], this code was extremely slow. + * https://github.com/square/okio/issues/171 + */ + @Test + fun indexOfByteStringAcrossSegmentBoundaries() { + sink.writeUtf8("a".repeat(SEGMENT_SIZE * 2 - 3)) + sink.writeUtf8("bcdefg") + sink.emit() + assertEquals((SEGMENT_SIZE * 2 - 4).toLong(), source.indexOf("ab".encodeUtf8())) + assertEquals((SEGMENT_SIZE * 2 - 4).toLong(), source.indexOf("abc".encodeUtf8())) + assertEquals((SEGMENT_SIZE * 2 - 4).toLong(), source.indexOf("abcd".encodeUtf8())) + assertEquals((SEGMENT_SIZE * 2 - 4).toLong(), source.indexOf("abcde".encodeUtf8())) + assertEquals((SEGMENT_SIZE * 2 - 4).toLong(), source.indexOf("abcdef".encodeUtf8())) + assertEquals((SEGMENT_SIZE * 2 - 4).toLong(), source.indexOf("abcdefg".encodeUtf8())) + assertEquals((SEGMENT_SIZE * 2 - 3).toLong(), source.indexOf("bcdefg".encodeUtf8())) + assertEquals((SEGMENT_SIZE * 2 - 2).toLong(), source.indexOf("cdefg".encodeUtf8())) + assertEquals((SEGMENT_SIZE * 2 - 1).toLong(), source.indexOf("defg".encodeUtf8())) + assertEquals((SEGMENT_SIZE * 2).toLong(), source.indexOf("efg".encodeUtf8())) + assertEquals((SEGMENT_SIZE * 2 + 1).toLong(), source.indexOf("fg".encodeUtf8())) + assertEquals((SEGMENT_SIZE * 2 + 2).toLong(), source.indexOf("g".encodeUtf8())) + } + + @Test + fun indexOfElement() { + sink.writeUtf8("a").writeUtf8("b".repeat(SEGMENT_SIZE)).writeUtf8("c") + sink.emit() + assertEquals(0, source.indexOfElement("DEFGaHIJK".encodeUtf8())) + assertEquals(1, source.indexOfElement("DEFGHIJKb".encodeUtf8())) + assertEquals((SEGMENT_SIZE + 1).toLong(), source.indexOfElement("cDEFGHIJK".encodeUtf8())) + assertEquals(1, source.indexOfElement("DEFbGHIc".encodeUtf8())) + assertEquals(-1L, source.indexOfElement("DEFGHIJK".encodeUtf8())) + assertEquals(-1L, source.indexOfElement("".encodeUtf8())) + } + + @Test + fun indexOfElementWithOffset() { + sink.writeUtf8("a").writeUtf8("b".repeat(SEGMENT_SIZE)).writeUtf8("c") + sink.emit() + assertEquals(-1, source.indexOfElement("DEFGaHIJK".encodeUtf8(), 1)) + assertEquals(15, source.indexOfElement("DEFGHIJKb".encodeUtf8(), 15)) + } + + @Test + fun indexOfByteWithFromIndex() { + sink.writeUtf8("aaa") + sink.emit() + assertEquals(0, source.indexOf('a'.code.toByte())) + assertEquals(0, source.indexOf('a'.code.toByte(), 0)) + assertEquals(1, source.indexOf('a'.code.toByte(), 1)) + assertEquals(2, source.indexOf('a'.code.toByte(), 2)) + } + + @Test + fun indexOfByteStringWithFromIndex() { + sink.writeUtf8("aaa") + sink.emit() + assertEquals(0, source.indexOf("a".encodeUtf8())) + assertEquals(0, source.indexOf("a".encodeUtf8(), 0)) + assertEquals(1, source.indexOf("a".encodeUtf8(), 1)) + assertEquals(2, source.indexOf("a".encodeUtf8(), 2)) + } + + @Test + fun indexOfElementWithFromIndex() { + sink.writeUtf8("aaa") + sink.emit() + assertEquals(0, source.indexOfElement("a".encodeUtf8())) + assertEquals(0, source.indexOfElement("a".encodeUtf8(), 0)) + assertEquals(1, source.indexOfElement("a".encodeUtf8(), 1)) + assertEquals(2, source.indexOfElement("a".encodeUtf8(), 2)) + } + + @Test + fun request() { + sink.writeUtf8("a").writeUtf8("b".repeat(SEGMENT_SIZE)).writeUtf8("c") + sink.emit() + assertTrue(source.request((SEGMENT_SIZE + 2).toLong())) + assertFalse(source.request((SEGMENT_SIZE + 3).toLong())) + } + + @Test + fun require() { + sink.writeUtf8("a").writeUtf8("b".repeat(SEGMENT_SIZE)).writeUtf8("c") + sink.emit() + source.require((SEGMENT_SIZE + 2).toLong()) + try { + source.require((SEGMENT_SIZE + 3).toLong()) + fail() + } catch (expected: EOFException) { + } + } + + @Test + fun inputStream() { + sink.writeUtf8("abc") + sink.emit() + val `in` = source.inputStream() + val bytes = byteArrayOf('z'.code.toByte(), 'z'.code.toByte(), 'z'.code.toByte()) + var read = `in`.read(bytes) + if (factory.isOneByteAtATime) { + assertEquals(1, read.toLong()) + assertByteArrayEquals("azz", bytes) + read = `in`.read(bytes) + assertEquals(1, read.toLong()) + assertByteArrayEquals("bzz", bytes) + read = `in`.read(bytes) + assertEquals(1, read.toLong()) + assertByteArrayEquals("czz", bytes) + } else { + assertEquals(3, read.toLong()) + assertByteArrayEquals("abc", bytes) + } + assertEquals(-1, `in`.read().toLong()) + } + + @Test + fun inputStreamOffsetCount() { + sink.writeUtf8("abcde") + sink.emit() + val `in` = source.inputStream() + val bytes = byteArrayOf('z'.code.toByte(), 'z'.code.toByte(), 'z'.code.toByte(), 'z'.code.toByte(), 'z'.code.toByte()) + val read = `in`.read(bytes, 1, 3) + if (factory.isOneByteAtATime) { + assertEquals(1, read.toLong()) + assertByteArrayEquals("zazzz", bytes) + } else { + assertEquals(3, read.toLong()) + assertByteArrayEquals("zabcz", bytes) + } + } + + @Test + fun inputStreamSkip() { + sink.writeUtf8("abcde") + sink.emit() + val `in` = source.inputStream() + assertEquals(4, `in`.skip(4)) + assertEquals('e'.code.toLong(), `in`.read().toLong()) + sink.writeUtf8("abcde") + sink.emit() + assertEquals(5, `in`.skip(10)) // Try to skip too much. + assertEquals(0, `in`.skip(1)) // Try to skip when exhausted. + } + + @Test + fun inputStreamCharByChar() { + sink.writeUtf8("abc") + sink.emit() + val `in` = source.inputStream() + assertEquals('a'.code.toLong(), `in`.read().toLong()) + assertEquals('b'.code.toLong(), `in`.read().toLong()) + assertEquals('c'.code.toLong(), `in`.read().toLong()) + assertEquals(-1, `in`.read().toLong()) + } + + @Test + fun inputStreamBounds() { + sink.writeUtf8("a".repeat(100)) + sink.emit() + val `in` = source.inputStream() + try { + `in`.read(ByteArray(100), 50, 51) + fail() + } catch (expected: java.lang.ArrayIndexOutOfBoundsException) { + } + } + + @Test + fun longHexString() { + assertLongHexString("8000000000000000", -0x7fffffffffffffffL - 1L) + assertLongHexString("fffffffffffffffe", -0x2L) + assertLongHexString("FFFFFFFFFFFFFFFe", -0x2L) + assertLongHexString("ffffffffffffffff", -0x1L) + assertLongHexString("FFFFFFFFFFFFFFFF", -0x1L) + assertLongHexString("0000000000000000", 0x0) + assertLongHexString("0000000000000001", 0x1) + assertLongHexString("7999999999999999", 0x7999999999999999L) + assertLongHexString("FF", 0xFF) + assertLongHexString("0000000000000001", 0x1) + } + + @Test + fun hexStringWithManyLeadingZeros() { + assertLongHexString("00000000000000001", 0x1) + assertLongHexString("0000000000000000ffffffffffffffff", -0x1L) + assertLongHexString("00000000000000007fffffffffffffff", 0x7fffffffffffffffL) + assertLongHexString("0".repeat(SEGMENT_SIZE + 1) + "1", 0x1) + } + + private fun assertLongHexString(s: String, expected: Long) { + sink.writeUtf8(s) + sink.emit() + val actual = source.readHexadecimalUnsignedLong() + assertEquals("$s --> $expected", expected, actual) + } + + @Test + fun longHexStringAcrossSegment() { + sink.writeUtf8("a".repeat(SEGMENT_SIZE - 8)).writeUtf8("FFFFFFFFFFFFFFFF") + sink.emit() + source.skip((SEGMENT_SIZE - 8).toLong()) + assertEquals(-1, source.readHexadecimalUnsignedLong()) + } + + @Test + fun longHexStringTooLongThrows() { + try { + sink.writeUtf8("fffffffffffffffff") + sink.emit() + source.readHexadecimalUnsignedLong() + fail() + } catch (e: NumberFormatException) { + assertEquals("Number too large: fffffffffffffffff", e.message) + } + } + + @Test + fun longHexStringTooShortThrows() { + try { + sink.writeUtf8(" ") + sink.emit() + source.readHexadecimalUnsignedLong() + fail() + } catch (e: NumberFormatException) { + assertEquals("Expected leading [0-9a-fA-F] character but was 0x20", e.message) + } + } + + @Test + fun longHexEmptySourceThrows() { + try { + sink.writeUtf8("") + sink.emit() + source.readHexadecimalUnsignedLong() + fail() + } catch (expected: EOFException) { + } + } + + @Test + fun longDecimalString() { + assertLongDecimalString("-9223372036854775808", -9223372036854775807L - 1L) + assertLongDecimalString("-1", -1L) + assertLongDecimalString("0", 0L) + assertLongDecimalString("1", 1L) + assertLongDecimalString("9223372036854775807", 9223372036854775807L) + assertLongDecimalString("00000001", 1L) + assertLongDecimalString("-000001", -1L) + } + + private fun assertLongDecimalString(s: String, expected: Long) { + sink.writeUtf8(s) + sink.writeUtf8("zzz") + sink.emit() + val actual = source.readDecimalLong() + assertEquals("$s --> $expected", expected, actual) + assertEquals("zzz", source.readUtf8()) + } + + @Test + fun longDecimalStringAcrossSegment() { + sink.writeUtf8("a".repeat(SEGMENT_SIZE - 8)).writeUtf8("1234567890123456") + sink.writeUtf8("zzz") + sink.emit() + source.skip((SEGMENT_SIZE - 8).toLong()) + assertEquals(1234567890123456L, source.readDecimalLong()) + assertEquals("zzz", source.readUtf8()) + } + + @Test + fun longDecimalStringTooLongThrows() { + try { + sink.writeUtf8("12345678901234567890") // Too many digits. + sink.emit() + source.readDecimalLong() + fail() + } catch (e: NumberFormatException) { + assertEquals("Number too large: 12345678901234567890", e.message) + } + } + + @Test + fun longDecimalStringTooHighThrows() { + try { + sink.writeUtf8("9223372036854775808") // Right size but cannot fit. + sink.emit() + source.readDecimalLong() + fail() + } catch (e: NumberFormatException) { + assertEquals("Number too large: 9223372036854775808", e.message) + } + } + + @Test + fun longDecimalStringTooLowThrows() { + try { + sink.writeUtf8("-9223372036854775809") // Right size but cannot fit. + sink.emit() + source.readDecimalLong() + fail() + } catch (e: NumberFormatException) { + assertEquals("Number too large: -9223372036854775809", e.message) + } + } + + @Test + fun longDecimalStringTooShortThrows() { + try { + sink.writeUtf8(" ") + sink.emit() + source.readDecimalLong() + fail() + } catch (e: NumberFormatException) { + assertEquals("Expected a digit or '-' but was 0x20", e.message) + } + } + + @Test + fun longDecimalEmptyThrows() { + try { + sink.writeUtf8("") + sink.emit() + source.readDecimalLong() + fail() + } catch (expected: EOFException) { + } + } + + @Test + fun codePoints() { + sink.write("7f".decodeHex()) + sink.emit() + assertEquals(0x7f, source.readUtf8CodePoint().toLong()) + sink.write("dfbf".decodeHex()) + sink.emit() + assertEquals(0x07ff, source.readUtf8CodePoint().toLong()) + sink.write("efbfbf".decodeHex()) + sink.emit() + assertEquals(0xffff, source.readUtf8CodePoint().toLong()) + sink.write("f48fbfbf".decodeHex()) + sink.emit() + assertEquals(0x10ffff, source.readUtf8CodePoint().toLong()) + } + + @Test + fun decimalStringWithManyLeadingZeros() { + assertLongDecimalString("00000000000000001", 1) + assertLongDecimalString("00000000000000009223372036854775807", 9223372036854775807L) + assertLongDecimalString("-00000000000000009223372036854775808", -9223372036854775807L - 1L) + assertLongDecimalString("0".repeat(SEGMENT_SIZE + 1) + "1", 1) + } + + @Test + fun select() { + val options = of( + "ROCK".encodeUtf8(), + "SCISSORS".encodeUtf8(), + "PAPER".encodeUtf8(), + ) + sink.writeUtf8("PAPER,SCISSORS,ROCK") + sink.emit() + assertEquals(2, source.select(options).toLong()) + assertEquals(','.code.toLong(), source.readByte().toLong()) + assertEquals(1, source.select(options).toLong()) + assertEquals(','.code.toLong(), source.readByte().toLong()) + assertEquals(0, source.select(options).toLong()) + assertTrue(source.exhausted()) + } + + /** Note that this test crashes the VM on Android. */ + @Test + fun selectSpanningMultipleSegments() { + val commonPrefix = randomBytes(SEGMENT_SIZE + 10) + val a = Buffer().write(commonPrefix).writeUtf8("a").readByteString() + val bc = Buffer().write(commonPrefix).writeUtf8("bc").readByteString() + val bd = Buffer().write(commonPrefix).writeUtf8("bd").readByteString() + val options = of(a, bc, bd) + sink.write(bd) + sink.write(a) + sink.write(bc) + sink.emit() + assertEquals(2, source.select(options).toLong()) + assertEquals(0, source.select(options).toLong()) + assertEquals(1, source.select(options).toLong()) + assertTrue(source.exhausted()) + } + + @Test + fun selectNotFound() { + val options = of( + "ROCK".encodeUtf8(), + "SCISSORS".encodeUtf8(), + "PAPER".encodeUtf8(), + ) + sink.writeUtf8("SPOCK") + sink.emit() + assertEquals(-1, source.select(options).toLong()) + assertEquals("SPOCK", source.readUtf8()) + } + + @Test + fun selectValuesHaveCommonPrefix() { + val options = of( + "abcd".encodeUtf8(), + "abce".encodeUtf8(), + "abcc".encodeUtf8(), + ) + sink.writeUtf8("abcc").writeUtf8("abcd").writeUtf8("abce") + sink.emit() + assertEquals(2, source.select(options).toLong()) + assertEquals(0, source.select(options).toLong()) + assertEquals(1, source.select(options).toLong()) + } + + @Test + fun selectLongerThanSource() { + val options = of( + "abcd".encodeUtf8(), + "abce".encodeUtf8(), + "abcc".encodeUtf8(), + ) + sink.writeUtf8("abc") + sink.emit() + assertEquals(-1, source.select(options).toLong()) + assertEquals("abc", source.readUtf8()) + } + + @Test + fun selectReturnsFirstByteStringThatMatches() { + val options = of( + "abcd".encodeUtf8(), + "abc".encodeUtf8(), + "abcde".encodeUtf8(), + ) + sink.writeUtf8("abcdef") + sink.emit() + assertEquals(0, source.select(options).toLong()) + assertEquals("ef", source.readUtf8()) + } + + @Test + fun selectFromEmptySource() { + val options = of( + "abc".encodeUtf8(), + "def".encodeUtf8(), + ) + assertEquals(-1, source.select(options).toLong()) + } + + @Test + fun selectNoByteStringsFromEmptySource() { + val options = Options.of() + assertEquals(-1, source.select(options).toLong()) + } + + @Test + fun peek() { + sink.writeUtf8("abcdefghi") + sink.emit() + assertEquals("abc", source.readUtf8(3)) + val peek = source.peek() + assertEquals("def", peek.readUtf8(3)) + assertEquals("ghi", peek.readUtf8(3)) + assertFalse(peek.request(1)) + assertEquals("def", source.readUtf8(3)) + } + + @Test + fun peekMultiple() { + sink.writeUtf8("abcdefghi") + sink.emit() + assertEquals("abc", source.readUtf8(3)) + val peek1 = source.peek() + val peek2 = source.peek() + assertEquals("def", peek1.readUtf8(3)) + assertEquals("def", peek2.readUtf8(3)) + assertEquals("ghi", peek2.readUtf8(3)) + assertFalse(peek2.request(1)) + assertEquals("ghi", peek1.readUtf8(3)) + assertFalse(peek1.request(1)) + assertEquals("def", source.readUtf8(3)) + } + + @Test + fun peekLarge() { + sink.writeUtf8("abcdef") + sink.writeUtf8("g".repeat(2 * SEGMENT_SIZE)) + sink.writeUtf8("hij") + sink.emit() + assertEquals("abc", source.readUtf8(3)) + val peek = source.peek() + assertEquals("def", peek.readUtf8(3)) + peek.skip((2 * SEGMENT_SIZE).toLong()) + assertEquals("hij", peek.readUtf8(3)) + assertFalse(peek.request(1)) + assertEquals("def", source.readUtf8(3)) + source.skip((2 * SEGMENT_SIZE).toLong()) + assertEquals("hij", source.readUtf8(3)) + } + + @Test + fun peekInvalid() { + sink.writeUtf8("abcdefghi") + sink.emit() + assertEquals("abc", source.readUtf8(3)) + val peek = source.peek() + assertEquals("def", peek.readUtf8(3)) + assertEquals("ghi", peek.readUtf8(3)) + assertFalse(peek.request(1)) + assertEquals("def", source.readUtf8(3)) + try { + peek.readUtf8() + fail() + } catch (e: IllegalStateException) { + assertEquals("Peek source is invalid because upstream source was used", e.message) + } + } + + @Test + fun peekSegmentThenInvalid() { + sink.writeUtf8("abc") + sink.writeUtf8("d".repeat(2 * SEGMENT_SIZE)) + sink.emit() + assertEquals("abc", source.readUtf8(3)) + + // Peek a little data and skip the rest of the upstream source + val peek = source.peek() + assertEquals("ddd", peek.readUtf8(3)) + source.readAll(blackholeSink()) + + // Skip the rest of the buffered data + peek.skip(peek.buffer.size) + try { + peek.readByte() + fail() + } catch (e: IllegalStateException) { + assertEquals("Peek source is invalid because upstream source was used", e.message) + } + } + + @Test + fun peekDoesntReadTooMuch() { + // 6 bytes in source's buffer plus 3 bytes upstream. + sink.writeUtf8("abcdef") + sink.emit() + source.require(6L) + sink.writeUtf8("ghi") + sink.emit() + val peek = source.peek() + + // Read 3 bytes. This reads some of the buffered data. + assertTrue(peek.request(3)) + if (source !is Buffer) { + assertEquals(6, source.buffer.size) + assertEquals(6, peek.buffer.size) + } + assertEquals("abc", peek.readUtf8(3L)) + + // Read 3 more bytes. This exhausts the buffered data. + assertTrue(peek.request(3)) + if (source !is Buffer) { + assertEquals(6, source.buffer.size) + assertEquals(3, peek.buffer.size) + } + assertEquals("def", peek.readUtf8(3L)) + + // Read 3 more bytes. This draws new bytes. + assertTrue(peek.request(3)) + assertEquals(9, source.buffer.size) + assertEquals(3, peek.buffer.size) + assertEquals("ghi", peek.readUtf8(3L)) + } + + @Test + fun rangeEquals() { + sink.writeUtf8("A man, a plan, a canal. Panama.") + sink.emit() + assertTrue(source.rangeEquals(7, "a plan".encodeUtf8())) + assertTrue(source.rangeEquals(0, "A man".encodeUtf8())) + assertTrue(source.rangeEquals(24, "Panama".encodeUtf8())) + assertFalse(source.rangeEquals(24, "Panama. Panama. Panama.".encodeUtf8())) + } + + @Test + fun rangeEqualsWithOffsetAndCount() { + sink.writeUtf8("A man, a plan, a canal. Panama.") + sink.emit() + assertTrue(source.rangeEquals(7, "aaa plannn".encodeUtf8(), 2, 6)) + assertTrue(source.rangeEquals(0, "AAA mannn".encodeUtf8(), 2, 5)) + assertTrue(source.rangeEquals(24, "PPPanamaaa".encodeUtf8(), 2, 6)) + } + + @Test + fun rangeEqualsOnlyReadsUntilMismatch() { + Assume.assumeTrue(factory === Factory.ONE_BYTE_AT_A_TIME_BUFFERED_SOURCE) // Other sources read in chunks anyway. + sink.writeUtf8("A man, a plan, a canal. Panama.") + sink.emit() + assertFalse(source.rangeEquals(0, "A man.".encodeUtf8())) + assertEquals("A man,", source.buffer.readUtf8()) + } + + @Test + fun rangeEqualsArgumentValidation() { + // Negative source offset. + assertFalse(source.rangeEquals(-1, "A".encodeUtf8())) + // Negative bytes offset. + assertFalse(source.rangeEquals(0, "A".encodeUtf8(), -1, 1)) + // Bytes offset longer than bytes length. + assertFalse(source.rangeEquals(0, "A".encodeUtf8(), 2, 1)) + // Negative byte count. + assertFalse(source.rangeEquals(0, "A".encodeUtf8(), 0, -1)) + // Byte count longer than bytes length. + assertFalse(source.rangeEquals(0, "A".encodeUtf8(), 0, 2)) + // Bytes offset plus byte count longer than bytes length. + assertFalse(source.rangeEquals(0, "A".encodeUtf8(), 1, 1)) + } + + @Test + fun readNioBuffer() { + val expected = if (factory.isOneByteAtATime) "a" else "abcdefg" + sink.writeUtf8("abcdefg") + sink.emit() + val nioByteBuffer = ByteBuffer.allocate(1024) + val byteCount = source.read(nioByteBuffer) + assertEquals(expected.length.toLong(), byteCount.toLong()) + assertEquals(expected.length.toLong(), nioByteBuffer.position().toLong()) + assertEquals(nioByteBuffer.capacity().toLong(), nioByteBuffer.limit().toLong()) + (nioByteBuffer as java.nio.Buffer).flip() // Cast necessary for Java 8. + val data = ByteArray(expected.length) + nioByteBuffer[data] + assertEquals(expected, data.decodeToString()) + } + + /** Note that this test crashes the VM on Android. */ + @Test + fun readLargeNioBufferOnlyReadsOneSegment() { + val expected = if (factory.isOneByteAtATime) "a" else "a".repeat(SEGMENT_SIZE) + sink.writeUtf8("a".repeat(SEGMENT_SIZE * 4)) + sink.emit() + val nioByteBuffer = ByteBuffer.allocate(SEGMENT_SIZE * 3) + val byteCount = source.read(nioByteBuffer) + assertEquals(expected.length.toLong(), byteCount.toLong()) + assertEquals(expected.length.toLong(), nioByteBuffer.position().toLong()) + assertEquals(nioByteBuffer.capacity().toLong(), nioByteBuffer.limit().toLong()) + (nioByteBuffer as java.nio.Buffer).flip() // Cast necessary for Java 8. + val data = ByteArray(expected.length) + nioByteBuffer[data] + assertEquals(expected, data.decodeToString()) + } + + @Test + fun factorySegmentSizes() { + sink.writeUtf8("abc") + sink.emit() + source.require(3) + if (factory.isOneByteAtATime) { + assertEquals(mutableListOf(1, 1, 1), segmentSizes(source.buffer)) + } else { + assertEquals(listOf(3), segmentSizes(source.buffer)) + } + } + + companion object { + @JvmStatic + @Parameters(name = "{0}") + fun parameters(): List<Array<Any>> { + return listOf( + arrayOf(Factory.BUFFER), + arrayOf(Factory.REAL_BUFFERED_SOURCE), + arrayOf(Factory.ONE_BYTE_AT_A_TIME_BUFFERED_SOURCE), + arrayOf(Factory.ONE_BYTE_AT_A_TIME_BUFFER), + arrayOf(Factory.PEEK_BUFFER), + arrayOf(Factory.PEEK_BUFFERED_SOURCE), + ) + } + } +} diff --git a/okio/src/jvmTest/kotlin/okio/ByteStringJavaTest.kt b/okio/src/jvmTest/kotlin/okio/ByteStringJavaTest.kt new file mode 100644 index 00000000..490eb425 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/ByteStringJavaTest.kt @@ -0,0 +1,278 @@ +/* + * Copyright 2014 Square Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream +import java.nio.ByteBuffer +import java.util.Random +import kotlin.text.Charsets.US_ASCII +import kotlin.text.Charsets.UTF_16BE +import kotlin.text.Charsets.UTF_32BE +import kotlin.text.Charsets.UTF_8 +import okio.ByteString.Companion.decodeHex +import okio.ByteString.Companion.encode +import okio.ByteString.Companion.encodeUtf8 +import okio.ByteString.Companion.readByteString +import okio.ByteString.Companion.toByteString +import okio.TestUtil.assertByteArraysEquals +import okio.TestUtil.assertEquivalent +import okio.TestUtil.makeSegments +import okio.TestUtil.reserialize +import org.junit.Assert.assertEquals +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameter +import org.junit.runners.Parameterized.Parameters + +@RunWith(Parameterized::class) +class ByteStringJavaTest { + interface Factory { + fun decodeHex(hex: String): ByteString + fun encodeUtf8(s: String): ByteString + + companion object { + val BYTE_STRING: Factory = object : Factory { + override fun decodeHex(hex: String): ByteString { + return hex.decodeHex() + } + + override fun encodeUtf8(s: String): ByteString { + return s.encodeUtf8() + } + } + val SEGMENTED_BYTE_STRING: Factory = object : Factory { + override fun decodeHex(hex: String): ByteString { + val buffer = Buffer() + buffer.write(hex.decodeHex()) + return buffer.snapshot() + } + + override fun encodeUtf8(s: String): ByteString { + val buffer = Buffer() + buffer.writeUtf8(s) + return buffer.snapshot() + } + } + val ONE_BYTE_PER_SEGMENT: Factory = object : Factory { + override fun decodeHex(hex: String): ByteString { + return makeSegments(hex.decodeHex()) + } + + override fun encodeUtf8(s: String): ByteString { + return makeSegments(s.encodeUtf8()) + } + } + } + } + + @Parameter(0) + lateinit var factory: Factory + + @Parameter(1) + lateinit var name: String + + @Test + fun ofByteBuffer() { + val bytes = "Hello, World!".toByteArray(UTF_8) + val byteBuffer = ByteBuffer.wrap(bytes) + (byteBuffer as java.nio.Buffer).position(2).limit(11) // Cast necessary for Java 8. + val byteString: ByteString = byteBuffer.toByteString() + // Verify that the bytes were copied out. + byteBuffer.put(4, 'a'.code.toByte()) + assertEquals("llo, Worl", byteString.utf8()) + } + + @Test + fun read() { + val inputStream = ByteArrayInputStream("abc".toByteArray(UTF_8)) + assertEquals("6162".decodeHex(), inputStream.readByteString(2)) + assertEquals("63".decodeHex(), inputStream.readByteString(1)) + assertEquals(ByteString.of(), inputStream.readByteString(0)) + } + + @Test + fun readAndToLowercase() { + val inputStream = ByteArrayInputStream("ABC".toByteArray(UTF_8)) + assertEquals("ab".encodeUtf8(), inputStream.readByteString(2).toAsciiLowercase()) + assertEquals("c".encodeUtf8(), inputStream.readByteString(1).toAsciiLowercase()) + assertEquals(ByteString.EMPTY, inputStream.readByteString(0).toAsciiLowercase()) + } + + @Test + fun readAndToUppercase() { + val inputStream = ByteArrayInputStream("abc".toByteArray(UTF_8)) + assertEquals("AB".encodeUtf8(), inputStream.readByteString(2).toAsciiUppercase()) + assertEquals("C".encodeUtf8(), inputStream.readByteString(1).toAsciiUppercase()) + assertEquals(ByteString.EMPTY, inputStream.readByteString(0).toAsciiUppercase()) + } + + @Test + fun write() { + val out = ByteArrayOutputStream() + factory.decodeHex("616263").write(out) + assertByteArraysEquals(byteArrayOf(0x61, 0x62, 0x63), out.toByteArray()) + } + + @Test + fun compareToSingleBytes() { + val originalByteStrings = listOf( + factory.decodeHex("00"), + factory.decodeHex("01"), + factory.decodeHex("7e"), + factory.decodeHex("7f"), + factory.decodeHex("80"), + factory.decodeHex("81"), + factory.decodeHex("fe"), + factory.decodeHex("ff"), + ) + val sortedByteStrings = originalByteStrings.toMutableList() + sortedByteStrings.shuffle(Random(0)) + sortedByteStrings.sort() + assertEquals(originalByteStrings, sortedByteStrings) + } + + @Test + fun compareToMultipleBytes() { + val originalByteStrings = listOf( + factory.decodeHex(""), + factory.decodeHex("00"), + factory.decodeHex("0000"), + factory.decodeHex("000000"), + factory.decodeHex("00000000"), + factory.decodeHex("0000000000"), + factory.decodeHex("0000000001"), + factory.decodeHex("000001"), + factory.decodeHex("00007f"), + factory.decodeHex("0000ff"), + factory.decodeHex("000100"), + factory.decodeHex("000101"), + factory.decodeHex("007f00"), + factory.decodeHex("00ff00"), + factory.decodeHex("010000"), + factory.decodeHex("010001"), + factory.decodeHex("01007f"), + factory.decodeHex("0100ff"), + factory.decodeHex("010100"), + factory.decodeHex("01010000"), + factory.decodeHex("0101000000"), + factory.decodeHex("0101000001"), + factory.decodeHex("010101"), + factory.decodeHex("7f0000"), + factory.decodeHex("7f0000ffff"), + factory.decodeHex("ffffff"), + ) + val sortedByteStrings = originalByteStrings.toMutableList() + sortedByteStrings.shuffle(Random(0)) + sortedByteStrings.sort() + assertEquals(originalByteStrings, sortedByteStrings) + } + + @Test + fun javaSerializationTestNonEmpty() { + val byteString = factory.encodeUtf8(bronzeHorseman) + assertEquivalent(byteString, reserialize(byteString)) + } + + @Test + fun javaSerializationTestEmpty() { + val byteString = factory.decodeHex("") + assertEquivalent(byteString, reserialize(byteString)) + } + + @Test + fun asByteBuffer() { + assertEquals( + 0x42, + ByteString.of(0x41.toByte(), 0x42.toByte(), 0x43.toByte()).asByteBuffer()[1].toLong(), + ) + } + + @Test + fun encodeDecodeStringUtf8() { + val byteString = bronzeHorseman.encode(UTF_8) + assertByteArraysEquals(byteString.toByteArray(), bronzeHorseman.toByteArray(UTF_8)) + assertEquals( + byteString, + ( + "d09dd0b020d0b1d0b5d180d0b5d0b3d18320d0bfd183d181" + + "d182d18bd0bdd0bdd18bd18520d0b2d0bed0bbd0bd" + ).decodeHex(), + ) + assertEquals(bronzeHorseman, byteString.string(UTF_8)) + } + + @Test + fun encodeDecodeStringUtf16be() { + val byteString = bronzeHorseman.encode(UTF_16BE) + assertByteArraysEquals(byteString.toByteArray(), bronzeHorseman.toByteArray(UTF_16BE)) + assertEquals( + byteString, + ( + "041d043000200431043504400435043304430020043f0443" + + "04410442044b043d043d044b044500200432043e043b043d" + ).decodeHex(), + ) + assertEquals(bronzeHorseman, byteString.string(UTF_16BE)) + } + + @Test + fun encodeDecodeStringUtf32be() { + val byteString: ByteString = bronzeHorseman.encode(UTF_32BE) + assertByteArraysEquals(byteString.toByteArray(), bronzeHorseman.toByteArray(UTF_32BE)) + assertEquals( + byteString, + ( + "0000041d0000043000000020000004310000043500000440" + + "000004350000043300000443000000200000043f0000044300000441000004420000044b0000043d0000043d" + + "0000044b0000044500000020000004320000043e0000043b0000043d" + ).decodeHex(), + ) + assertEquals(bronzeHorseman, byteString.string(UTF_32BE)) + } + + @Test + fun encodeDecodeStringAsciiIsLossy() { + val byteString: ByteString = bronzeHorseman.encode(US_ASCII) + assertByteArraysEquals(byteString.toByteArray(), bronzeHorseman.toByteArray(US_ASCII)) + assertEquals( + byteString, + "3f3f203f3f3f3f3f3f203f3f3f3f3f3f3f3f3f203f3f3f3f".decodeHex(), + ) + assertEquals("?? ?????? ????????? ????", byteString.string(US_ASCII)) + } + + @Test + fun decodeMalformedStringReturnsReplacementCharacter() { + val string = "04".decodeHex().string(UTF_16BE) + assertEquals("\ufffd", string) + } + + companion object { + private val bronzeHorseman = "На берегу пустынных волн" + + @JvmStatic + @Parameters(name = "{1}") + fun parameters(): List<Array<Any>> { + return listOf( + arrayOf(Factory.BYTE_STRING, "ByteString"), + arrayOf(Factory.SEGMENTED_BYTE_STRING, "SegmentedByteString"), + arrayOf(Factory.ONE_BYTE_PER_SEGMENT, "SegmentedByteString (one-at-a-time)"), + ) + } + } +} diff --git a/okio/src/jvmTest/kotlin/okio/ByteStringKotlinTest.kt b/okio/src/jvmTest/kotlin/okio/ByteStringKotlinTest.kt index c554251f..ba91338d 100644 --- a/okio/src/jvmTest/kotlin/okio/ByteStringKotlinTest.kt +++ b/okio/src/jvmTest/kotlin/okio/ByteStringKotlinTest.kt @@ -15,14 +15,14 @@ */ package okio -import okio.ByteString.Companion.encode -import okio.ByteString.Companion.encodeUtf8 -import okio.ByteString.Companion.readByteString -import okio.ByteString.Companion.toByteString import java.io.ByteArrayInputStream import java.nio.ByteBuffer import kotlin.test.Test import kotlin.test.assertEquals +import okio.ByteString.Companion.encode +import okio.ByteString.Companion.encodeUtf8 +import okio.ByteString.Companion.readByteString +import okio.ByteString.Companion.toByteString class ByteStringKotlinTest { @Test fun arrayToByteString() { diff --git a/okio/src/jvmTest/kotlin/okio/CipherAlgorithm.kt b/okio/src/jvmTest/kotlin/okio/CipherAlgorithm.kt index f9f42b03..69415241 100644 --- a/okio/src/jvmTest/kotlin/okio/CipherAlgorithm.kt +++ b/okio/src/jvmTest/kotlin/okio/CipherAlgorithm.kt @@ -23,7 +23,7 @@ data class CipherAlgorithm( val transformation: String, val padding: Boolean, val keyLength: Int, - val ivLength: Int? = null + val ivLength: Int? = null, ) { fun createCipherFactory(random: Random): CipherFactory { val key = random.nextBytes(keyLength) @@ -57,7 +57,7 @@ data class CipherAlgorithm( CipherAlgorithm("DESede/CBC/NoPadding", false, 24, 8), CipherAlgorithm("DESede/CBC/PKCS5Padding", true, 24, 8), CipherAlgorithm("DESede/ECB/NoPadding", false, 24), - CipherAlgorithm("DESede/ECB/PKCS5Padding", true, 24) + CipherAlgorithm("DESede/ECB/PKCS5Padding", true, 24), ) } } diff --git a/okio/src/jvmTest/kotlin/okio/CipherFactory.kt b/okio/src/jvmTest/kotlin/okio/CipherFactory.kt index 7b93c652..bb4d3cf3 100644 --- a/okio/src/jvmTest/kotlin/okio/CipherFactory.kt +++ b/okio/src/jvmTest/kotlin/okio/CipherFactory.kt @@ -19,7 +19,7 @@ import javax.crypto.Cipher class CipherFactory( private val transformation: String, - private val init: Cipher.(mode: Int) -> Unit + private val init: Cipher.(mode: Int) -> Unit, ) { val blockSize get() = newCipher().blockSize diff --git a/okio/src/jvmTest/kotlin/okio/CipherSinkTest.kt b/okio/src/jvmTest/kotlin/okio/CipherSinkTest.kt index f273971d..84d27d06 100644 --- a/okio/src/jvmTest/kotlin/okio/CipherSinkTest.kt +++ b/okio/src/jvmTest/kotlin/okio/CipherSinkTest.kt @@ -15,10 +15,10 @@ */ package okio +import kotlin.random.Random import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.Parameterized -import kotlin.random.Random @RunWith(Parameterized::class) class CipherSinkTest(private val cipherAlgorithm: CipherAlgorithm) { @@ -84,7 +84,7 @@ class CipherSinkTest(private val cipherAlgorithm: CipherAlgorithm) { val cipherSink = buffer.cipherSink(cipherFactory.encrypt) cipherSink.buffer().use { data.forEach { - byte -> + byte -> it.writeByte(byte.toInt()) } } diff --git a/okio/src/jvmTest/kotlin/okio/CipherSourceTest.kt b/okio/src/jvmTest/kotlin/okio/CipherSourceTest.kt index f97258e2..16775350 100644 --- a/okio/src/jvmTest/kotlin/okio/CipherSourceTest.kt +++ b/okio/src/jvmTest/kotlin/okio/CipherSourceTest.kt @@ -15,10 +15,10 @@ */ package okio +import kotlin.random.Random import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.Parameterized -import kotlin.random.Random @RunWith(Parameterized::class) class CipherSourceTest(private val cipherAlgorithm: CipherAlgorithm) { diff --git a/okio/src/jvmTest/kotlin/okio/DeflateKotlinTest.kt b/okio/src/jvmTest/kotlin/okio/DeflateKotlinTest.kt index 7a1744ad..5174871b 100644 --- a/okio/src/jvmTest/kotlin/okio/DeflateKotlinTest.kt +++ b/okio/src/jvmTest/kotlin/okio/DeflateKotlinTest.kt @@ -16,11 +16,11 @@ package okio -import okio.ByteString.Companion.decodeHex -import org.junit.Test import java.util.zip.Deflater import java.util.zip.Inflater import kotlin.test.assertEquals +import okio.ByteString.Companion.decodeHex +import org.junit.Test class DeflateKotlinTest { @Test fun deflate() { diff --git a/okio/src/jvmTest/kotlin/okio/DeflaterSinkTest.kt b/okio/src/jvmTest/kotlin/okio/DeflaterSinkTest.kt new file mode 100644 index 00000000..57a8586d --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/DeflaterSinkTest.kt @@ -0,0 +1,171 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.util.zip.Deflater +import java.util.zip.Inflater +import java.util.zip.InflaterInputStream +import okio.TestUtil.SEGMENT_SIZE +import okio.TestUtil.randomBytes +import org.junit.Assert +import org.junit.Test + +class DeflaterSinkTest { + @Test + fun deflateWithClose() { + val data = Buffer() + val original = "They're moving in herds. They do move in herds." + data.writeUtf8(original) + val sink = Buffer() + val deflaterSink = DeflaterSink(sink, Deflater()) + deflaterSink.write(data, data.size) + deflaterSink.close() + val inflated = inflate(sink) + Assert.assertEquals(original, inflated.readUtf8()) + } + + @Test + fun deflateWithSyncFlush() { + val original = "Yes, yes, yes. That's why we're taking extreme precautions." + val data = Buffer() + data.writeUtf8(original) + val sink = Buffer() + val deflaterSink = DeflaterSink(sink, Deflater()) + deflaterSink.write(data, data.size) + deflaterSink.flush() + val inflated = inflate(sink) + Assert.assertEquals(original, inflated.readUtf8()) + } + + @Test + fun deflateWellCompressed() { + val original = "a".repeat(1024 * 1024) + val data = Buffer() + data.writeUtf8(original) + val sink = Buffer() + val deflaterSink = DeflaterSink(sink, Deflater()) + deflaterSink.write(data, data.size) + deflaterSink.close() + val inflated = inflate(sink) + Assert.assertEquals(original, inflated.readUtf8()) + } + + @Test + fun deflatePoorlyCompressed() { + val original = randomBytes(1024 * 1024) + val data = Buffer() + data.write(original) + val sink = Buffer() + val deflaterSink = DeflaterSink(sink, Deflater()) + deflaterSink.write(data, data.size) + deflaterSink.close() + val inflated = inflate(sink) + Assert.assertEquals(original, inflated.readByteString()) + } + + @Test + fun multipleSegmentsWithoutCompression() { + val buffer = Buffer() + val deflater = Deflater() + deflater.setLevel(Deflater.NO_COMPRESSION) + val deflaterSink = DeflaterSink(buffer, deflater) + val byteCount = SEGMENT_SIZE * 4 + deflaterSink.write(Buffer().writeUtf8("a".repeat(byteCount)), byteCount.toLong()) + deflaterSink.close() + Assert.assertEquals("a".repeat(byteCount), inflate(buffer).readUtf8(byteCount.toLong())) + } + + @Test + fun deflateIntoNonemptySink() { + val original = "They're moving in herds. They do move in herds." + + // Exercise all possible offsets for the outgoing segment. + for (i in 0 until SEGMENT_SIZE) { + val data = Buffer().writeUtf8(original) + val sink = Buffer().writeUtf8("a".repeat(i)) + val deflaterSink = DeflaterSink(sink, Deflater()) + deflaterSink.write(data, data.size) + deflaterSink.close() + sink.skip(i.toLong()) + val inflated = inflate(sink) + Assert.assertEquals(original, inflated.readUtf8()) + } + } + + /** + * This test deflates a single segment of without compression because that's + * the easiest way to force close() to emit a large amount of data to the + * underlying sink. + */ + @Test + fun closeWithExceptionWhenWritingAndClosing() { + val mockSink = MockSink() + mockSink.scheduleThrow(0, IOException("first")) + mockSink.scheduleThrow(1, IOException("second")) + val deflater = Deflater() + deflater.setLevel(Deflater.NO_COMPRESSION) + val deflaterSink = DeflaterSink(mockSink, deflater) + deflaterSink.write(Buffer().writeUtf8("a".repeat(SEGMENT_SIZE)), SEGMENT_SIZE.toLong()) + try { + deflaterSink.close() + Assert.fail() + } catch (expected: IOException) { + Assert.assertEquals("first", expected.message) + } + mockSink.assertLogContains("close()") + } + + /** + * This test confirms that we swallow NullPointerException from Deflater and + * rethrow as an IOException. + */ + @Test + fun rethrowNullPointerAsIOException() { + val deflater = Deflater() + // Close to cause a NullPointerException + deflater.end() + + val data = Buffer().apply { + writeUtf8("They're moving in herds. They do move in herds.") + } + val deflaterSink = DeflaterSink(Buffer(), deflater) + + val ioe = Assert.assertThrows("", IOException::class.java) { + deflaterSink.write(data, data.size) + } + + Assert.assertTrue(ioe.cause is NullPointerException) + } + + /** + * Uses streaming decompression to inflate `deflated`. The input must + * either be finished or have a trailing sync flush. + */ + private fun inflate(deflated: Buffer): Buffer { + val deflatedIn = deflated.inputStream() + val inflater = Inflater() + val inflatedIn = InflaterInputStream(deflatedIn, inflater) + val result = Buffer() + val buffer = ByteArray(8192) + while (!inflater.needsInput() || deflated.size > 0 || deflatedIn.available() > 0) { + val count = inflatedIn.read(buffer, 0, buffer.size) + if (count != -1) { + result.write(buffer, 0, count) + } + } + return result + } +} diff --git a/okio/src/jvmTest/kotlin/okio/FileHandleFileSystemTest.kt b/okio/src/jvmTest/kotlin/okio/FileHandleFileSystemTest.kt new file mode 100644 index 00000000..021ffb4c --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/FileHandleFileSystemTest.kt @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import com.google.common.jimfs.Configuration +import com.google.common.jimfs.Jimfs +import java.nio.file.FileSystems +import kotlinx.datetime.Clock +import okio.FileHandleFileSystemTest.FileHandleTestingFileSystem +import okio.FileSystem.Companion.asOkioFileSystem + +/** + * Run a regular file system test, but use [FileHandle] for more file system operations than usual. + * This is intended to increase test coverage for [FileHandle]. + */ +class FileHandleFileSystemTest : AbstractFileSystemTest( + clock = Clock.System, + fileSystem = FileHandleTestingFileSystem(FileSystem.SYSTEM), + windowsLimitations = Path.DIRECTORY_SEPARATOR == "\\", + allowClobberingEmptyDirectories = Path.DIRECTORY_SEPARATOR == "\\", + allowAtomicMoveFromFileToDirectory = false, + temporaryDirectory = FileSystem.SYSTEM_TEMPORARY_DIRECTORY, +) { + /** + * A testing-only file system that implements all reading and writing operations with + * [FileHandle]. This is intended to increase test coverage for [FileHandle]. + */ + class FileHandleTestingFileSystem(delegate: FileSystem) : ForwardingFileSystem(delegate) { + override fun source(file: Path): Source { + val fileHandle = openReadOnly(file) + return fileHandle.source() + .also { fileHandle.close() } + } + + override fun sink(file: Path, mustCreate: Boolean): Sink { + val fileHandle = openReadWrite(file, mustCreate = mustCreate, mustExist = false) + fileHandle.resize(0L) // If the file already has data, get rid of it. + return fileHandle.sink() + .also { fileHandle.close() } + } + + override fun appendingSink(file: Path, mustExist: Boolean): Sink { + val fileHandle = openReadWrite(file, mustCreate = false, mustExist = mustExist) + return fileHandle.appendingSink() + .also { fileHandle.close() } + } + } +} + +class FileHandleNioJimFileSystemWrapperFileSystemTest : AbstractFileSystemTest( + clock = Clock.System, + fileSystem = FileHandleTestingFileSystem( + Jimfs + .newFileSystem( + when (Path.DIRECTORY_SEPARATOR == "\\") { + true -> Configuration.windows() + false -> Configuration.unix() + }, + ).asOkioFileSystem(), + ), + windowsLimitations = false, + allowClobberingEmptyDirectories = true, + allowAtomicMoveFromFileToDirectory = true, + temporaryDirectory = FileSystem.SYSTEM_TEMPORARY_DIRECTORY, +) + +class FileHandleNioDefaultFileSystemWrapperFileSystemTest : AbstractFileSystemTest( + clock = Clock.System, + fileSystem = FileHandleTestingFileSystem( + FileSystems.getDefault().asOkioFileSystem(), + ), + windowsLimitations = false, + allowClobberingEmptyDirectories = Path.DIRECTORY_SEPARATOR == "\\", + allowAtomicMoveFromFileToDirectory = false, + allowRenameWhenTargetIsOpen = Path.DIRECTORY_SEPARATOR != "\\", + temporaryDirectory = FileSystem.SYSTEM_TEMPORARY_DIRECTORY, +) diff --git a/okio/src/jvmTest/kotlin/okio/FileLeakTest.kt b/okio/src/jvmTest/kotlin/okio/FileLeakTest.kt new file mode 100644 index 00000000..5fd18a6b --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/FileLeakTest.kt @@ -0,0 +1,111 @@ +/* + * Copyright (C) 2023 Square, Inc. and others. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.util.zip.ZipEntry +import java.util.zip.ZipOutputStream +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import okio.Path.Companion.toPath +import okio.fakefilesystem.FakeFileSystem +import org.junit.After +import org.junit.Before +import org.junit.Test + +class FileLeakTest { + + private lateinit var fakeFileSystem: FakeFileSystem + private val fakeZip = "/test.zip".toPath() + private val fakeEntry = "some.file".toPath() + private val fakeDirectory = "/another/".toPath() + private val fakeEntry2 = fakeDirectory / "another.file" + + @Before + fun setup() { + fakeFileSystem = FakeFileSystem() + with(fakeFileSystem) { + write(fakeZip) { + writeZip { + putEntry(fakeEntry.name) { + writeUtf8("FooBar") + } + try { + putNextEntry(ZipEntry(fakeDirectory.name).apply { time = 0L }) + } finally { + closeEntry() + } + putEntry(fakeEntry2.toString()) { + writeUtf8("SomethingElse") + } + } + } + } + } + + @After + fun tearDown() { + fakeFileSystem.checkNoOpenFiles() + } + + @Test + fun zipFileSystemExistsTest() { + val zipFileSystem = fakeFileSystem.openZip(fakeZip) + assertTrue(zipFileSystem.exists(fakeEntry)) + } + + @Test + fun zipFileSystemMetadataTest() { + val zipFileSystem = fakeFileSystem.openZip(fakeZip) + assertNotNull(zipFileSystem.metadataOrNull(fakeEntry)) + } + + @Test + fun zipFileSystemSourceTest() { + val zipFileSystem = fakeFileSystem.openZip(fakeZip) + zipFileSystem.source(fakeEntry).use { source -> + assertEquals("FooBar", source.buffer().readUtf8()) + } + } + + @Test + fun zipFileSystemListRecursiveTest() { + val zipFileSystem = fakeFileSystem.openZip(fakeZip) + zipFileSystem.listRecursively("/".toPath()).toList() + fakeFileSystem.delete(fakeZip) + } +} + +/** + * Writes a ZIP file to a [BufferedSink]. + */ +private inline fun <R> BufferedSink.writeZip(action: ZipOutputStream.() -> R): R { + return ZipOutputStream(outputStream()).use(action) +} + +/** + * Adds a new ZIP entry named [name], populates it with [action], and closes the entry. + */ +private inline fun <R> ZipOutputStream.putEntry(name: String, action: BufferedSink.() -> R): R { + putNextEntry(ZipEntry(name).apply { time = 0L }) + val sink = sink().buffer() + return try { + sink.action() + } finally { + sink.flush() + closeEntry() + } +} diff --git a/okio/src/jvmTest/kotlin/okio/FileSystemJavaTest.kt b/okio/src/jvmTest/kotlin/okio/FileSystemJavaTest.kt new file mode 100644 index 00000000..699db3ef --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/FileSystemJavaTest.kt @@ -0,0 +1,122 @@ +/* + * Copyright (C) 2020 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.File +import java.nio.file.Paths +import okio.ByteString.Companion.encodeUtf8 +import okio.Path.Companion.toOkioPath +import okio.Path.Companion.toPath +import okio.fakefilesystem.FakeFileSystem +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test + +class FileSystemJavaTest { + @Test + fun pathApi() { + val path = "/home/jesse/todo.txt".toPath(false) + assertThat("/home/jesse".toPath(false).div("todo.txt")).isEqualTo(path) + assertThat("/home/jesse/todo.txt".toPath(false)).isEqualTo(path) + assertThat(path.isAbsolute).isTrue() + assertThat(path.isRelative).isFalse() + assertThat(path.isRoot).isFalse() + assertThat(path.name).isEqualTo("todo.txt") + assertThat(path.nameBytes).isEqualTo("todo.txt".encodeUtf8()) + assertThat(path.parent).isEqualTo("/home/jesse".toPath(false)) + assertThat(path.volumeLetter).isNull() + } + + @Test + fun directorySeparator() { + assertThat(Path.DIRECTORY_SEPARATOR).isIn("/", "\\") + } + + /** Like the same test in JvmTest, but this is using the Java APIs. */ + @Test + fun javaIoFileToOkioPath() { + val string = "/foo/bar/baz".replace("/", Path.DIRECTORY_SEPARATOR) + val javaIoFile = File(string) + val okioPath: Path = string.toPath(false) + assertThat(javaIoFile.toOkioPath(false)).isEqualTo(okioPath) + assertThat(okioPath.toFile()).isEqualTo(javaIoFile) + } + + /** Like the same test in JvmTest, but this is using the Java APIs. */ + @Test + fun nioPathToOkioPath() { + val string = "/foo/bar/baz".replace("/", Path.DIRECTORY_SEPARATOR) + val nioPath = Paths.get(string) + val okioPath: Path = string.toPath(false) + assertThat(nioPath.toOkioPath(false)).isEqualTo(okioPath) + assertThat(okioPath.toNioPath() as Any).isEqualTo(nioPath) + } + + // Just confirm these APIs exist; don't invoke them + @Suppress("unused") + fun fileSystemApi() { + val fileSystem = FileSystem.SYSTEM + val pathA: Path = "a.txt".toPath() + val pathB: Path = "b.txt".toPath() + fileSystem.canonicalize(pathA) + fileSystem.metadata(pathA) + fileSystem.metadataOrNull(pathA) + fileSystem.exists(pathA) + fileSystem.list(pathA) + fileSystem.listOrNull(pathA) + fileSystem.listRecursively(pathA, false) + fileSystem.listRecursively(pathA) + fileSystem.openReadOnly(pathA) + fileSystem.openReadWrite(pathA, mustCreate = false, mustExist = false) + fileSystem.openReadWrite(pathA) + fileSystem.source(pathA) + // Note that FileSystem.read() isn't available to Java callers. + fileSystem.sink(pathA, false) + fileSystem.sink(pathA) + // Note that FileSystem.write() isn't available to Java callers. + fileSystem.appendingSink(pathA, false) + fileSystem.appendingSink(pathA) + fileSystem.createDirectory(pathA) + fileSystem.createDirectories(pathA) + fileSystem.atomicMove(pathA, pathB) + fileSystem.copy(pathA, pathB) + fileSystem.delete(pathA) + fileSystem.deleteRecursively(pathA) + fileSystem.createSymlink(pathA, pathB) + } + + @Test + fun fakeFileSystemApi() { + val fakeFileSystem = FakeFileSystem() + assertThat(fakeFileSystem.clock).isNotNull() + assertThat(fakeFileSystem.allPaths).isEmpty() + assertThat(fakeFileSystem.openPaths).isEmpty() + fakeFileSystem.checkNoOpenFiles() + } + + @Test + fun forwardingFileSystemApi() { + val fakeFileSystem = FakeFileSystem() + val log: MutableList<String> = ArrayList() + val forwardingFileSystem: ForwardingFileSystem = object : ForwardingFileSystem(fakeFileSystem) { + override fun onPathParameter(path: Path, functionName: String, parameterName: String): Path { + log.add("$functionName($parameterName=$path)") + return path + } + } + forwardingFileSystem.metadataOrNull("/".toPath(false)) + assertThat(log).containsExactly("metadataOrNull(path=/)") + } +} diff --git a/okio/src/jvmTest/kotlin/okio/FixedLengthSourceTest.kt b/okio/src/jvmTest/kotlin/okio/FixedLengthSourceTest.kt new file mode 100644 index 00000000..79184177 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/FixedLengthSourceTest.kt @@ -0,0 +1,165 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import kotlin.test.fail +import okio.internal.FixedLengthSource +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test + +internal class FixedLengthSourceTest { + @Test + fun happyPathWithTruncate() { + val delegate = Buffer().writeUtf8("abcdefghijklmnop") + val fixedLengthSource = FixedLengthSource(delegate, 16, truncate = true) + val buffer = Buffer() + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(10L) + assertThat(buffer.readUtf8()).isEqualTo("abcdefghij") + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(6L) + assertThat(buffer.readUtf8()).isEqualTo("klmnop") + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(-1L) + assertThat(buffer.readUtf8()).isEqualTo("") + } + + @Test + fun happyPathNoTruncate() { + val delegate = Buffer().writeUtf8("abcdefghijklmnop") + val fixedLengthSource = FixedLengthSource(delegate, 16, truncate = false) + val buffer = Buffer() + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(10L) + assertThat(buffer.readUtf8()).isEqualTo("abcdefghij") + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(6L) + assertThat(buffer.readUtf8()).isEqualTo("klmnop") + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(-1L) + assertThat(buffer.readUtf8()).isEqualTo("") + } + + @Test + fun delegateTooLongWithTruncate() { + val delegate = Buffer().writeUtf8("abcdefghijklmnopqr") + val fixedLengthSource = FixedLengthSource(delegate, 16, truncate = true) + val buffer = Buffer() + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(10L) + assertThat(buffer.readUtf8()).isEqualTo("abcdefghij") + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(6L) + assertThat(buffer.readUtf8()).isEqualTo("klmnop") + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(-1L) + assertThat(buffer.readUtf8()).isEqualTo("") + } + + @Test + fun delegateTooLongWithTruncateFencepost() { + val delegate = Buffer().writeUtf8("abcdefghijklmnop") + val fixedLengthSource = FixedLengthSource(delegate, 10, truncate = true) + val buffer = Buffer() + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(10L) + assertThat(buffer.readUtf8()).isEqualTo("abcdefghij") + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(-1L) + assertThat(buffer.readUtf8()).isEmpty() + } + + @Test + fun delegateTooLongNoTruncate() { + val delegate = Buffer().writeUtf8("abcdefghijklmnopqr") + val fixedLengthSource = FixedLengthSource(delegate, 16, truncate = false) + val buffer = Buffer() + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(10L) + assertThat(buffer.readUtf8()).isEqualTo("abcdefghij") + try { + fixedLengthSource.read(buffer, 10L) + fail() + } catch (e: IOException) { + assertThat(e).hasMessage("expected 16 bytes but got 18") + assertThat(buffer.readUtf8()).isEqualTo("klmnop") // Doesn't produce too many bytes! + } + try { + fixedLengthSource.read(buffer, 10L) + fail() + } catch (e: IOException) { + assertThat(e).hasMessage("expected 16 bytes but got 18") + assertThat(buffer.readUtf8()).isEmpty() // Doesn't produce any bytes! + } + } + + @Test + fun delegateTooLongNoTruncateFencepost() { + val delegate = Buffer().writeUtf8("abcdefghijklmnop") + val fixedLengthSource = FixedLengthSource(delegate, 10, truncate = false) + val buffer = Buffer() + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(10L) + assertThat(buffer.readUtf8()).isEqualTo("abcdefghij") + try { + fixedLengthSource.read(buffer, 10L) + fail() + } catch (e: IOException) { + assertThat(e).hasMessage("expected 10 bytes but got 16") + assertThat(buffer.readUtf8()).isEmpty() // Doesn't produce too many bytes! + } + try { + fixedLengthSource.read(buffer, 10L) + fail() + } catch (e: IOException) { + assertThat(e).hasMessage("expected 10 bytes but got 16") + assertThat(buffer.readUtf8()).isEmpty() // Doesn't produce any bytes! + } + } + + @Test + fun delegateTooShortWithTruncate() { + val delegate = Buffer().writeUtf8("abcdefghijklmn") + val fixedLengthSource = FixedLengthSource(delegate, 16, truncate = true) + val buffer = Buffer() + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(10L) + assertThat(buffer.readUtf8()).isEqualTo("abcdefghij") + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(4L) + assertThat(buffer.readUtf8()).isEqualTo("klmn") + try { + fixedLengthSource.read(buffer, 10L) + fail() + } catch (e: IOException) { + assertThat(e).hasMessage("expected 16 bytes but got 14") + } + try { + fixedLengthSource.read(buffer, 10L) + fail() + } catch (e: IOException) { + assertThat(e).hasMessage("expected 16 bytes but got 14") + } + } + + @Test + fun delegateTooShortNoTruncate() { + val delegate = Buffer().writeUtf8("abcdefghijklmn") + val fixedLengthSource = FixedLengthSource(delegate, 16, truncate = false) + val buffer = Buffer() + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(10L) + assertThat(buffer.readUtf8()).isEqualTo("abcdefghij") + assertThat(fixedLengthSource.read(buffer, 10L)).isEqualTo(4L) + assertThat(buffer.readUtf8()).isEqualTo("klmn") + try { + fixedLengthSource.read(buffer, 10L) + fail() + } catch (e: IOException) { + assertThat(e).hasMessage("expected 16 bytes but got 14") + } + try { + fixedLengthSource.read(buffer, 10L) + fail() + } catch (e: IOException) { + assertThat(e).hasMessage("expected 16 bytes but got 14") + } + } +} diff --git a/okio/src/jvmTest/kotlin/okio/ForwardingTimeoutKotlinTest.kt b/okio/src/jvmTest/kotlin/okio/ForwardingTimeoutKotlinTest.kt index 4db02c50..f2c19c6e 100644 --- a/okio/src/jvmTest/kotlin/okio/ForwardingTimeoutKotlinTest.kt +++ b/okio/src/jvmTest/kotlin/okio/ForwardingTimeoutKotlinTest.kt @@ -15,9 +15,9 @@ */ package okio +import java.util.concurrent.TimeUnit import org.assertj.core.api.Assertions.assertThat import org.junit.Test -import java.util.concurrent.TimeUnit class ForwardingTimeoutKotlinTest { @Test fun getAndSetDelegate() { diff --git a/okio/src/jvmTest/kotlin/okio/ForwardingTimeoutTest.kt b/okio/src/jvmTest/kotlin/okio/ForwardingTimeoutTest.kt new file mode 100644 index 00000000..637cb976 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/ForwardingTimeoutTest.kt @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2018 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.util.concurrent.TimeUnit +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test + +class ForwardingTimeoutTest { + @Test + fun getAndSetDelegate() { + val timeout1 = Timeout() + val timeout2 = Timeout() + val forwardingTimeout = ForwardingTimeout(timeout1) + forwardingTimeout.timeout(5, TimeUnit.SECONDS) + assertThat(timeout1.timeoutNanos()).isNotEqualTo(0L) + assertThat(timeout2.timeoutNanos()).isEqualTo(0L) + forwardingTimeout.clearTimeout() + assertThat(timeout1.timeoutNanos()).isEqualTo(0L) + assertThat(timeout2.timeoutNanos()).isEqualTo(0L) + assertThat(forwardingTimeout.delegate).isEqualTo(timeout1) + assertThat(forwardingTimeout.setDelegate(timeout2)).isEqualTo(forwardingTimeout) + forwardingTimeout.timeout(5, TimeUnit.SECONDS) + assertThat(timeout1.timeoutNanos()).isEqualTo(0L) + assertThat(timeout2.timeoutNanos()).isNotEqualTo(0L) + forwardingTimeout.clearTimeout() + assertThat(timeout1.timeoutNanos()).isEqualTo(0L) + assertThat(timeout2.timeoutNanos()).isEqualTo(0L) + assertThat(forwardingTimeout.delegate).isEqualTo(timeout2) + } +} diff --git a/okio/src/jvmTest/kotlin/okio/GzipKotlinTest.kt b/okio/src/jvmTest/kotlin/okio/GzipKotlinTest.kt index 6bbe9b51..dfe7182a 100644 --- a/okio/src/jvmTest/kotlin/okio/GzipKotlinTest.kt +++ b/okio/src/jvmTest/kotlin/okio/GzipKotlinTest.kt @@ -16,21 +16,35 @@ package okio +import kotlin.test.assertEquals import okio.ByteString.Companion.decodeHex import org.junit.Test -import kotlin.test.assertEquals class GzipKotlinTest { @Test fun sink() { val data = Buffer() - val gzip = (data as Sink).gzip() - gzip.buffer().writeUtf8("Hi!").close() + (data as Sink).gzip().buffer().use { gzip -> + gzip.writeUtf8("Hi!") + } assertEquals("1f8b0800000000000000f3c8540400dac59e7903000000", data.readByteString().hex()) } @Test fun source() { val buffer = Buffer().write("1f8b0800000000000000f3c8540400dac59e7903000000".decodeHex()) - val gzip = (buffer as Source).gzip() - assertEquals("Hi!", gzip.buffer().readUtf8()) + (buffer as Source).gzip().buffer().use { gzip -> + assertEquals("Hi!", gzip.readUtf8()) + } + } + + @Test fun extraLongXlen() { + val xlen = 0xffff + val buffer = Buffer() + .write("1f8b0804000000000000".decodeHex()) + .writeShort(xlen) + .write(ByteArray(xlen)) + .write("f3c8540400dac59e7903000000".decodeHex()) + (buffer as Source).gzip().buffer().use { gzip -> + assertEquals("Hi!", gzip.readUtf8()) + } } } diff --git a/okio/src/jvmTest/kotlin/okio/GzipSinkTest.kt b/okio/src/jvmTest/kotlin/okio/GzipSinkTest.kt new file mode 100644 index 00000000..b3fe171a --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/GzipSinkTest.kt @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.IOException +import okio.TestUtil.SEGMENT_SIZE +import org.junit.Assert.assertEquals +import org.junit.Assert.fail +import org.junit.Test + +class GzipSinkTest { + @Test + fun gzipGunzip() { + val data = Buffer() + val original = "It's a UNIX system! I know this!" + data.writeUtf8(original) + val sink = Buffer() + val gzipSink = GzipSink(sink) + gzipSink.write(data, data.size) + gzipSink.close() + val inflated = gunzip(sink) + assertEquals(original, inflated.readUtf8()) + } + + @Test + fun closeWithExceptionWhenWritingAndClosing() { + val mockSink = MockSink() + mockSink.scheduleThrow(0, IOException("first")) + mockSink.scheduleThrow(1, IOException("second")) + val gzipSink = GzipSink(mockSink) + gzipSink.write(Buffer().writeUtf8("a".repeat(SEGMENT_SIZE)), SEGMENT_SIZE.toLong()) + try { + gzipSink.close() + fail() + } catch (expected: IOException) { + assertEquals("first", expected.message) + } + mockSink.assertLogContains("close()") + } + + private fun gunzip(gzipped: Buffer): Buffer { + val result = Buffer() + val source = GzipSource(gzipped) + while (source.read(result, Int.MAX_VALUE.toLong()) != -1L) { + } + return result + } +} diff --git a/okio/src/jvmTest/kotlin/okio/GzipSourceTest.kt b/okio/src/jvmTest/kotlin/okio/GzipSourceTest.kt new file mode 100644 index 00000000..812aab14 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/GzipSourceTest.kt @@ -0,0 +1,232 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.IOException +import java.util.zip.CRC32 +import kotlin.text.Charsets.UTF_8 +import okio.ByteString.Companion.decodeHex +import okio.ByteString.Companion.of +import okio.TestUtil.reverseBytes +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Assert.fail +import org.junit.Test + +class GzipSourceTest { + @Test + fun gunzip() { + val gzipped = Buffer() + gzipped.write(gzipHeader) + gzipped.write(deflated) + gzipped.write(gzipTrailer) + assertGzipped(gzipped) + } + + @Test + fun gunzip_withHCRC() { + val hcrc = CRC32() + val gzipHeader = gzipHeaderWithFlags(0x02.toByte()) + hcrc.update(gzipHeader.toByteArray()) + val gzipped = Buffer() + gzipped.write(gzipHeader) + gzipped.writeShort(hcrc.value.toShort().reverseBytes().toInt()) // little endian + gzipped.write(deflated) + gzipped.write(gzipTrailer) + assertGzipped(gzipped) + } + + @Test + fun gunzip_withExtra() { + val gzipped = Buffer() + gzipped.write(gzipHeaderWithFlags(0x04.toByte())) + gzipped.writeShort(7.toShort().reverseBytes().toInt()) // little endian extra length + gzipped.write("blubber".toByteArray(UTF_8), 0, 7) + gzipped.write(deflated) + gzipped.write(gzipTrailer) + assertGzipped(gzipped) + } + + @Test + fun gunzip_withName() { + val gzipped = Buffer() + gzipped.write(gzipHeaderWithFlags(0x08.toByte())) + gzipped.write("foo.txt".toByteArray(UTF_8), 0, 7) + gzipped.writeByte(0) // zero-terminated + gzipped.write(deflated) + gzipped.write(gzipTrailer) + assertGzipped(gzipped) + } + + @Test + fun gunzip_withComment() { + val gzipped = Buffer() + gzipped.write(gzipHeaderWithFlags(0x10.toByte())) + gzipped.write("rubbish".toByteArray(UTF_8), 0, 7) + gzipped.writeByte(0) // zero-terminated + gzipped.write(deflated) + gzipped.write(gzipTrailer) + assertGzipped(gzipped) + } + + /** + * For portability, it is a good idea to export the gzipped bytes and try running gzip. Ex. + * `echo gzipped | base64 --decode | gzip -l -v` + */ + @Test + fun gunzip_withAll() { + val gzipped = Buffer() + gzipped.write(gzipHeaderWithFlags(0x1c.toByte())) + gzipped.writeShort(7.toShort().reverseBytes().toInt()) // little endian extra length + gzipped.write("blubber".toByteArray(UTF_8), 0, 7) + gzipped.write("foo.txt".toByteArray(UTF_8), 0, 7) + gzipped.writeByte(0) // zero-terminated + gzipped.write("rubbish".toByteArray(UTF_8), 0, 7) + gzipped.writeByte(0) // zero-terminated + gzipped.write(deflated) + gzipped.write(gzipTrailer) + assertGzipped(gzipped) + } + + private fun assertGzipped(gzipped: Buffer) { + val gunzipped = gunzip(gzipped) + assertEquals("It's a UNIX system! I know this!", gunzipped.readUtf8()) + } + + /** + * Note that you cannot test this with old versions of gzip, as they interpret flag bit 1 as + * CONTINUATION, not HCRC. For example, this is the case with the default gzip on osx. + */ + @Test + fun gunzipWhenHeaderCRCIncorrect() { + val gzipped = Buffer() + gzipped.write(gzipHeaderWithFlags(0x02.toByte())) + gzipped.writeShort(0.toShort().toInt()) // wrong HCRC! + gzipped.write(deflated) + gzipped.write(gzipTrailer) + try { + gunzip(gzipped) + fail() + } catch (e: IOException) { + assertEquals("FHCRC: actual 0x0000261d != expected 0x00000000", e.message) + } + } + + @Test + fun gunzipWhenCRCIncorrect() { + val gzipped = Buffer() + gzipped.write(gzipHeader) + gzipped.write(deflated) + gzipped.writeInt(0x1234567.reverseBytes()) // wrong CRC + gzipped.write(gzipTrailer.toByteArray(), 3, 4) + try { + gunzip(gzipped) + fail() + } catch (e: IOException) { + assertEquals("CRC: actual 0x37ad8f8d != expected 0x01234567", e.message) + } + } + + @Test + fun gunzipWhenLengthIncorrect() { + val gzipped = Buffer() + gzipped.write(gzipHeader) + gzipped.write(deflated) + gzipped.write(gzipTrailer.toByteArray(), 0, 4) + gzipped.writeInt(0x123456.reverseBytes()) // wrong length + try { + gunzip(gzipped) + fail() + } catch (e: IOException) { + assertEquals("ISIZE: actual 0x00000020 != expected 0x00123456", e.message) + } + } + + @Test + fun gunzipExhaustsSource() { + val gzippedSource = Buffer() + .write("1f8b08000000000000004b4c4a0600c241243503000000".decodeHex()) // 'abc' + val exhaustableSource = ExhaustableSource(gzippedSource) + val gunzippedSource = GzipSource(exhaustableSource).buffer() + assertEquals('a'.code.toLong(), gunzippedSource.readByte().toLong()) + assertEquals('b'.code.toLong(), gunzippedSource.readByte().toLong()) + assertEquals('c'.code.toLong(), gunzippedSource.readByte().toLong()) + assertFalse(exhaustableSource.exhausted) + assertEquals(-1, gunzippedSource.read(Buffer(), 1)) + assertTrue(exhaustableSource.exhausted) + } + + @Test + fun gunzipThrowsIfSourceIsNotExhausted() { + val gzippedSource = Buffer() + .write("1f8b08000000000000004b4c4a0600c241243503000000".decodeHex()) // 'abc' + gzippedSource.writeByte('d'.code) // This byte shouldn't be here! + val gunzippedSource = GzipSource(gzippedSource).buffer() + assertEquals('a'.code.toLong(), gunzippedSource.readByte().toLong()) + assertEquals('b'.code.toLong(), gunzippedSource.readByte().toLong()) + assertEquals('c'.code.toLong(), gunzippedSource.readByte().toLong()) + try { + gunzippedSource.readByte() + fail() + } catch (expected: IOException) { + } + } + + private fun gzipHeaderWithFlags(flags: Byte): ByteString { + val result = gzipHeader.toByteArray() + result[3] = flags + return of(*result) + } + + private val gzipHeader = "1f8b0800000000000000".decodeHex() + + // Deflated "It's a UNIX system! I know this!" + private val deflated = "f32c512f56485408f5f38c5028ae2c2e49cd5554f054c8cecb2f5728c9c82c560400".decodeHex() + private val gzipTrailer = ( + "" + + "8d8fad37" + // Checksum of deflated. + "20000000" + ) // 32 in little endian. + .decodeHex() + + private fun gunzip(gzipped: Buffer): Buffer { + val result = Buffer() + val source = GzipSource(gzipped) + while (source.read(result, Int.MAX_VALUE.toLong()) != -1L) { + } + return result + } + + /** This source keeps track of whether its read has returned -1. */ + internal class ExhaustableSource(private val source: Source) : Source { + var exhausted = false + + override fun read(sink: Buffer, byteCount: Long): Long { + val result = source.read(sink, byteCount) + if (result == -1L) exhausted = true + return result + } + + override fun timeout(): Timeout { + return source.timeout() + } + + override fun close() { + source.close() + } + } +} diff --git a/okio/src/jvmTest/kotlin/okio/InflaterSourceTest.kt b/okio/src/jvmTest/kotlin/okio/InflaterSourceTest.kt new file mode 100644 index 00000000..53273776 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/InflaterSourceTest.kt @@ -0,0 +1,215 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.EOFException +import java.util.zip.DeflaterOutputStream +import java.util.zip.Inflater +import okio.BufferedSourceFactory.Companion.PARAMETERIZED_TEST_VALUES +import okio.ByteString.Companion.decodeBase64 +import okio.ByteString.Companion.encodeUtf8 +import okio.TestUtil.SEGMENT_SIZE +import okio.TestUtil.randomBytes +import org.assertj.core.api.Assertions.assertThat +import org.junit.Assert.assertEquals +import org.junit.Assert.fail +import org.junit.Assume.assumeFalse +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters + +@RunWith(Parameterized::class) +class InflaterSourceTest( + private val bufferFactory: BufferedSourceFactory, +) { + private lateinit var deflatedSink: BufferedSink + private lateinit var deflatedSource: BufferedSource + + init { + resetDeflatedSourceAndSink() + } + + private fun resetDeflatedSourceAndSink() { + val pipe = bufferFactory.pipe() + deflatedSink = pipe.sink + deflatedSource = pipe.source + } + + @Test + fun inflate() { + decodeBase64("eJxzz09RyEjNKVAoLdZRKE9VL0pVyMxTKMlIVchIzEspVshPU0jNS8/MS00tKtYDAF6CD5s=") + val inflated = inflate(deflatedSource) + assertEquals("God help us, we're in the hands of engineers.", inflated.readUtf8()) + } + + @Test + fun inflateTruncated() { + decodeBase64("eJxzz09RyEjNKVAoLdZRKE9VL0pVyMxTKMlIVchIzEspVshPU0jNS8/MS00tKtYDAF6CDw==") + try { + inflate(deflatedSource) + fail() + } catch (expected: EOFException) { + } + } + + @Test + fun inflateWellCompressed() { + decodeBase64( + "eJztwTEBAAAAwqCs61/CEL5AAQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB8BtFeWvE=", + ) + val original = "a".repeat(1024 * 1024) + deflate(original.encodeUtf8()) + val inflated = inflate(deflatedSource) + assertEquals(original, inflated.readUtf8()) + } + + @Test + fun inflatePoorlyCompressed() { + assumeFalse(bufferFactory.isOneByteAtATime) // 8 GiB for 1 byte per segment! + val original = randomBytes(1024 * 1024) + deflate(original) + val inflated = inflate(deflatedSource) + assertEquals(original, inflated.readByteString()) + } + + @Test + fun inflateIntoNonemptySink() { + for (i in 0 until SEGMENT_SIZE) { + resetDeflatedSourceAndSink() + val inflated = Buffer().writeUtf8("a".repeat(i)) + deflate("God help us, we're in the hands of engineers.".encodeUtf8()) + val source = InflaterSource(deflatedSource, Inflater()) + while (source.read(inflated, Int.MAX_VALUE.toLong()) != -1L) { + } + inflated.skip(i.toLong()) + assertEquals("God help us, we're in the hands of engineers.", inflated.readUtf8()) + } + } + + @Test + fun inflateSingleByte() { + val inflated = Buffer() + decodeBase64("eJxzz09RyEjNKVAoLdZRKE9VL0pVyMxTKMlIVchIzEspVshPU0jNS8/MS00tKtYDAF6CD5s=") + val source = InflaterSource(deflatedSource, Inflater()) + source.read(inflated, 1) + source.close() + assertEquals("G", inflated.readUtf8()) + assertEquals(0, inflated.size) + } + + @Test + fun inflateByteCount() { + assumeFalse(bufferFactory.isOneByteAtATime) // This test assumes one step. + val inflated = Buffer() + decodeBase64("eJxzz09RyEjNKVAoLdZRKE9VL0pVyMxTKMlIVchIzEspVshPU0jNS8/MS00tKtYDAF6CD5s=") + val source = InflaterSource(deflatedSource, Inflater()) + source.read(inflated, 11) + source.close() + assertEquals("God help us", inflated.readUtf8()) + assertEquals(0, inflated.size) + } + + @Test + fun sourceExhaustedPrematurelyOnRead() { + // Deflate 0 bytes of data that lacks the in-stream terminator. + decodeBase64("eJwAAAD//w==") + val inflated = Buffer() + val inflater = Inflater() + val source = InflaterSource(deflatedSource, inflater) + assertThat(deflatedSource.exhausted()).isFalse + try { + source.read(inflated, Long.MAX_VALUE) + fail() + } catch (expected: EOFException) { + assertThat(expected).hasMessage("source exhausted prematurely") + } + + // Despite the exception, the read() call made forward progress on the underlying stream! + assertThat(deflatedSource.exhausted()).isTrue + } + + /** + * Confirm that [InflaterSource.readOrInflate] consumes a byte on each call even if it + * doesn't produce a byte on every call. + */ + @Test + fun readOrInflateMakesByteByByteProgress() { + // Deflate 0 bytes of data that lacks the in-stream terminator. + decodeBase64("eJwAAAD//w==") + val deflatedByteCount = 7 + val inflated = Buffer() + val inflater = Inflater() + val source = InflaterSource(deflatedSource, inflater) + assertThat(deflatedSource.exhausted()).isFalse + if (bufferFactory.isOneByteAtATime) { + for (i in 0 until deflatedByteCount) { + assertThat(inflater.bytesRead).isEqualTo(i.toLong()) + assertThat(source.readOrInflate(inflated, Long.MAX_VALUE)).isEqualTo(0L) + } + } else { + assertThat(source.readOrInflate(inflated, Long.MAX_VALUE)).isEqualTo(0L) + } + assertThat(inflater.bytesRead).isEqualTo(deflatedByteCount.toLong()) + assertThat(deflatedSource.exhausted()).isTrue() + } + + private fun decodeBase64(s: String) { + deflatedSink.write(s.decodeBase64()!!) + deflatedSink.flush() + } + + /** Use DeflaterOutputStream to deflate source. */ + private fun deflate(source: ByteString) { + val sink = DeflaterOutputStream(deflatedSink.outputStream()).sink() + sink.write(Buffer().write(source), source.size.toLong()) + sink.close() + } + + /** Returns a new buffer containing the inflated contents of `deflated`. */ + private fun inflate(deflated: BufferedSource?): Buffer { + val result = Buffer() + val source = InflaterSource(deflated!!, Inflater()) + while (source.read(result, Int.MAX_VALUE.toLong()) != -1L) { + } + return result + } + + companion object { + /** + * Use a parameterized test to control how many bytes the InflaterSource gets with each request + * for more bytes. + */ + @JvmStatic + @Parameters(name = "{0}") + fun parameters(): List<Array<Any>> = PARAMETERIZED_TEST_VALUES + } +} diff --git a/okio/src/jvmTest/kotlin/okio/JimfsOkioRoundTripTest.kt b/okio/src/jvmTest/kotlin/okio/JimfsOkioRoundTripTest.kt new file mode 100644 index 00000000..d4e2239b --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/JimfsOkioRoundTripTest.kt @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2023 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import com.google.common.jimfs.Configuration +import com.google.common.jimfs.Jimfs +import java.nio.file.StandardOpenOption +import kotlin.io.path.readText +import kotlin.io.path.writeText +import kotlin.test.BeforeTest +import kotlin.test.assertEquals +import okio.FileSystem.Companion.asOkioFileSystem +import org.junit.Test + +class JimfsOkioRoundTripTest { + private val temporaryDirectory = FileSystem.SYSTEM_TEMPORARY_DIRECTORY + private val jimFs = Jimfs.newFileSystem( + when (Path.DIRECTORY_SEPARATOR == "\\") { + true -> Configuration.windows() + false -> Configuration.unix() + }.toBuilder() + .setWorkingDirectory(temporaryDirectory.toString()) + .build(), + ) + private val jimFsRoot = jimFs.rootDirectories.first() + private val okioFs = jimFs.asOkioFileSystem() + private val base: Path = temporaryDirectory / "${this::class.simpleName}-${randomToken(16)}" + + @BeforeTest + fun setUp() { + okioFs.createDirectory(base) + } + + @Test + fun writeOkioReadJim() { + val path = base / "file-handle-write-okio-and-read-jim" + + okioFs.write(path) { + writeUtf8("abcdefghijklmnop") + } + + assertEquals("abcdefghijklmnop", jimFsRoot.resolve(path.toString()).readText(Charsets.UTF_8)) + } + + @Test + fun writeJimReadOkio() { + val path = base / "file-handle-write-jim-and-read-okio" + jimFsRoot.resolve(path.toString()).writeText("abcdefghijklmnop", Charsets.UTF_8, StandardOpenOption.CREATE, StandardOpenOption.WRITE) + + okioFs.openReadWrite(path).use { handle -> + handle.source().buffer().use { source -> + assertEquals("abcde", source.readUtf8(5)) + assertEquals("fghijklmnop", source.readUtf8()) + } + } + } +} diff --git a/okio/src/jvmTest/kotlin/okio/JvmSystemFileSystemTest.kt b/okio/src/jvmTest/kotlin/okio/JvmSystemFileSystemTest.kt new file mode 100644 index 00000000..fc4b7c09 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/JvmSystemFileSystemTest.kt @@ -0,0 +1,86 @@ +/* + * Copyright (C) 2020 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import com.google.common.jimfs.Configuration +import com.google.common.jimfs.Jimfs +import java.io.InterruptedIOException +import java.nio.file.FileSystems +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.fail +import kotlinx.datetime.Clock +import okio.FileSystem.Companion.asOkioFileSystem +import org.junit.Test + +/** + * This test will run using [NioSystemFileSystem] by default. If [java.nio.file.Files] is not found + * on the classpath, [JvmSystemFileSystem] will be use instead. + */ +class NioSystemFileSystemTest : AbstractFileSystemTest( + clock = Clock.System, + fileSystem = FileSystem.SYSTEM, + windowsLimitations = Path.DIRECTORY_SEPARATOR == "\\", + allowClobberingEmptyDirectories = Path.DIRECTORY_SEPARATOR == "\\", + allowAtomicMoveFromFileToDirectory = false, + temporaryDirectory = FileSystem.SYSTEM_TEMPORARY_DIRECTORY, +) + +class JvmSystemFileSystemTest : AbstractFileSystemTest( + clock = Clock.System, + fileSystem = JvmSystemFileSystem(), + windowsLimitations = Path.DIRECTORY_SEPARATOR == "\\", + allowClobberingEmptyDirectories = Path.DIRECTORY_SEPARATOR == "\\", + allowAtomicMoveFromFileToDirectory = false, + temporaryDirectory = FileSystem.SYSTEM_TEMPORARY_DIRECTORY, +) { + + @Test fun checkInterruptedBeforeDeleting() { + Thread.currentThread().interrupt() + try { + fileSystem.delete(base) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("interrupted", expected.message) + assertFalse(Thread.interrupted()) + } + } +} + +class NioJimFileSystemWrappingFileSystemTest : AbstractFileSystemTest( + clock = Clock.System, + fileSystem = Jimfs + .newFileSystem( + when (Path.DIRECTORY_SEPARATOR == "\\") { + true -> Configuration.windows() + false -> Configuration.unix() + }, + ).asOkioFileSystem(), + windowsLimitations = false, + allowClobberingEmptyDirectories = true, + allowAtomicMoveFromFileToDirectory = true, + temporaryDirectory = FileSystem.SYSTEM_TEMPORARY_DIRECTORY, +) + +class NioDefaultFileSystemWrappingFileSystemTest : AbstractFileSystemTest( + clock = Clock.System, + fileSystem = FileSystems.getDefault().asOkioFileSystem(), + windowsLimitations = false, + allowClobberingEmptyDirectories = Path.DIRECTORY_SEPARATOR == "\\", + allowAtomicMoveFromFileToDirectory = false, + allowRenameWhenTargetIsOpen = Path.DIRECTORY_SEPARATOR != "\\", + temporaryDirectory = FileSystem.SYSTEM_TEMPORARY_DIRECTORY, +) diff --git a/okio/src/jvmTest/kotlin/okio/JvmTest.kt b/okio/src/jvmTest/kotlin/okio/JvmTest.kt index f7114bcf..690e64a3 100644 --- a/okio/src/jvmTest/kotlin/okio/JvmTest.kt +++ b/okio/src/jvmTest/kotlin/okio/JvmTest.kt @@ -15,14 +15,13 @@ */ package okio -import okio.Path.Companion.toOkioPath -import okio.Path.Companion.toPath -import org.assertj.core.api.Assertions.assertThat import java.io.File import java.nio.file.Paths import kotlin.test.Test +import okio.Path.Companion.toOkioPath +import okio.Path.Companion.toPath +import org.assertj.core.api.Assertions.assertThat -@ExperimentalFileSystem class JvmTest { @Test fun baseDirectoryConsistentWithJavaIoFile() { diff --git a/okio/src/jvmTest/kotlin/okio/JvmTesting.kt b/okio/src/jvmTest/kotlin/okio/JvmTesting.kt new file mode 100644 index 00000000..e6b091d7 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/JvmTesting.kt @@ -0,0 +1,54 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import okio.Path.Companion.toOkioPath +import okio.Path.Companion.toPath + +actual fun assertRelativeTo( + a: Path, + b: Path, + bRelativeToA: Path, + sameAsNio: Boolean, +) { + val actual = b.relativeTo(a) + assertEquals(bRelativeToA, actual) + assertEquals(b.normalized().withUnixSlashes(), (a / actual).normalized().withUnixSlashes()) + // Also confirm our behavior is consistent with java.nio. + if (sameAsNio) { + // On Windows, java.nio will modify slashes to backslashes for relative paths, so we force it. + val nioPath = a.toNioPath().relativize(b.toNioPath()) + .toOkioPath(normalize = true).withUnixSlashes().toPath() + assertEquals(bRelativeToA, nioPath) + } +} + +actual fun assertRelativeToFails( + a: Path, + b: Path, + sameAsNio: Boolean, +): IllegalArgumentException { + // Check java.nio first. + if (sameAsNio) { + assertFailsWith<IllegalArgumentException> { + a.toNioPath().relativize(b.toNioPath()) + } + } + // Return okio. + return assertFailsWith { b.relativeTo(a) } +} diff --git a/okio/src/jvmTest/kotlin/okio/LargeStreamsTest.kt b/okio/src/jvmTest/kotlin/okio/LargeStreamsTest.kt new file mode 100644 index 00000000..47d387fe --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/LargeStreamsTest.kt @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2016 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.util.concurrent.Future +import java.util.zip.Deflater +import java.util.zip.GZIPInputStream +import java.util.zip.GZIPOutputStream +import okio.ByteString.Companion.decodeHex +import okio.HashingSink.Companion.sha256 +import okio.TestUtil.SEGMENT_SIZE +import okio.TestUtil.randomSource +import okio.TestingExecutors.newExecutorService +import org.junit.Assert.assertEquals +import org.junit.Test + +/** Slow running tests that run a large amount of data through a stream. */ +class LargeStreamsTest { + @Test + fun test() { + val pipe = Pipe((1024 * 1024).toLong()) + val future = readAllAndCloseAsync(randomSource(FOUR_GIB_PLUS_ONE), pipe.sink) + val hashingSink = sha256(blackholeSink()) + readAllAndClose(pipe.source, hashingSink) + assertEquals(FOUR_GIB_PLUS_ONE, future.get() as Long) + assertEquals(SHA256_RANDOM_FOUR_GIB_PLUS_1, hashingSink.hash) + } + + /** Note that this test hangs on Android. */ + @Test + fun gzipSource() { + val pipe = Pipe(1024L * 1024) + val gzipOut = object : GZIPOutputStream(pipe.sink.buffer().outputStream()) { + init { + // Disable compression to speed up a slow test. Improved from 141s to 33s on one machine. + def.setLevel(Deflater.NO_COMPRESSION) + } + } + val future = readAllAndCloseAsync( + randomSource(FOUR_GIB_PLUS_ONE), + gzipOut.sink(), + ) + val hashingSink = sha256(blackholeSink()) + val gzipSource = GzipSource(pipe.source) + readAllAndClose(gzipSource, hashingSink) + assertEquals(FOUR_GIB_PLUS_ONE, future.get() as Long) + assertEquals(SHA256_RANDOM_FOUR_GIB_PLUS_1, hashingSink.hash) + } + + /** Note that this test hangs on Android. */ + @Test + fun gzipSink() { + val pipe = Pipe(1024L * 1024) + val gzipSink = GzipSink(pipe.sink) + + // Disable compression to speed up a slow test. Improved from 141s to 35s on one machine. + gzipSink.deflater.setLevel(Deflater.NO_COMPRESSION) + val future = readAllAndCloseAsync(randomSource(FOUR_GIB_PLUS_ONE), gzipSink) + val hashingSink = sha256(blackholeSink()) + val gzipIn = GZIPInputStream(pipe.source.buffer().inputStream()) + readAllAndClose(gzipIn.source(), hashingSink) + assertEquals(FOUR_GIB_PLUS_ONE, future.get() as Long) + assertEquals(SHA256_RANDOM_FOUR_GIB_PLUS_1, hashingSink.hash) + } + + /** Reads all bytes from `source` and writes them to `sink`. */ + private fun readAllAndClose(source: Source, sink: Sink): Long { + var result = 0L + val buffer = Buffer() + while (true) { + val count = source.read(buffer, SEGMENT_SIZE.toLong()) + if (count == -1L) break + sink.write(buffer, count) + result += count + } + source.close() + sink.close() + return result + } + + /** Calls [readAllAndClose] on a background thread. */ + private fun readAllAndCloseAsync(source: Source, sink: Sink): Future<Long> { + val executor = newExecutorService(0) + return try { + executor.submit<Long> { readAllAndClose(source, sink) } + } finally { + executor.shutdown() + } + } + + companion object { + /** 4 GiB plus 1 byte. This is greater than what can be expressed in an unsigned int. */ + const val FOUR_GIB_PLUS_ONE = 0x100000001L + + /** SHA-256 of `TestUtil.randomSource(FOUR_GIB_PLUS_ONE)`. */ + val SHA256_RANDOM_FOUR_GIB_PLUS_1 = + "9654947a655c5efc445502fd1bf11117d894b7812b7974fde8ca4a02c5066315".decodeHex() + } +} diff --git a/okio/src/jvmTest/kotlin/okio/MessageDigestConsistencyTest.kt b/okio/src/jvmTest/kotlin/okio/MessageDigestConsistencyTest.kt new file mode 100644 index 00000000..dac72988 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/MessageDigestConsistencyTest.kt @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2020 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.security.MessageDigest +import java.util.Random +import kotlin.test.Test +import okio.ByteString.Companion.toByteString +import okio.internal.HashFunction +import okio.internal.Md5 +import okio.internal.Sha1 +import okio.internal.Sha256 +import okio.internal.Sha512 +import org.assertj.core.api.Assertions.assertThat + +/** + * Confirm Okio is consistent with the JDK's MessageDigest algorithms for various sizes and slices. + * This makes repeated calls to update() with byte arrays of various sizes and contents to defend + * against bugs in batching inputs. + */ +class MessageDigestConsistencyTest { + @Test fun sha1() { + test("SHA-1") { Sha1() } + } + + @Test fun sha256() { + test("SHA-256") { Sha256() } + } + + @Test fun sha512() { + test("SHA-512") { Sha512() } + } + + @Test fun md5() { + test("MD5") { Md5() } + } + + private fun test(algorithm: String, newHashFunction: () -> HashFunction) { + for (seed in 0L until 1000L) { + for (updateCount in 0 until 10) { + test( + algorithm = algorithm, + hashFunction = newHashFunction(), + seed = seed, + updateCount = updateCount, + ) + } + } + } + + private fun test( + algorithm: String, + hashFunction: HashFunction, + seed: Long, + updateCount: Int, + ) { + val data = Buffer() + + val random = Random(seed) + for (i in 0 until updateCount) { + val size = random.nextInt(1000) + 1 // size must be >= 1. + val byteArray = ByteArray(size).also { random.nextBytes(it) } + val offset = random.nextInt(size) + val byteCount = random.nextInt(size - offset) + + hashFunction.update( + input = byteArray, + offset = offset, + byteCount = byteCount, + ) + + data.write( + source = byteArray, + offset = offset, + byteCount = byteCount, + ) + } + + val okioHash = hashFunction.digest() + + val byteArray = data.readByteArray() + val jdkMessageDigest = MessageDigest.getInstance(algorithm) + val jdkHash = jdkMessageDigest.digest(byteArray) + + assertThat(okioHash.toByteString()).isEqualTo(jdkHash.toByteString()) + } +} diff --git a/okio/src/jvmTest/kotlin/okio/NioTest.kt b/okio/src/jvmTest/kotlin/okio/NioTest.kt new file mode 100644 index 00000000..35e11ceb --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/NioTest.kt @@ -0,0 +1,142 @@ +/* + * Copyright (C) 2018 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.nio.ByteBuffer +import java.nio.channels.FileChannel +import java.nio.channels.ReadableByteChannel +import java.nio.channels.WritableByteChannel +import java.nio.file.StandardOpenOption +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue +import kotlin.text.Charsets.UTF_8 +import org.junit.Rule +import org.junit.Test +import org.junit.rules.TemporaryFolder + +/** Test interop between our beloved Okio and java.nio. */ +class NioTest { + @JvmField @Rule + var temporaryFolder = TemporaryFolder() + + @Test + fun sourceIsOpen() { + val source = (Buffer() as Source).buffer() + assertTrue(source.isOpen()) + source.close() + assertFalse(source.isOpen()) + } + + @Test + fun sinkIsOpen() { + val sink = (Buffer() as Sink).buffer() + assertTrue(sink.isOpen) + sink.close() + assertFalse(sink.isOpen) + } + + @Test + fun writableChannelNioFile() { + val file = temporaryFolder.newFile() + val fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.WRITE) + testWritableByteChannel(fileChannel) + val emitted = file.source().buffer() + assertEquals("defghijklmnopqrstuvw", emitted.readUtf8()) + emitted.close() + } + + @Test + fun writableChannelBuffer() { + val buffer = Buffer() + testWritableByteChannel(buffer) + assertEquals("defghijklmnopqrstuvw", buffer.readUtf8()) + } + + @Test + fun writableChannelBufferedSink() { + val buffer = Buffer() + val bufferedSink = (buffer as Sink).buffer() + testWritableByteChannel(bufferedSink) + assertEquals("defghijklmnopqrstuvw", buffer.readUtf8()) + } + + @Test + fun readableChannelNioFile() { + val file = temporaryFolder.newFile() + val initialData = file.sink().buffer() + initialData.writeUtf8("abcdefghijklmnopqrstuvwxyz") + initialData.close() + val fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ) + testReadableByteChannel(fileChannel) + } + + @Test + fun readableChannelBuffer() { + val buffer = Buffer() + buffer.writeUtf8("abcdefghijklmnopqrstuvwxyz") + testReadableByteChannel(buffer) + } + + @Test + fun readableChannelBufferedSource() { + val buffer = Buffer() + val bufferedSource = (buffer as Source).buffer() + buffer.writeUtf8("abcdefghijklmnopqrstuvwxyz") + testReadableByteChannel(bufferedSource) + } + + /** + * Does some basic writes to `channel`. We execute this against both Okio's channels and + * also a standard implementation from the JDK to confirm that their behavior is consistent. + */ + private fun testWritableByteChannel(channel: WritableByteChannel) { + assertTrue(channel.isOpen) + val byteBuffer = ByteBuffer.allocate(1024) + byteBuffer.put("abcdefghijklmnopqrstuvwxyz".toByteArray(UTF_8)) + (byteBuffer as java.nio.Buffer).flip() // Cast necessary for Java 8. + (byteBuffer as java.nio.Buffer).position(3) // Cast necessary for Java 8. + (byteBuffer as java.nio.Buffer).limit(23) // Cast necessary for Java 8. + val byteCount = channel.write(byteBuffer) + assertEquals(20, byteCount) + assertEquals(23, byteBuffer.position()) + assertEquals(23, byteBuffer.limit()) + channel.close() + assertEquals(channel is Buffer, channel.isOpen) // Buffer.close() does nothing. + } + + /** + * Does some basic reads from `channel`. We execute this against both Okio's channels and + * also a standard implementation from the JDK to confirm that their behavior is consistent. + */ + private fun testReadableByteChannel(channel: ReadableByteChannel) { + assertTrue(channel.isOpen) + val byteBuffer = ByteBuffer.allocate(1024) + (byteBuffer as java.nio.Buffer).position(3) // Cast necessary for Java 8. + (byteBuffer as java.nio.Buffer).limit(23) // Cast necessary for Java 8. + val byteCount = channel.read(byteBuffer) + assertEquals(20, byteCount) + assertEquals(23, byteBuffer.position()) + assertEquals(23, byteBuffer.limit()) + channel.close() + assertEquals(channel is Buffer, channel.isOpen) // Buffer.close() does nothing. + (byteBuffer as java.nio.Buffer).flip() // Cast necessary for Java 8. + (byteBuffer as java.nio.Buffer).position(3) // Cast necessary for Java 8. + val data = ByteArray(byteBuffer.remaining()) + byteBuffer[data] + assertEquals("abcdefghijklmnopqrst", String(data, UTF_8)) + } +} diff --git a/okio/src/jvmTest/kotlin/okio/OkioKotlinTest.kt b/okio/src/jvmTest/kotlin/okio/OkioKotlinTest.kt index 58c6bcad..86270d1e 100644 --- a/okio/src/jvmTest/kotlin/okio/OkioKotlinTest.kt +++ b/okio/src/jvmTest/kotlin/okio/OkioKotlinTest.kt @@ -16,17 +16,17 @@ package okio -import org.assertj.core.api.Assertions.assertThat -import org.junit.Ignore -import org.junit.Rule -import org.junit.Test -import org.junit.rules.TemporaryFolder import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream import java.io.File import java.net.Socket import java.nio.file.StandardOpenOption import java.nio.file.StandardOpenOption.APPEND +import org.assertj.core.api.Assertions.assertThat +import org.junit.Ignore +import org.junit.Rule +import org.junit.Test +import org.junit.rules.TemporaryFolder class OkioKotlinTest { @get:Rule val temp = TemporaryFolder() @@ -96,7 +96,8 @@ class OkioKotlinTest { } @Ignore("Not sure how to test this") - @Test fun pathSourceWithOptions() { + @Test + fun pathSourceWithOptions() { val folder = temp.newFolder() val file = File(folder, "new.txt") file.toPath().source(StandardOpenOption.CREATE_NEW) diff --git a/okio/src/jvmTest/kotlin/okio/OkioTest.kt b/okio/src/jvmTest/kotlin/okio/OkioTest.kt new file mode 100644 index 00000000..9514ddd2 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/OkioTest.kt @@ -0,0 +1,147 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.ByteArrayInputStream +import java.io.ByteArrayOutputStream +import java.nio.file.Files +import kotlin.text.Charsets.UTF_8 +import okio.TestUtil.SEGMENT_SIZE +import okio.TestUtil.assertNoEmptySegments +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Assert.fail +import org.junit.Rule +import org.junit.Test +import org.junit.rules.TemporaryFolder + +class OkioTest { + @JvmField + @Rule + var temporaryFolder = TemporaryFolder() + + @Test + fun readWriteFile() { + val file = temporaryFolder.newFile() + val sink = file.sink().buffer() + sink.writeUtf8("Hello, java.io file!") + sink.close() + assertTrue(file.exists()) + assertEquals(20, file.length()) + val source = file.source().buffer() + assertEquals("Hello, java.io file!", source.readUtf8()) + source.close() + } + + @Test + fun appendFile() { + val file = temporaryFolder.newFile() + var sink = file.appendingSink().buffer() + sink.writeUtf8("Hello, ") + sink.close() + assertTrue(file.exists()) + assertEquals(7, file.length()) + sink = file.appendingSink().buffer() + sink.writeUtf8("java.io file!") + sink.close() + assertEquals(20, file.length()) + val source = file.source().buffer() + assertEquals("Hello, java.io file!", source.readUtf8()) + source.close() + } + + @Test + fun readWritePath() { + val path = temporaryFolder.newFile().toPath() + val sink = path.sink().buffer() + sink.writeUtf8("Hello, java.nio file!") + sink.close() + assertTrue(Files.exists(path)) + assertEquals(21, Files.size(path)) + val source = path.source().buffer() + assertEquals("Hello, java.nio file!", source.readUtf8()) + source.close() + } + + @Test + fun sinkFromOutputStream() { + val data = Buffer() + data.writeUtf8("a") + data.writeUtf8("b".repeat(9998)) + data.writeUtf8("c") + val out = ByteArrayOutputStream() + val sink = out.sink() + sink.write(data, 3) + assertEquals("abb", out.toString("UTF-8")) + sink.write(data, data.size) + assertEquals("a" + "b".repeat(9998) + "c", out.toString("UTF-8")) + } + + @Test + fun sourceFromInputStream() { + val inputStream = ByteArrayInputStream( + ("a" + "b".repeat(SEGMENT_SIZE * 2) + "c").toByteArray(UTF_8), + ) + + // Source: ab...bc + val source = inputStream.source() + val sink = Buffer() + + // Source: b...bc. Sink: abb. + assertEquals(3, source.read(sink, 3)) + assertEquals("abb", sink.readUtf8(3)) + + // Source: b...bc. Sink: b...b. + assertEquals(SEGMENT_SIZE.toLong(), source.read(sink, 20000)) + assertEquals("b".repeat(SEGMENT_SIZE), sink.readUtf8()) + + // Source: b...bc. Sink: b...bc. + assertEquals((SEGMENT_SIZE - 1).toLong(), source.read(sink, 20000)) + assertEquals("b".repeat(SEGMENT_SIZE - 2) + "c", sink.readUtf8()) + + // Source and sink are empty. + assertEquals(-1, source.read(sink, 1)) + } + + @Test + fun sourceFromInputStreamWithSegmentSize() { + val inputStream = ByteArrayInputStream(ByteArray(SEGMENT_SIZE)) + val source = inputStream.source() + val sink = Buffer() + assertEquals(SEGMENT_SIZE.toLong(), source.read(sink, SEGMENT_SIZE.toLong())) + assertEquals(-1, source.read(sink, SEGMENT_SIZE.toLong())) + assertNoEmptySegments(sink) + } + + @Test + fun sourceFromInputStreamBounds() { + val source = ByteArrayInputStream(ByteArray(100)).source() + try { + source.read(Buffer(), -1) + fail() + } catch (expected: IllegalArgumentException) { + } + } + + @Test + fun blackhole() { + val data = Buffer() + data.writeUtf8("blackhole") + val blackhole = blackholeSink() + blackhole.write(data, 5) + assertEquals("hole", data.readUtf8()) + } +} diff --git a/okio/src/jvmTest/kotlin/okio/PipeKotlinTest.kt b/okio/src/jvmTest/kotlin/okio/PipeKotlinTest.kt index 3a41e742..ac50f3a0 100644 --- a/okio/src/jvmTest/kotlin/okio/PipeKotlinTest.kt +++ b/okio/src/jvmTest/kotlin/okio/PipeKotlinTest.kt @@ -15,6 +15,10 @@ */ package okio +import java.io.IOException +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import kotlin.test.assertFailsWith import org.junit.After import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse @@ -22,19 +26,16 @@ import org.junit.Assert.assertTrue import org.junit.Assert.fail import org.junit.Rule import org.junit.Test -import java.io.IOException -import java.util.concurrent.CountDownLatch -import java.util.concurrent.Executors -import java.util.concurrent.TimeUnit -import kotlin.test.assertFailsWith import org.junit.rules.Timeout as JUnitTimeout class PipeKotlinTest { - @JvmField @Rule val timeout = JUnitTimeout(5, TimeUnit.SECONDS) + @JvmField @Rule + val timeout = JUnitTimeout(5, TimeUnit.SECONDS) - private val executorService = Executors.newScheduledThreadPool(1) + private val executorService = TestingExecutors.newScheduledExecutorService(1) - @After @Throws(Exception::class) + @After + @Throws(Exception::class) fun tearDown() { executorService.shutdown() } @@ -108,7 +109,8 @@ class PipeKotlinTest { pipe.fold(foldSink) latch.countDown() }, - 500, TimeUnit.MILLISECONDS + 500, + TimeUnit.MILLISECONDS, ) val sink = pipe.sink.buffer() @@ -536,7 +538,8 @@ class PipeKotlinTest { } assertEquals("boom", foldFailure.message) }, - 500, TimeUnit.MILLISECONDS + 500, + TimeUnit.MILLISECONDS, ) val writeFailure = assertFailsWith<IOException> { @@ -656,7 +659,8 @@ class PipeKotlinTest { { pipe.cancel() }, - smallerTimeoutNanos, TimeUnit.NANOSECONDS + smallerTimeoutNanos, + TimeUnit.NANOSECONDS, ) val pipeSink = pipe.sink.buffer() @@ -699,7 +703,8 @@ class PipeKotlinTest { { pipe.cancel() }, - smallerTimeoutNanos, TimeUnit.NANOSECONDS + smallerTimeoutNanos, + TimeUnit.NANOSECONDS, ) val pipeSource = pipe.source.buffer() @@ -783,8 +788,9 @@ class PipeKotlinTest { val elapsed = TimeUnit.MILLISECONDS.toNanos(System.currentTimeMillis() - start) assertEquals( - expected.toDouble(), elapsed.toDouble(), - TimeUnit.MILLISECONDS.toNanos(200).toDouble() + expected.toDouble(), + elapsed.toDouble(), + TimeUnit.MILLISECONDS.toNanos(200).toDouble(), ) } diff --git a/okio/src/jvmTest/kotlin/okio/PipeTest.kt b/okio/src/jvmTest/kotlin/okio/PipeTest.kt new file mode 100644 index 00000000..74b61290 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/PipeTest.kt @@ -0,0 +1,358 @@ +/* + * Copyright (C) 2016 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.IOException +import java.io.InterruptedIOException +import java.util.Random +import java.util.concurrent.TimeUnit +import okio.ByteString.Companion.decodeHex +import okio.HashingSink.Companion.sha1 +import okio.TestUtil.assumeNotWindows +import okio.TestingExecutors.newScheduledExecutorService +import org.junit.After +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Assert.fail +import org.junit.Test + +class PipeTest { + private val executorService = newScheduledExecutorService(2) + + @After + fun tearDown() { + executorService.shutdown() + } + + @Test + fun test() { + val pipe = Pipe(6) + pipe.sink.write(Buffer().writeUtf8("abc"), 3L) + val source = pipe.source + val readBuffer = Buffer() + assertEquals(3L, source.read(readBuffer, 6L)) + assertEquals("abc", readBuffer.readUtf8()) + pipe.sink.close() + assertEquals(-1L, source.read(readBuffer, 6L)) + source.close() + } + + /** + * A producer writes the first 16 MiB of bytes generated by `new Random(0)` to a sink, and a + * consumer consumes them. Both compute hashes of their data to confirm that they're as expected. + */ + @Test + fun largeDataset() { + val pipe = Pipe(1000L) // An awkward size to force producer/consumer exchange. + val totalBytes = 16L * 1024L * 1024L + val expectedHash = "7c3b224bea749086babe079360cf29f98d88262d".decodeHex() + + // Write data to the sink. + val sinkHash = executorService.submit<ByteString> { + val hashingSink = sha1(pipe.sink) + val random = Random(0) + val data = ByteArray(8192) + val buffer = Buffer() + var i = 0L + while (i < totalBytes) { + random.nextBytes(data) + buffer.write(data) + hashingSink.write(buffer, buffer.size) + i += data.size.toLong() + } + hashingSink.close() + hashingSink.hash + } + + // Read data from the source. + val sourceHash = executorService.submit<ByteString> { + val blackhole = Buffer() + val hashingSink = sha1(blackhole) + val buffer = Buffer() + while (pipe.source.read(buffer, Long.MAX_VALUE) != -1L) { + hashingSink.write(buffer, buffer.size) + blackhole.clear() + } + pipe.source.close() + hashingSink.hash + } + assertEquals(expectedHash, sinkHash.get()) + assertEquals(expectedHash, sourceHash.get()) + } + + @Test + fun sinkTimeout() { + assumeNotWindows() + val pipe = Pipe(3) + pipe.sink.timeout().timeout(1000, TimeUnit.MILLISECONDS) + pipe.sink.write(Buffer().writeUtf8("abc"), 3L) + val start = now() + try { + pipe.sink.write(Buffer().writeUtf8("def"), 3L) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + } + assertElapsed(1000.0, start) + val readBuffer = Buffer() + assertEquals(3L, pipe.source.read(readBuffer, 6L)) + assertEquals("abc", readBuffer.readUtf8()) + } + + @Test + fun sourceTimeout() { + assumeNotWindows() + val pipe = Pipe(3L) + pipe.source.timeout().timeout(1000, TimeUnit.MILLISECONDS) + val start = now() + val readBuffer = Buffer() + try { + pipe.source.read(readBuffer, 6L) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + } + assertElapsed(1000.0, start) + assertEquals(0, readBuffer.size) + } + + /** + * The writer is writing 12 bytes as fast as it can to a 3 byte buffer. The reader alternates + * sleeping 1000 ms, then reading 3 bytes. That should make for an approximate timeline like + * this: + * + * ``` + * 0: writer writes 'abc', blocks 0: reader sleeps until 1000 + * 1000: reader reads 'abc', sleeps until 2000 + * 1000: writer writes 'def', blocks + * 2000: reader reads 'def', sleeps until 3000 + * 2000: writer writes 'ghi', blocks + * 3000: reader reads 'ghi', sleeps until 4000 + * 3000: writer writes 'jkl', returns + * 4000: reader reads 'jkl', returns + * ``` + * + * + * Because the writer is writing to a buffer, it finishes before the reader does. + */ + @Test + fun sinkBlocksOnSlowReader() { + val pipe = Pipe(3L) + executorService.execute { + val buffer = Buffer() + Thread.sleep(1000L) + assertEquals(3, pipe.source.read(buffer, Long.MAX_VALUE)) + assertEquals("abc", buffer.readUtf8()) + Thread.sleep(1000L) + assertEquals(3, pipe.source.read(buffer, Long.MAX_VALUE)) + assertEquals("def", buffer.readUtf8()) + Thread.sleep(1000L) + assertEquals(3, pipe.source.read(buffer, Long.MAX_VALUE)) + assertEquals("ghi", buffer.readUtf8()) + Thread.sleep(1000L) + assertEquals(3, pipe.source.read(buffer, Long.MAX_VALUE)) + assertEquals("jkl", buffer.readUtf8()) + } + val start = now() + pipe.sink.write(Buffer().writeUtf8("abcdefghijkl"), 12) + assertElapsed(3000.0, start) + } + + @Test + fun sinkWriteFailsByClosedReader() { + val pipe = Pipe(3L) + executorService.schedule( + { + pipe.source.close() + }, + 1000, + TimeUnit.MILLISECONDS, + ) + val start = now() + try { + pipe.sink.write(Buffer().writeUtf8("abcdef"), 6) + fail() + } catch (expected: IOException) { + assertEquals("source is closed", expected.message) + assertElapsed(1000.0, start) + } + } + + @Test + fun sinkFlushDoesntWaitForReader() { + val pipe = Pipe(100L) + pipe.sink.write(Buffer().writeUtf8("abc"), 3) + pipe.sink.flush() + val bufferedSource = pipe.source.buffer() + assertEquals("abc", bufferedSource.readUtf8(3)) + } + + @Test + fun sinkFlushFailsIfReaderIsClosedBeforeAllDataIsRead() { + val pipe = Pipe(100L) + pipe.sink.write(Buffer().writeUtf8("abc"), 3) + pipe.source.close() + try { + pipe.sink.flush() + fail() + } catch (expected: IOException) { + assertEquals("source is closed", expected.message) + } + } + + @Test + fun sinkCloseFailsIfReaderIsClosedBeforeAllDataIsRead() { + val pipe = Pipe(100L) + pipe.sink.write(Buffer().writeUtf8("abc"), 3) + pipe.source.close() + try { + pipe.sink.close() + fail() + } catch (expected: IOException) { + assertEquals("source is closed", expected.message) + } + } + + @Test + fun sinkClose() { + val pipe = Pipe(100L) + pipe.sink.close() + try { + pipe.sink.write(Buffer().writeUtf8("abc"), 3) + fail() + } catch (expected: IllegalStateException) { + assertEquals("closed", expected.message) + } + try { + pipe.sink.flush() + fail() + } catch (expected: IllegalStateException) { + assertEquals("closed", expected.message) + } + } + + @Test + fun sinkMultipleClose() { + val pipe = Pipe(100L) + pipe.sink.close() + pipe.sink.close() + } + + @Test + fun sinkCloseDoesntWaitForSourceRead() { + val pipe = Pipe(100L) + pipe.sink.write(Buffer().writeUtf8("abc"), 3) + pipe.sink.close() + val bufferedSource = pipe.source.buffer() + assertEquals("abc", bufferedSource.readUtf8()) + assertTrue(bufferedSource.exhausted()) + } + + @Test + fun sourceClose() { + val pipe = Pipe(100L) + pipe.source.close() + try { + pipe.source.read(Buffer(), 3) + fail() + } catch (expected: IllegalStateException) { + assertEquals("closed", expected.message) + } + } + + @Test + fun sourceMultipleClose() { + val pipe = Pipe(100L) + pipe.source.close() + pipe.source.close() + } + + @Test + fun sourceReadUnblockedByClosedSink() { + val pipe = Pipe(3L) + executorService.schedule( + { + pipe.sink.close() + }, + 1000, + TimeUnit.MILLISECONDS, + ) + val start = now() + val readBuffer = Buffer() + assertEquals(-1, pipe.source.read(readBuffer, Long.MAX_VALUE)) + assertEquals(0, readBuffer.size) + assertElapsed(1000.0, start) + } + + /** + * The writer has 12 bytes to write. It alternates sleeping 1000 ms, then writing 3 bytes. The + * reader is reading as fast as it can. That should make for an approximate timeline like this: + * + * ``` + * 0: writer sleeps until 1000 + * 0: reader blocks + * 1000: writer writes 'abc', sleeps until 2000 + * 1000: reader reads 'abc' + * 2000: writer writes 'def', sleeps until 3000 + * 2000: reader reads 'def' + * 3000: writer writes 'ghi', sleeps until 4000 + * 3000: reader reads 'ghi' + * 4000: writer writes 'jkl', returns + * 4000: reader reads 'jkl', returns + * ``` + */ + @Test + fun sourceBlocksOnSlowWriter() { + val pipe = Pipe(100L) + executorService.execute { + Thread.sleep(1000L) + pipe.sink.write(Buffer().writeUtf8("abc"), 3) + Thread.sleep(1000L) + pipe.sink.write(Buffer().writeUtf8("def"), 3) + Thread.sleep(1000L) + pipe.sink.write(Buffer().writeUtf8("ghi"), 3) + Thread.sleep(1000L) + pipe.sink.write(Buffer().writeUtf8("jkl"), 3) + } + val start = now() + val readBuffer = Buffer() + assertEquals(3, pipe.source.read(readBuffer, Long.MAX_VALUE)) + assertEquals("abc", readBuffer.readUtf8()) + assertElapsed(1000.0, start) + assertEquals(3, pipe.source.read(readBuffer, Long.MAX_VALUE)) + assertEquals("def", readBuffer.readUtf8()) + assertElapsed(2000.0, start) + assertEquals(3, pipe.source.read(readBuffer, Long.MAX_VALUE)) + assertEquals("ghi", readBuffer.readUtf8()) + assertElapsed(3000.0, start) + assertEquals(3, pipe.source.read(readBuffer, Long.MAX_VALUE)) + assertEquals("jkl", readBuffer.readUtf8()) + assertElapsed(4000.0, start) + } + + /** Returns the nanotime in milliseconds as a double for measuring timeouts. */ + private fun now(): Double { + return System.nanoTime() / 1000000.0 + } + + /** + * Fails the test unless the time from start until now is duration, accepting differences in + * -50..+450 milliseconds. + */ + private fun assertElapsed(duration: Double, start: Double) { + assertEquals(duration, now() - start - 200.0, 250.0) + } +} diff --git a/okio/src/jvmTest/kotlin/okio/ReadUtf8LineTest.kt b/okio/src/jvmTest/kotlin/okio/ReadUtf8LineTest.kt new file mode 100644 index 00000000..6f64355b --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/ReadUtf8LineTest.kt @@ -0,0 +1,217 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.EOFException +import okio.TestUtil.SEGMENT_SIZE +import org.junit.Assert.assertEquals +import org.junit.Assert.assertNull +import org.junit.Assert.assertTrue +import org.junit.Assert.fail +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameter +import org.junit.runners.Parameterized.Parameters + +@RunWith(Parameterized::class) +class ReadUtf8LineTest { + interface Factory { + fun create(data: Buffer): BufferedSource + } + + @Parameter + lateinit var factory: Factory + private lateinit var data: Buffer + private lateinit var source: BufferedSource + + @Before + fun setUp() { + data = Buffer() + source = factory.create(data) + } + + @Test + fun readLines() { + data.writeUtf8("abc\ndef\n") + assertEquals("abc", source.readUtf8LineStrict()) + assertEquals("def", source.readUtf8LineStrict()) + try { + source.readUtf8LineStrict() + fail() + } catch (expected: EOFException) { + assertEquals("\\n not found: limit=0 content=…", expected.message) + } + } + + @Test + fun readUtf8LineStrictWithLimits() { + val lens = intArrayOf(1, SEGMENT_SIZE - 2, SEGMENT_SIZE - 1, SEGMENT_SIZE, SEGMENT_SIZE * 10) + for (len in lens) { + data.writeUtf8("a".repeat(len)).writeUtf8("\n") + assertEquals(len.toLong(), source.readUtf8LineStrict(len.toLong()).length.toLong()) + source.readUtf8() + data.writeUtf8("a".repeat(len)).writeUtf8("\n").writeUtf8("a".repeat(len)) + assertEquals(len.toLong(), source.readUtf8LineStrict(len.toLong()).length.toLong()) + source.readUtf8() + data.writeUtf8("a".repeat(len)).writeUtf8("\r\n") + assertEquals(len.toLong(), source.readUtf8LineStrict(len.toLong()).length.toLong()) + source.readUtf8() + data.writeUtf8("a".repeat(len)).writeUtf8("\r\n").writeUtf8("a".repeat(len)) + assertEquals(len.toLong(), source.readUtf8LineStrict(len.toLong()).length.toLong()) + source.readUtf8() + } + } + + @Test + fun readUtf8LineStrictNoBytesConsumedOnFailure() { + data.writeUtf8("abc\n") + try { + source.readUtf8LineStrict(2) + fail() + } catch (expected: EOFException) { + assertTrue(expected.message!!.startsWith("\\n not found: limit=2 content=61626")) + } + assertEquals("abc", source.readUtf8LineStrict(3)) + } + + @Test + fun readUtf8LineStrictEmptyString() { + data.writeUtf8("\r\nabc") + assertEquals("", source.readUtf8LineStrict(0)) + assertEquals("abc", source.readUtf8()) + } + + @Test + fun readUtf8LineStrictNonPositive() { + data.writeUtf8("\r\n") + try { + source.readUtf8LineStrict(-1) + fail("Expected failure: limit must be greater than 0") + } catch (expected: IllegalArgumentException) { + } + } + + @Test + fun eofExceptionProvidesLimitedContent() { + data.writeUtf8("aaaaaaaabbbbbbbbccccccccdddddddde") + try { + source.readUtf8LineStrict() + fail() + } catch (expected: EOFException) { + assertEquals( + "\\n not found: limit=33 content=616161616161616162626262626262626363636363636363" + + "6464646464646464…", + expected.message, + ) + } + } + + @Test + fun newlineAtEnd() { + data.writeUtf8("abc\n") + assertEquals("abc", source.readUtf8LineStrict(3)) + assertTrue(source.exhausted()) + data.writeUtf8("abc\r\n") + assertEquals("abc", source.readUtf8LineStrict(3)) + assertTrue(source.exhausted()) + data.writeUtf8("abc\r") + try { + source.readUtf8LineStrict(3) + fail() + } catch (expected: EOFException) { + assertEquals("\\n not found: limit=3 content=6162630d…", expected.message) + } + source.readUtf8() + data.writeUtf8("abc") + try { + source.readUtf8LineStrict(3) + fail() + } catch (expected: EOFException) { + assertEquals("\\n not found: limit=3 content=616263…", expected.message) + } + } + + @Test + fun emptyLines() { + data.writeUtf8("\n\n\n") + assertEquals("", source.readUtf8LineStrict()) + assertEquals("", source.readUtf8LineStrict()) + assertEquals("", source.readUtf8LineStrict()) + assertTrue(source.exhausted()) + } + + @Test + fun crDroppedPrecedingLf() { + data.writeUtf8("abc\r\ndef\r\nghi\rjkl\r\n") + assertEquals("abc", source.readUtf8LineStrict()) + assertEquals("def", source.readUtf8LineStrict()) + assertEquals("ghi\rjkl", source.readUtf8LineStrict()) + } + + @Test + fun bufferedReaderCompatible() { + data.writeUtf8("abc\ndef") + assertEquals("abc", source.readUtf8Line()) + assertEquals("def", source.readUtf8Line()) + assertNull(source.readUtf8Line()) + } + + @Test + fun bufferedReaderCompatibleWithTrailingNewline() { + data.writeUtf8("abc\ndef\n") + assertEquals("abc", source.readUtf8Line()) + assertEquals("def", source.readUtf8Line()) + assertNull(source.readUtf8Line()) + } + + companion object { + @JvmStatic + @Parameters(name = "{0}") + fun parameters(): List<Array<Any>> { + return listOf( + arrayOf( + object : Factory { + override fun create(data: Buffer) = data + override fun toString() = "Buffer" + }, + ), + arrayOf( + object : Factory { + override fun create(data: Buffer) = RealBufferedSource(data) + override fun toString() = "RealBufferedSource" + }, + ), + arrayOf( + object : Factory { + override fun create(data: Buffer): BufferedSource { + return RealBufferedSource( + object : ForwardingSource(data) { + override fun read(sink: Buffer, byteCount: Long): Long { + return super.read(sink, 1L.coerceAtMost(byteCount)) + } + }, + ) + } + + override fun toString() = "Slow RealBufferedSource" + }, + ), + ) + } + } +} diff --git a/okio/src/jvmTest/kotlin/okio/SegmentSharingTest.kt b/okio/src/jvmTest/kotlin/okio/SegmentSharingTest.kt index 73e74d99..df4a9d5e 100644 --- a/okio/src/jvmTest/kotlin/okio/SegmentSharingTest.kt +++ b/okio/src/jvmTest/kotlin/okio/SegmentSharingTest.kt @@ -15,14 +15,14 @@ */ package okio +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue import okio.ByteString.Companion.encodeUtf8 import okio.TestUtil.assertEquivalent import okio.TestUtil.bufferWithSegments import okio.TestUtil.takeAllPoolSegments import org.junit.Test -import kotlin.test.assertEquals -import kotlin.test.assertFailsWith -import kotlin.test.assertTrue /** Tests behavior optimized by sharing segments between buffers and byte strings. */ class SegmentSharingTest { @@ -40,12 +40,12 @@ class SegmentSharingTest { @Test fun snapshotGetByte() { val byteString = bufferWithSegments(xs, ys, zs).snapshot() - assertEquals('x', byteString[0].toChar()) - assertEquals('x', byteString[xs.length - 1].toChar()) - assertEquals('y', byteString[xs.length].toChar()) - assertEquals('y', byteString[xs.length + ys.length - 1].toChar()) - assertEquals('z', byteString[xs.length + ys.length].toChar()) - assertEquals('z', byteString[xs.length + ys.length + zs.length - 1].toChar()) + assertEquals('x', byteString[0].toInt().toChar()) + assertEquals('x', byteString[xs.length - 1].toInt().toChar()) + assertEquals('y', byteString[xs.length].toInt().toChar()) + assertEquals('y', byteString[xs.length + ys.length - 1].toInt().toChar()) + assertEquals('z', byteString[xs.length + ys.length].toInt().toChar()) + assertEquals('z', byteString[xs.length + ys.length + zs.length - 1].toInt().toChar()) assertFailsWith<IndexOutOfBoundsException> { byteString[-1] } diff --git a/okio/src/jvmTest/kotlin/okio/SocketTimeoutTest.kt b/okio/src/jvmTest/kotlin/okio/SocketTimeoutTest.kt new file mode 100644 index 00000000..f38796f7 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/SocketTimeoutTest.kt @@ -0,0 +1,141 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.EOFException +import java.io.IOException +import java.io.InputStream +import java.io.OutputStream +import java.net.InetAddress +import java.net.ServerSocket +import java.net.Socket +import java.net.SocketTimeoutException +import java.util.concurrent.TimeUnit +import org.junit.Assert.assertTrue +import org.junit.Assert.fail +import org.junit.Test + +class SocketTimeoutTest { + @Test + fun readWithoutTimeout() { + val socket = socket(ONE_MB, 0) + val source = socket.source().buffer() + source.timeout().timeout(5000, TimeUnit.MILLISECONDS) + source.require(ONE_MB.toLong()) + socket.close() + } + + @Test + fun readWithTimeout() { + val socket = socket(0, 0) + val source = socket.source().buffer() + source.timeout().timeout(250, TimeUnit.MILLISECONDS) + try { + source.require(ONE_MB.toLong()) + fail() + } catch (expected: SocketTimeoutException) { + } + socket.close() + } + + @Test + fun writeWithoutTimeout() { + val socket = socket(0, ONE_MB) + val sink: Sink = socket.sink().buffer() + sink.timeout().timeout(500, TimeUnit.MILLISECONDS) + val data = ByteArray(ONE_MB) + sink.write(Buffer().write(data), data.size.toLong()) + sink.flush() + socket.close() + } + + @Test + fun writeWithTimeout() { + val socket = socket(0, 0) + val sink = socket.sink() + sink.timeout().timeout(500, TimeUnit.MILLISECONDS) + val data = ByteArray(ONE_MB) + val start = System.nanoTime() + try { + sink.write(Buffer().write(data), data.size.toLong()) + sink.flush() + fail() + } catch (expected: SocketTimeoutException) { + } + val elapsed = System.nanoTime() - start + socket.close() + assertTrue("elapsed: $elapsed", TimeUnit.NANOSECONDS.toMillis(elapsed) >= 500) + assertTrue("elapsed: $elapsed", TimeUnit.NANOSECONDS.toMillis(elapsed) <= 750) + } + + companion object { + // The size of the socket buffers to use. Less than half the data transferred during tests to + // ensure send and receive buffers are flooded and any necessary blocking behavior takes place. + private const val SOCKET_BUFFER_SIZE = 256 * 1024 + private const val ONE_MB = 1024 * 1024 + + /** + * Returns a socket that can read `readableByteCount` incoming bytes and + * will accept `writableByteCount` written bytes. The socket will idle + * for 5 seconds when the required data has been read and written. + */ + fun socket(readableByteCount: Int, writableByteCount: Int): Socket { + val inetAddress = InetAddress.getByName("localhost") + val serverSocket = ServerSocket(0, 50, inetAddress) + serverSocket.reuseAddress = true + serverSocket.receiveBufferSize = SOCKET_BUFFER_SIZE + val peer: Thread = object : Thread("peer") { + override fun run() { + var socket: Socket? = null + try { + socket = serverSocket.accept() + socket.sendBufferSize = SOCKET_BUFFER_SIZE + writeFully(socket.getOutputStream(), readableByteCount) + readFully(socket.getInputStream(), writableByteCount) + sleep(5000) // Sleep 5 seconds so the peer can close the connection. + } catch (ignored: Exception) { + } finally { + try { + socket?.close() + } catch (ignored: IOException) { + } + } + } + } + peer.start() + val socket = Socket(serverSocket.inetAddress, serverSocket.localPort) + socket.receiveBufferSize = SOCKET_BUFFER_SIZE + socket.sendBufferSize = SOCKET_BUFFER_SIZE + return socket + } + + private fun writeFully(out: OutputStream, byteCount: Int) { + out.write(ByteArray(byteCount)) + out.flush() + } + + private fun readFully(`in`: InputStream, byteCount: Int): ByteArray { + var count = 0 + val result = ByteArray(byteCount) + while (count < byteCount) { + val read = `in`.read(result, count, result.size - count) + if (read == -1) throw EOFException() + count += read + } + return result + } + } +} diff --git a/okio/src/jvmTest/kotlin/okio/TestUtil.kt b/okio/src/jvmTest/kotlin/okio/TestUtil.kt index f9f61e62..703986e7 100644 --- a/okio/src/jvmTest/kotlin/okio/TestUtil.kt +++ b/okio/src/jvmTest/kotlin/okio/TestUtil.kt @@ -15,16 +15,17 @@ */ package okio -import okio.ByteString.Companion.encodeUtf8 -import org.junit.Assume import java.io.IOException import java.io.ObjectInputStream import java.io.ObjectOutputStream import java.io.Serializable +import java.util.Locale import java.util.Random import kotlin.test.assertEquals import kotlin.test.assertFalse import kotlin.test.assertTrue +import okio.ByteString.Companion.encodeUtf8 +import org.junit.Assume object TestUtil { // Necessary to make an internal member visible to Java. @@ -175,10 +176,10 @@ object TestUtil { } /** Serializes original to bytes, then deserializes those bytes and returns the result. */ + // Assume serialization doesn't change types. @Suppress("UNCHECKED_CAST") @Throws(Exception::class) @JvmStatic - // Assume serialization doesn't change types. fun <T : Serializable> reserialize(original: T): T { val buffer = Buffer() val out = ObjectOutputStream(buffer.outputStream()) @@ -298,5 +299,5 @@ object TestUtil { return reversed.toShort() } - fun assumeNotWindows() = Assume.assumeFalse(System.getProperty("os.name").toLowerCase().contains("win")) + fun assumeNotWindows() = Assume.assumeFalse(System.getProperty("os.name").lowercase(Locale.getDefault()).contains("win")) } diff --git a/okio/src/jvmTest/kotlin/okio/ThrottlerTakeTest.kt b/okio/src/jvmTest/kotlin/okio/ThrottlerTakeTest.kt index aab94b30..5349dd3d 100644 --- a/okio/src/jvmTest/kotlin/okio/ThrottlerTakeTest.kt +++ b/okio/src/jvmTest/kotlin/okio/ThrottlerTakeTest.kt @@ -15,9 +15,9 @@ */ package okio +import java.util.concurrent.TimeUnit import org.assertj.core.api.Assertions.assertThat import org.junit.Test -import java.util.concurrent.TimeUnit class ThrottlerTakeTest { private var nowNanos = 0L diff --git a/okio/src/jvmTest/kotlin/okio/ThrottlerTest.kt b/okio/src/jvmTest/kotlin/okio/ThrottlerTest.kt index 0b151794..baa487c1 100644 --- a/okio/src/jvmTest/kotlin/okio/ThrottlerTest.kt +++ b/okio/src/jvmTest/kotlin/okio/ThrottlerTest.kt @@ -15,12 +15,11 @@ */ package okio +import kotlin.test.Ignore import okio.TestUtil.randomSource import org.junit.After import org.junit.Before import org.junit.Test -import java.util.concurrent.Executors -import kotlin.test.Ignore @Ignore("These tests are flaky and fail on slower hardware, need to be improved") class ThrottlerTest { @@ -31,7 +30,7 @@ class ThrottlerTest { private val throttlerSlow = Throttler() private val threads = 4 - private val executorService = Executors.newFixedThreadPool(threads) + private val executorService = TestingExecutors.newExecutorService(threads) private var stopwatch = Stopwatch() @Before fun setup() { diff --git a/okio/src/jvmTest/kotlin/okio/TimeoutFactory.kt b/okio/src/jvmTest/kotlin/okio/TimeoutFactory.kt new file mode 100644 index 00000000..3a9cac7c --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/TimeoutFactory.kt @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2023 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +enum class TimeoutFactory { + BASE { + override fun newTimeout() = Timeout() + }, + + FORWARDING { + override fun newTimeout() = ForwardingTimeout(BASE.newTimeout()) + }, + + ASYNC { + override fun newTimeout() = AsyncTimeout() + }, + ; + + abstract fun newTimeout(): Timeout +} diff --git a/okio/src/jvmTest/kotlin/okio/TimeoutTest.kt b/okio/src/jvmTest/kotlin/okio/TimeoutTest.kt new file mode 100644 index 00000000..2d5c7989 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/TimeoutTest.kt @@ -0,0 +1,127 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.util.concurrent.TimeUnit +import org.junit.After +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Rule +import org.junit.Test +import org.junit.rules.Timeout as JUnitTimeout + +class TimeoutTest { + @JvmField @Rule + val timeout = JUnitTimeout(5, TimeUnit.SECONDS) + + private val executorService = TestingExecutors.newExecutorService(1) + + @After + @Throws(Exception::class) + fun tearDown() { + executorService.shutdown() + } + + @Test fun intersectWithReturnsAValue() { + val timeoutA = Timeout() + val timeoutB = Timeout() + + val s = timeoutA.intersectWith(timeoutB) { "hello" } + assertEquals("hello", s) + } + + @Test fun intersectWithPrefersSmallerTimeout() { + val timeoutA = Timeout() + timeoutA.timeout(smallerTimeoutNanos, TimeUnit.NANOSECONDS) + + val timeoutB = Timeout() + timeoutB.timeout(biggerTimeoutNanos, TimeUnit.NANOSECONDS) + + timeoutA.intersectWith(timeoutB) { + assertEquals(smallerTimeoutNanos, timeoutA.timeoutNanos()) + assertEquals(biggerTimeoutNanos, timeoutB.timeoutNanos()) + } + timeoutB.intersectWith(timeoutA) { + assertEquals(smallerTimeoutNanos, timeoutA.timeoutNanos()) + assertEquals(smallerTimeoutNanos, timeoutB.timeoutNanos()) + } + assertEquals(smallerTimeoutNanos, timeoutA.timeoutNanos()) + assertEquals(biggerTimeoutNanos, timeoutB.timeoutNanos()) + } + + @Test fun intersectWithPrefersNonZeroTimeout() { + val timeoutA = Timeout() + + val timeoutB = Timeout() + timeoutB.timeout(biggerTimeoutNanos, TimeUnit.NANOSECONDS) + + timeoutA.intersectWith(timeoutB) { + assertEquals(biggerTimeoutNanos, timeoutA.timeoutNanos()) + assertEquals(biggerTimeoutNanos, timeoutB.timeoutNanos()) + } + timeoutB.intersectWith(timeoutA) { + assertEquals(0L, timeoutA.timeoutNanos()) + assertEquals(biggerTimeoutNanos, timeoutB.timeoutNanos()) + } + assertEquals(0L, timeoutA.timeoutNanos()) + assertEquals(biggerTimeoutNanos, timeoutB.timeoutNanos()) + } + + @Test fun intersectWithPrefersSmallerDeadline() { + val timeoutA = Timeout() + timeoutA.deadlineNanoTime(smallerDeadlineNanos) + + val timeoutB = Timeout() + timeoutB.deadlineNanoTime(biggerDeadlineNanos) + + timeoutA.intersectWith(timeoutB) { + assertEquals(smallerDeadlineNanos, timeoutA.deadlineNanoTime()) + assertEquals(biggerDeadlineNanos, timeoutB.deadlineNanoTime()) + } + timeoutB.intersectWith(timeoutA) { + assertEquals(smallerDeadlineNanos, timeoutA.deadlineNanoTime()) + assertEquals(smallerDeadlineNanos, timeoutB.deadlineNanoTime()) + } + assertEquals(smallerDeadlineNanos, timeoutA.deadlineNanoTime()) + assertEquals(biggerDeadlineNanos, timeoutB.deadlineNanoTime()) + } + + @Test fun intersectWithPrefersNonZeroDeadline() { + val timeoutA = Timeout() + + val timeoutB = Timeout() + timeoutB.deadlineNanoTime(biggerDeadlineNanos) + + timeoutA.intersectWith(timeoutB) { + assertEquals(biggerDeadlineNanos, timeoutA.deadlineNanoTime()) + assertEquals(biggerDeadlineNanos, timeoutB.deadlineNanoTime()) + } + timeoutB.intersectWith(timeoutA) { + assertFalse(timeoutA.hasDeadline()) + assertEquals(biggerDeadlineNanos, timeoutB.deadlineNanoTime()) + } + assertFalse(timeoutA.hasDeadline()) + assertEquals(biggerDeadlineNanos, timeoutB.deadlineNanoTime()) + } + + companion object { + val smallerTimeoutNanos = TimeUnit.MILLISECONDS.toNanos(500L) + val biggerTimeoutNanos = TimeUnit.MILLISECONDS.toNanos(1500L) + + val smallerDeadlineNanos = TimeUnit.MILLISECONDS.toNanos(500L) + val biggerDeadlineNanos = TimeUnit.MILLISECONDS.toNanos(1500L) + } +} diff --git a/okio/src/jvmTest/kotlin/okio/Utf8Test.kt b/okio/src/jvmTest/kotlin/okio/Utf8Test.kt new file mode 100644 index 00000000..3657f136 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/Utf8Test.kt @@ -0,0 +1,307 @@ +/* + * Copyright (C) 2014 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.EOFException +import kotlin.text.Charsets.UTF_8 +import okio.ByteString.Companion.decodeHex +import okio.ByteString.Companion.of +import okio.TestUtil.SEGMENT_SIZE +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Assert.fail +import org.junit.Test + +class Utf8Test { + @Test + fun oneByteCharacters() { + assertEncoded("00", 0x00) // Smallest 1-byte character. + assertEncoded("20", ' '.code) + assertEncoded("7e", '~'.code) + assertEncoded("7f", 0x7f) // Largest 1-byte character. + } + + @Test + fun twoByteCharacters() { + assertEncoded("c280", 0x0080) // Smallest 2-byte character. + assertEncoded("c3bf", 0x00ff) + assertEncoded("c480", 0x0100) + assertEncoded("dfbf", 0x07ff) // Largest 2-byte character. + } + + @Test + fun threeByteCharacters() { + assertEncoded("e0a080", 0x0800) // Smallest 3-byte character. + assertEncoded("e0bfbf", 0x0fff) + assertEncoded("e18080", 0x1000) + assertEncoded("e1bfbf", 0x1fff) + assertEncoded("ed8080", 0xd000) + assertEncoded("ed9fbf", 0xd7ff) // Largest character lower than the min surrogate. + assertEncoded("ee8080", 0xe000) // Smallest character greater than the max surrogate. + assertEncoded("eebfbf", 0xefff) + assertEncoded("ef8080", 0xf000) + assertEncoded("efbfbf", 0xffff) // Largest 3-byte character. + } + + @Test + fun fourByteCharacters() { + assertEncoded("f0908080", 0x010000) // Smallest surrogate pair. + assertEncoded("f48fbfbf", 0x10ffff) // Largest code point expressible by UTF-16. + } + + @Test + fun danglingHighSurrogate() { + assertStringEncoded("3f", "\ud800") // "?" + } + + @Test + fun lowSurrogateWithoutHighSurrogate() { + assertStringEncoded("3f", "\udc00") // "?" + } + + @Test + fun highSurrogateFollowedByNonSurrogate() { + assertStringEncoded("3f61", "\ud800\u0061") // "?a": Following character is too low. + assertStringEncoded("3fee8080", "\ud800\ue000") // "?\ue000": Following character is too high. + } + + @Test + fun doubleLowSurrogate() { + assertStringEncoded("3f3f", "\udc00\udc00") // "??" + } + + @Test + fun doubleHighSurrogate() { + assertStringEncoded("3f3f", "\ud800\ud800") // "??" + } + + @Test + fun highSurrogateLowSurrogate() { + assertStringEncoded("3f3f", "\udc00\ud800") // "??" + } + + @Test + fun multipleSegmentString() { + val a = "a".repeat(SEGMENT_SIZE + SEGMENT_SIZE + 1) + val encoded = Buffer().writeUtf8(a) + val expected = Buffer().write(a.toByteArray(UTF_8)) + assertEquals(expected, encoded) + } + + @Test + fun stringSpansSegments() { + val buffer = Buffer() + val a = "a".repeat(SEGMENT_SIZE - 1) + val b = "bb" + val c = "c".repeat(SEGMENT_SIZE - 1) + buffer.writeUtf8(a) + buffer.writeUtf8(b) + buffer.writeUtf8(c) + assertEquals(a + b + c, buffer.readUtf8()) + } + + @Test + fun readEmptyBufferThrowsEofException() { + val buffer = Buffer() + try { + buffer.readUtf8CodePoint() + fail() + } catch (expected: EOFException) { + } + } + + @Test + fun readLeadingContinuationByteReturnsReplacementCharacter() { + val buffer = Buffer() + buffer.writeByte(0xbf) + assertEquals(REPLACEMENT_CODE_POINT.toLong(), buffer.readUtf8CodePoint().toLong()) + assertTrue(buffer.exhausted()) + } + + @Test + fun readMissingContinuationBytesThrowsEofException() { + val buffer = Buffer() + buffer.writeByte(0xdf) + try { + buffer.readUtf8CodePoint() + fail() + } catch (expected: EOFException) { + } + assertFalse(buffer.exhausted()) // Prefix byte wasn't consumed. + } + + @Test + fun readTooLargeCodepointReturnsReplacementCharacter() { + // 5-byte and 6-byte code points are not supported. + val buffer = Buffer() + buffer.write("f888808080".decodeHex()) + assertEquals(REPLACEMENT_CODE_POINT.toLong(), buffer.readUtf8CodePoint().toLong()) + assertEquals(REPLACEMENT_CODE_POINT.toLong(), buffer.readUtf8CodePoint().toLong()) + assertEquals(REPLACEMENT_CODE_POINT.toLong(), buffer.readUtf8CodePoint().toLong()) + assertEquals(REPLACEMENT_CODE_POINT.toLong(), buffer.readUtf8CodePoint().toLong()) + assertEquals(REPLACEMENT_CODE_POINT.toLong(), buffer.readUtf8CodePoint().toLong()) + assertTrue(buffer.exhausted()) + } + + @Test + fun readNonContinuationBytesReturnsReplacementCharacter() { + // Use a non-continuation byte where a continuation byte is expected. + val buffer = Buffer() + buffer.write("df20".decodeHex()) + assertEquals(REPLACEMENT_CODE_POINT.toLong(), buffer.readUtf8CodePoint().toLong()) + assertEquals(0x20, buffer.readUtf8CodePoint().toLong()) // Non-continuation character not consumed. + assertTrue(buffer.exhausted()) + } + + @Test + fun readCodePointBeyondUnicodeMaximum() { + // A 4-byte encoding with data above the U+10ffff Unicode maximum. + val buffer = Buffer() + buffer.write("f4908080".decodeHex()) + assertEquals(REPLACEMENT_CODE_POINT.toLong(), buffer.readUtf8CodePoint().toLong()) + assertTrue(buffer.exhausted()) + } + + @Test + fun readSurrogateCodePoint() { + val buffer = Buffer() + buffer.write("eda080".decodeHex()) + assertEquals(REPLACEMENT_CODE_POINT.toLong(), buffer.readUtf8CodePoint().toLong()) + assertTrue(buffer.exhausted()) + buffer.write("edbfbf".decodeHex()) + assertEquals(REPLACEMENT_CODE_POINT.toLong(), buffer.readUtf8CodePoint().toLong()) + assertTrue(buffer.exhausted()) + } + + @Test + fun readOverlongCodePoint() { + // Use 2 bytes to encode data that only needs 1 byte. + val buffer = Buffer() + buffer.write("c080".decodeHex()) + assertEquals(REPLACEMENT_CODE_POINT.toLong(), buffer.readUtf8CodePoint().toLong()) + assertTrue(buffer.exhausted()) + } + + @Test + fun writeSurrogateCodePoint() { + assertStringEncoded("ed9fbf", "\ud7ff") // Below lowest surrogate is okay. + assertStringEncoded("3f", "\ud800") // Lowest surrogate gets '?'. + assertStringEncoded("3f", "\udfff") // Highest surrogate gets '?'. + assertStringEncoded("ee8080", "\ue000") // Above highest surrogate is okay. + } + + @Test + fun writeCodePointBeyondUnicodeMaximum() { + val buffer = Buffer() + try { + buffer.writeUtf8CodePoint(0x110000) + fail() + } catch (expected: IllegalArgumentException) { + assertEquals("Unexpected code point: 0x110000", expected.message) + } + } + + @Test + fun size() { + assertEquals(0, "".utf8Size()) + assertEquals(3, "abc".utf8Size()) + assertEquals(16, "təˈranəˌsôr".utf8Size()) + } + + @Test + fun sizeWithBounds() { + assertEquals(0, "".utf8Size(0, 0)) + assertEquals(0, "abc".utf8Size(0, 0)) + assertEquals(1, "abc".utf8Size(1, 2)) + assertEquals(2, "abc".utf8Size(0, 2)) + assertEquals(3, "abc".utf8Size(0, 3)) + assertEquals(16, "təˈranəˌsôr".utf8Size(0, 11)) + assertEquals(5, "təˈranəˌsôr".utf8Size(3, 7)) + } + + @Test + fun sizeBoundsCheck() { + try { + null!!.utf8Size(0, 0) + fail() + } catch (expected: NullPointerException) { + } + try { + "abc".utf8Size(-1, 2) + fail() + } catch (expected: IllegalArgumentException) { + } + try { + "abc".utf8Size(2, 1) + fail() + } catch (expected: IllegalArgumentException) { + } + try { + "abc".utf8Size(1, 4) + fail() + } catch (expected: IllegalArgumentException) { + } + } + + private fun assertEncoded(hex: String, vararg codePoints: Int) { + assertCodePointEncoded(hex, *codePoints) + assertCodePointDecoded(hex, *codePoints) + assertStringEncoded(hex, String(codePoints, 0, codePoints.size)) + } + + private fun assertCodePointEncoded(hex: String, vararg codePoints: Int) { + val buffer = Buffer() + for (codePoint in codePoints) { + buffer.writeUtf8CodePoint(codePoint) + } + assertEquals(buffer.readByteString(), hex.decodeHex()) + } + + private fun assertCodePointDecoded(hex: String, vararg codePoints: Int) { + val buffer = Buffer().write(hex.decodeHex()) + for (codePoint in codePoints) { + assertEquals(codePoint.toLong(), buffer.readUtf8CodePoint().toLong()) + } + assertTrue(buffer.exhausted()) + } + + private fun assertStringEncoded(hex: String, string: String) { + val expectedUtf8 = hex.decodeHex() + + // Confirm our expectations are consistent with the platform. + val platformUtf8 = of(*string.toByteArray(charset("UTF-8"))) + assertEquals(expectedUtf8, platformUtf8) + + // Confirm our implementation matches those expectations. + val actualUtf8 = Buffer().writeUtf8(string).readByteString() + assertEquals(expectedUtf8, actualUtf8) + + // Confirm we are consistent when writing one code point at a time. + val bufferUtf8 = Buffer() + var i = 0 + while (i < string.length) { + val c = string.codePointAt(i) + bufferUtf8.writeUtf8CodePoint(c) + i += Character.charCount(c) + } + assertEquals(expectedUtf8, bufferUtf8.readByteString()) + + // Confirm we are consistent when measuring lengths. + assertEquals(expectedUtf8.size.toLong(), string.utf8Size()) + assertEquals(expectedUtf8.size.toLong(), string.utf8Size(0, string.length)) + } +} diff --git a/okio/src/jvmTest/kotlin/okio/WaitUntilNotifiedTest.kt b/okio/src/jvmTest/kotlin/okio/WaitUntilNotifiedTest.kt new file mode 100644 index 00000000..44c6bbfd --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/WaitUntilNotifiedTest.kt @@ -0,0 +1,236 @@ +/* + * Copyright (C) 2016 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import java.io.InterruptedIOException +import java.util.concurrent.TimeUnit +import okio.TestUtil.assumeNotWindows +import okio.TestingExecutors.newScheduledExecutorService +import org.junit.After +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Assert.fail +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters + +@RunWith(Parameterized::class) +class WaitUntilNotifiedTest( + factory: TimeoutFactory, +) { + private val timeout = factory.newTimeout() + private val executorService = newScheduledExecutorService(0) + + @After + fun tearDown() { + executorService.shutdown() + } + + @Test + @Synchronized + fun notified() { + timeout.timeout(5000, TimeUnit.MILLISECONDS) + val start = now() + executorService.schedule( + { + synchronized(this@WaitUntilNotifiedTest) { + (this as Object).notify() + } + }, + 1000, + TimeUnit.MILLISECONDS, + ) + timeout.waitUntilNotified(this) + assertElapsed(1000.0, start) + } + + @Test + @Synchronized + fun timeout() { + assumeNotWindows() + timeout.timeout(1000, TimeUnit.MILLISECONDS) + val start = now() + try { + timeout.waitUntilNotified(this) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + } + assertElapsed(1000.0, start) + } + + @Test + @Synchronized + fun deadline() { + assumeNotWindows() + timeout.deadline(1000, TimeUnit.MILLISECONDS) + val start = now() + try { + timeout.waitUntilNotified(this) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + } + assertElapsed(1000.0, start) + } + + @Test + @Synchronized + fun deadlineBeforeTimeout() { + assumeNotWindows() + timeout.timeout(5000, TimeUnit.MILLISECONDS) + timeout.deadline(1000, TimeUnit.MILLISECONDS) + val start = now() + try { + timeout.waitUntilNotified(this) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + } + assertElapsed(1000.0, start) + } + + @Test + @Synchronized + fun timeoutBeforeDeadline() { + assumeNotWindows() + timeout.timeout(1000, TimeUnit.MILLISECONDS) + timeout.deadline(5000, TimeUnit.MILLISECONDS) + val start = now() + try { + timeout.waitUntilNotified(this) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + } + assertElapsed(1000.0, start) + } + + @Test + @Synchronized + fun deadlineAlreadyReached() { + assumeNotWindows() + timeout.deadlineNanoTime(System.nanoTime()) + val start = now() + try { + timeout.waitUntilNotified(this) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + } + assertElapsed(0.0, start) + } + + @Test + @Synchronized + fun threadInterrupted() { + assumeNotWindows() + val start = now() + Thread.currentThread().interrupt() + try { + timeout.waitUntilNotified(this) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("interrupted", expected.message) + assertTrue(Thread.interrupted()) + } + assertElapsed(0.0, start) + } + + @Test + @Synchronized + fun threadInterruptedOnThrowIfReached() { + assumeNotWindows() + Thread.currentThread().interrupt() + try { + timeout.throwIfReached() + fail() + } catch (expected: InterruptedIOException) { + assertEquals("interrupted", expected.message) + assertTrue(Thread.interrupted()) + } + } + + @Test + @Synchronized + fun cancelBeforeWaitDoesNothing() { + assumeNotWindows() + timeout.timeout(1000, TimeUnit.MILLISECONDS) + timeout.cancel() + val start = now() + try { + timeout.waitUntilNotified(this) + fail() + } catch (expected: InterruptedIOException) { + assertEquals("timeout", expected.message) + } + assertElapsed(1000.0, start) + } + + @Test + @Synchronized + fun canceledTimeoutDoesNotThrowWhenNotNotifiedOnTime() { + timeout.timeout(1000, TimeUnit.MILLISECONDS) + timeout.cancelLater(500) + + val start = now() + timeout.waitUntilNotified(this) // Returns early but doesn't throw. + assertElapsed(1000.0, start) + } + + @Test + @Synchronized + fun multipleCancelsAreIdempotent() { + timeout.timeout(1000, TimeUnit.MILLISECONDS) + timeout.cancelLater(250) + timeout.cancelLater(500) + timeout.cancelLater(750) + + val start = now() + timeout.waitUntilNotified(this) // Returns early but doesn't throw. + assertElapsed(1000.0, start) + } + + /** Returns the nanotime in milliseconds as a double for measuring timeouts. */ + private fun now(): Double { + return System.nanoTime() / 1000000.0 + } + + /** + * Fails the test unless the time from start until now is duration, accepting differences in + * -50..+450 milliseconds. + */ + private fun assertElapsed(duration: Double, start: Double) { + assertEquals(duration, now() - start - 200.0, 250.0) + } + + private fun Timeout.cancelLater(delay: Long) { + executorService.schedule( + { + cancel() + }, + delay, + TimeUnit.MILLISECONDS, + ) + } + + companion object { + @Parameters(name = "{0}") + @JvmStatic + fun parameters(): List<Array<out Any?>> = TimeoutFactory.entries.map { arrayOf(it) } + } +} diff --git a/okio/src/jvmTest/kotlin/okio/ZipBuilder.kt b/okio/src/jvmTest/kotlin/okio/ZipBuilder.kt new file mode 100644 index 00000000..1abcc031 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/ZipBuilder.kt @@ -0,0 +1,158 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import org.junit.Assume.assumeTrue + +/** + * Execute the `zip` command line program to create reference zip files for testing. + * + * Unfortunately the `zip` command line is limited in its ability to create exactly the zip files + * we want for testing. In particular, the only way it will create zip64 archives is if the input + * file is received from a UNIX pipe. (In such cases the file length is unknown in advance, and so + * the tool uses zip64 for the possibility of a very large file.) Files received from a pipe are + * always named `-`. + */ +class ZipBuilder( + private val directory: Path, +) { + private val fileSystem = FileSystem.SYSTEM + + private val entries = mutableListOf<Entry>() + private val options = mutableListOf<String>() + private var archiveComment: String = "" + + @JvmOverloads + fun addEntry( + path: String, + content: String? = null, + directory: Boolean = false, + comment: String = "", + modifiedAt: String? = null, + accessedAt: String? = null, + zip64: Boolean = false, + ) = apply { + entries += Entry(path, content, directory, comment, modifiedAt, accessedAt, zip64) + } + + fun addOption(option: String) = apply { options += option } + + fun archiveComment(archiveComment: String) = apply { this.archiveComment = archiveComment } + + fun build(): Path { + assumeTrue("ZipBuilder doesn't work on Windows", Path.DIRECTORY_SEPARATOR == "/") + + val archive = directory / "${randomToken(16)}.zip" + val anyZip64 = entries.any { it.zip64 } + + require(!anyZip64 || entries.size == 1) { + "ZipBuilder permits at most one zip64 entry" + } + require(!anyZip64 || archiveComment.isEmpty()) { + "Cannot combine archiveComment with zip64" + } + + val promptForComments = entries.any { it.comment.isNotEmpty() } + if (promptForComments) { + options += "--entry-comments" + } + if (archiveComment.isNotEmpty()) { + options += "--archive-comment" + } + + val command = mutableListOf<String>() + command += "zip" + command += options + command += archive.toString() + + for (entry in entries) { + if (!entry.zip64) { + val absolutePath = directory / entry.path + fileSystem.createDirectories(absolutePath.parent!!) + + if (entry.directory) { + fileSystem.createDirectories(absolutePath) + } else { + fileSystem.write(absolutePath) { + writeUtf8(entry.content!!) + } + } + + if (entry.modifiedAt != null) { + touch("-m", absolutePath, entry.modifiedAt) + } + if (entry.accessedAt != null) { + touch("-a", absolutePath, entry.accessedAt) + } + } + command += entry.path + } + + val process = ProcessBuilder() + .command(command) + .directory(directory.toFile()) + .redirectOutput(ProcessBuilder.Redirect.INHERIT) + .redirectError(ProcessBuilder.Redirect.INHERIT) + .start() + + process.outputStream.sink().buffer().use { sink -> + if (anyZip64) { + sink.writeUtf8(entries.single().content!!) + } + if (promptForComments) { + for (entry in entries) { + sink.writeUtf8(entry.comment) + sink.writeUtf8("\n") + sink.flush() + } + } + sink.writeUtf8(archiveComment) + } + + val result = process.waitFor() + require(result == 0) { "process failed: $command" } + + return archive + } + + private fun touch(option: String, absolutePath: Path, date: String) { + val exitCode = ProcessBuilder() + .command("touch", option, "-t", date, absolutePath.toString()) + .apply { environment()["TZ"] = "UTC" } + .start() + .waitFor() + require(exitCode == 0) + } + + private class Entry( + val path: String, + val content: String?, + val directory: Boolean, + val comment: String, + val modifiedAt: String?, + val accessedAt: String?, + val zip64: Boolean, + ) { + init { + require(directory != (content != null)) { "must be a directory or have content" } + if (zip64) { + require(path == "-") { "zip64 file name must be '-'" } + require(comment == "") { "zip64 must not have comments" } + require(modifiedAt == null) { "zip64 must not have modifiedAt" } + } + } + } +} diff --git a/okio/src/jvmTest/kotlin/okio/ZipFileSystemJavaTest.kt b/okio/src/jvmTest/kotlin/okio/ZipFileSystemJavaTest.kt new file mode 100644 index 00000000..6d636145 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/ZipFileSystemJavaTest.kt @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2020 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import okio.Path.Companion.toPath +import org.assertj.core.api.Assertions +import org.junit.Before +import org.junit.Test + +class ZipFileSystemJavaTest { + private val fileSystem = FileSystem.SYSTEM + private val base = FileSystem.SYSTEM_TEMPORARY_DIRECTORY.div(randomToken(16)) + + @Before + fun setUp() { + fileSystem.createDirectory(base) + } + + @Test + fun zipFileSystemApi() { + val zipPath = ZipBuilder(base) + .addEntry("hello.txt", "Hello World") + .build() + val zipFileSystem = fileSystem.openZip(zipPath) + zipFileSystem.source("hello.txt".toPath(false)).buffer().use { source -> + val content = source.readUtf8() + Assertions.assertThat(content).isEqualTo("Hello World") + } + } +} diff --git a/okio/src/jvmTest/kotlin/okio/ZipFileSystemTest.kt b/okio/src/jvmTest/kotlin/okio/ZipFileSystemTest.kt new file mode 100644 index 00000000..fd7aa46b --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/ZipFileSystemTest.kt @@ -0,0 +1,502 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio + +import kotlin.test.assertFailsWith +import kotlinx.datetime.Instant +import okio.ByteString.Companion.decodeHex +import okio.ByteString.Companion.encodeUtf8 +import okio.Path.Companion.toPath +import org.assertj.core.api.Assertions.assertThat +import org.junit.Before +import org.junit.Test + +class ZipFileSystemTest { + private val fileSystem = FileSystem.SYSTEM + private var base = FileSystem.SYSTEM_TEMPORARY_DIRECTORY / randomToken(16) + + @Before + fun setUp() { + fileSystem.createDirectory(base) + } + + @Test + fun emptyZip() { + // ZipBuilder cannot write empty zips. + val zipPath = base / "empty.zip" + fileSystem.write(zipPath) { + write("504b0506000000000000000000000000000000000000".decodeHex()) + } + + val zipFileSystem = fileSystem.openZip(zipPath) + assertThat(zipFileSystem.list("/".toPath())).isEmpty() + } + + @Test + fun emptyZipWithPrependedData() { + // ZipBuilder cannot write empty zips. + val zipPath = base / "empty.zip" + fileSystem.write(zipPath) { + writeUtf8("Hello I'm junk data prepended to the ZIP!") + write("504b0506000000000000000000000000000000000000".decodeHex()) + } + + val zipFileSystem = fileSystem.openZip(zipPath) + assertThat(zipFileSystem.list("/".toPath())).isEmpty() + } + + @Test + fun zipWithFiles() { + val zipPath = ZipBuilder(base) + .addEntry("hello.txt", "Hello World") + .addEntry("directory/subdirectory/child.txt", "Another file!") + .build() + val zipFileSystem = fileSystem.openZip(zipPath) + + assertThat(zipFileSystem.read("hello.txt".toPath()) { readUtf8() }) + .isEqualTo("Hello World") + + assertThat(zipFileSystem.read("directory/subdirectory/child.txt".toPath()) { readUtf8() }) + .isEqualTo("Another file!") + + assertThat(zipFileSystem.list("/".toPath())) + .hasSameElementsAs(listOf("/hello.txt".toPath(), "/directory".toPath())) + assertThat(zipFileSystem.list("/directory".toPath())) + .containsExactly("/directory/subdirectory".toPath()) + assertThat(zipFileSystem.list("/directory/subdirectory".toPath())) + .containsExactly("/directory/subdirectory/child.txt".toPath()) + } + + /** + * Note that the zip tool does not compress files that don't benefit from it. Examples above like + * 'Hello World' are stored, not deflated. + */ + @Test + fun zipWithDeflate() { + val content = "Android\n".repeat(1000) + val zipPath = ZipBuilder(base) + .addEntry("a.txt", content) + .addOption("--compression-method") + .addOption("deflate") + .build() + assertThat(fileSystem.metadata(zipPath).size).isLessThan(content.length.toLong()) + val zipFileSystem = fileSystem.openZip(zipPath) + + assertThat(zipFileSystem.read("a.txt".toPath()) { readUtf8() }) + .isEqualTo(content) + } + + @Test + fun zipWithStore() { + val content = "Android\n".repeat(1000) + val zipPath = ZipBuilder(base) + .addEntry("a.txt", content) + .addOption("--compression-method") + .addOption("store") + .build() + assertThat(fileSystem.metadata(zipPath).size).isGreaterThan(content.length.toLong()) + val zipFileSystem = fileSystem.openZip(zipPath) + + assertThat(zipFileSystem.read("a.txt".toPath()) { readUtf8() }) + .isEqualTo(content) + } + + /** + * Confirm we can read zip files that have file comments, even if these comments are not exposed + * in the public API. + */ + @Test + fun zipWithFileComments() { + val zipPath = ZipBuilder(base) + .addEntry("a.txt", "Android", comment = "A is for Android") + .addEntry("b.txt", "Banana", comment = "B or not to Be") + .build() + val zipFileSystem = fileSystem.openZip(zipPath) + + assertThat(zipFileSystem.read("a.txt".toPath()) { readUtf8() }) + .isEqualTo("Android") + + assertThat(zipFileSystem.read("b.txt".toPath()) { readUtf8() }) + .isEqualTo("Banana") + } + + @Test + fun zipWithFileModifiedDate() { + val zipPath = ZipBuilder(base) + .addEntry( + path = "a.txt", + content = "Android", + modifiedAt = "200102030405.06", + accessedAt = "200102030405.07", + ) + .addEntry( + path = "b.txt", + content = "Banana", + modifiedAt = "200908070605.04", + accessedAt = "200908070605.03", + ) + .build() + val zipFileSystem = fileSystem.openZip(zipPath) + + zipFileSystem.metadata("a.txt".toPath()) + .apply { + assertThat(isRegularFile).isTrue() + assertThat(isDirectory).isFalse() + assertThat(size).isEqualTo(7L) + assertThat(createdAtMillis).isNull() + assertThat(lastModifiedAtMillis).isEqualTo("2001-02-03T04:05:06Z".toEpochMillis()) + assertThat(lastAccessedAtMillis).isEqualTo("2001-02-03T04:05:07Z".toEpochMillis()) + } + + zipFileSystem.metadata("b.txt".toPath()) + .apply { + assertThat(isRegularFile).isTrue() + assertThat(isDirectory).isFalse() + assertThat(size).isEqualTo(6L) + assertThat(createdAtMillis).isNull() + assertThat(lastModifiedAtMillis).isEqualTo("2009-08-07T06:05:04Z".toEpochMillis()) + assertThat(lastAccessedAtMillis).isEqualTo("2009-08-07T06:05:03Z".toEpochMillis()) + } + } + + /** Confirm we suffer UNIX limitations on our date format. */ + @Test + fun zipWithFileOutOfBoundsModifiedDate() { + val zipPath = ZipBuilder(base) + .addEntry( + path = "a.txt", + content = "Android", + modifiedAt = "196912310000.00", + accessedAt = "196912300000.00", + ) + .addEntry( + path = "b.txt", + content = "Banana", + modifiedAt = "203801190314.07", // Last UNIX date representable in 31 bits. + accessedAt = "203801190314.08", // Overflows! + ) + .build() + val zipFileSystem = fileSystem.openZip(zipPath) + + println(Instant.fromEpochMilliseconds(-2147483648000L)) + + zipFileSystem.metadata("a.txt".toPath()) + .apply { + assertThat(isRegularFile).isTrue() + assertThat(isDirectory).isFalse() + assertThat(size).isEqualTo(7L) + assertThat(createdAtMillis).isNull() + assertThat(lastModifiedAtMillis).isEqualTo("1969-12-31T00:00:00Z".toEpochMillis()) + assertThat(lastAccessedAtMillis).isEqualTo("1969-12-30T00:00:00Z".toEpochMillis()) + } + + // Greater than the upper bound wraps around. + zipFileSystem.metadata("b.txt".toPath()) + .apply { + assertThat(isRegularFile).isTrue() + assertThat(isDirectory).isFalse() + assertThat(size).isEqualTo(6L) + assertThat(createdAtMillis).isNull() + assertThat(lastModifiedAtMillis).isEqualTo("2038-01-19T03:14:07Z".toEpochMillis()) + assertThat(lastAccessedAtMillis).isEqualTo("1901-12-13T20:45:52Z".toEpochMillis()) + } + } + + /** + * Directories are optional in the zip file. But if we want metadata on them they must be stored. + * Note that this test adds the directories last; otherwise adding child files to them will cause + * their modified at times to change. + */ + @Test + fun zipWithDirectoryModifiedDate() { + val zipPath = ZipBuilder(base) + .addEntry("a/a.txt", "Android") + .addEntry( + path = "a", + directory = true, + modifiedAt = "200102030405.06", + accessedAt = "200102030405.07", + ) + .addEntry("b/b.txt", "Android") + .addEntry( + path = "b", + directory = true, + modifiedAt = "200908070605.04", + accessedAt = "200908070605.03", + ) + .build() + val zipFileSystem = fileSystem.openZip(zipPath) + + zipFileSystem.metadata("a".toPath()) + .apply { + assertThat(isRegularFile).isFalse() + assertThat(isDirectory).isTrue() + assertThat(size).isNull() + assertThat(createdAtMillis).isNull() + assertThat(lastModifiedAtMillis).isEqualTo("2001-02-03T04:05:06Z".toEpochMillis()) + assertThat(lastAccessedAtMillis).isEqualTo("2001-02-03T04:05:07Z".toEpochMillis()) + } + assertThat(zipFileSystem.list("a".toPath())).containsExactly("/a/a.txt".toPath()) + + zipFileSystem.metadata("b".toPath()) + .apply { + assertThat(isRegularFile).isFalse() + assertThat(isDirectory).isTrue() + assertThat(size).isNull() + assertThat(createdAtMillis).isNull() + assertThat(lastModifiedAtMillis).isEqualTo("2009-08-07T06:05:04Z".toEpochMillis()) + assertThat(lastAccessedAtMillis).isEqualTo("2009-08-07T06:05:03Z".toEpochMillis()) + } + assertThat(zipFileSystem.list("b".toPath())).containsExactly("/b/b.txt".toPath()) + } + + @Test + fun zipWithModifiedDate() { + val zipPath = ZipBuilder(base) + .addEntry( + "a/a.txt", + modifiedAt = "197001010001.00", + accessedAt = "197001010002.00", + content = "Android", + ) + .build() + val zipFileSystem = fileSystem.openZip(zipPath) + + zipFileSystem.metadata("a/a.txt".toPath()) + .apply { + assertThat(createdAtMillis).isNull() + assertThat(lastModifiedAtMillis).isEqualTo("1970-01-01T00:01:00Z".toEpochMillis()) + assertThat(lastAccessedAtMillis).isEqualTo("1970-01-01T00:02:00Z".toEpochMillis()) + } + } + + /** Build a very small zip file with just a single empty directory. */ + @Test + fun zipWithEmptyDirectory() { + val zipPath = ZipBuilder(base) + .addEntry( + path = "a", + directory = true, + modifiedAt = "200102030405.06", + accessedAt = "200102030405.07", + ) + .build() + val zipFileSystem = fileSystem.openZip(zipPath) + + zipFileSystem.metadata("a".toPath()) + .apply { + assertThat(isRegularFile).isFalse() + assertThat(isDirectory).isTrue() + assertThat(size).isNull() + assertThat(createdAtMillis).isNull() + assertThat(lastModifiedAtMillis).isEqualTo("2001-02-03T04:05:06Z".toEpochMillis()) + assertThat(lastAccessedAtMillis).isEqualTo("2001-02-03T04:05:07Z".toEpochMillis()) + } + assertThat(zipFileSystem.list("a".toPath())).isEmpty() + } + + /** + * The `--no-dir-entries` option causes the zip file to omit the directories from the encoded + * file. Our implementation synthesizes these missing directories automatically. + */ + @Test + fun zipWithSyntheticDirectory() { + val zipPath = ZipBuilder(base) + .addEntry("a/a.txt", "Android") + .addEntry("a", directory = true) + .addEntry("b/b.txt", "Android") + .addEntry("b", directory = true) + .addOption("--no-dir-entries") + .build() + val zipFileSystem = fileSystem.openZip(zipPath) + + zipFileSystem.metadata("a".toPath()) + .apply { + assertThat(isRegularFile).isFalse() + assertThat(isDirectory).isTrue() + assertThat(size).isNull() + assertThat(createdAtMillis).isNull() + assertThat(lastModifiedAtMillis).isNull() + assertThat(lastAccessedAtMillis).isNull() + } + assertThat(zipFileSystem.list("a".toPath())).containsExactly("/a/a.txt".toPath()) + + zipFileSystem.metadata("b".toPath()) + .apply { + assertThat(isRegularFile).isFalse() + assertThat(isDirectory).isTrue() + assertThat(size).isNull() + assertThat(createdAtMillis).isNull() + assertThat(lastModifiedAtMillis).isNull() + assertThat(lastAccessedAtMillis).isNull() + } + assertThat(zipFileSystem.list("b".toPath())).containsExactly("/b/b.txt".toPath()) + } + + /** + * Force a file to be encoded with zip64 metadata. We use a pipe to force the zip command to + * create a zip64 archive; otherwise we'd need to add a very large file to get this format. + */ + @Test + fun zip64() { + val zipPath = ZipBuilder(base) + .addEntry("-", "Android", zip64 = true) + .build() + val zipFileSystem = fileSystem.openZip(zipPath) + + assertThat(zipFileSystem.read("-".toPath()) { readUtf8() }) + .isEqualTo("Android") + } + + /** + * Confirm we can read zip files with a full-archive comment, even if this comment is not surfaced + * in our API. + */ + @Test + fun zipWithArchiveComment() { + val zipPath = ZipBuilder(base) + .addEntry("a.txt", "Android") + .archiveComment("this comment applies to the entire archive") + .build() + val zipFileSystem = fileSystem.openZip(zipPath) + + assertThat(zipFileSystem.read("a.txt".toPath()) { readUtf8() }) + .isEqualTo("Android") + } + + @Test + fun cannotReadZipWithSpanning() { + // Spanned archives must be at least 64 KiB. + val largeFile = randomToken(length = 128 * 1024) + val zipPath = ZipBuilder(base) + .addEntry("large_file.txt", largeFile) + .addOption("--split-size") + .addOption("64k") + .build() + assertFailsWith<IOException> { + fileSystem.openZip(zipPath) + } + } + + @Test + fun cannotReadZipWithEncryption() { + val zipPath = ZipBuilder(base) + .addEntry("a.txt", "Android") + .addOption("--password") + .addOption("secret") + .build() + assertFailsWith<IOException> { + fileSystem.openZip(zipPath) + } + } + + @Test + fun zipTooShort() { + val zipPath = ZipBuilder(base) + .addEntry("a.txt", "Android") + .build() + + val prefix = fileSystem.read(zipPath) { readByteString(20) } + fileSystem.write(zipPath) { write(prefix) } + + assertFailsWith<IOException> { + fileSystem.openZip(zipPath) + } + } + + /** + * The zip format permits multiple files with the same names. For example, + * `kotlin-gradle-plugin-1.5.20.jar` contains two copies of + * `META-INF/kotlin-gradle-statistics.kotlin_module`. + * + * We used to crash on duplicates, but they are common in practice so now we prefer the last + * entry. This behavior is consistent with both [java.util.zip.ZipFile] and + * [java.nio.file.FileSystem]. + */ + @Test + fun filesOverlap() { + val zipPath = ZipBuilder(base) + .addEntry("hello.txt", "This is the first hello.txt") + .addEntry("xxxxx.xxx", "This is the second hello.txt") + .build() + val original = fileSystem.read(zipPath) { readByteString() } + val rewritten = original.replaceAll("xxxxx.xxx".encodeUtf8(), "hello.txt".encodeUtf8()) + fileSystem.write(zipPath) { write(rewritten) } + + val zipFileSystem = fileSystem.openZip(zipPath) + assertThat(zipFileSystem.read("hello.txt".toPath()) { readUtf8() }) + .isEqualTo("This is the second hello.txt") + assertThat(zipFileSystem.list("/".toPath())) + .containsExactly("/hello.txt".toPath()) + } + + @Test + fun canonicalizationValid() { + val zipPath = ZipBuilder(base) + .addEntry("hello.txt", "Hello World") + .addEntry("directory/child.txt", "Another file!") + .build() + val zipFileSystem = fileSystem.openZip(zipPath) + + assertThat(zipFileSystem.canonicalize("/".toPath())).isEqualTo("/".toPath()) + assertThat(zipFileSystem.canonicalize(".".toPath())).isEqualTo("/".toPath()) + assertThat(zipFileSystem.canonicalize("not/a/path/../../..".toPath())).isEqualTo("/".toPath()) + assertThat(zipFileSystem.canonicalize("hello.txt".toPath())).isEqualTo("/hello.txt".toPath()) + assertThat(zipFileSystem.canonicalize("stuff/../hello.txt".toPath())).isEqualTo("/hello.txt".toPath()) + assertThat(zipFileSystem.canonicalize("directory".toPath())).isEqualTo("/directory".toPath()) + assertThat(zipFileSystem.canonicalize("directory/whevs/..".toPath())).isEqualTo("/directory".toPath()) + assertThat(zipFileSystem.canonicalize("directory/child.txt".toPath())).isEqualTo("/directory/child.txt".toPath()) + assertThat(zipFileSystem.canonicalize("directory/whevs/../child.txt".toPath())).isEqualTo("/directory/child.txt".toPath()) + } + + @Test + fun canonicalizationInvalidThrows() { + val zipPath = ZipBuilder(base) + .addEntry("hello.txt", "Hello World") + .addEntry("directory/child.txt", "Another file!") + .build() + val zipFileSystem = fileSystem.openZip(zipPath) + + assertFailsWith<FileNotFoundException> { + zipFileSystem.canonicalize("not/a/path".toPath()) + } + } +} + +private fun ByteString.replaceAll(a: ByteString, b: ByteString): ByteString { + val buffer = Buffer() + buffer.write(this) + buffer.replace(a, b) + return buffer.readByteString() +} + +private fun Buffer.replace(a: ByteString, b: ByteString) { + val result = Buffer() + while (!exhausted()) { + val index = indexOf(a) + if (index == -1L) { + result.writeAll(this) + } else { + result.write(this, index) + result.write(b) + skip(a.size.toLong()) + } + } + writeAll(result) +} + +/** Decodes this ISO8601 time string. */ +fun String.toEpochMillis() = Instant.parse(this).toEpochMilliseconds() diff --git a/okio/src/jvmTest/kotlin/okio/internal/HmacTest.kt b/okio/src/jvmTest/kotlin/okio/internal/HmacTest.kt new file mode 100644 index 00000000..01666d17 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/internal/HmacTest.kt @@ -0,0 +1,111 @@ +/* + * Copyright (C) 2020 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio.internal + +import javax.crypto.Mac +import javax.crypto.spec.SecretKeySpec +import kotlin.random.Random +import okio.ByteString +import org.junit.Assert +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +/** + * Check the [Hmac] implementation against the reference [Mac] JVM implementation. + */ +@RunWith(Parameterized::class) +class HmacTest(val parameters: Parameters) { + + companion object { + @get:Parameterized.Parameters(name = "{0}") + @get:JvmStatic + val parameters: List<Parameters> + get() { + val algorithms = enumValues<Parameters.Algorithm>() + val keySizes = listOf(8, 32, 48, 64, 128, 256) + val dataSizes = listOf(0, 32, 64, 128, 256, 512) + return algorithms.flatMap { algorithm -> + keySizes.flatMap { keySize -> + dataSizes.map { dataSize -> + Parameters( + algorithm, + keySize, + dataSize, + ) + } + } + } + } + } + + private val keySize + get() = parameters.keySize + private val dataSize + get() = parameters.dataSize + private val algorithm + get() = parameters.algorithmName + + private val random = Random(682741861446) + + private val key = random.nextBytes(keySize) + private val bytes = random.nextBytes(dataSize) + private val mac = parameters.createMac(key) + + private val expected = hmac(algorithm, key, bytes) + + @Test + fun hmac() { + mac.update(bytes) + val hmacValue = mac.digest() + + Assert.assertArrayEquals(expected, hmacValue) + } + + @Test + fun hmacBytes() { + for (byte in bytes) { + mac.update(byteArrayOf(byte)) + } + val hmacValue = mac.digest() + + Assert.assertArrayEquals(expected, hmacValue) + } + + data class Parameters( + val algorithm: Algorithm, + val keySize: Int, + val dataSize: Int, + ) { + val algorithmName + get() = algorithm.algorithmName + + internal fun createMac(key: ByteArray) = + algorithm.HmacFactory(ByteString(key)) + + enum class Algorithm( + val algorithmName: String, + internal val HmacFactory: (key: ByteString) -> Hmac, + ) { + SHA_1("HmacSha1", Hmac.Companion::sha1), + SHA_256("HmacSha256", Hmac.Companion::sha256), + SHA_512("HmacSha512", Hmac.Companion::sha512), + } + } +} + +private fun hmac(algorithm: String, key: ByteArray, bytes: ByteArray) = + Mac.getInstance(algorithm).apply { init(SecretKeySpec(key, algorithm)) }.doFinal(bytes) diff --git a/okio/src/jvmTest/kotlin/okio/internal/ResourceFileSystemTest.kt b/okio/src/jvmTest/kotlin/okio/internal/ResourceFileSystemTest.kt new file mode 100644 index 00000000..13c63158 --- /dev/null +++ b/okio/src/jvmTest/kotlin/okio/internal/ResourceFileSystemTest.kt @@ -0,0 +1,505 @@ +/* + * Copyright (C) 2021 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okio.internal + +import java.net.URL +import java.net.URLClassLoader +import java.util.Enumeration +import kotlin.reflect.KClass +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.fail +import okio.BufferedSource +import okio.ByteString +import okio.FileNotFoundException +import okio.FileSystem +import okio.ForwardingFileSystem +import okio.IOException +import okio.Path +import okio.Path.Companion.toPath +import okio.ZipBuilder +import okio.randomToken +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test + +class ResourceFileSystemTest { + private val fileSystem = FileSystem.RESOURCES as ResourceFileSystem + private var base = FileSystem.SYSTEM_TEMPORARY_DIRECTORY / randomToken(16) + + @Test + fun testResourceA() { + val path = "okio/resourcefilesystem/a.txt".toPath() + + val metadata = fileSystem.metadataOrNull(path)!! + + assertThat(metadata.size).isEqualTo(1L) + assertThat(metadata.isRegularFile).isTrue() + assertThat(metadata.isDirectory).isFalse() + + val content = fileSystem.read(path) { readUtf8() } + assertThat(fileSystem.metadata(path).isRegularFile).isTrue() + + assertThat(content).isEqualTo("a") + } + + @Test + fun testResourceB() { + val path = "okio/resourcefilesystem/b/b.txt".toPath() + + val metadata = fileSystem.metadataOrNull(path)!! + + assertThat(metadata.size).isEqualTo(3L) + assertThat(metadata.isRegularFile).isTrue() + assertThat(metadata.isDirectory).isFalse() + + val content = fileSystem.read(path) { readUtf8() } + assertThat(fileSystem.metadata(path).isRegularFile).isTrue() + + assertThat(content).isEqualTo("b/b") + } + + @Test + fun testSingleArchive() { + val zipPath = ZipBuilder(base) + .addEntry("hello.txt", "Hello World") + .addEntry("directory/subdirectory/child.txt", "Another file!") + .addEntry("META-INF/MANIFEST.MF", "Manifest-Version: 1.0\n") + .build() + val resourceFileSystem = ResourceFileSystem( + classLoader = URLClassLoader(arrayOf(zipPath.toFile().toURI().toURL()), null), + indexEagerly = false, + ) + + assertThat(resourceFileSystem.read("hello.txt".toPath()) { readUtf8() }) + .isEqualTo("Hello World") + assertThat( + resourceFileSystem.metadata("hello.txt".toPath()).isRegularFile, + ).isTrue() + + assertThat(resourceFileSystem.read("directory/subdirectory/child.txt".toPath()) { readUtf8() }) + .isEqualTo("Another file!") + assertThat( + resourceFileSystem.metadata("directory/subdirectory/child.txt".toPath()).isRegularFile, + ).isTrue() + + assertThat(resourceFileSystem.list("/".toPath())) + .hasSameElementsAs(listOf("/META-INF".toPath(), "/hello.txt".toPath(), "/directory".toPath())) + assertThat(resourceFileSystem.list("/directory".toPath())) + .containsExactly("/directory/subdirectory".toPath()) + assertThat(resourceFileSystem.list("/directory/subdirectory".toPath())) + .containsExactly("/directory/subdirectory/child.txt".toPath()) + + val metadata = resourceFileSystem.metadata(".".toPath()) + assertThat(metadata.isDirectory).isTrue() + } + + @Test + fun testDirectoryAndJarOverlap() { + val filesAPath = base / "filesA" + FileSystem.SYSTEM.createDirectories(filesAPath / "colors") + FileSystem.SYSTEM.write(filesAPath / "colors" / "red.txt") { writeUtf8("Apples are red") } + FileSystem.SYSTEM.write(filesAPath / "colors" / "green.txt") { writeUtf8("Grass is green") } + val zipBPath = ZipBuilder(base) + .addEntry("META-INF/MANIFEST.MF", "Manifest-Version: 1.0\n") + .addEntry("colors/blue.txt", "The sky is blue") + .addEntry("colors/green.txt", "Limes are green") + .build() + + val resourceFileSystem = ResourceFileSystem( + classLoader = URLClassLoader( + arrayOf( + filesAPath.toFile().toURI().toURL(), + zipBPath.toFile().toURI().toURL(), + ), + null, + ), + indexEagerly = false, + ) + + assertThat(resourceFileSystem.read("/colors/red.txt".toPath()) { readUtf8() }) + .isEqualTo("Apples are red") + assertThat(resourceFileSystem.read("/colors/green.txt".toPath()) { readUtf8() }) + .isEqualTo("Grass is green") + assertThat(resourceFileSystem.read("/colors/blue.txt".toPath()) { readUtf8() }) + .isEqualTo("The sky is blue") + + assertThat(resourceFileSystem.metadata("/colors/red.txt".toPath()).isRegularFile).isTrue() + assertThat(resourceFileSystem.metadata("/colors/green.txt".toPath()).isRegularFile).isTrue() + assertThat(resourceFileSystem.metadata("/colors/blue.txt".toPath()).isRegularFile).isTrue() + + assertThat(resourceFileSystem.list("/".toPath())) + .hasSameElementsAs(listOf("/META-INF".toPath(), "/colors".toPath())) + assertThat(resourceFileSystem.list("/colors".toPath())).hasSameElementsAs( + listOf( + "/colors/red.txt".toPath(), + "/colors/green.txt".toPath(), + "/colors/blue.txt".toPath(), + ), + ) + + assertThat(resourceFileSystem.metadata("/".toPath()).isDirectory).isTrue() + assertThat(resourceFileSystem.metadata("/colors".toPath()).isDirectory).isTrue() + } + + @Test + fun testDirectoryAndDirectoryOverlap() { + val filesAPath = base / "filesA" + FileSystem.SYSTEM.createDirectories(filesAPath / "colors") + FileSystem.SYSTEM.write(filesAPath / "colors" / "red.txt") { writeUtf8("Apples are red") } + FileSystem.SYSTEM.write(filesAPath / "colors" / "green.txt") { writeUtf8("Grass is green") } + val filesBPath = base / "filesB" + FileSystem.SYSTEM.createDirectories(filesBPath / "colors") + FileSystem.SYSTEM.write(filesBPath / "colors" / "blue.txt") { writeUtf8("The sky is blue") } + FileSystem.SYSTEM.write(filesBPath / "colors" / "green.txt") { writeUtf8("Limes are green") } + + val resourceFileSystem = ResourceFileSystem( + classLoader = URLClassLoader( + arrayOf( + filesAPath.toFile().toURI().toURL(), + filesBPath.toFile().toURI().toURL(), + ), + null, + ), + indexEagerly = false, + ) + + assertThat(resourceFileSystem.read("/colors/red.txt".toPath()) { readUtf8() }) + .isEqualTo("Apples are red") + assertThat(resourceFileSystem.read("/colors/green.txt".toPath()) { readUtf8() }) + .isEqualTo("Grass is green") + assertThat(resourceFileSystem.read("/colors/blue.txt".toPath()) { readUtf8() }) + .isEqualTo("The sky is blue") + + assertThat(resourceFileSystem.metadata("/colors/red.txt".toPath()).isRegularFile).isTrue() + assertThat(resourceFileSystem.metadata("/colors/green.txt".toPath()).isRegularFile).isTrue() + assertThat(resourceFileSystem.metadata("/colors/blue.txt".toPath()).isRegularFile).isTrue() + + assertThat(resourceFileSystem.list("/".toPath())) + .hasSameElementsAs(listOf("/colors".toPath())) + assertThat(resourceFileSystem.list("/colors".toPath())).hasSameElementsAs( + listOf( + "/colors/red.txt".toPath(), + "/colors/green.txt".toPath(), + "/colors/blue.txt".toPath(), + ), + ) + + assertThat(resourceFileSystem.metadata("/".toPath()).isDirectory).isTrue() + assertThat(resourceFileSystem.metadata("/colors".toPath()).isDirectory).isTrue() + } + + @Test + fun testJarAndJarOverlap() { + val zipAPath = ZipBuilder(base) + .addEntry("colors/red.txt", "Apples are red") + .addEntry("colors/green.txt", "Grass is green") + .addEntry("META-INF/MANIFEST.MF", "Manifest-Version: 1.0\n") + .build() + val zipBPath = ZipBuilder(base) + .addEntry("colors/blue.txt", "The sky is blue") + .addEntry("colors/green.txt", "Limes are green") + .addEntry("META-INF/MANIFEST.MF", "Manifest-Version: 1.0\n") + .build() + val resourceFileSystem = ResourceFileSystem( + classLoader = URLClassLoader( + arrayOf( + zipAPath.toFile().toURI().toURL(), + zipBPath.toFile().toURI().toURL(), + ), + null, + ), + indexEagerly = false, + ) + + assertThat(resourceFileSystem.read("/colors/red.txt".toPath()) { readUtf8() }) + .isEqualTo("Apples are red") + assertThat(resourceFileSystem.read("/colors/green.txt".toPath()) { readUtf8() }) + .isEqualTo("Grass is green") + assertThat(resourceFileSystem.read("/colors/blue.txt".toPath()) { readUtf8() }) + .isEqualTo("The sky is blue") + + assertThat(resourceFileSystem.metadata("/colors/red.txt".toPath()).isRegularFile).isTrue() + assertThat(resourceFileSystem.metadata("/colors/green.txt".toPath()).isRegularFile).isTrue() + assertThat(resourceFileSystem.metadata("/colors/blue.txt".toPath()).isRegularFile).isTrue() + + assertThat(resourceFileSystem.list("/".toPath())) + .hasSameElementsAs(listOf("/META-INF".toPath(), "/colors".toPath())) + assertThat(resourceFileSystem.list("/colors".toPath())).hasSameElementsAs( + listOf( + "/colors/red.txt".toPath(), + "/colors/green.txt".toPath(), + "/colors/blue.txt".toPath(), + ), + ) + + assertThat(resourceFileSystem.metadata("/".toPath()).isDirectory).isTrue() + assertThat(resourceFileSystem.metadata("/colors".toPath()).isDirectory).isTrue() + } + + @Test + fun testResourceMissing() { + val path = "okio/resourcefilesystem/b/c.txt".toPath() + + assertThat(fileSystem.metadataOrNull(path)).isNull() + + try { + fileSystem.read(path) { readUtf8() } + fail() + } catch (ioe: IOException) { + assertThat(ioe.message).isEqualTo("file not found: okio/resourcefilesystem/b/c.txt") + } + } + + @Test + fun testProjectIsListable() { + val path = "okio/resourcefilesystem/b/".toPath() + + val metadata = fileSystem.metadataOrNull(path)!! + + assertThat(metadata.isDirectory).isTrue() + assertThat(metadata.createdAtMillis).isGreaterThan(1L) + + assertThat(fileSystem.list(path).map { it.name }).containsExactly("b.txt") + } + + @Test + fun testResourceFromJar() { + val path = "LICENSE-junit.txt".toPath() + + val metadata = fileSystem.metadataOrNull(path)!! + + assertThat(metadata.size).isGreaterThan(10000L) + assertThat(metadata.isRegularFile).isTrue() + assertThat(metadata.isDirectory).isFalse() + + val content = fileSystem.read(path) { readUtf8Line() } + + assertThat(content).isEqualTo("JUnit") + } + + @Test + fun testClassFilesOmittedFromJar() { + assertThat(fileSystem.list("/org/junit/rules".toPath())).isEmpty() + assertThat(fileSystem.metadataOrNull("/org/junit/Test.class".toPath())).isNull() + } + + @Test + fun testClassFilesOmittedFromDirectory() { + val filesPath = base / "files" + val packagePath = filesPath / "com" / "example" / "project" + FileSystem.SYSTEM.createDirectories(packagePath) + FileSystem.SYSTEM.write(packagePath / "Hello.class") { writeUtf8("cafebabe") } + + val resourceFileSystem = ResourceFileSystem( + classLoader = URLClassLoader( + arrayOf(filesPath.toFile().toURI().toURL()), + null, + ), + indexEagerly = false, + ) + + assertThat(resourceFileSystem.list("/com/example/project".toPath())).isEmpty() + assertThat(resourceFileSystem.metadataOrNull("/com/example/project/Hello.class".toPath())) + .isNull() + assertFailsWith<FileNotFoundException> { + resourceFileSystem.source("/com/example/project/Hello.class".toPath()) + } + } + + @Test + fun testDirectoryFromJar() { + val path = "org/junit/".toPath() + + val metadata = fileSystem.metadataOrNull(path) + assertThat(metadata?.isDirectory).isTrue() + + val files = fileSystem.list(path).map { it.name } + assertThat(files).contains("matchers", "rules") + assertThat(files.filter { it.endsWith(".class") }).isEmpty() + } + + @Test + fun packagePath() { + val path = ByteString::class.java.`package`.toPath() + + assertThat((path / "a.txt").toString()) + .isEqualTo("okio${Path.DIRECTORY_SEPARATOR}a.txt") + } + + @Test + fun classResource() { + val path = ByteString::class.packagePath!! + + assertThat((path / "a.txt").toString()) + .isEqualTo("okio${Path.DIRECTORY_SEPARATOR}a.txt") + } + + /** + * Confirm that class loaders aren't accessed until the file system is used. This should save + * resources so that zip files aren't decoded if they're unused. + */ + @Test + fun testIndexLazily() { + val classLoader = object : ClassLoader() { + override fun findResources(name: String?): Enumeration<URL> { + throw Exception("finding a resource") + } + } + + val resourceFileSystem = ResourceFileSystem( + classLoader = classLoader, + indexEagerly = false, + ) + + assertThat( + assertFailsWith<Exception> { + resourceFileSystem.list("/".toPath()) + }, + ).hasMessage("finding a resource") + } + + @Test + fun testIndexEagerly() { + val classLoader = object : ClassLoader() { + override fun findResources(name: String?): Enumeration<URL> { + throw Exception("finding a resource") + } + } + + assertThat( + assertFailsWith<Exception> { + ResourceFileSystem( + classLoader = classLoader, + indexEagerly = true, + ) + }, + ).hasMessage("finding a resource") + } + + /** Confirm we can read individual files without triggering indexing. */ + @Test + fun testSourceDoesntTriggerIndexing() { + val processedPaths = mutableSetOf<Path>() + val recordingFileSystem = object : ForwardingFileSystem(FileSystem.SYSTEM) { + override fun onPathParameter(path: Path, functionName: String, parameterName: String): Path { + processedPaths += path + return super.onPathParameter(path, functionName, parameterName) + } + } + + val zipAPath = ZipBuilder(base) + .addEntry("colors/red.txt", "Apples are red") + .addEntry("META-INF/MANIFEST.MF", "Manifest-Version: 1.0\n") + .build() + val zipBPath = ZipBuilder(base) + .addEntry("colors/blue.txt", "The sky is blue") + .addEntry("META-INF/MANIFEST.MF", "Manifest-Version: 1.0\n") + .build() + val resourceFileSystem = ResourceFileSystem( + classLoader = URLClassLoader( + arrayOf( + zipAPath.toFile().toURI().toURL(), + zipBPath.toFile().toURI().toURL(), + ), + null, + ), + indexEagerly = false, + systemFileSystem = recordingFileSystem, + ) + + // Reading paths with source() or read() doesn't index zips. + assertThat( + resourceFileSystem.read("/colors/red.txt".toPath()) { readUtf8() }, + ).isEqualTo("Apples are red") + assertThat( + resourceFileSystem.read("/colors/blue.txt".toPath()) { readUtf8() }, + ).isEqualTo("The sky is blue") + assertThat(processedPaths).isEmpty() + + // Calling list() does though. + assertThat(resourceFileSystem.list("/colors".toPath())).containsExactlyInAnyOrder( + "/colors/red.txt".toPath(), + "/colors/blue.txt".toPath(), + ) + assertThat(processedPaths).containsExactlyInAnyOrder( + zipAPath, + zipBPath, + ) + } + + /** + * Our resource file system uses [URLClassLoader] internally, which means we need to go back and + * forth between [File], [URL], and [URI] models for component paths. This is a big hazard for + * escaping special characters and it's likely that some file paths won't survive the round trip! + */ + @Test + fun fileNameWithSpaceInPath() { + val zipPath = ZipBuilder(base / "space in directory name") + .addEntry("META-INF/MANIFEST.MF", "Manifest-Version: 1.0\n") + .addEntry("hello.txt", "Hello World") + .build() + val resourceFileSystem = ResourceFileSystem( + classLoader = URLClassLoader(arrayOf(zipPath.toFile().toURI().toURL()), null), + indexEagerly = false, + ) + + assertThat(resourceFileSystem.read("hello.txt".toPath()) { readUtf8() }) + .isEqualTo("Hello World") + assertThat(resourceFileSystem.metadata("hello.txt".toPath())).isNotNull() + } + + @Test + fun missingResourceSilentlyIgnored() { + val zipAPath = ZipBuilder(base) + .addEntry("META-INF/MANIFEST.MF", "Manifest-Version: 1.0\n") + .addEntry("hello.txt", "Hello World") + .build() + val resourceFileSystem = ResourceFileSystem( + classLoader = URLClassLoader( + arrayOf( + zipAPath.toFile().toURI().toURL(), + (base / "missing.zip").toFile().toURI().toURL(), + ), + null, + ), + indexEagerly = false, + ) + + assertThat(resourceFileSystem.read("hello.txt".toPath()) { readUtf8() }) + .isEqualTo("Hello World") + } + + @Test + fun listSpecialCharacterNamedFiles() { + val path = "okio/resourcefilesystem/non-ascii".toPath() + + assertThat(fileSystem.listRecursively(path).toList()).containsExactly( + "/okio/resourcefilesystem/non-ascii/ギリシア神話".toPath(), + "/okio/resourcefilesystem/non-ascii/ギリシア神話/Ἰλιάς".toPath(), + ) + + val content = fileSystem.read( + "/okio/resourcefilesystem/non-ascii/ギリシア神話/Ἰλιάς".toPath(), + BufferedSource::readUtf8, + ) + assertEquals("Chante, ô déesse, le courroux du Péléide Achille,\n", content) + } + + private fun Package.toPath(): Path = name.replace(".", "/").toPath() + + private val KClass<*>.packagePath: Path? + get() = qualifiedName?.replace(".", "/")?.toPath()?.parent +} |