aboutsummaryrefslogtreecommitdiff
path: root/typing_extensions/src/test_typing_extensions.py
diff options
context:
space:
mode:
Diffstat (limited to 'typing_extensions/src/test_typing_extensions.py')
-rw-r--r--typing_extensions/src/test_typing_extensions.py19
1 files changed, 17 insertions, 2 deletions
diff --git a/typing_extensions/src/test_typing_extensions.py b/typing_extensions/src/test_typing_extensions.py
index 53a8343..b8fe5e3 100644
--- a/typing_extensions/src/test_typing_extensions.py
+++ b/typing_extensions/src/test_typing_extensions.py
@@ -12,7 +12,7 @@ import types
from unittest import TestCase, main, skipUnless, skipIf
from test import ann_module, ann_module2, ann_module3
import typing
-from typing import TypeVar, Optional, Union, Any
+from typing import TypeVar, Optional, Union, Any, AnyStr
from typing import T, KT, VT # Not in __all__.
from typing import Tuple, List, Dict, Iterable, Iterator, Callable
from typing import Generic, NamedTuple
@@ -23,7 +23,7 @@ from typing_extensions import TypeAlias, ParamSpec, Concatenate, ParamSpecArgs,
from typing_extensions import Awaitable, AsyncIterator, AsyncContextManager, Required, NotRequired
from typing_extensions import Protocol, runtime, runtime_checkable, Annotated, overload, final, is_typeddict
from typing_extensions import TypeVarTuple, Unpack, dataclass_transform, reveal_type, Never, assert_never, LiteralString
-from typing_extensions import get_type_hints, get_origin, get_args
+from typing_extensions import assert_type, get_type_hints, get_origin, get_args
# Flags used to mark tests that only apply after a specific
# version of the typing module.
@@ -425,6 +425,21 @@ class OverloadTests(BaseTestCase):
blah()
+class AssertTypeTests(BaseTestCase):
+
+ def test_basics(self):
+ arg = 42
+ self.assertIs(assert_type(arg, int), arg)
+ self.assertIs(assert_type(arg, Union[str, float]), arg)
+ self.assertIs(assert_type(arg, AnyStr), arg)
+ self.assertIs(assert_type(arg, None), arg)
+
+ def test_errors(self):
+ # Bogus calls are not expected to fail.
+ arg = 42
+ self.assertIs(assert_type(arg, 42), arg)
+ self.assertIs(assert_type(arg, 'hello'), arg)
+
T_a = TypeVar('T_a')