From 0ff1e24e9486900a895af805e58f4e468ec5edf7 Mon Sep 17 00:00:00 2001 From: Abseil Team Date: Wed, 18 Oct 2023 12:27:30 -0700 Subject: Add an assertDataclassEqual method that provides better errors when it fails. PiperOrigin-RevId: 574555720 --- CHANGELOG.md | 3 ++ absl/testing/absltest.py | 61 +++++++++++++++++++++++++++ absl/testing/tests/absltest_test.py | 83 +++++++++++++++++++++++++++++++++++++ 3 files changed, 147 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f265810..776dc3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com). with a construct to modify values. The new interface parallels `absl.flags.FlagValues.__setattr__` but checks that the provided value conforms to the flag's expected type. +* (testing) Added a new method `absltest.TestCase.assertDataclassEqual` that + tests equality of `dataclass.dataclass` objects with better error messages + when the assert fails. ### Fixed diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py index 61b096d..a3b4e34 100644 --- a/absl/testing/absltest.py +++ b/absl/testing/absltest.py @@ -20,6 +20,7 @@ tests. from collections import abc import contextlib +import dataclasses import difflib import enum import errno @@ -1730,6 +1731,66 @@ class TestCase(unittest.TestCase): raise self.failureException('\n'.join(message)) + def assertDataclassEqual(self, first, second, msg=None): + """Asserts two dataclasses are equal with more informative errors. + + Arguments must both be dataclasses. This compares equality of individual + fields and takes care to not compare fields that are marked as + non-comparable. It gives per field differences, which are easier to parse + than the comparison of the string representations from assertEqual. + + In cases where the dataclass has a custom __eq__, and it is defined in a + way that is inconsistent with equality of comparable fields, we raise an + exception without further trying to figure out how they are different. + + Args: + first: A dataclass, the first value. + second: A dataclass, the second value. + msg: An optional str, the associated message. + + Raises: + AssertionError: if the dataclasses are not equal. + """ + + if not dataclasses.is_dataclass(first) or isinstance(first, type): + raise self.failureException('First argument is not a dataclass instance.') + if not dataclasses.is_dataclass(second) or isinstance(second, type): + raise self.failureException( + 'Second argument is not a dataclass instance.' + ) + + if first == second: + return + + if type(first) is not type(second): + self.fail( + 'Found different dataclass types: %s != %s' + % (type(first), type(second)), + msg, + ) + + # Make sure to skip fields that are marked compare=False. + different = [ + (f.name, getattr(first, f.name), getattr(second, f.name)) + for f in dataclasses.fields(first) + if f.compare and getattr(first, f.name) != getattr(second, f.name) + ] + + safe_repr = unittest.util.safe_repr # pytype: disable=module-attr + message = ['%s != %s' % (safe_repr(first), safe_repr(second))] + if different: + message.append('Fields that differ:') + message.extend( + '%s: %s != %s' % (k, safe_repr(first_v), safe_repr(second_v)) + for k, first_v, second_v in different + ) + else: + message.append( + 'Cannot detect difference by examining the fields of the dataclass.' + ) + + raise self.fail('\n'.join(message), msg) + def assertUrlEqual(self, a, b, msg=None): """Asserts that urls are equal, ignoring ordering of query params.""" parsed_a = parse.urlparse(a) diff --git a/absl/testing/tests/absltest_test.py b/absl/testing/tests/absltest_test.py index f23d856..baeccaf 100644 --- a/absl/testing/tests/absltest_test.py +++ b/absl/testing/tests/absltest_test.py @@ -16,6 +16,7 @@ import collections import contextlib +import dataclasses import io import os import pathlib @@ -24,6 +25,7 @@ import stat import string import subprocess import tempfile +import textwrap from typing import Optional import unittest @@ -1974,6 +1976,87 @@ class InitNotNecessaryForAssertsTest(absltest.TestCase): Subclass().assertEqual({}, {}) +@dataclasses.dataclass +class _ExampleDataclass: + comparable: str + not_comparable: str = dataclasses.field(compare=False) + comparable2: str = 'comparable2' + + +@dataclasses.dataclass +class _ExampleCustomEqualDataclass: + value: str + + def __eq__(self, other): + return False + + +class TestAssertDataclassEqual(absltest.TestCase): + + def test_assert_dataclass_equal_checks_a_for_dataclass(self): + b = _ExampleDataclass('a', 'b') + + message = 'First argument is not a dataclass instance.' + with self.assertRaisesWithLiteralMatch(AssertionError, message): + self.assertDataclassEqual('a', b) + + def test_assert_dataclass_equal_checks_b_for_dataclass(self): + a = _ExampleDataclass('a', 'b') + + message = 'Second argument is not a dataclass instance.' + with self.assertRaisesWithLiteralMatch(AssertionError, message): + self.assertDataclassEqual(a, 'b') + + def test_assert_dataclass_equal_different_dataclasses(self): + a = _ExampleDataclass('a', 'b') + b = _ExampleCustomEqualDataclass('c') + + message = """Found different dataclass types: != """ + with self.assertRaisesWithLiteralMatch(AssertionError, message): + self.assertDataclassEqual(a, b) + + def test_assert_dataclass_equal(self): + a = _ExampleDataclass(comparable='a', not_comparable='b') + b = _ExampleDataclass(comparable='a', not_comparable='c') + + self.assertDataclassEqual(a, a) + self.assertDataclassEqual(a, b) + self.assertDataclassEqual(b, a) + + def test_assert_dataclass_fails_non_equal_classes_assert_dict_passes(self): + a = _ExampleCustomEqualDataclass(value='a') + b = _ExampleCustomEqualDataclass(value='a') + + message = textwrap.dedent("""\ + _ExampleCustomEqualDataclass(value='a') != _ExampleCustomEqualDataclass(value='a') + Cannot detect difference by examining the fields of the dataclass.""") + with self.assertRaisesWithLiteralMatch(AssertionError, message): + self.assertDataclassEqual(a, b) + + def test_assert_dataclass_fails_assert_dict_fails_one_field(self): + a = _ExampleDataclass(comparable='a', not_comparable='b') + b = _ExampleDataclass(comparable='c', not_comparable='d') + + message = textwrap.dedent("""\ + _ExampleDataclass(comparable='a', not_comparable='b', comparable2='comparable2') != _ExampleDataclass(comparable='c', not_comparable='d', comparable2='comparable2') + Fields that differ: + comparable: 'a' != 'c'""") + with self.assertRaisesWithLiteralMatch(AssertionError, message): + self.assertDataclassEqual(a, b) + + def test_assert_dataclass_fails_assert_dict_fails_multiple_fields(self): + a = _ExampleDataclass(comparable='a', not_comparable='b', comparable2='c') + b = _ExampleDataclass(comparable='c', not_comparable='d', comparable2='e') + + message = textwrap.dedent("""\ + _ExampleDataclass(comparable='a', not_comparable='b', comparable2='c') != _ExampleDataclass(comparable='c', not_comparable='d', comparable2='e') + Fields that differ: + comparable: 'a' != 'c' + comparable2: 'c' != 'e'""") + with self.assertRaisesWithLiteralMatch(AssertionError, message): + self.assertDataclassEqual(a, b) + + class GetCommandStringTest(parameterized.TestCase): @parameterized.parameters( -- cgit v1.2.3