From b275a2d013f502de4aaedbc5ea849d10fbca5255 Mon Sep 17 00:00:00 2001 From: Martin Raison Date: Mon, 1 Feb 2021 15:48:49 +0100 Subject: Contextual serialization for derived classes (#1277) Fixes #1276 The issue #1276 is happening because, in the implementation of ContextualSerializer, the serializer is looked up using the concrete class instead of the serializable class --- .../kotlinx/serialization/ContextualSerializer.kt | 11 +++-- .../features/DerivedContextualSerializerTest.kt | 49 ++++++++++++++++++++++ 2 files changed, 54 insertions(+), 6 deletions(-) create mode 100644 formats/json/commonTest/src/kotlinx/serialization/features/DerivedContextualSerializerTest.kt diff --git a/core/commonMain/src/kotlinx/serialization/ContextualSerializer.kt b/core/commonMain/src/kotlinx/serialization/ContextualSerializer.kt index 3919fd12..e43e6250 100644 --- a/core/commonMain/src/kotlinx/serialization/ContextualSerializer.kt +++ b/core/commonMain/src/kotlinx/serialization/ContextualSerializer.kt @@ -45,6 +45,9 @@ public class ContextualSerializer( private val typeParametersSerializers: Array> ) : KSerializer { + private fun serializer(serializersModule: SerializersModule): KSerializer = + serializersModule.getContextual(serializableClass) ?: fallbackSerializer ?: serializableClass.serializerNotRegistered() + // Used from auto-generated code public constructor(serializableClass: KClass) : this(serializableClass, null, EMPTY_SERIALIZER_ARRAY) @@ -52,14 +55,10 @@ public class ContextualSerializer( buildSerialDescriptor("kotlinx.serialization.ContextualSerializer", SerialKind.CONTEXTUAL).withContext(serializableClass) public override fun serialize(encoder: Encoder, value: T) { - val clz = value::class - val serializer = encoder.serializersModule.getContextual(clz) ?: fallbackSerializer ?: serializableClass.serializerNotRegistered() - @Suppress("UNCHECKED_CAST") - encoder.encodeSerializableValue(serializer as SerializationStrategy, value) + encoder.encodeSerializableValue(serializer(encoder.serializersModule), value) } public override fun deserialize(decoder: Decoder): T { - val serializer = decoder.serializersModule.getContextual(serializableClass) ?: fallbackSerializer ?: serializableClass.serializerNotRegistered() - return decoder.decodeSerializableValue(serializer) + return decoder.decodeSerializableValue(serializer(decoder.serializersModule)) } } diff --git a/formats/json/commonTest/src/kotlinx/serialization/features/DerivedContextualSerializerTest.kt b/formats/json/commonTest/src/kotlinx/serialization/features/DerivedContextualSerializerTest.kt new file mode 100644 index 00000000..7da16fb8 --- /dev/null +++ b/formats/json/commonTest/src/kotlinx/serialization/features/DerivedContextualSerializerTest.kt @@ -0,0 +1,49 @@ +package kotlinx.serialization.features + +import kotlin.test.* +import kotlinx.serialization.* +import kotlinx.serialization.descriptors.* +import kotlinx.serialization.encoding.* +import kotlinx.serialization.json.* +import kotlinx.serialization.modules.* + +class DerivedContextualSerializerTest { + + @Serializable + abstract class Message + + @Serializable + class SimpleMessage(val body: String) : Message() + + @Serializable + class Holder(@Contextual val message: Message) + + object MessageAsStringSerializer : KSerializer { + override val descriptor: SerialDescriptor = + PrimitiveSerialDescriptor("kotlinx.serialization.MessageAsStringSerializer", PrimitiveKind.STRING) + + override fun serialize(encoder: Encoder, value: Message) { + // dummy serializer that assumes Message is always SimpleMessage + check(value is SimpleMessage) + encoder.encodeString(value.body) + } + + override fun deserialize(decoder: Decoder): Message { + return SimpleMessage(decoder.decodeString()) + } + } + + @Test + fun testDerivedContextualSerializer() { + val module = SerializersModule { + contextual(MessageAsStringSerializer) + } + val format = Json { serializersModule = module } + val data = Holder(SimpleMessage("hello")) + val serialized = format.encodeToString(data) + assertEquals("""{"message":"hello"}""", serialized) + val deserialized = format.decodeFromString(serialized) + assertTrue(deserialized.message is SimpleMessage) + assertEquals("hello", deserialized.message.body) + } +} -- cgit v1.2.3