diff options
author | Vsevolod Tolstopyatov <qwwdfsad@gmail.com> | 2020-03-20 12:24:15 +0300 |
---|---|---|
committer | Vsevolod Tolstopyatov <qwwdfsad@gmail.com> | 2020-03-26 13:49:21 +0300 |
commit | 16ae7fae0f93f1e4bc98330d21b696342f3e862c (patch) | |
tree | 66a1a18f60c0e62abb6b18368536dcad4e01a7f9 /formats/protobuf | |
parent | 23c1b541ed55ec34029b5d205818cc868a0271df (diff) | |
download | kotlinx.serialization-16ae7fae0f93f1e4bc98330d21b696342f3e862c.tar.gz |
Optimize ProtoBuf encoding:
* Leverage HotSpot intrinsics for conversion to little-endian (LE)
* Optimize LE loops
* Optimize varint decoding: add a fast path for single-byte and two-byte values, which are the most common
* Single-copy String decoding
* Make the extractParameters loops independent of escape analysis (EA) and the inliner
Diffstat (limited to 'formats/protobuf')
6 files changed, 108 insertions, 91 deletions
diff --git a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/ProtoBuf.kt b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/ProtoBuf.kt index c8b1a660..af6f88d3 100644 --- a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/ProtoBuf.kt +++ b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/ProtoBuf.kt @@ -12,7 +12,6 @@ import kotlinx.serialization.internal.* import kotlinx.serialization.modules.* import kotlinx.serialization.protobuf.ProtoBuf.Varint.decodeSignedVarintInt import kotlinx.serialization.protobuf.ProtoBuf.Varint.decodeSignedVarintLong -import kotlinx.serialization.protobuf.ProtoBuf.Varint.decodeVarint import kotlinx.serialization.protobuf.ProtoBuf.Varint.encodeVarint import kotlin.jvm.* @@ -160,7 +159,7 @@ public class ProtoBuf( ProtoNumberType.DEFAULT ) - override fun SerialDescriptor.getTag(index: Int) = this.getProtoDesc(index) + override fun SerialDescriptor.getTag(index: Int) = extractParameters(index) @Suppress("UNCHECKED_CAST", "NAME_SHADOWING") override fun <T> encodeSerializableValue(serializer: SerializationStrategy<T>, value: T) = when { @@ -235,12 +234,12 @@ public class ProtoBuf( fun writeDouble(value: Double, tag: Int) { out.encode32((tag shl 3) or i64) - value.toLittleEndian().writeTo(out) + value.reverseBytes().writeTo(out) } fun writeFloat(value: Float, tag: Int) { out.encode32((tag shl 3) or i32) - value.toLittleEndian().writeTo(out) + value.reverseBytes().writeTo(out) } private fun OutputStream.encode32( @@ -248,7 +247,7 @@ public class ProtoBuf( format: ProtoNumberType = ProtoNumberType.DEFAULT ) { when (format) { - ProtoNumberType.FIXED -> number.toLittleEndian().writeTo(this) + ProtoNumberType.FIXED -> number.reverseBytes().writeTo(this) ProtoNumberType.DEFAULT -> encodeVarint(number.toLong()) ProtoNumberType.SIGNED -> encodeVarint(((number shl 1) xor (number shr 31))) } @@ -257,7 +256,7 @@ public class ProtoBuf( private fun OutputStream.encode64(number: Long, format: 
ProtoNumberType = ProtoNumberType.DEFAULT) { when (format) { - ProtoNumberType.FIXED -> number.toLittleEndian().writeTo(this) + ProtoNumberType.FIXED -> number.reverseBytes().writeTo(this) ProtoNumberType.DEFAULT -> encodeVarint(number) ProtoNumberType.SIGNED -> encodeVarint((number shl 1) xor (number shr 63)) } @@ -265,7 +264,7 @@ public class ProtoBuf( } - private open inner class ProtobufReader(val decoder: ProtobufDecoder) : TaggedDecoder<ProtoDesc>() { + private open inner class ProtobufReader(@JvmField val decoder: ProtobufDecoder) : TaggedDecoder<ProtoDesc>() { override val context: SerialModule get() = this@ProtoBuf.context @@ -338,7 +337,10 @@ public class ProtoBuf( override fun decodeTaggedByte(tag: ProtoDesc): Byte = decoder.nextInt(tag.numberType).toByte() override fun decodeTaggedShort(tag: ProtoDesc): Short = decoder.nextInt(tag.numberType).toShort() - override fun decodeTaggedInt(tag: ProtoDesc): Int = decoder.nextInt(tag.numberType) + + override fun decodeTaggedInt(tag: ProtoDesc): Int { + return decoder.nextInt(tag.numberType) + } override fun decodeTaggedLong(tag: ProtoDesc): Long = decoder.nextLong(tag.numberType) override fun decodeTaggedFloat(tag: ProtoDesc): Float = decoder.nextFloat() override fun decodeTaggedDouble(tag: ProtoDesc): Double = decoder.nextDouble() @@ -365,7 +367,7 @@ public class ProtoBuf( return setOfEntries.associateBy({ it.key }, { it.value }) as T } - override fun SerialDescriptor.getTag(index: Int) = this.getProtoDesc(index) + override fun SerialDescriptor.getTag(index: Int) = extractParameters(index) override fun decodeElementIndex(descriptor: SerialDescriptor): Int { while (true) { @@ -385,10 +387,10 @@ public class ProtoBuf( decoder: ProtobufDecoder, private val targetTag: ProtoDesc ) : ProtobufReader(decoder) { - private var ind = -1 + private var index = -1 override fun decodeElementIndex(descriptor: SerialDescriptor) = - if (decoder.currentId == targetTag.protoId) ++ind else READ_DONE + if (decoder.currentId == 
targetTag.protoId) ++index else READ_DONE override fun SerialDescriptor.getTag(index: Int): ProtoDesc = targetTag } @@ -399,7 +401,7 @@ public class ProtoBuf( else ProtoDesc(2, (parentTag?.numberType ?: ProtoNumberType.DEFAULT)) } - internal class ProtobufDecoder(private val inp: ByteArrayInputStream) { + internal class ProtobufDecoder(private val input: ByteArrayInputStream) { @JvmField public var currentId = -1 @JvmField @@ -432,16 +434,16 @@ public class ProtoBuf( @Suppress("NOTHING_TO_INLINE") private inline fun assertWireType(expected: Int) { - if (currentType != expected) throw ProtobufDecodingException("Expected wire type $expected, but found ${currentType}") + if (currentType != expected) throw ProtobufDecodingException("Expected wire type $expected, but found $currentType") } fun nextObject(): ByteArray { assertWireType(SIZE_DELIMITED) - val len = decode32() - check(len >= 0) - val ans = inp.readExactNBytes(len) + val length = decode32() + check(length >= 0) + val result = input.readExactNBytes(length) readTag() - return ans + return result } private fun InputStream.readExactNBytes(bytes: Int): ByteArray { @@ -458,66 +460,69 @@ public class ProtoBuf( fun nextInt(format: ProtoNumberType): Int { val wireType = if (format == ProtoNumberType.FIXED) i32 else VARINT assertWireType(wireType) - val ans = decode32(format) + val result = decode32(format) readTag() - return ans + return result } fun nextLong(format: ProtoNumberType): Long { val wireType = if (format == ProtoNumberType.FIXED) i64 else VARINT assertWireType(wireType) - val ans = decode64(format) + val result = decode64(format) readTag() - return ans + return result } fun nextFloat(): Float { assertWireType(i32) - val ans = readIntLittleEndian() + val result = Float.fromBits(readIntLittleEndian()) readTag() - return Float.fromBits(ans) + return result } private fun readIntLittleEndian(): Int { var result = 0 for (i in 0..3) { - val byte = inp.read() - result = (result shl 8) or byte + val byte = 
input.read() and 0x000000FF + result = result or (byte shl (i * 8)) } - return result.toLittleEndian() + return result } private fun readLongLittleEndian(): Long { var result = 0L for (i in 0..7) { - val byte = inp.read() - result = (result shl 8) or byte.toLong() + val byte = (input.read() and 0x000000FF).toLong() + result = result or (byte shl (i * 8)) } - return result.toLittleEndian() + return result } fun nextDouble(): Double { assertWireType(i64) - val ans = readLongLittleEndian() + val result = Double.fromBits(readLongLittleEndian()) readTag() - return Double.fromBits(ans) + return result } fun nextString(): String { - val bytes = this.nextObject() - return bytes.decodeToString() + assertWireType(SIZE_DELIMITED) + val length = decode32() + check(length >= 0) + val result = input.readString(length) + readTag() + return result } - private fun decode32(format: ProtoNumberType = ProtoNumberType.DEFAULT, eofAllowed: Boolean = false): Int = - when (format) { - ProtoNumberType.DEFAULT -> decodeVarint(inp, 64, eofAllowed).toInt() - ProtoNumberType.SIGNED -> decodeSignedVarintInt(inp) - ProtoNumberType.FIXED -> readIntLittleEndian() - } + private fun decode32(format: ProtoNumberType = ProtoNumberType.DEFAULT, eofAllowed: Boolean = false): Int = when (format) { + ProtoNumberType.DEFAULT -> input.readVarint64(eofAllowed).toInt() + ProtoNumberType.SIGNED -> decodeSignedVarintInt(input) + ProtoNumberType.FIXED -> readIntLittleEndian() + } private fun decode64(format: ProtoNumberType = ProtoNumberType.DEFAULT): Long = when (format) { - ProtoNumberType.DEFAULT -> decodeVarint(inp, 64) - ProtoNumberType.SIGNED -> decodeSignedVarintLong(inp) + ProtoNumberType.DEFAULT -> input.readVarint64(false) + ProtoNumberType.SIGNED -> decodeSignedVarintLong(input) ProtoNumberType.FIXED -> readLongLittleEndian() } } @@ -545,29 +550,8 @@ public class ProtoBuf( write((value and 0x7F).toInt()) } - internal fun decodeVarint(inp: InputStream, bitLimit: Int = 32, eofOnStartAllowed: Boolean = 
false): Long { - var result = 0L - var shift = 0 - var b: Int - do { - if (shift >= bitLimit) { - // Out of range - throw ProtobufDecodingException("Varint too long: exceeded $bitLimit bits") - } - // Get 7 bits from next byte - b = inp.read() - if (b == -1) { - if (eofOnStartAllowed && shift == 0) return -1 - else error("Unexpected EOF") - } - result = result or (b.toLong() and 0x7FL shl shift) - shift += 7 - } while (b and 0x80 != 0) - return result - } - - internal fun decodeSignedVarintInt(inp: InputStream): Int { - val raw = decodeVarint(inp, 32).toInt() + internal fun decodeSignedVarintInt(input: InputStream): Int { + val raw = input.readVarint32() val temp = raw shl 31 shr 31 xor raw shr 1 // This extra step lets us deal with the largest signed values by treating // negative results from read unsigned methods as like unsigned values. @@ -575,8 +559,8 @@ public class ProtoBuf( return temp xor (raw and (1 shl 31)) } - internal fun decodeSignedVarintLong(inp: InputStream): Long { - val raw = decodeVarint(inp, 64) + internal fun decodeSignedVarintLong(input: InputStream): Long { + val raw = input.readVarint64(false) val temp = raw shl 63 shr 63 xor raw shr 1 // This extra step lets us deal with the largest signed values by treating // negative results from read unsigned methods as like unsigned values @@ -594,10 +578,6 @@ public class ProtoBuf( return ProtobufDecoder(ByteArrayInputStream(bytes)) } - private fun SerialDescriptor.getProtoDesc(index: Int): ProtoDesc { - return extractParameters(this, index) - } - internal const val VARINT = 0 internal const val i64 = 1 internal const val SIZE_DELIMITED = 2 @@ -618,19 +598,9 @@ public class ProtoBuf( } } -private fun Short.toLittleEndian(): Short = - (((this.toInt() and 0xff) shl 8) or ((this.toInt() and 0xffff) ushr 8)).toShort() - -private fun Int.toLittleEndian(): Int = - ((this and 0xffff).toShort().toLittleEndian().toInt() shl 16) or ((this ushr 16).toShort().toLittleEndian().toInt() and 0xffff) - -private fun 
Long.toLittleEndian(): Long = - ((this and 0xffffffff).toInt().toLittleEndian().toLong() shl 32) or ((this ushr 32).toInt() - .toLittleEndian().toLong() and 0xffffffff) - -private fun Float.toLittleEndian(): Int = toRawBits().toLittleEndian() +private fun Float.reverseBytes(): Int = toRawBits().reverseBytes() -private fun Double.toLittleEndian(): Long = toRawBits().toLittleEndian() +private fun Double.reverseBytes(): Long = toRawBits().reverseBytes() private fun Int.writeTo(out: OutputStream) { for (i in 3 downTo 0) { diff --git a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/ProtoTypes.kt b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/ProtoTypes.kt index de2c950d..08f784f2 100644 --- a/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/ProtoTypes.kt +++ b/formats/protobuf/commonMain/src/kotlinx/serialization/protobuf/ProtoTypes.kt @@ -24,7 +24,6 @@ public enum class ProtoNumberType(internal val signature: Long) { SIGNED(2L shl 32), FIXED(3L shl 32); - internal fun equalTo(descriptor: ProtoDesc): Boolean { return descriptor and MASK == signature } @@ -53,11 +52,12 @@ internal val ProtoDesc.numberType: ProtoNumberType get() { return ProtoNumberType.FIXED } -internal fun extractParameters(descriptor: SerialDescriptor, index: Int): ProtoDesc { - val annotations = descriptor.getElementAnnotations(index) +internal fun SerialDescriptor.extractParameters(index: Int): ProtoDesc { + val annotations = getElementAnnotations(index) var protoId: Int = index + 1 var format: ProtoNumberType = ProtoNumberType.DEFAULT - for (annotation in annotations) { + for (i in annotations.indices) { // Allocation-friendly loop + val annotation = annotations[i] if (annotation is ProtoId) { protoId = annotation.id } else if (annotation is ProtoType) { @@ -68,12 +68,17 @@ internal fun extractParameters(descriptor: SerialDescriptor, index: Int): ProtoD } internal fun extractProtoId(descriptor: SerialDescriptor, index: Int, zeroBasedDefault: Boolean = 
false): Int { - val protoId = descriptor.findAnnotation<ProtoId>(index) - return protoId?.id ?: (if (zeroBasedDefault) index else index + 1) + val annotations = descriptor.getElementAnnotations(index) + for (i in annotations.indices) { // Allocation-friendly loop + val annotation = annotations[i] + if (annotation is ProtoId) { + return annotation.id + } + } + return if (zeroBasedDefault) index else index + 1 } public class ProtobufDecodingException(message: String) : SerializationException(message) -internal inline fun <reified A: Annotation> SerialDescriptor.findAnnotation(elementIndex: Int): A? { - return getElementAnnotations(elementIndex).find { it is A } as A? -} +internal expect fun Int.reverseBytes(): Int +internal expect fun Long.reverseBytes(): Long diff --git a/formats/protobuf/jsMain/src/kotlinx/serialization/protobuf/Bytes.kt b/formats/protobuf/jsMain/src/kotlinx/serialization/protobuf/Bytes.kt new file mode 100644 index 00000000..288a4a0b --- /dev/null +++ b/formats/protobuf/jsMain/src/kotlinx/serialization/protobuf/Bytes.kt @@ -0,0 +1,14 @@ +/* + * Copyright 2017-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. 
+ */ + +package kotlinx.serialization.protobuf + +private fun Short.reverseBytes(): Short = (((this.toInt() and 0xff) shl 8) or ((this.toInt() and 0xffff) ushr 8)).toShort() + +internal actual fun Int.reverseBytes(): Int = + ((this and 0xffff).toShort().reverseBytes().toInt() shl 16) or ((this ushr 16).toShort().reverseBytes().toInt() and 0xffff) + +internal actual fun Long.reverseBytes(): Long = + ((this and 0xffffffff).toInt().reverseBytes().toLong() shl 32) or ((this ushr 32).toInt() + .reverseBytes().toLong() and 0xffffffff) diff --git a/formats/protobuf/jvmMain/src/kotlinx/serialization/protobuf/Bytes.kt b/formats/protobuf/jvmMain/src/kotlinx/serialization/protobuf/Bytes.kt new file mode 100644 index 00000000..bfb6fdd1 --- /dev/null +++ b/formats/protobuf/jvmMain/src/kotlinx/serialization/protobuf/Bytes.kt @@ -0,0 +1,9 @@ +/* + * Copyright 2017-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.serialization.protobuf + +internal actual fun Int.reverseBytes(): Int = Integer.reverseBytes(this) + +internal actual fun Long.reverseBytes(): Long = java.lang.Long.reverseBytes(this) diff --git a/formats/protobuf/nativeMain/src/kotlinx/serialization/protobuf/Bytes.kt b/formats/protobuf/nativeMain/src/kotlinx/serialization/protobuf/Bytes.kt new file mode 100644 index 00000000..288a4a0b --- /dev/null +++ b/formats/protobuf/nativeMain/src/kotlinx/serialization/protobuf/Bytes.kt @@ -0,0 +1,14 @@ +/* + * Copyright 2017-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. 
+ */ + +package kotlinx.serialization.protobuf + +private fun Short.reverseBytes(): Short = (((this.toInt() and 0xff) shl 8) or ((this.toInt() and 0xffff) ushr 8)).toShort() + +internal actual fun Int.reverseBytes(): Int = + ((this and 0xffff).toShort().reverseBytes().toInt() shl 16) or ((this ushr 16).toShort().reverseBytes().toInt() and 0xffff) + +internal actual fun Long.reverseBytes(): Long = + ((this and 0xffffffff).toInt().reverseBytes().toLong() shl 32) or ((this ushr 32).toInt() + .reverseBytes().toLong() and 0xffffffff) diff --git a/formats/protobuf/nativeTest/src/kotlinx/serialization/NativeTest.kt b/formats/protobuf/nativeTest/src/kotlinx/serialization/NativeTest.kt index eea03158..5880d946 100644 --- a/formats/protobuf/nativeTest/src/kotlinx/serialization/NativeTest.kt +++ b/formats/protobuf/nativeTest/src/kotlinx/serialization/NativeTest.kt @@ -27,4 +27,9 @@ class CommonTest { val id1 = country.descriptor.findAnnotation<ProtoId>(0)?.id ?: 0 assertEquals(10, id1) } + + private inline fun <reified A: Annotation> SerialDescriptor.findAnnotation(elementIndex: Int): A? { + return getElementAnnotations(elementIndex).find { it is A } as A? + } + } |