summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Raison <martinraison@users.noreply.github.com>2021-02-01 15:48:49 +0100
committerGitHub <noreply@github.com>2021-02-01 17:48:49 +0300
commitb275a2d013f502de4aaedbc5ea849d10fbca5255 (patch)
treeb9fa81a5589c5b62c36f18e25626d1fc26f8967a
parentac5450475f5e6e86f7de0db821bfc6014befb126 (diff)
downloadkotlinx.serialization-b275a2d013f502de4aaedbc5ea849d10fbca5255.tar.gz
Contextual serialization for derived classes (#1277)
Fixes #1276 The issue #1276 is happening because, in the implementation of ContextualSerializer, the serializer is looked up using the concrete class instead of the serializable class
-rw-r--r--core/commonMain/src/kotlinx/serialization/ContextualSerializer.kt11
-rw-r--r--formats/json/commonTest/src/kotlinx/serialization/features/DerivedContextualSerializerTest.kt49
2 files changed, 54 insertions, 6 deletions
diff --git a/core/commonMain/src/kotlinx/serialization/ContextualSerializer.kt b/core/commonMain/src/kotlinx/serialization/ContextualSerializer.kt
index 3919fd12..e43e6250 100644
--- a/core/commonMain/src/kotlinx/serialization/ContextualSerializer.kt
+++ b/core/commonMain/src/kotlinx/serialization/ContextualSerializer.kt
@@ -45,6 +45,9 @@ public class ContextualSerializer<T : Any>(
private val typeParametersSerializers: Array<KSerializer<*>>
) : KSerializer<T> {
+ private fun serializer(serializersModule: SerializersModule): KSerializer<T> =
+ serializersModule.getContextual(serializableClass) ?: fallbackSerializer ?: serializableClass.serializerNotRegistered()
+
// Used from auto-generated code
public constructor(serializableClass: KClass<T>) : this(serializableClass, null, EMPTY_SERIALIZER_ARRAY)
@@ -52,14 +55,10 @@ public class ContextualSerializer<T : Any>(
buildSerialDescriptor("kotlinx.serialization.ContextualSerializer", SerialKind.CONTEXTUAL).withContext(serializableClass)
public override fun serialize(encoder: Encoder, value: T) {
- val clz = value::class
- val serializer = encoder.serializersModule.getContextual(clz) ?: fallbackSerializer ?: serializableClass.serializerNotRegistered()
- @Suppress("UNCHECKED_CAST")
- encoder.encodeSerializableValue(serializer as SerializationStrategy<T>, value)
+ encoder.encodeSerializableValue(serializer(encoder.serializersModule), value)
}
public override fun deserialize(decoder: Decoder): T {
- val serializer = decoder.serializersModule.getContextual(serializableClass) ?: fallbackSerializer ?: serializableClass.serializerNotRegistered()
- return decoder.decodeSerializableValue(serializer)
+ return decoder.decodeSerializableValue(serializer(decoder.serializersModule))
}
}
diff --git a/formats/json/commonTest/src/kotlinx/serialization/features/DerivedContextualSerializerTest.kt b/formats/json/commonTest/src/kotlinx/serialization/features/DerivedContextualSerializerTest.kt
new file mode 100644
index 00000000..7da16fb8
--- /dev/null
+++ b/formats/json/commonTest/src/kotlinx/serialization/features/DerivedContextualSerializerTest.kt
@@ -0,0 +1,49 @@
+package kotlinx.serialization.features
+
+import kotlin.test.*
+import kotlinx.serialization.*
+import kotlinx.serialization.descriptors.*
+import kotlinx.serialization.encoding.*
+import kotlinx.serialization.json.*
+import kotlinx.serialization.modules.*
+
+class DerivedContextualSerializerTest {
+
+ @Serializable
+ abstract class Message
+
+ @Serializable
+ class SimpleMessage(val body: String) : Message()
+
+ @Serializable
+ class Holder(@Contextual val message: Message)
+
+ object MessageAsStringSerializer : KSerializer<Message> {
+ override val descriptor: SerialDescriptor =
+ PrimitiveSerialDescriptor("kotlinx.serialization.MessageAsStringSerializer", PrimitiveKind.STRING)
+
+ override fun serialize(encoder: Encoder, value: Message) {
+ // dummy serializer that assumes Message is always SimpleMessage
+ check(value is SimpleMessage)
+ encoder.encodeString(value.body)
+ }
+
+ override fun deserialize(decoder: Decoder): Message {
+ return SimpleMessage(decoder.decodeString())
+ }
+ }
+
+ @Test
+ fun testDerivedContextualSerializer() {
+ val module = SerializersModule {
+ contextual(MessageAsStringSerializer)
+ }
+ val format = Json { serializersModule = module }
+ val data = Holder(SimpleMessage("hello"))
+ val serialized = format.encodeToString(data)
+ assertEquals("""{"message":"hello"}""", serialized)
+ val deserialized = format.decodeFromString<Holder>(serialized)
+ assertTrue(deserialized.message is SimpleMessage)
+ assertEquals("hello", deserialized.message.body)
+ }
+}