aboutsummaryrefslogtreecommitdiff
path: root/tests/test_callbacks.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_callbacks.py')
-rw-r--r--tests/test_callbacks.py84
1 files changed, 77 insertions, 7 deletions
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index 039b877c..4a652f53 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -1,7 +1,11 @@
-# -*- coding: utf-8 -*-
+import time
+from threading import Thread
+
import pytest
+
+import env # noqa: F401
from pybind11_tests import callbacks as m
-from threading import Thread
+from pybind11_tests import detailed_error_messages_enabled
def test_callbacks():
@@ -14,7 +18,7 @@ def test_callbacks():
return "func2", a, b, c, d
def func3(a):
- return "func3({})".format(a)
+ return f"func3({a})"
assert m.test_callback1(func1) == "func1"
assert m.test_callback2(func2) == ("func2", "Hello", "x", True, 5)
@@ -67,21 +71,35 @@ def test_keyword_args_and_generalized_unpacking():
with pytest.raises(RuntimeError) as excinfo:
m.test_arg_conversion_error1(f)
- assert "Unable to convert call argument" in str(excinfo.value)
+ assert str(excinfo.value) == "Unable to convert call argument " + (
+ "'1' of type 'UnregisteredType' to Python object"
+ if detailed_error_messages_enabled
+ else "'1' to Python object (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)"
+ )
with pytest.raises(RuntimeError) as excinfo:
m.test_arg_conversion_error2(f)
- assert "Unable to convert call argument" in str(excinfo.value)
+ assert str(excinfo.value) == "Unable to convert call argument " + (
+ "'expected_name' of type 'UnregisteredType' to Python object"
+ if detailed_error_messages_enabled
+ else "'expected_name' to Python object "
+ "(#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)"
+ )
def test_lambda_closure_cleanup():
- m.test_cleanup()
+ m.test_lambda_closure_cleanup()
cstats = m.payload_cstats()
assert cstats.alive() == 0
assert cstats.copy_constructions == 1
assert cstats.move_constructions >= 1
+def test_cpp_callable_cleanup():
+ alive_counts = m.test_cpp_callable_cleanup()
+ assert alive_counts == [0, 1, 2, 1, 2, 1, 0]
+
+
def test_cpp_function_roundtrip():
"""Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer"""
@@ -92,6 +110,10 @@ def test_cpp_function_roundtrip():
m.test_dummy_function(m.roundtrip(m.dummy_function))
== "matches dummy_function: eval(1) = 2"
)
+ assert (
+ m.test_dummy_function(m.dummy_function_overloaded)
+ == "matches dummy_function: eval(1) = 2"
+ )
assert m.roundtrip(None, expect_none=True) is None
assert (
m.test_dummy_function(lambda x: x + 2)
@@ -119,6 +141,16 @@ def test_movable_object():
assert m.callback_with_movable(lambda _: None) is True
+@pytest.mark.skipif(
+ "env.PYPY",
+ reason="PyPy segfaults on here. See discussion on #1413.",
+)
+def test_python_builtins():
+ """Test if python builtins like sum() can be used as callbacks"""
+ assert m.test_sum_builtin(sum, [1, 2, 3]) == 6
+ assert m.test_sum_builtin(sum, []) == 0
+
+
def test_async_callbacks():
# serves as state for async callback
class Item:
@@ -139,10 +171,48 @@ def test_async_callbacks():
from time import sleep
sleep(0.5)
- assert sum(res) == sum([x + 3 for x in work])
+ assert sum(res) == sum(x + 3 for x in work)
def test_async_async_callbacks():
t = Thread(target=test_async_callbacks)
t.start()
t.join()
+
+
+def test_callback_num_times():
+ # Super-simple micro-benchmarking related to PR #2919.
+ # Example runtimes (Intel Xeon 2.2GHz, fully optimized):
+ # num_millions 1, repeats 2: 0.1 secs
+ # num_millions 20, repeats 10: 11.5 secs
+ one_million = 1000000
+ num_millions = 1 # Try 20 for actual micro-benchmarking.
+ repeats = 2 # Try 10.
+ rates = []
+ for rep in range(repeats):
+ t0 = time.time()
+ m.callback_num_times(lambda: None, num_millions * one_million)
+ td = time.time() - t0
+ rate = num_millions / td if td else 0
+ rates.append(rate)
+ if not rep:
+ print()
+ print(
+ f"callback_num_times: {num_millions:d} million / {td:.3f} seconds = {rate:.3f} million / second"
+ )
+ if len(rates) > 1:
+ print("Min Mean Max")
+ print(f"{min(rates):6.3f} {sum(rates) / len(rates):6.3f} {max(rates):6.3f}")
+
+
+def test_custom_func():
+ assert m.custom_function(4) == 36
+ assert m.roundtrip(m.custom_function)(4) == 36
+
+
+@pytest.mark.skipif(
+ m.custom_function2 is None, reason="Current PYBIND11_INTERNALS_VERSION too low"
+)
+def test_custom_func2():
+ assert m.custom_function2(3) == 27
+ assert m.roundtrip(m.custom_function2)(3) == 27