aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAbseil Team <absl-team@google.com>2023-10-18 12:27:30 -0700
committerCopybara-Service <copybara-worker@google.com>2023-10-18 12:28:09 -0700
commit0ff1e24e9486900a895af805e58f4e468ec5edf7 (patch)
treeb196061da151d8933a69a7dd5c18241c9986e631
parent9499935b4f8f9db17114cdb21b12b3275efe21cc (diff)
downloadabsl-py-0ff1e24e9486900a895af805e58f4e468ec5edf7.tar.gz
Add an assertDataclassEqual method that provides better errors when it fails.
PiperOrigin-RevId: 574555720
-rw-r--r--CHANGELOG.md3
-rw-r--r--absl/testing/absltest.py61
-rw-r--r--absl/testing/tests/absltest_test.py83
3 files changed, 147 insertions, 0 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index f265810..776dc3f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com).
with a construct to modify values. The new interface parallels
`absl.flags.FlagValues.__setattr__` but checks that the provided value
conforms to the flag's expected type.
+* (testing) Added a new method `absltest.TestCase.assertDataclassEqual` that
+ tests equality of `dataclass.dataclass` objects with better error messages
+ when the assert fails.
### Fixed
diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py
index 61b096d..a3b4e34 100644
--- a/absl/testing/absltest.py
+++ b/absl/testing/absltest.py
@@ -20,6 +20,7 @@ tests.
from collections import abc
import contextlib
+import dataclasses
import difflib
import enum
import errno
@@ -1730,6 +1731,66 @@ class TestCase(unittest.TestCase):
raise self.failureException('\n'.join(message))
+ def assertDataclassEqual(self, first, second, msg=None):
+ """Asserts two dataclasses are equal with more informative errors.
+
+ Arguments must both be dataclasses. This compares equality of individual
+ fields and takes care to not compare fields that are marked as
+ non-comparable. It gives per field differences, which are easier to parse
+ than the comparison of the string representations from assertEqual.
+
+ In cases where the dataclass has a custom __eq__, and it is defined in a
+ way that is inconsistent with equality of comparable fields, we raise an
+ exception without further trying to figure out how they are different.
+
+ Args:
+ first: A dataclass, the first value.
+ second: A dataclass, the second value.
+ msg: An optional str, the associated message.
+
+ Raises:
+ AssertionError: if the dataclasses are not equal.
+ """
+
+ if not dataclasses.is_dataclass(first) or isinstance(first, type):
+ raise self.failureException('First argument is not a dataclass instance.')
+ if not dataclasses.is_dataclass(second) or isinstance(second, type):
+ raise self.failureException(
+ 'Second argument is not a dataclass instance.'
+ )
+
+ if first == second:
+ return
+
+ if type(first) is not type(second):
+ self.fail(
+ 'Found different dataclass types: %s != %s'
+ % (type(first), type(second)),
+ msg,
+ )
+
+ # Make sure to skip fields that are marked compare=False.
+ different = [
+ (f.name, getattr(first, f.name), getattr(second, f.name))
+ for f in dataclasses.fields(first)
+ if f.compare and getattr(first, f.name) != getattr(second, f.name)
+ ]
+
+ safe_repr = unittest.util.safe_repr # pytype: disable=module-attr
+ message = ['%s != %s' % (safe_repr(first), safe_repr(second))]
+ if different:
+ message.append('Fields that differ:')
+ message.extend(
+ '%s: %s != %s' % (k, safe_repr(first_v), safe_repr(second_v))
+ for k, first_v, second_v in different
+ )
+ else:
+ message.append(
+ 'Cannot detect difference by examining the fields of the dataclass.'
+ )
+
+ raise self.fail('\n'.join(message), msg)
+
def assertUrlEqual(self, a, b, msg=None):
"""Asserts that urls are equal, ignoring ordering of query params."""
parsed_a = parse.urlparse(a)
diff --git a/absl/testing/tests/absltest_test.py b/absl/testing/tests/absltest_test.py
index f23d856..baeccaf 100644
--- a/absl/testing/tests/absltest_test.py
+++ b/absl/testing/tests/absltest_test.py
@@ -16,6 +16,7 @@
import collections
import contextlib
+import dataclasses
import io
import os
import pathlib
@@ -24,6 +25,7 @@ import stat
import string
import subprocess
import tempfile
+import textwrap
from typing import Optional
import unittest
@@ -1974,6 +1976,87 @@ class InitNotNecessaryForAssertsTest(absltest.TestCase):
Subclass().assertEqual({}, {})
+@dataclasses.dataclass
+class _ExampleDataclass:
+ comparable: str
+ not_comparable: str = dataclasses.field(compare=False)
+ comparable2: str = 'comparable2'
+
+
+@dataclasses.dataclass
+class _ExampleCustomEqualDataclass:
+ value: str
+
+ def __eq__(self, other):
+ return False
+
+
+class TestAssertDataclassEqual(absltest.TestCase):
+
+ def test_assert_dataclass_equal_checks_a_for_dataclass(self):
+ b = _ExampleDataclass('a', 'b')
+
+ message = 'First argument is not a dataclass instance.'
+ with self.assertRaisesWithLiteralMatch(AssertionError, message):
+ self.assertDataclassEqual('a', b)
+
+ def test_assert_dataclass_equal_checks_b_for_dataclass(self):
+ a = _ExampleDataclass('a', 'b')
+
+ message = 'Second argument is not a dataclass instance.'
+ with self.assertRaisesWithLiteralMatch(AssertionError, message):
+ self.assertDataclassEqual(a, 'b')
+
+ def test_assert_dataclass_equal_different_dataclasses(self):
+ a = _ExampleDataclass('a', 'b')
+ b = _ExampleCustomEqualDataclass('c')
+
+ message = """Found different dataclass types: <class '__main__._ExampleDataclass'> != <class '__main__._ExampleCustomEqualDataclass'>"""
+ with self.assertRaisesWithLiteralMatch(AssertionError, message):
+ self.assertDataclassEqual(a, b)
+
+ def test_assert_dataclass_equal(self):
+ a = _ExampleDataclass(comparable='a', not_comparable='b')
+ b = _ExampleDataclass(comparable='a', not_comparable='c')
+
+ self.assertDataclassEqual(a, a)
+ self.assertDataclassEqual(a, b)
+ self.assertDataclassEqual(b, a)
+
+ def test_assert_dataclass_fails_non_equal_classes_assert_dict_passes(self):
+ a = _ExampleCustomEqualDataclass(value='a')
+ b = _ExampleCustomEqualDataclass(value='a')
+
+ message = textwrap.dedent("""\
+ _ExampleCustomEqualDataclass(value='a') != _ExampleCustomEqualDataclass(value='a')
+ Cannot detect difference by examining the fields of the dataclass.""")
+ with self.assertRaisesWithLiteralMatch(AssertionError, message):
+ self.assertDataclassEqual(a, b)
+
+ def test_assert_dataclass_fails_assert_dict_fails_one_field(self):
+ a = _ExampleDataclass(comparable='a', not_comparable='b')
+ b = _ExampleDataclass(comparable='c', not_comparable='d')
+
+ message = textwrap.dedent("""\
+ _ExampleDataclass(comparable='a', not_comparable='b', comparable2='comparable2') != _ExampleDataclass(comparable='c', not_comparable='d', comparable2='comparable2')
+ Fields that differ:
+ comparable: 'a' != 'c'""")
+ with self.assertRaisesWithLiteralMatch(AssertionError, message):
+ self.assertDataclassEqual(a, b)
+
+ def test_assert_dataclass_fails_assert_dict_fails_multiple_fields(self):
+ a = _ExampleDataclass(comparable='a', not_comparable='b', comparable2='c')
+ b = _ExampleDataclass(comparable='c', not_comparable='d', comparable2='e')
+
+ message = textwrap.dedent("""\
+ _ExampleDataclass(comparable='a', not_comparable='b', comparable2='c') != _ExampleDataclass(comparable='c', not_comparable='d', comparable2='e')
+ Fields that differ:
+ comparable: 'a' != 'c'
+ comparable2: 'c' != 'e'""")
+ with self.assertRaisesWithLiteralMatch(AssertionError, message):
+ self.assertDataclassEqual(a, b)
+
+
class GetCommandStringTest(parameterized.TestCase):
@parameterized.parameters(