diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2021-11-25 20:05:29 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2021-11-25 20:05:29 +0000 |
commit | 2f0421fc344db1fc129b5ece09f6c9b662648c2e (patch) | |
tree | 7c474e5234417d4e078c4a115ce27112bf34c301 | |
parent | 6c9a00a78b690ca3f86b75d07c5270e21facad62 (diff) | |
parent | b8563af5d8e399995ad8a5d01bf40ee260885b32 (diff) | |
download | tokio-android12-mainline-mediaprovider-release.tar.gz |
Snap for 7947225 from b8563af5d8e399995ad8a5d01bf40ee260885b32 to mainline-mediaprovider-releaseandroid-mainline-12.0.0_r90android-mainline-12.0.0_r76android-mainline-12.0.0_r121android-mainline-12.0.0_r106aml_mpr_311911090android12-mainline-mediaprovider-release
Change-Id: I1319ffb9800fb9ba5c4f6e30ce14f5fe2ae7153b
269 files changed, 20073 insertions, 4465 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json index b75e29b..92e2ebe 100644 --- a/.cargo_vcs_info.json +++ b/.cargo_vcs_info.json @@ -1,5 +1,5 @@ { "git": { - "sha1": "a5ee2f0d3d78daa01e2c6c12d22b82474dc5c32a" + "sha1": "623c09c52c2c38a8d75e94c166593547e8477707" } } @@ -1,4 +1,4 @@ -// This file is generated by cargo2android.py --device --run --features io-util,macros,rt-multi-thread,sync,net,fs,time --tests --patch=patches/Android.bp.patch. +// This file is generated by cargo2android.py --config cargo2android.json. // Do not modify this file as changes will be overridden on upgrade. package { @@ -39,6 +39,7 @@ rust_library { "sync", "time", "tokio-macros", + "winapi", ], cfgs: ["tokio_track_caller"], rustlibs: [ @@ -58,7 +59,7 @@ rust_library { } rust_defaults { - name: "tokio_defaults", + name: "tokio_defaults_tokio", crate_name: "tokio", test_suites: ["general-tests"], auto_gen_config: true, @@ -100,8 +101,23 @@ rust_defaults { } rust_test_host { + name: "tokio_host_test_tests__require_full", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/_require_full.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests__require_full", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/_require_full.rs"], +} + +rust_test_host { name: "tokio_host_test_tests_buffered", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/buffered.rs"], test_options: { unit_test: true, @@ -110,13 +126,13 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_buffered", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/buffered.rs"], } rust_test_host { name: "tokio_host_test_tests_io_async_read", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_async_read.rs"], test_options: { unit_test: true, @@ -125,13 +141,43 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_io_async_read", - defaults: ["tokio_defaults"], + 
defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_async_read.rs"], } rust_test_host { + name: "tokio_host_test_tests_io_chain", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_chain.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_chain", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_chain.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_copy", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_copy.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_copy", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_copy.rs"], +} + +rust_test_host { name: "tokio_host_test_tests_io_copy_bidirectional", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_copy_bidirectional.rs"], test_options: { unit_test: true, @@ -140,13 +186,43 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_io_copy_bidirectional", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_copy_bidirectional.rs"], } rust_test_host { + name: "tokio_host_test_tests_io_driver", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_driver.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_driver", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_driver.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_driver_drop", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_driver_drop.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_driver_drop", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_driver_drop.rs"], +} + +rust_test_host { name: "tokio_host_test_tests_io_lines", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_lines.rs"], test_options: { unit_test: true, @@ -155,13 +231,13 @@ rust_test_host { rust_test { 
name: "tokio_device_test_tests_io_lines", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_lines.rs"], } rust_test_host { name: "tokio_host_test_tests_io_mem_stream", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_mem_stream.rs"], test_options: { unit_test: true, @@ -170,13 +246,13 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_io_mem_stream", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_mem_stream.rs"], } rust_test_host { name: "tokio_host_test_tests_io_read", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_read.rs"], test_options: { unit_test: true, @@ -185,13 +261,13 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_io_read", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_read.rs"], } rust_test_host { name: "tokio_host_test_tests_io_read_buf", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_read_buf.rs"], test_options: { unit_test: true, @@ -200,13 +276,43 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_io_read_buf", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_read_buf.rs"], } rust_test_host { + name: "tokio_host_test_tests_io_read_exact", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_exact.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_read_exact", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_exact.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_read_line", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_line.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_read_line", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_line.rs"], +} + +rust_test_host { name: 
"tokio_host_test_tests_io_read_to_end", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_read_to_end.rs"], test_options: { unit_test: true, @@ -215,13 +321,58 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_io_read_to_end", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_read_to_end.rs"], } rust_test_host { + name: "tokio_host_test_tests_io_read_to_string", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_to_string.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_read_to_string", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_to_string.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_read_until", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_until.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_read_until", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_until.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_split", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_split.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_split", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_split.rs"], +} + +rust_test_host { name: "tokio_host_test_tests_io_take", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_take.rs"], test_options: { unit_test: true, @@ -230,13 +381,13 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_io_take", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_take.rs"], } rust_test_host { name: "tokio_host_test_tests_io_write", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_write.rs"], test_options: { unit_test: true, @@ -245,13 +396,13 @@ rust_test_host { rust_test { name: 
"tokio_device_test_tests_io_write", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_write.rs"], } rust_test_host { name: "tokio_host_test_tests_io_write_all", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_write_all.rs"], test_options: { unit_test: true, @@ -260,13 +411,13 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_io_write_all", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_write_all.rs"], } rust_test_host { name: "tokio_host_test_tests_io_write_buf", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_write_buf.rs"], test_options: { unit_test: true, @@ -275,13 +426,13 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_io_write_buf", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_write_buf.rs"], } rust_test_host { name: "tokio_host_test_tests_io_write_int", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_write_int.rs"], test_options: { unit_test: true, @@ -290,13 +441,13 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_io_write_int", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/io_write_int.rs"], } rust_test_host { name: "tokio_host_test_tests_macros_join", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/macros_join.rs"], test_options: { unit_test: true, @@ -305,13 +456,103 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_macros_join", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/macros_join.rs"], } rust_test_host { + name: "tokio_host_test_tests_macros_pin", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_pin.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_macros_pin", + defaults: 
["tokio_defaults_tokio"], + srcs: ["tests/macros_pin.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_macros_select", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_select.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_macros_select", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_select.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_macros_test", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_test.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_macros_test", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_test.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_macros_try_join", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_try_join.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_macros_try_join", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_try_join.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_net_bind_resource", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/net_bind_resource.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_net_bind_resource", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/net_bind_resource.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_net_lookup_host", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/net_lookup_host.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_net_lookup_host", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/net_lookup_host.rs"], +} + +rust_test_host { name: "tokio_host_test_tests_no_rt", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/no_rt.rs"], test_options: { unit_test: true, @@ -320,13 +561,28 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_no_rt", - 
defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/no_rt.rs"], } rust_test_host { + name: "tokio_host_test_tests_process_kill_on_drop", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/process_kill_on_drop.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_process_kill_on_drop", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/process_kill_on_drop.rs"], +} + +rust_test_host { name: "tokio_host_test_tests_rt_basic", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/rt_basic.rs"], test_options: { unit_test: true, @@ -335,13 +591,29 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_rt_basic", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/rt_basic.rs"], } rust_test_host { + name: "tokio_host_test_tests_rt_common", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/rt_common.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_rt_common", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/rt_common.rs"], +} + + +rust_test_host { name: "tokio_host_test_tests_rt_threaded", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/rt_threaded.rs"], test_options: { unit_test: true, @@ -350,13 +622,13 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_rt_threaded", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/rt_threaded.rs"], } rust_test_host { name: "tokio_host_test_tests_sync_barrier", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/sync_barrier.rs"], test_options: { unit_test: true, @@ -365,13 +637,13 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_sync_barrier", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/sync_barrier.rs"], } rust_test_host { name: 
"tokio_host_test_tests_sync_broadcast", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/sync_broadcast.rs"], test_options: { unit_test: true, @@ -380,13 +652,13 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_sync_broadcast", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/sync_broadcast.rs"], } rust_test_host { name: "tokio_host_test_tests_sync_errors", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/sync_errors.rs"], test_options: { unit_test: true, @@ -395,13 +667,13 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_sync_errors", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/sync_errors.rs"], } rust_test_host { name: "tokio_host_test_tests_sync_mpsc", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/sync_mpsc.rs"], test_options: { unit_test: true, @@ -410,13 +682,28 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_sync_mpsc", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/sync_mpsc.rs"], } rust_test_host { + name: "tokio_host_test_tests_sync_mutex", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_mutex.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_mutex", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_mutex.rs"], +} + +rust_test_host { name: "tokio_host_test_tests_sync_mutex_owned", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/sync_mutex_owned.rs"], test_options: { unit_test: true, @@ -425,13 +712,43 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_sync_mutex_owned", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/sync_mutex_owned.rs"], } rust_test_host { + name: "tokio_host_test_tests_sync_notify", + defaults: ["tokio_defaults_tokio"], + 
srcs: ["tests/sync_notify.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_notify", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_notify.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_sync_oneshot", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_oneshot.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_oneshot", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_oneshot.rs"], +} + +rust_test_host { name: "tokio_host_test_tests_sync_rwlock", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/sync_rwlock.rs"], test_options: { unit_test: true, @@ -440,13 +757,43 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_sync_rwlock", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/sync_rwlock.rs"], } rust_test_host { + name: "tokio_host_test_tests_sync_semaphore", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_semaphore.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_semaphore", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_semaphore.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_sync_semaphore_owned", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_semaphore_owned.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_semaphore_owned", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_semaphore_owned.rs"], +} + +rust_test_host { name: "tokio_host_test_tests_sync_watch", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/sync_watch.rs"], test_options: { unit_test: true, @@ -455,13 +802,43 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_sync_watch", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: 
["tests/sync_watch.rs"], } rust_test_host { + name: "tokio_host_test_tests_task_abort", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/task_abort.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_task_abort", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/task_abort.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_task_blocking", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/task_blocking.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_task_blocking", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/task_blocking.rs"], +} + +rust_test_host { name: "tokio_host_test_tests_task_local", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/task_local.rs"], test_options: { unit_test: true, @@ -470,13 +847,13 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_task_local", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/task_local.rs"], } rust_test_host { name: "tokio_host_test_tests_task_local_set", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/task_local_set.rs"], test_options: { unit_test: true, @@ -485,13 +862,13 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_task_local_set", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/task_local_set.rs"], } rust_test_host { name: "tokio_host_test_tests_tcp_accept", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/tcp_accept.rs"], test_options: { unit_test: true, @@ -500,13 +877,28 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_tcp_accept", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/tcp_accept.rs"], } rust_test_host { + name: "tokio_host_test_tests_tcp_connect", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_connect.rs"], + 
test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_tcp_connect", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_connect.rs"], +} + +rust_test_host { name: "tokio_host_test_tests_tcp_echo", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/tcp_echo.rs"], test_options: { unit_test: true, @@ -515,13 +907,28 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_tcp_echo", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/tcp_echo.rs"], } rust_test_host { + name: "tokio_host_test_tests_tcp_into_split", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_into_split.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_tcp_into_split", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_into_split.rs"], +} + +rust_test_host { name: "tokio_host_test_tests_tcp_into_std", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/tcp_into_std.rs"], test_options: { unit_test: true, @@ -530,13 +937,28 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_tcp_into_std", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/tcp_into_std.rs"], } rust_test_host { + name: "tokio_host_test_tests_tcp_peek", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_peek.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_tcp_peek", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_peek.rs"], +} + +rust_test_host { name: "tokio_host_test_tests_tcp_shutdown", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/tcp_shutdown.rs"], test_options: { unit_test: true, @@ -545,13 +967,43 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_tcp_shutdown", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/tcp_shutdown.rs"], 
} rust_test_host { + name: "tokio_host_test_tests_tcp_socket", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_socket.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_tcp_socket", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_socket.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_tcp_split", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_split.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_tcp_split", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_split.rs"], +} + +rust_test_host { name: "tokio_host_test_tests_time_rt", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/time_rt.rs"], test_options: { unit_test: true, @@ -560,13 +1012,43 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_time_rt", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/time_rt.rs"], } rust_test_host { + name: "tokio_host_test_tests_udp", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/udp.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_udp", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/udp.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_uds_cred", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/uds_cred.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_uds_cred", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/uds_cred.rs"], +} + +rust_test_host { name: "tokio_host_test_tests_uds_split", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/uds_split.rs"], test_options: { unit_test: true, @@ -575,6 +1057,6 @@ rust_test_host { rust_test { name: "tokio_device_test_tests_uds_split", - defaults: ["tokio_defaults"], + defaults: ["tokio_defaults_tokio"], srcs: ["tests/uds_split.rs"], } diff --git 
a/CHANGELOG.md b/CHANGELOG.md index 0808920..afa8bf0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,457 @@ +# 1.14.0 (November 15, 2021) + +### Fixed + +- macros: fix compiler errors when using `mut` patterns in `select!` ([#4211]) +- sync: fix a data race between `oneshot::Sender::send` and awaiting a + `oneshot::Receiver` when the oneshot has been closed ([#4226]) +- sync: make `AtomicWaker` panic safe ([#3689]) +- runtime: fix basic scheduler dropping tasks outside a runtime context + ([#4213]) + +### Added + +- stats: add `RuntimeStats::busy_duration_total` ([#4179], [#4223]) + +### Changed + +- io: updated `copy` buffer size to match `std::io::copy` ([#4209]) + +### Documented + +- io: rename buffer to file in doc-test ([#4230]) +- sync: fix Notify example ([#4212]) + +[#4211]: https://github.com/tokio-rs/tokio/pull/4211 +[#4226]: https://github.com/tokio-rs/tokio/pull/4226 +[#3689]: https://github.com/tokio-rs/tokio/pull/3689 +[#4213]: https://github.com/tokio-rs/tokio/pull/4213 +[#4179]: https://github.com/tokio-rs/tokio/pull/4179 +[#4223]: https://github.com/tokio-rs/tokio/pull/4223 +[#4209]: https://github.com/tokio-rs/tokio/pull/4209 +[#4230]: https://github.com/tokio-rs/tokio/pull/4230 +[#4212]: https://github.com/tokio-rs/tokio/pull/4212 + +# 1.13.1 (November 15, 2021) + +### Fixed + +- sync: fix a data race between `oneshot::Sender::send` and awaiting a + `oneshot::Receiver` when the oneshot has been closed ([#4226]) + +[#4226]: https://github.com/tokio-rs/tokio/pull/4226 + +# 1.13.0 (October 29, 2021) + +### Fixed + +- sync: fix `Notify` to clone the waker before locking its waiter list ([#4129]) +- tokio: add riscv32 to non atomic64 architectures ([#4185]) + +### Added + +- net: add `poll_{recv,send}_ready` methods to `udp` and `uds_datagram` ([#4131]) +- net: add `try_*`, `readable`, `writable`, `ready`, and `peer_addr` methods to split halves ([#4120]) +- sync: add `blocking_lock` to `Mutex` ([#4130]) +- sync: add 
`watch::Sender::send_replace` ([#3962], [#4195]) +- sync: expand `Debug` for `Mutex<T>` impl to unsized `T` ([#4134]) +- tracing: instrument time::Sleep ([#4072]) +- tracing: use structured location fields for spawned tasks ([#4128]) + +### Changed + +- io: add assert in `copy_bidirectional` that `poll_write` is sensible ([#4125]) +- macros: use qualified syntax when polling in `select!` ([#4192]) +- runtime: handle `block_on` wakeups better ([#4157]) +- task: allocate callback on heap immediately in debug mode ([#4203]) +- tokio: assert platform-minimum requirements at build time ([#3797]) + +### Documented + +- docs: conversion of doc comments to indicative mood ([#4174]) +- docs: add returning on the first error example for `try_join!` ([#4133]) +- docs: fixing broken links in `tokio/src/lib.rs` ([#4132]) +- signal: add example with background listener ([#4171]) +- sync: add more oneshot examples ([#4153]) +- time: document `Interval::tick` cancel safety ([#4152]) + +[#3797]: https://github.com/tokio-rs/tokio/pull/3797 +[#3962]: https://github.com/tokio-rs/tokio/pull/3962 +[#4072]: https://github.com/tokio-rs/tokio/pull/4072 +[#4120]: https://github.com/tokio-rs/tokio/pull/4120 +[#4125]: https://github.com/tokio-rs/tokio/pull/4125 +[#4128]: https://github.com/tokio-rs/tokio/pull/4128 +[#4129]: https://github.com/tokio-rs/tokio/pull/4129 +[#4130]: https://github.com/tokio-rs/tokio/pull/4130 +[#4131]: https://github.com/tokio-rs/tokio/pull/4131 +[#4132]: https://github.com/tokio-rs/tokio/pull/4132 +[#4133]: https://github.com/tokio-rs/tokio/pull/4133 +[#4134]: https://github.com/tokio-rs/tokio/pull/4134 +[#4152]: https://github.com/tokio-rs/tokio/pull/4152 +[#4153]: https://github.com/tokio-rs/tokio/pull/4153 +[#4157]: https://github.com/tokio-rs/tokio/pull/4157 +[#4171]: https://github.com/tokio-rs/tokio/pull/4171 +[#4174]: https://github.com/tokio-rs/tokio/pull/4174 +[#4185]: https://github.com/tokio-rs/tokio/pull/4185 +[#4192]: 
https://github.com/tokio-rs/tokio/pull/4192 +[#4195]: https://github.com/tokio-rs/tokio/pull/4195 +[#4203]: https://github.com/tokio-rs/tokio/pull/4203 + +# 1.12.0 (September 21, 2021) + +### Fixed + +- mpsc: ensure `try_reserve` error is consistent with `try_send` ([#4119]) +- mpsc: use `spin_loop_hint` instead of `yield_now` ([#4115]) +- sync: make `SendError` field public ([#4097]) + +### Added + +- io: add POSIX AIO on FreeBSD ([#4054]) +- io: add convenience method `AsyncSeekExt::rewind` ([#4107]) +- runtime: add tracing span for `block_on` futures ([#4094]) +- runtime: callback when a worker parks and unparks ([#4070]) +- sync: implement `try_recv` for mpsc channels ([#4113]) + +### Documented + +- docs: clarify CPU-bound tasks on Tokio ([#4105]) +- mpsc: document spurious failures on `poll_recv` ([#4117]) +- mpsc: document that `PollSender` impls `Sink` ([#4110]) +- task: document non-guarantees of `yield_now` ([#4091]) +- time: document paused time details better ([#4061], [#4103]) + +[#4027]: https://github.com/tokio-rs/tokio/pull/4027 +[#4054]: https://github.com/tokio-rs/tokio/pull/4054 +[#4061]: https://github.com/tokio-rs/tokio/pull/4061 +[#4070]: https://github.com/tokio-rs/tokio/pull/4070 +[#4091]: https://github.com/tokio-rs/tokio/pull/4091 +[#4094]: https://github.com/tokio-rs/tokio/pull/4094 +[#4097]: https://github.com/tokio-rs/tokio/pull/4097 +[#4103]: https://github.com/tokio-rs/tokio/pull/4103 +[#4105]: https://github.com/tokio-rs/tokio/pull/4105 +[#4107]: https://github.com/tokio-rs/tokio/pull/4107 +[#4110]: https://github.com/tokio-rs/tokio/pull/4110 +[#4113]: https://github.com/tokio-rs/tokio/pull/4113 +[#4115]: https://github.com/tokio-rs/tokio/pull/4115 +[#4117]: https://github.com/tokio-rs/tokio/pull/4117 +[#4119]: https://github.com/tokio-rs/tokio/pull/4119 + +# 1.11.0 (August 31, 2021) + +### Fixed + + - time: don't panic when Instant is not monotonic ([#4044]) + - io: fix panic in `fill_buf` by not calling `poll_fill_buf` twice 
([#4084]) + +### Added + + - watch: add `watch::Sender::subscribe` ([#3800]) + - process: add `from_std` to `ChildStd*` ([#4045]) + - stats: initial work on runtime stats ([#4043]) + +### Changed + + - tracing: change span naming to new console convention ([#4042]) + - io: speed-up waking by using uninitialized array ([#4055], [#4071], [#4075]) + +### Documented + + - time: make Sleep examples easier to find ([#4040]) + +[#3800]: https://github.com/tokio-rs/tokio/pull/3800 +[#4040]: https://github.com/tokio-rs/tokio/pull/4040 +[#4042]: https://github.com/tokio-rs/tokio/pull/4042 +[#4043]: https://github.com/tokio-rs/tokio/pull/4043 +[#4044]: https://github.com/tokio-rs/tokio/pull/4044 +[#4045]: https://github.com/tokio-rs/tokio/pull/4045 +[#4055]: https://github.com/tokio-rs/tokio/pull/4055 +[#4071]: https://github.com/tokio-rs/tokio/pull/4071 +[#4075]: https://github.com/tokio-rs/tokio/pull/4075 +[#4084]: https://github.com/tokio-rs/tokio/pull/4084 + +# 1.10.1 (August 24, 2021) + +### Fixed + + - runtime: fix leak in UnownedTask ([#4063]) + +[#4063]: https://github.com/tokio-rs/tokio/pull/4063 + +# 1.10.0 (August 12, 2021) + +### Added + + - io: add `(read|write)_f(32|64)[_le]` methods ([#4022]) + - io: add `fill_buf` and `consume` to `AsyncBufReadExt` ([#3991]) + - process: add `Child::raw_handle()` on windows ([#3998]) + +### Fixed + + - doc: fix non-doc builds with `--cfg docsrs` ([#4020]) + - io: flush eagerly in `io::copy` ([#4001]) + - runtime: a debug assert was sometimes triggered during shutdown ([#4005]) + - sync: use `spin_loop_hint` instead of `yield_now` in mpsc ([#4037]) + - tokio: the test-util feature depends on rt, sync, and time ([#4036]) + +### Changes + + - runtime: reorganize parts of the runtime ([#3979], [#4005]) + - signal: make windows docs for signal module show up on unix builds ([#3770]) + - task: quickly send task to heap on debug mode ([#4009]) + +### Documented + + - io: document cancellation safety of `AsyncBufReadExt` ([#3997]) + - 
sync: document when `watch::send` fails ([#4021]) + +[#3770]: https://github.com/tokio-rs/tokio/pull/3770 +[#3979]: https://github.com/tokio-rs/tokio/pull/3979 +[#3991]: https://github.com/tokio-rs/tokio/pull/3991 +[#3997]: https://github.com/tokio-rs/tokio/pull/3997 +[#3998]: https://github.com/tokio-rs/tokio/pull/3998 +[#4001]: https://github.com/tokio-rs/tokio/pull/4001 +[#4005]: https://github.com/tokio-rs/tokio/pull/4005 +[#4009]: https://github.com/tokio-rs/tokio/pull/4009 +[#4020]: https://github.com/tokio-rs/tokio/pull/4020 +[#4021]: https://github.com/tokio-rs/tokio/pull/4021 +[#4022]: https://github.com/tokio-rs/tokio/pull/4022 +[#4036]: https://github.com/tokio-rs/tokio/pull/4036 +[#4037]: https://github.com/tokio-rs/tokio/pull/4037 + +# 1.9.0 (July 22, 2021) + +### Added + + - net: allow customized I/O operations for `TcpStream` ([#3888]) + - sync: add getter for the mutex from a guard ([#3928]) + - task: expose nameable future for `TaskLocal::scope` ([#3273]) + +### Fixed + + - Fix leak if output of future panics on drop ([#3967]) + - Fix leak in `LocalSet` ([#3978]) + +### Changes + + - runtime: reorganize parts of the runtime ([#3909], [#3939], [#3950], [#3955], [#3980]) + - sync: clean up `OnceCell` ([#3945]) + - task: remove mutex in `JoinError` ([#3959]) + +[#3273]: https://github.com/tokio-rs/tokio/pull/3273 +[#3888]: https://github.com/tokio-rs/tokio/pull/3888 +[#3909]: https://github.com/tokio-rs/tokio/pull/3909 +[#3928]: https://github.com/tokio-rs/tokio/pull/3928 +[#3934]: https://github.com/tokio-rs/tokio/pull/3934 +[#3939]: https://github.com/tokio-rs/tokio/pull/3939 +[#3945]: https://github.com/tokio-rs/tokio/pull/3945 +[#3950]: https://github.com/tokio-rs/tokio/pull/3950 +[#3955]: https://github.com/tokio-rs/tokio/pull/3955 +[#3959]: https://github.com/tokio-rs/tokio/pull/3959 +[#3967]: https://github.com/tokio-rs/tokio/pull/3967 +[#3978]: https://github.com/tokio-rs/tokio/pull/3978 +[#3980]: https://github.com/tokio-rs/tokio/pull/3980 + 
+# 1.8.3 (July 26, 2021) + +This release backports two fixes from 1.9.0 + +### Fixed + + - Fix leak if output of future panics on drop ([#3967]) + - Fix leak in `LocalSet` ([#3978]) + +[#3967]: https://github.com/tokio-rs/tokio/pull/3967 +[#3978]: https://github.com/tokio-rs/tokio/pull/3978 + +# 1.8.2 (July 19, 2021) + +Fixes a missed edge case from 1.8.1. + +### Fixed + +- runtime: drop canceled future on next poll (#3965) + +# 1.8.1 (July 6, 2021) + +Forward ports 1.5.1 fixes. + +### Fixed + +- runtime: remotely abort tasks on `JoinHandle::abort` ([#3934]) + +[#3934]: https://github.com/tokio-rs/tokio/pull/3934 + +# 1.8.0 (July 2, 2021) + +### Added + +- io: add `get_{ref,mut}` methods to `AsyncFdReadyGuard` and `AsyncFdReadyMutGuard` ([#3807]) +- io: efficient implementation of vectored writes for `BufWriter` ([#3163]) +- net: add ready/try methods to `NamedPipe{Client,Server}` ([#3866], [#3899]) +- sync: add `watch::Receiver::borrow_and_update` ([#3813]) +- sync: implement `From<T>` for `OnceCell<T>` ([#3877]) +- time: allow users to specify Interval behaviour when delayed ([#3721]) + +### Added (unstable) + +- rt: add `tokio::task::Builder` ([#3881]) + +### Fixed + +- net: handle HUP event with `UnixStream` ([#3898]) + +### Documented + +- doc: document cancellation safety ([#3900]) +- time: add wait alias to sleep ([#3897]) +- time: document auto-advancing behaviour of runtime ([#3763]) + +[#3163]: https://github.com/tokio-rs/tokio/pull/3163 +[#3721]: https://github.com/tokio-rs/tokio/pull/3721 +[#3763]: https://github.com/tokio-rs/tokio/pull/3763 +[#3807]: https://github.com/tokio-rs/tokio/pull/3807 +[#3813]: https://github.com/tokio-rs/tokio/pull/3813 +[#3866]: https://github.com/tokio-rs/tokio/pull/3866 +[#3877]: https://github.com/tokio-rs/tokio/pull/3877 +[#3881]: https://github.com/tokio-rs/tokio/pull/3881 +[#3897]: https://github.com/tokio-rs/tokio/pull/3897 +[#3898]: https://github.com/tokio-rs/tokio/pull/3898 +[#3899]: 
https://github.com/tokio-rs/tokio/pull/3899 +[#3900]: https://github.com/tokio-rs/tokio/pull/3900 + +# 1.7.2 (July 6, 2021) + +Forward ports 1.5.1 fixes. + +### Fixed + +- runtime: remotely abort tasks on `JoinHandle::abort` ([#3934]) + +[#3934]: https://github.com/tokio-rs/tokio/pull/3934 + +# 1.7.1 (June 18, 2021) + +### Fixed + +- runtime: fix early task shutdown during runtime shutdown ([#3870]) + +[#3870]: https://github.com/tokio-rs/tokio/pull/3870 + +# 1.7.0 (June 15, 2021) + +### Added + +- net: add named pipes on windows ([#3760]) +- net: add `TcpSocket` from `std::net::TcpStream` conversion ([#3838]) +- sync: add `receiver_count` to `watch::Sender` ([#3729]) +- sync: export `sync::notify::Notified` future publicly ([#3840]) +- tracing: instrument task wakers ([#3836]) + +### Fixed + +- macros: suppress `clippy::default_numeric_fallback` lint in generated code ([#3831]) +- runtime: immediately drop new tasks when runtime is shut down ([#3752]) +- sync: deprecate unused `mpsc::RecvError` type ([#3833]) + +### Documented + +- io: clarify EOF condition for `AsyncReadExt::read_buf` ([#3850]) +- io: clarify limits on return values of `AsyncWrite::poll_write` ([#3820]) +- sync: add examples to Semaphore ([#3808]) + +[#3729]: https://github.com/tokio-rs/tokio/pull/3729 +[#3752]: https://github.com/tokio-rs/tokio/pull/3752 +[#3760]: https://github.com/tokio-rs/tokio/pull/3760 +[#3808]: https://github.com/tokio-rs/tokio/pull/3808 +[#3820]: https://github.com/tokio-rs/tokio/pull/3820 +[#3831]: https://github.com/tokio-rs/tokio/pull/3831 +[#3833]: https://github.com/tokio-rs/tokio/pull/3833 +[#3836]: https://github.com/tokio-rs/tokio/pull/3836 +[#3838]: https://github.com/tokio-rs/tokio/pull/3838 +[#3840]: https://github.com/tokio-rs/tokio/pull/3840 +[#3850]: https://github.com/tokio-rs/tokio/pull/3850 + +# 1.6.3 (July 6, 2021) + +Forward ports 1.5.1 fixes. 
+ +### Fixed + +- runtime: remotely abort tasks on `JoinHandle::abort` ([#3934]) + +[#3934]: https://github.com/tokio-rs/tokio/pull/3934 + +# 1.6.2 (June 14, 2021) + +### Fixes + +- test: sub-ms `time:advance` regression introduced in 1.6 ([#3852]) + +[#3852]: https://github.com/tokio-rs/tokio/pull/3852 + +# 1.6.1 (May 28, 2021) + +This release reverts [#3518] because it doesn't work on some kernels due to +a kernel bug. ([#3803]) + +[#3518]: https://github.com/tokio-rs/tokio/issues/3518 +[#3803]: https://github.com/tokio-rs/tokio/issues/3803 + +# 1.6.0 (May 14, 2021) + +### Added + +- fs: try doing a non-blocking read before punting to the threadpool ([#3518]) +- io: add `write_all_buf` to `AsyncWriteExt` ([#3737]) +- io: implement `AsyncSeek` for `BufReader`, `BufWriter`, and `BufStream` ([#3491]) +- net: support non-blocking vectored I/O ([#3761]) +- sync: add `mpsc::Sender::{reserve_owned, try_reserve_owned}` ([#3704]) +- sync: add a `MutexGuard::map` method that returns a `MappedMutexGuard` ([#2472]) +- time: add getter for Interval's period ([#3705]) + +### Fixed + +- io: wake pending writers on `DuplexStream` close ([#3756]) +- process: avoid redundant effort to reap orphan processes ([#3743]) +- signal: use `std::os::raw::c_int` instead of `libc::c_int` on public API ([#3774]) +- sync: preserve permit state in `notify_waiters` ([#3660]) +- task: update `JoinHandle` panic message ([#3727]) +- time: prevent `time::advance` from going too far ([#3712]) + +### Documented + +- net: hide `net::unix::datagram` module from docs ([#3775]) +- process: updated example ([#3748]) +- sync: `Barrier` doc should use task, not thread ([#3780]) +- task: update documentation on `block_in_place` ([#3753]) + +[#2472]: https://github.com/tokio-rs/tokio/pull/2472 +[#3491]: https://github.com/tokio-rs/tokio/pull/3491 +[#3518]: https://github.com/tokio-rs/tokio/pull/3518 +[#3660]: https://github.com/tokio-rs/tokio/pull/3660 +[#3704]: https://github.com/tokio-rs/tokio/pull/3704 
+[#3705]: https://github.com/tokio-rs/tokio/pull/3705 +[#3712]: https://github.com/tokio-rs/tokio/pull/3712 +[#3727]: https://github.com/tokio-rs/tokio/pull/3727 +[#3737]: https://github.com/tokio-rs/tokio/pull/3737 +[#3743]: https://github.com/tokio-rs/tokio/pull/3743 +[#3748]: https://github.com/tokio-rs/tokio/pull/3748 +[#3753]: https://github.com/tokio-rs/tokio/pull/3753 +[#3756]: https://github.com/tokio-rs/tokio/pull/3756 +[#3761]: https://github.com/tokio-rs/tokio/pull/3761 +[#3774]: https://github.com/tokio-rs/tokio/pull/3774 +[#3775]: https://github.com/tokio-rs/tokio/pull/3775 +[#3780]: https://github.com/tokio-rs/tokio/pull/3780 + +# 1.5.1 (July 6, 2021) + +### Fixed + +- runtime: remotely abort tasks on `JoinHandle::abort` ([#3934]) + +[#3934]: https://github.com/tokio-rs/tokio/pull/3934 + # 1.5.0 (April 12, 2021) ### Added @@ -19,6 +473,7 @@ - rt: fix panic in `JoinHandle::abort()` when called from other threads ([#3672]) - sync: don't panic in `oneshot::try_recv` ([#3674]) - sync: fix notifications getting dropped on receiver drop ([#3652]) +- sync: fix `Semaphore` permit overflow calculation ([#3644]) ### Documented @@ -3,21 +3,20 @@ # When uploading crates to the registry Cargo will automatically # "normalize" Cargo.toml files for maximal compatibility # with all versions of Cargo and also rewrite `path` dependencies -# to registry (e.g., crates.io) dependencies +# to registry (e.g., crates.io) dependencies. # -# If you believe there's an error in this file please file an -# issue against the rust-lang/cargo repository. If you're -# editing this file be aware that the upstream Cargo.toml -# will likely look very different (and much more reasonable) +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. 
[package] edition = "2018" name = "tokio" -version = "1.5.0" +version = "1.14.0" authors = ["Tokio Contributors <team@tokio.rs>"] description = "An event-driven, non-blocking I/O platform for writing asynchronous I/O\nbacked applications.\n" homepage = "https://tokio.rs" -documentation = "https://docs.rs/tokio/1.5.0/tokio/" +documentation = "https://docs.rs/tokio/1.14.0/tokio/" readme = "README.md" keywords = ["io", "async", "non-blocking", "futures"] categories = ["asynchronous", "network-programming"] @@ -57,7 +56,7 @@ optional = true version = "0.2.0" [dependencies.tokio-macros] -version = "1.1.0" +version = "1.6.0" optional = true [dev-dependencies.async-stream] version = "0.3" @@ -66,12 +65,18 @@ version = "0.3" version = "0.3.0" features = ["async-await"] +[dev-dependencies.mockall] +version = "0.10.2" + [dev-dependencies.proptest] version = "1" [dev-dependencies.rand] version = "0.8.0" +[dev-dependencies.socket2] +version = "0.4" + [dev-dependencies.tempfile] version = "3.1.0" @@ -90,19 +95,23 @@ full = ["fs", "io-util", "io-std", "macros", "net", "parking_lot", "process", "r io-std = [] io-util = ["memchr", "bytes"] macros = ["tokio-macros"] -net = ["libc", "mio/os-poll", "mio/os-util", "mio/tcp", "mio/udp", "mio/uds"] +net = ["libc", "mio/os-poll", "mio/os-util", "mio/tcp", "mio/udp", "mio/uds", "winapi/namedpipeapi"] process = ["bytes", "once_cell", "libc", "mio/os-poll", "mio/os-util", "mio/uds", "signal-hook-registry", "winapi/threadpoollegacyapiset"] rt = [] rt-multi-thread = ["num_cpus", "rt"] signal = ["once_cell", "libc", "mio/os-poll", "mio/uds", "mio/os-util", "signal-hook-registry", "winapi/consoleapi"] +stats = [] sync = [] -test-util = [] +test-util = ["rt", "sync", "time"] time = [] [target."cfg(loom)".dev-dependencies.loom] version = "0.5" features = ["futures", "checkpoint"] +[target."cfg(target_os = \"freebsd\")".dev-dependencies.mio-aio] +version = "0.6.0" +features = ["tokio"] [target."cfg(tokio_unstable)".dependencies.tracing] -version = 
"0.1.21" +version = "0.1.25" features = ["std"] optional = true default-features = false @@ -117,8 +126,10 @@ optional = true version = "0.2.42" [target."cfg(unix)".dev-dependencies.nix] -version = "0.19.0" +version = "0.22.0" [target."cfg(windows)".dependencies.winapi] version = "0.3.8" optional = true default-features = false +[target."cfg(windows)".dev-dependencies.ntapi] +version = "0.3.6" diff --git a/Cargo.toml.orig b/Cargo.toml.orig index 5e53c3f..348ec46 100644 --- a/Cargo.toml.orig +++ b/Cargo.toml.orig @@ -7,12 +7,12 @@ name = "tokio" # - README.md # - Update CHANGELOG.md. # - Create "v1.0.x" git tag. -version = "1.5.0" +version = "1.14.0" edition = "2018" authors = ["Tokio Contributors <team@tokio.rs>"] license = "MIT" readme = "README.md" -documentation = "https://docs.rs/tokio/1.5.0/tokio/" +documentation = "https://docs.rs/tokio/1.14.0/tokio/" repository = "https://github.com/tokio-rs/tokio" homepage = "https://tokio.rs" description = """ @@ -47,6 +47,7 @@ io-util = ["memchr", "bytes"] # stdin, stdout, stderr io-std = [] macros = ["tokio-macros"] +stats = [] net = [ "libc", "mio/os-poll", @@ -54,6 +55,7 @@ net = [ "mio/tcp", "mio/udp", "mio/uds", + "winapi/namedpipeapi", ] process = [ "bytes", @@ -81,11 +83,11 @@ signal = [ "winapi/consoleapi", ] sync = [] -test-util = [] +test-util = ["rt", "sync", "time"] time = [] [dependencies] -tokio-macros = { version = "1.1.0", path = "../tokio-macros", optional = true } +tokio-macros = { version = "1.6.0", path = "../tokio-macros", optional = true } pin-project-lite = "0.2.0" @@ -100,7 +102,7 @@ parking_lot = { version = "0.11.0", optional = true } # Currently unstable. The API exposed by these features may be broken at any time. # Requires `--cfg tokio_unstable` to enable. 
[target.'cfg(tokio_unstable)'.dependencies] -tracing = { version = "0.1.21", default-features = false, features = ["std"], optional = true } # Not in full +tracing = { version = "0.1.25", default-features = false, features = ["std"], optional = true } # Not in full [target.'cfg(unix)'.dependencies] libc = { version = "0.2.42", optional = true } @@ -108,21 +110,29 @@ signal-hook-registry = { version = "1.1.1", optional = true } [target.'cfg(unix)'.dev-dependencies] libc = { version = "0.2.42" } -nix = { version = "0.19.0" } +nix = { version = "0.22.0" } [target.'cfg(windows)'.dependencies.winapi] version = "0.3.8" default-features = false optional = true +[target.'cfg(windows)'.dev-dependencies.ntapi] +version = "0.3.6" + [dev-dependencies] tokio-test = { version = "0.4.0", path = "../tokio-test" } tokio-stream = { version = "0.1", path = "../tokio-stream" } futures = { version = "0.3.0", features = ["async-await"] } +mockall = "0.10.2" proptest = "1" rand = "0.8.0" tempfile = "3.1.0" async-stream = "0.3" +socket2 = "0.4" + +[target.'cfg(target_os = "freebsd")'.dev-dependencies] +mio-aio = { version = "0.6.0", features = ["tokio"] } [target.'cfg(loom)'.dev-dependencies] loom = { version = "0.5", features = ["futures", "checkpoint"] } @@ -7,13 +7,13 @@ third_party { } url { type: ARCHIVE - value: "https://static.crates.io/crates/tokio/tokio-1.5.0.crate" + value: "https://static.crates.io/crates/tokio/tokio-1.14.0.crate" } - version: "1.5.0" + version: "1.14.0" license_type: NOTICE last_upgrade_date { year: 2021 - month: 4 - day: 21 + month: 11 + day: 16 } } diff --git a/README.android b/README.android new file mode 100644 index 0000000..a0d1ce4 --- /dev/null +++ b/README.android @@ -0,0 +1 @@ +The Android.bp file contains the "winapi" feature even though it is not used on Android. This gets added automatically by cargo2android.py and removing it would be non-trivial. 
It should have no effect on the crate, as the code is included only on Windows, so we keep it rather than maintaining a patch to remove it. @@ -50,7 +50,15 @@ an asynchronous application. ## Example -A basic TCP echo server with Tokio: +A basic TCP echo server with Tokio. + +Make sure you activated the full features of the tokio crate on Cargo.toml: + +```toml +[dependencies] +tokio = { version = "1.14.0", features = ["full"] } +``` +Then, on your main.rs: ```rust,no_run use tokio::net::TcpListener; @@ -58,7 +66,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> { - let mut listener = TcpListener::bind("127.0.0.1:8080").await?; + let listener = TcpListener::bind("127.0.0.1:8080").await?; loop { let (mut socket, _) = listener.accept().await?; @@ -132,7 +140,7 @@ several other libraries, including: * [`tower`]: A library of modular and reusable components for building robust networking clients and servers. -* [`tracing`]: A framework for application-level tracing and async-aware diagnostics. +* [`tracing`] (formerly `tokio-trace`): A framework for application-level tracing and async-aware diagnostics. * [`rdbc`]: A Rust database connectivity library for MySQL, Postgres and SQLite. @@ -155,9 +163,35 @@ several other libraries, including: ## Supported Rust Versions -Tokio is built against the latest stable release. The minimum supported version is 1.45. -The current Tokio version is not guaranteed to build on Rust versions earlier than the -minimum supported version. +Tokio is built against the latest stable release. The minimum supported version +is 1.45. The current Tokio version is not guaranteed to build on Rust versions +earlier than the minimum supported version. + +## Release schedule + +Tokio doesn't follow a fixed release schedule, but we typically make one to two +new minor releases each month. We make patch releases for bugfixes as necessary. 
+ +## Bug patching policy + +For the purposes of making patch releases with bugfixes, we have designated +certain minor releases as LTS (long term support) releases. Whenever a bug +warrants a patch release with a fix for the bug, it will be backported and +released as a new patch release for each LTS minor version. Our current LTS +releases are: + + * `1.8.x` - LTS release until February 2022. + +Each LTS release will continue to receive backported fixes for at least half a +year. If you wish to use a fixed minor release in your project, we recommend +that you use an LTS release. + +To use a fixed minor version, you can specify the version with a tilde. For +example, to specify that you wish to use the newest `1.8.x` patch release, you +can use the following dependency specification: +```text +tokio = { version = "~1.8", features = [...] } +``` ## License diff --git a/TEST_MAPPING b/TEST_MAPPING index 686fbda..6c266ef 100644 --- a/TEST_MAPPING +++ b/TEST_MAPPING @@ -20,15 +20,33 @@ "name": "tokio-test_device_test_tests_macros" }, { + "name": "tokio_device_test_tests__require_full" + }, + { "name": "tokio_device_test_tests_buffered" }, { + "name": "tokio_device_test_tests_io_async_fd" + }, + { "name": "tokio_device_test_tests_io_async_read" }, { + "name": "tokio_device_test_tests_io_chain" + }, + { + "name": "tokio_device_test_tests_io_copy" + }, + { "name": "tokio_device_test_tests_io_copy_bidirectional" }, { + "name": "tokio_device_test_tests_io_driver" + }, + { + "name": "tokio_device_test_tests_io_driver_drop" + }, + { "name": "tokio_device_test_tests_io_lines" }, { @@ -41,9 +59,24 @@ "name": "tokio_device_test_tests_io_read_buf" }, { + "name": "tokio_device_test_tests_io_read_exact" + }, + { + "name": "tokio_device_test_tests_io_read_line" + }, + { "name": "tokio_device_test_tests_io_read_to_end" }, { + "name": "tokio_device_test_tests_io_read_to_string" + }, + { + "name": "tokio_device_test_tests_io_read_until" + }, + { + "name": 
"tokio_device_test_tests_io_split" + }, + { "name": "tokio_device_test_tests_io_take" }, { @@ -62,12 +95,36 @@ "name": "tokio_device_test_tests_macros_join" }, { + "name": "tokio_device_test_tests_macros_pin" + }, + { + "name": "tokio_device_test_tests_macros_select" + }, + { + "name": "tokio_device_test_tests_macros_test" + }, + { + "name": "tokio_device_test_tests_macros_try_join" + }, + { + "name": "tokio_device_test_tests_net_bind_resource" + }, + { + "name": "tokio_device_test_tests_net_lookup_host" + }, + { "name": "tokio_device_test_tests_no_rt" }, { + "name": "tokio_device_test_tests_process_kill_on_drop" + }, + { "name": "tokio_device_test_tests_rt_basic" }, { + "name": "tokio_device_test_tests_rt_common" + }, + { "name": "tokio_device_test_tests_rt_threaded" }, { @@ -83,15 +140,36 @@ "name": "tokio_device_test_tests_sync_mpsc" }, { + "name": "tokio_device_test_tests_sync_mutex" + }, + { "name": "tokio_device_test_tests_sync_mutex_owned" }, { + "name": "tokio_device_test_tests_sync_notify" + }, + { + "name": "tokio_device_test_tests_sync_oneshot" + }, + { "name": "tokio_device_test_tests_sync_rwlock" }, { + "name": "tokio_device_test_tests_sync_semaphore" + }, + { + "name": "tokio_device_test_tests_sync_semaphore_owned" + }, + { "name": "tokio_device_test_tests_sync_watch" }, { + "name": "tokio_device_test_tests_task_abort" + }, + { + "name": "tokio_device_test_tests_task_blocking" + }, + { "name": "tokio_device_test_tests_task_local" }, { @@ -101,18 +179,39 @@ "name": "tokio_device_test_tests_tcp_accept" }, { + "name": "tokio_device_test_tests_tcp_connect" + }, + { "name": "tokio_device_test_tests_tcp_echo" }, { + "name": "tokio_device_test_tests_tcp_into_split" + }, + { "name": "tokio_device_test_tests_tcp_into_std" }, { + "name": "tokio_device_test_tests_tcp_peek" + }, + { "name": "tokio_device_test_tests_tcp_shutdown" }, { + "name": "tokio_device_test_tests_tcp_socket" + }, + { + "name": "tokio_device_test_tests_tcp_split" + }, + { "name": 
"tokio_device_test_tests_time_rt" }, { + "name": "tokio_device_test_tests_udp" + }, + { + "name": "tokio_device_test_tests_uds_cred" + }, + { "name": "tokio_device_test_tests_uds_split" } ] diff --git a/cargo2android.json b/cargo2android.json new file mode 100644 index 0000000..c715d16 --- /dev/null +++ b/cargo2android.json @@ -0,0 +1,11 @@ +{ + "add-toplevel-block": "cargo2android_tests.bp", + "apex-available": [ + "//apex_available:platform", + "com.android.resolv" + ], + "min_sdk_version": "29", + "features": "io-util,macros,rt-multi-thread,sync,net,fs,time", + "device": true, + "run": true +}
\ No newline at end of file diff --git a/cargo2android_tests.bp b/cargo2android_tests.bp new file mode 100644 index 0000000..763471b --- /dev/null +++ b/cargo2android_tests.bp @@ -0,0 +1,1017 @@ +rust_defaults { + name: "tokio_defaults_tokio", + crate_name: "tokio", + test_suites: ["general-tests"], + auto_gen_config: true, + edition: "2018", + features: [ + "bytes", + "fs", + "full", + "io-util", + "libc", + "macros", + "memchr", + "mio", + "net", + "num_cpus", + "rt", + "rt-multi-thread", + "sync", + "time", + "tokio-macros", + ], + cfgs: ["tokio_track_caller"], + rustlibs: [ + "libasync_stream", + "libbytes", + "libfutures", + "liblibc", + "libmemchr", + "libmio", + "libnix", + "libnum_cpus", + "libpin_project_lite", + "librand", + "libtokio", + "libtokio_stream", + "libtokio_test", + ], + proc_macros: ["libtokio_macros"], +} + +rust_test_host { + name: "tokio_host_test_tests__require_full", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/_require_full.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests__require_full", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/_require_full.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_buffered", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/buffered.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_buffered", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/buffered.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_async_fd", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_async_fd.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_async_fd", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_async_fd.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_async_read", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_async_read.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + 
name: "tokio_device_test_tests_io_async_read", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_async_read.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_chain", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_chain.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_chain", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_chain.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_copy", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_copy.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_copy", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_copy.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_copy_bidirectional", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_copy_bidirectional.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_copy_bidirectional", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_copy_bidirectional.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_driver", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_driver.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_driver", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_driver.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_driver_drop", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_driver_drop.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_driver_drop", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_driver_drop.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_lines", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_lines.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_lines", + defaults: 
["tokio_defaults_tokio"], + srcs: ["tests/io_lines.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_mem_stream", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_mem_stream.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_mem_stream", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_mem_stream.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_read", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_read", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_read_buf", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_buf.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_read_buf", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_buf.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_read_exact", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_exact.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_read_exact", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_exact.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_read_line", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_line.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_read_line", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_line.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_read_to_end", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_to_end.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_read_to_end", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_to_end.rs"], +} + 
+rust_test_host { + name: "tokio_host_test_tests_io_read_to_string", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_to_string.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_read_to_string", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_to_string.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_read_until", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_until.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_read_until", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_read_until.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_split", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_split.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_split", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_split.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_take", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_take.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_take", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_take.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_write", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_write.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_write", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_write.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_write_all", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_write_all.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_write_all", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_write_all.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_write_buf", + defaults: 
["tokio_defaults_tokio"], + srcs: ["tests/io_write_buf.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_write_buf", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_write_buf.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_io_write_int", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_write_int.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_io_write_int", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/io_write_int.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_macros_join", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_join.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_macros_join", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_join.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_macros_pin", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_pin.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_macros_pin", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_pin.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_macros_select", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_select.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_macros_select", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_select.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_macros_test", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_test.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_macros_test", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_test.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_macros_try_join", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_try_join.rs"], + 
test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_macros_try_join", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/macros_try_join.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_net_bind_resource", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/net_bind_resource.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_net_bind_resource", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/net_bind_resource.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_net_lookup_host", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/net_lookup_host.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_net_lookup_host", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/net_lookup_host.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_no_rt", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/no_rt.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_no_rt", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/no_rt.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_process_kill_on_drop", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/process_kill_on_drop.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_process_kill_on_drop", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/process_kill_on_drop.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_rt_basic", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/rt_basic.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_rt_basic", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/rt_basic.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_rt_common", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/rt_common.rs"], + test_options: { + unit_test: true, + }, 
+} + +rust_test { + name: "tokio_device_test_tests_rt_common", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/rt_common.rs"], +} + + +rust_test_host { + name: "tokio_host_test_tests_rt_threaded", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/rt_threaded.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_rt_threaded", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/rt_threaded.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_sync_barrier", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_barrier.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_barrier", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_barrier.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_sync_broadcast", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_broadcast.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_broadcast", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_broadcast.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_sync_errors", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_errors.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_errors", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_errors.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_sync_mpsc", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_mpsc.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_mpsc", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_mpsc.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_sync_mutex", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_mutex.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_mutex", + defaults: 
["tokio_defaults_tokio"], + srcs: ["tests/sync_mutex.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_sync_mutex_owned", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_mutex_owned.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_mutex_owned", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_mutex_owned.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_sync_notify", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_notify.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_notify", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_notify.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_sync_oneshot", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_oneshot.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_oneshot", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_oneshot.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_sync_rwlock", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_rwlock.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_rwlock", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_rwlock.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_sync_semaphore", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_semaphore.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_semaphore", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_semaphore.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_sync_semaphore_owned", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_semaphore_owned.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_semaphore_owned", + defaults: 
["tokio_defaults_tokio"], + srcs: ["tests/sync_semaphore_owned.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_sync_watch", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_watch.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_sync_watch", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/sync_watch.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_task_abort", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/task_abort.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_task_abort", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/task_abort.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_task_blocking", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/task_blocking.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_task_blocking", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/task_blocking.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_task_local", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/task_local.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_task_local", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/task_local.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_task_local_set", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/task_local_set.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_task_local_set", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/task_local_set.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_tcp_accept", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_accept.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_tcp_accept", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_accept.rs"], +} + 
+rust_test_host { + name: "tokio_host_test_tests_tcp_connect", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_connect.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_tcp_connect", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_connect.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_tcp_echo", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_echo.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_tcp_echo", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_echo.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_tcp_into_split", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_into_split.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_tcp_into_split", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_into_split.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_tcp_into_std", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_into_std.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_tcp_into_std", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_into_std.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_tcp_peek", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_peek.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_tcp_peek", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_peek.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_tcp_shutdown", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_shutdown.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_tcp_shutdown", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_shutdown.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_tcp_socket", + defaults: 
["tokio_defaults_tokio"], + srcs: ["tests/tcp_socket.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_tcp_socket", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_socket.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_tcp_split", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_split.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_tcp_split", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/tcp_split.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_time_rt", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/time_rt.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_time_rt", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/time_rt.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_udp", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/udp.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_udp", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/udp.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_uds_cred", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/uds_cred.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_uds_cred", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/uds_cred.rs"], +} + +rust_test_host { + name: "tokio_host_test_tests_uds_split", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/uds_split.rs"], + test_options: { + unit_test: true, + }, +} + +rust_test { + name: "tokio_device_test_tests_uds_split", + defaults: ["tokio_defaults_tokio"], + srcs: ["tests/uds_split.rs"], +}
\ No newline at end of file diff --git a/docs/reactor-refactor.md b/docs/reactor-refactor.md index a0b5447..1c9ace1 100644 --- a/docs/reactor-refactor.md +++ b/docs/reactor-refactor.md @@ -228,7 +228,7 @@ It is only possible to implement `AsyncRead` and `AsyncWrite` for resource types themselves and not for `&Resource`. Implementing the traits for `&Resource` would permit concurrent operations to the resource. Because only a single waker is stored per direction, any concurrent usage would result in deadlocks. An -alterate implementation would call for a `Vec<Waker>` but this would result in +alternate implementation would call for a `Vec<Waker>` but this would result in memory leaks. ## Enabling reads and writes for `&TcpStream` @@ -268,9 +268,9 @@ select! { } ``` -It is also possible to sotre a `TcpStream` in an `Arc`. +It is also possible to store a `TcpStream` in an `Arc`. ```rust let arc_stream = Arc::new(my_tcp_stream); let n = arc_stream.by_ref().read(buf).await?; -```
\ No newline at end of file +``` diff --git a/patches/Android.bp.patch b/patches/Android.bp.patch deleted file mode 100644 index 759d95f..0000000 --- a/patches/Android.bp.patch +++ /dev/null @@ -1,310 +0,0 @@ -diff --git a/Android.bp b/Android.bp -index 6b8ca5b..222916b 100644 ---- a/Android.bp -+++ b/Android.bp -@@ -50,6 +50,11 @@ rust_library { - "libpin_project_lite", - ], - proc_macros: ["libtokio_macros"], -+ apex_available: [ -+ "//apex_available:platform", -+ "com.android.resolv", -+ ], -+ min_sdk_version: "29", - } - - rust_defaults { -@@ -61,6 +66,7 @@ rust_defaults { - features: [ - "bytes", - "fs", -+ "full", - "io-util", - "libc", - "macros", -@@ -108,36 +114,6 @@ rust_test { - srcs: ["tests/buffered.rs"], - } - --rust_test_host { -- name: "tokio_host_test_tests_fs_file", -- defaults: ["tokio_defaults"], -- srcs: ["tests/fs_file.rs"], -- test_options: { -- unit_test: true, -- }, --} -- --rust_test { -- name: "tokio_device_test_tests_fs_file", -- defaults: ["tokio_defaults"], -- srcs: ["tests/fs_file.rs"], --} -- --rust_test_host { -- name: "tokio_host_test_tests_fs_link", -- defaults: ["tokio_defaults"], -- srcs: ["tests/fs_link.rs"], -- test_options: { -- unit_test: true, -- }, --} -- --rust_test { -- name: "tokio_device_test_tests_fs_link", -- defaults: ["tokio_defaults"], -- srcs: ["tests/fs_link.rs"], --} -- - rust_test_host { - name: "tokio_host_test_tests_io_async_read", - defaults: ["tokio_defaults"], -@@ -348,51 +324,6 @@ rust_test { - srcs: ["tests/no_rt.rs"], - } - --rust_test_host { -- name: "tokio_host_test_tests_process_issue_2174", -- defaults: ["tokio_defaults"], -- srcs: ["tests/process_issue_2174.rs"], -- test_options: { -- unit_test: true, -- }, --} -- --rust_test { -- name: "tokio_device_test_tests_process_issue_2174", -- defaults: ["tokio_defaults"], -- srcs: ["tests/process_issue_2174.rs"], --} -- --rust_test_host { -- name: "tokio_host_test_tests_process_issue_42", -- defaults: ["tokio_defaults"], -- srcs: 
["tests/process_issue_42.rs"], -- test_options: { -- unit_test: true, -- }, --} -- --rust_test { -- name: "tokio_device_test_tests_process_issue_42", -- defaults: ["tokio_defaults"], -- srcs: ["tests/process_issue_42.rs"], --} -- --rust_test_host { -- name: "tokio_host_test_tests_process_smoke", -- defaults: ["tokio_defaults"], -- srcs: ["tests/process_smoke.rs"], -- test_options: { -- unit_test: true, -- }, --} -- --rust_test { -- name: "tokio_device_test_tests_process_smoke", -- defaults: ["tokio_defaults"], -- srcs: ["tests/process_smoke.rs"], --} -- - rust_test_host { - name: "tokio_host_test_tests_rt_basic", - defaults: ["tokio_defaults"], -@@ -423,111 +354,6 @@ rust_test { - srcs: ["tests/rt_threaded.rs"], - } - --rust_test_host { -- name: "tokio_host_test_tests_signal_ctrl_c", -- defaults: ["tokio_defaults"], -- srcs: ["tests/signal_ctrl_c.rs"], -- test_options: { -- unit_test: true, -- }, --} -- --rust_test { -- name: "tokio_device_test_tests_signal_ctrl_c", -- defaults: ["tokio_defaults"], -- srcs: ["tests/signal_ctrl_c.rs"], --} -- --rust_test_host { -- name: "tokio_host_test_tests_signal_drop_rt", -- defaults: ["tokio_defaults"], -- srcs: ["tests/signal_drop_rt.rs"], -- test_options: { -- unit_test: true, -- }, --} -- --rust_test { -- name: "tokio_device_test_tests_signal_drop_rt", -- defaults: ["tokio_defaults"], -- srcs: ["tests/signal_drop_rt.rs"], --} -- --rust_test_host { -- name: "tokio_host_test_tests_signal_drop_signal", -- defaults: ["tokio_defaults"], -- srcs: ["tests/signal_drop_signal.rs"], -- test_options: { -- unit_test: true, -- }, --} -- --rust_test { -- name: "tokio_device_test_tests_signal_drop_signal", -- defaults: ["tokio_defaults"], -- srcs: ["tests/signal_drop_signal.rs"], --} -- --rust_test_host { -- name: "tokio_host_test_tests_signal_multi_rt", -- defaults: ["tokio_defaults"], -- srcs: ["tests/signal_multi_rt.rs"], -- test_options: { -- unit_test: true, -- }, --} -- --rust_test { -- name: 
"tokio_device_test_tests_signal_multi_rt", -- defaults: ["tokio_defaults"], -- srcs: ["tests/signal_multi_rt.rs"], --} -- --rust_test_host { -- name: "tokio_host_test_tests_signal_no_rt", -- defaults: ["tokio_defaults"], -- srcs: ["tests/signal_no_rt.rs"], -- test_options: { -- unit_test: true, -- }, --} -- --rust_test { -- name: "tokio_device_test_tests_signal_no_rt", -- defaults: ["tokio_defaults"], -- srcs: ["tests/signal_no_rt.rs"], --} -- --rust_test_host { -- name: "tokio_host_test_tests_signal_notify_both", -- defaults: ["tokio_defaults"], -- srcs: ["tests/signal_notify_both.rs"], -- test_options: { -- unit_test: true, -- }, --} -- --rust_test { -- name: "tokio_device_test_tests_signal_notify_both", -- defaults: ["tokio_defaults"], -- srcs: ["tests/signal_notify_both.rs"], --} -- --rust_test_host { -- name: "tokio_host_test_tests_signal_twice", -- defaults: ["tokio_defaults"], -- srcs: ["tests/signal_twice.rs"], -- test_options: { -- unit_test: true, -- }, --} -- --rust_test { -- name: "tokio_device_test_tests_signal_twice", -- defaults: ["tokio_defaults"], -- srcs: ["tests/signal_twice.rs"], --} -- - rust_test_host { - name: "tokio_host_test_tests_sync_barrier", - defaults: ["tokio_defaults"], -@@ -603,21 +429,6 @@ rust_test { - srcs: ["tests/sync_mutex_owned.rs"], - } - --rust_test_host { -- name: "tokio_host_test_tests_sync_once_cell", -- defaults: ["tokio_defaults"], -- srcs: ["tests/sync_once_cell.rs"], -- test_options: { -- unit_test: true, -- }, --} -- --rust_test { -- name: "tokio_device_test_tests_sync_once_cell", -- defaults: ["tokio_defaults"], -- srcs: ["tests/sync_once_cell.rs"], --} -- - rust_test_host { - name: "tokio_host_test_tests_sync_rwlock", - defaults: ["tokio_defaults"], -@@ -738,21 +549,6 @@ rust_test { - srcs: ["tests/tcp_shutdown.rs"], - } - --rust_test_host { -- name: "tokio_host_test_tests_time_interval", -- defaults: ["tokio_defaults"], -- srcs: ["tests/time_interval.rs"], -- test_options: { -- unit_test: true, -- }, --} -- 
--rust_test { -- name: "tokio_device_test_tests_time_interval", -- defaults: ["tokio_defaults"], -- srcs: ["tests/time_interval.rs"], --} -- - rust_test_host { - name: "tokio_host_test_tests_time_rt", - defaults: ["tokio_defaults"], -@@ -768,21 +564,6 @@ rust_test { - srcs: ["tests/time_rt.rs"], - } - --rust_test_host { -- name: "tokio_host_test_tests_time_timeout", -- defaults: ["tokio_defaults"], -- srcs: ["tests/time_timeout.rs"], -- test_options: { -- unit_test: true, -- }, --} -- --rust_test { -- name: "tokio_device_test_tests_time_timeout", -- defaults: ["tokio_defaults"], -- srcs: ["tests/time_timeout.rs"], --} -- - rust_test_host { - name: "tokio_host_test_tests_uds_split", - defaults: ["tokio_defaults"], -@@ -797,18 +578,3 @@ rust_test { - defaults: ["tokio_defaults"], - srcs: ["tests/uds_split.rs"], - } -- --rust_test_host { -- name: "tokio_host_test_tests_uds_stream", -- defaults: ["tokio_defaults"], -- srcs: ["tests/uds_stream.rs"], -- test_options: { -- unit_test: true, -- }, --} -- --rust_test { -- name: "tokio_device_test_tests_uds_stream", -- defaults: ["tokio_defaults"], -- srcs: ["tests/uds_stream.rs"], --} diff --git a/patches/io_mem_stream.patch b/patches/io_mem_stream.patch new file mode 100644 index 0000000..c21ce18 --- /dev/null +++ b/patches/io_mem_stream.patch @@ -0,0 +1,12 @@ +diff --git a/tests/io_mem_stream.rs b/tests/io_mem_stream.rs +index 01baa53..520391a 100644 +--- a/tests/io_mem_stream.rs ++++ b/tests/io_mem_stream.rs +@@ -63,6 +63,7 @@ async fn disconnect() { + } + + #[tokio::test] ++#[cfg(not(target_os = "android"))] + async fn disconnect_reader() { + let (a, mut b) = duplex(2); + diff --git a/patches/rt_common.patch b/patches/rt_common.patch new file mode 100644 index 0000000..1444cfe --- /dev/null +++ b/patches/rt_common.patch @@ -0,0 +1,12 @@ +diff --git a/tests/rt_common.rs b/tests/rt_common.rs +index cb1d0f6..e5fc7a9 100644 +--- a/tests/rt_common.rs ++++ b/tests/rt_common.rs +@@ -647,6 +647,7 @@ rt_test! 
{ + } + + #[test] ++ #[cfg(not(target_os = "android"))] + fn panic_in_task() { + let rt = rt(); + let (tx, rx) = oneshot::channel(); diff --git a/patches/task_abort.patch b/patches/task_abort.patch new file mode 100644 index 0000000..df05ccb --- /dev/null +++ b/patches/task_abort.patch @@ -0,0 +1,20 @@ +diff --git a/tests/task_abort.rs b/tests/task_abort.rs +index cdaa405..ec0eed7 100644 +--- a/tests/task_abort.rs ++++ b/tests/task_abort.rs +@@ -180,6 +180,7 @@ fn test_abort_wakes_task_3964() { + /// Checks that aborting a task whose destructor panics does not allow the + /// panic to escape the task. + #[test] ++#[cfg(not(target_os = "android"))] + fn test_abort_task_that_panics_on_drop_contained() { + let rt = Builder::new_current_thread().enable_time().build().unwrap(); + +@@ -204,6 +205,7 @@ fn test_abort_task_that_panics_on_drop_contained() { + + /// Checks that aborting a task whose destructor panics has the expected result. + #[test] ++#[cfg(not(target_os = "android"))] + fn test_abort_task_that_panics_on_drop_returned() { + let rt = Builder::new_current_thread().enable_time().build().unwrap(); + diff --git a/patches/task_blocking.patch b/patches/task_blocking.patch new file mode 100644 index 0000000..7f4f7d4 --- /dev/null +++ b/patches/task_blocking.patch @@ -0,0 +1,12 @@ +diff --git a/tests/task_blocking.rs b/tests/task_blocking.rs +index 82bef8a..d9514d2 100644 +--- a/tests/task_blocking.rs ++++ b/tests/task_blocking.rs +@@ -114,6 +114,7 @@ fn can_enter_basic_rt_from_within_block_in_place() { + } + + #[test] ++#[cfg(not(target_os = "android"))] + fn useful_panic_message_when_dropping_rt_in_rt() { + use std::panic::{catch_unwind, AssertUnwindSafe}; + diff --git a/src/coop.rs b/src/coop.rs index 16d93fb..256e962 100644 --- a/src/coop.rs +++ b/src/coop.rs @@ -69,14 +69,14 @@ cfg_rt_multi_thread! { } } -/// Run the given closure with a cooperative task budget. When the function +/// Runs the given closure with a cooperative task budget. 
When the function /// returns, the budget is reset to the value prior to calling the function. #[inline(always)] pub(crate) fn budget<R>(f: impl FnOnce() -> R) -> R { with_budget(Budget::initial(), f) } -/// Run the given closure with an unconstrained task budget. When the function returns, the budget +/// Runs the given closure with an unconstrained task budget. When the function returns, the budget /// is reset to the value prior to calling the function. #[inline(always)] pub(crate) fn with_unconstrained<R>(f: impl FnOnce() -> R) -> R { @@ -108,7 +108,7 @@ fn with_budget<R>(budget: Budget, f: impl FnOnce() -> R) -> R { } cfg_rt_multi_thread! { - /// Set the current task's budget + /// Sets the current task's budget. pub(crate) fn set(budget: Budget) { CURRENT.with(|cell| cell.set(budget)) } @@ -120,7 +120,7 @@ cfg_rt_multi_thread! { } cfg_rt! { - /// Forcibly remove the budgeting constraints early. + /// Forcibly removes the budgeting constraints early. /// /// Returns the remaining budget pub(crate) fn stop() -> Budget { @@ -186,7 +186,7 @@ cfg_coop! { } impl Budget { - /// Decrement the budget. Returns `true` if successful. Decrementing fails + /// Decrements the budget. Returns `true` if successful. Decrementing fails /// when there is not enough remaining budget. fn decrement(&mut self) -> bool { if let Some(num) = &mut self.0 { diff --git a/src/doc/mod.rs b/src/doc/mod.rs new file mode 100644 index 0000000..3a94934 --- /dev/null +++ b/src/doc/mod.rs @@ -0,0 +1,24 @@ +//! Types which are documented locally in the Tokio crate, but does not actually +//! live here. +//! +//! **Note** this module is only visible on docs.rs, you cannot use it directly +//! in your own code. + +/// The name of a type which is not defined here. +/// +/// This is typically used as an alias for another type, like so: +/// +/// ```rust,ignore +/// /// See [some::other::location](https://example.com). 
+/// type DEFINED_ELSEWHERE = crate::doc::NotDefinedHere; +/// ``` +/// +/// This type is uninhabitable like the [`never` type] to ensure that no one +/// will ever accidentally use it. +/// +/// [`never` type]: https://doc.rust-lang.org/std/primitive.never.html +#[derive(Debug)] +pub enum NotDefinedHere {} + +pub mod os; +pub mod winapi; diff --git a/src/doc/os.rs b/src/doc/os.rs new file mode 100644 index 0000000..0ddf869 --- /dev/null +++ b/src/doc/os.rs @@ -0,0 +1,26 @@ +//! See [std::os](https://doc.rust-lang.org/std/os/index.html). + +/// Platform-specific extensions to `std` for Windows. +/// +/// See [std::os::windows](https://doc.rust-lang.org/std/os/windows/index.html). +pub mod windows { + /// Windows-specific extensions to general I/O primitives. + /// + /// See [std::os::windows::io](https://doc.rust-lang.org/std/os/windows/io/index.html). + pub mod io { + /// See [std::os::windows::io::RawHandle](https://doc.rust-lang.org/std/os/windows/io/type.RawHandle.html) + pub type RawHandle = crate::doc::NotDefinedHere; + + /// See [std::os::windows::io::AsRawHandle](https://doc.rust-lang.org/std/os/windows/io/trait.AsRawHandle.html) + pub trait AsRawHandle { + /// See [std::os::windows::io::FromRawHandle::from_raw_handle](https://doc.rust-lang.org/std/os/windows/io/trait.AsRawHandle.html#tymethod.as_raw_handle) + fn as_raw_handle(&self) -> RawHandle; + } + + /// See [std::os::windows::io::FromRawHandle](https://doc.rust-lang.org/std/os/windows/io/trait.FromRawHandle.html) + pub trait FromRawHandle { + /// See [std::os::windows::io::FromRawHandle::from_raw_handle](https://doc.rust-lang.org/std/os/windows/io/trait.FromRawHandle.html#tymethod.from_raw_handle) + unsafe fn from_raw_handle(handle: RawHandle) -> Self; + } + } +} diff --git a/src/doc/winapi.rs b/src/doc/winapi.rs new file mode 100644 index 0000000..be68749 --- /dev/null +++ b/src/doc/winapi.rs @@ -0,0 +1,66 @@ +//! See [winapi]. +//! +//! 
[winapi]: https://docs.rs/winapi + +/// See [winapi::shared](https://docs.rs/winapi/*/winapi/shared/index.html). +pub mod shared { + /// See [winapi::shared::winerror](https://docs.rs/winapi/*/winapi/shared/winerror/index.html). + #[allow(non_camel_case_types)] + pub mod winerror { + /// See [winapi::shared::winerror::ERROR_ACCESS_DENIED][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/shared/winerror/constant.ERROR_ACCESS_DENIED.html + pub type ERROR_ACCESS_DENIED = crate::doc::NotDefinedHere; + + /// See [winapi::shared::winerror::ERROR_PIPE_BUSY][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/shared/winerror/constant.ERROR_PIPE_BUSY.html + pub type ERROR_PIPE_BUSY = crate::doc::NotDefinedHere; + + /// See [winapi::shared::winerror::ERROR_MORE_DATA][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/shared/winerror/constant.ERROR_MORE_DATA.html + pub type ERROR_MORE_DATA = crate::doc::NotDefinedHere; + } +} + +/// See [winapi::um](https://docs.rs/winapi/*/winapi/um/index.html). +pub mod um { + /// See [winapi::um::winbase](https://docs.rs/winapi/*/winapi/um/winbase/index.html). 
+ #[allow(non_camel_case_types)] + pub mod winbase { + /// See [winapi::um::winbase::PIPE_TYPE_MESSAGE][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.PIPE_TYPE_MESSAGE.html + pub type PIPE_TYPE_MESSAGE = crate::doc::NotDefinedHere; + + /// See [winapi::um::winbase::PIPE_TYPE_BYTE][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.PIPE_TYPE_BYTE.html + pub type PIPE_TYPE_BYTE = crate::doc::NotDefinedHere; + + /// See [winapi::um::winbase::PIPE_CLIENT_END][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.PIPE_CLIENT_END.html + pub type PIPE_CLIENT_END = crate::doc::NotDefinedHere; + + /// See [winapi::um::winbase::PIPE_SERVER_END][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.PIPE_SERVER_END.html + pub type PIPE_SERVER_END = crate::doc::NotDefinedHere; + + /// See [winapi::um::winbase::SECURITY_IDENTIFICATION][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/winbase/constant.SECURITY_IDENTIFICATION.html + pub type SECURITY_IDENTIFICATION = crate::doc::NotDefinedHere; + } + + /// See [winapi::um::minwinbase](https://docs.rs/winapi/*/winapi/um/minwinbase/index.html). + #[allow(non_camel_case_types)] + pub mod minwinbase { + /// See [winapi::um::minwinbase::SECURITY_ATTRIBUTES][winapi] + /// + /// [winapi]: https://docs.rs/winapi/*/winapi/um/minwinbase/constant.SECURITY_ATTRIBUTES.html + pub type SECURITY_ATTRIBUTES = crate::doc::NotDefinedHere; + } +} diff --git a/src/fs/create_dir.rs b/src/fs/create_dir.rs index e03b04d..4119695 100644 --- a/src/fs/create_dir.rs +++ b/src/fs/create_dir.rs @@ -3,7 +3,7 @@ use crate::fs::asyncify; use std::io; use std::path::Path; -/// Creates a new, empty directory at the provided path +/// Creates a new, empty directory at the provided path. 
/// /// This is an async version of [`std::fs::create_dir`][std] /// diff --git a/src/fs/dir_builder.rs b/src/fs/dir_builder.rs index b184934..97168bf 100644 --- a/src/fs/dir_builder.rs +++ b/src/fs/dir_builder.rs @@ -14,7 +14,7 @@ pub struct DirBuilder { /// Indicates whether to create parent directories if they are missing. recursive: bool, - /// Set the Unix mode for newly created directories. + /// Sets the Unix mode for newly created directories. #[cfg(unix)] pub(super) mode: Option<u32>, } diff --git a/src/fs/file.rs b/src/fs/file.rs index 5c06e73..61071cf 100644 --- a/src/fs/file.rs +++ b/src/fs/file.rs @@ -3,7 +3,7 @@ //! [`File`]: File use self::State::*; -use crate::fs::{asyncify, sys}; +use crate::fs::asyncify; use crate::io::blocking::Buf; use crate::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use crate::sync::Mutex; @@ -19,6 +19,19 @@ use std::task::Context; use std::task::Poll; use std::task::Poll::*; +#[cfg(test)] +use super::mocks::spawn_blocking; +#[cfg(test)] +use super::mocks::JoinHandle; +#[cfg(test)] +use super::mocks::MockFile as StdFile; +#[cfg(not(test))] +use crate::blocking::spawn_blocking; +#[cfg(not(test))] +use crate::blocking::JoinHandle; +#[cfg(not(test))] +use std::fs::File as StdFile; + /// A reference to an open file on the filesystem. 
/// /// This is a specialized version of [`std::fs::File`][std] for usage from the @@ -61,7 +74,7 @@ use std::task::Poll::*; /// # } /// ``` /// -/// Read the contents of a file into a buffer +/// Read the contents of a file into a buffer: /// /// ```no_run /// use tokio::fs::File; @@ -78,7 +91,7 @@ use std::task::Poll::*; /// # } /// ``` pub struct File { - std: Arc<sys::File>, + std: Arc<StdFile>, inner: Mutex<Inner>, } @@ -96,7 +109,7 @@ struct Inner { #[derive(Debug)] enum State { Idle(Option<Buf>), - Busy(sys::Blocking<(Operation, Buf)>), + Busy(JoinHandle<(Operation, Buf)>), } #[derive(Debug)] @@ -142,7 +155,7 @@ impl File { /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt pub async fn open(path: impl AsRef<Path>) -> io::Result<File> { let path = path.as_ref().to_owned(); - let std = asyncify(|| sys::File::open(path)).await?; + let std = asyncify(|| StdFile::open(path)).await?; Ok(File::from_std(std)) } @@ -182,7 +195,7 @@ impl File { /// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt pub async fn create(path: impl AsRef<Path>) -> io::Result<File> { let path = path.as_ref().to_owned(); - let std_file = asyncify(move || sys::File::create(path)).await?; + let std_file = asyncify(move || StdFile::create(path)).await?; Ok(File::from_std(std_file)) } @@ -199,7 +212,7 @@ impl File { /// let std_file = std::fs::File::open("foo.txt").unwrap(); /// let file = tokio::fs::File::from_std(std_file); /// ``` - pub fn from_std(std: sys::File) -> File { + pub fn from_std(std: StdFile) -> File { File { std: Arc::new(std), inner: Mutex::new(Inner { @@ -323,7 +336,7 @@ impl File { let std = self.std.clone(); - inner.state = Busy(sys::run(move || { + inner.state = Busy(spawn_blocking(move || { let res = if let Some(seek) = seek { (&*std).seek(seek).and_then(|_| std.set_len(size)) } else { @@ -370,7 +383,7 @@ impl File { asyncify(move || std.metadata()).await } - /// Create a new `File` instance that shares the same underlying file handle + /// Creates a new `File` instance 
that shares the same underlying file handle /// as the existing `File` instance. Reads, writes, and seeks will affect both /// File instances simultaneously. /// @@ -409,7 +422,7 @@ impl File { /// # Ok(()) /// # } /// ``` - pub async fn into_std(mut self) -> sys::File { + pub async fn into_std(mut self) -> StdFile { self.inner.get_mut().complete_inflight().await; Arc::try_unwrap(self.std).expect("Arc::try_unwrap failed") } @@ -434,7 +447,7 @@ impl File { /// # Ok(()) /// # } /// ``` - pub fn try_into_std(mut self) -> Result<sys::File, Self> { + pub fn try_into_std(mut self) -> Result<StdFile, Self> { match Arc::try_unwrap(self.std) { Ok(file) => Ok(file), Err(std_file_arc) => { @@ -502,7 +515,7 @@ impl AsyncRead for File { buf.ensure_capacity_for(dst); let std = me.std.clone(); - inner.state = Busy(sys::run(move || { + inner.state = Busy(spawn_blocking(move || { let res = buf.read_from(&mut &*std); (Operation::Read(res), buf) })); @@ -569,7 +582,7 @@ impl AsyncSeek for File { let std = me.std.clone(); - inner.state = Busy(sys::run(move || { + inner.state = Busy(spawn_blocking(move || { let res = (&*std).seek(pos); (Operation::Seek(res), buf) })); @@ -636,7 +649,7 @@ impl AsyncWrite for File { let n = buf.copy_from(src); let std = me.std.clone(); - inner.state = Busy(sys::run(move || { + inner.state = Busy(spawn_blocking(move || { let res = if let Some(seek) = seek { (&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std)) } else { @@ -685,8 +698,8 @@ impl AsyncWrite for File { } } -impl From<sys::File> for File { - fn from(std: sys::File) -> Self { +impl From<StdFile> for File { + fn from(std: StdFile) -> Self { Self::from_std(std) } } @@ -709,7 +722,7 @@ impl std::os::unix::io::AsRawFd for File { #[cfg(unix)] impl std::os::unix::io::FromRawFd for File { unsafe fn from_raw_fd(fd: std::os::unix::io::RawFd) -> Self { - sys::File::from_raw_fd(fd).into() + StdFile::from_raw_fd(fd).into() } } @@ -723,7 +736,7 @@ impl std::os::windows::io::AsRawHandle for File { 
#[cfg(windows)] impl std::os::windows::io::FromRawHandle for File { unsafe fn from_raw_handle(handle: std::os::windows::io::RawHandle) -> Self { - sys::File::from_raw_handle(handle).into() + StdFile::from_raw_handle(handle).into() } } @@ -756,3 +769,6 @@ impl Inner { } } } + +#[cfg(test)] +mod tests; diff --git a/tests/fs_file_mocked.rs b/src/fs/file/tests.rs index 7771532..28b5ffe 100644 --- a/tests/fs_file_mocked.rs +++ b/src/fs/file/tests.rs @@ -1,80 +1,21 @@ -#![warn(rust_2018_idioms)] -#![cfg(feature = "full")] - -macro_rules! ready { - ($e:expr $(,)?) => { - match $e { - std::task::Poll::Ready(t) => t, - std::task::Poll::Pending => return std::task::Poll::Pending, - } - }; -} - -#[macro_export] -macro_rules! cfg_fs { - ($($item:item)*) => { $($item)* } -} - -#[macro_export] -macro_rules! cfg_io_std { - ($($item:item)*) => { $($item)* } -} - -use futures::future; - -// Load source -#[allow(warnings)] -#[path = "../src/fs/file.rs"] -mod file; -use file::File; - -#[allow(warnings)] -#[path = "../src/io/blocking.rs"] -mod blocking; - -// Load mocked types -mod support { - pub(crate) mod mock_file; - pub(crate) mod mock_pool; -} -pub(crate) use support::mock_pool as pool; - -// Place them where the source expects them -pub(crate) mod io { - pub(crate) use tokio::io::*; - - pub(crate) use crate::blocking; - - pub(crate) mod sys { - pub(crate) use crate::support::mock_pool::{run, Blocking}; - } -} -pub(crate) mod fs { - pub(crate) mod sys { - pub(crate) use crate::support::mock_file::File; - pub(crate) use crate::support::mock_pool::{run, Blocking}; - } - - pub(crate) use crate::support::mock_pool::asyncify; -} -pub(crate) mod sync { - pub(crate) use tokio::sync::Mutex; -} -use fs::sys; - -use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; -use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok, task}; - -use std::io::SeekFrom; +use super::*; +use crate::{ + fs::mocks::*, + io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}, +}; +use 
mockall::{predicate::eq, Sequence}; +use tokio_test::{assert_pending, assert_ready_err, assert_ready_ok, task}; const HELLO: &[u8] = b"hello world..."; const FOO: &[u8] = b"foo bar baz..."; #[test] fn open_read() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO); - + let mut file = MockFile::default(); + file.expect_inner_read().once().returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); let mut buf = [0; 1024]; @@ -83,12 +24,10 @@ fn open_read() { assert_eq!(0, pool::len()); assert_pending!(t.poll()); - assert_eq!(1, mock.remaining()); assert_eq!(1, pool::len()); pool::run_one(); - assert_eq!(0, mock.remaining()); assert!(t.is_woken()); let n = assert_ready_ok!(t.poll()); @@ -98,9 +37,11 @@ fn open_read() { #[test] fn read_twice_before_dispatch() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO); - + let mut file = MockFile::default(); + file.expect_inner_read().once().returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); let mut buf = [0; 1024]; @@ -120,8 +61,11 @@ fn read_twice_before_dispatch() { #[test] fn read_with_smaller_buf() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO); + let mut file = MockFile::default(); + file.expect_inner_read().once().returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); @@ -153,8 +97,22 @@ fn read_with_smaller_buf() { #[test] fn read_with_bigger_buf() { - let (mock, file) = sys::File::mock(); - mock.read(&HELLO[..4]).read(&HELLO[4..]); + let mut seq = Sequence::new(); + let mut file = MockFile::default(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..4].copy_from_slice(&HELLO[..4]); + Ok(4) + }); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len() - 4].copy_from_slice(&HELLO[4..]); + 
Ok(HELLO.len() - 4) + }); let mut file = File::from_std(file); @@ -194,8 +152,19 @@ fn read_with_bigger_buf() { #[test] fn read_err_then_read_success() { - let (mock, file) = sys::File::mock(); - mock.read_err().read(&HELLO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); @@ -225,8 +194,11 @@ fn read_err_then_read_success() { #[test] fn open_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO); + let mut file = MockFile::default(); + file.expect_inner_write() + .once() + .with(eq(HELLO)) + .returning(|buf| Ok(buf.len())); let mut file = File::from_std(file); @@ -235,12 +207,10 @@ fn open_write() { assert_eq!(0, pool::len()); assert_ready_ok!(t.poll()); - assert_eq!(1, mock.remaining()); assert_eq!(1, pool::len()); pool::run_one(); - assert_eq!(0, mock.remaining()); assert!(!t.is_woken()); let mut t = task::spawn(file.flush()); @@ -249,7 +219,7 @@ fn open_write() { #[test] fn flush_while_idle() { - let (_mock, file) = sys::File::mock(); + let file = MockFile::default(); let mut file = File::from_std(file); @@ -271,13 +241,42 @@ fn read_with_buffer_larger_than_max() { for i in 0..(chunk_d - 1) { data.push((i % 151) as u8); } - - let (mock, file) = sys::File::mock(); - mock.read(&data[0..chunk_a]) - .read(&data[chunk_a..chunk_b]) - .read(&data[chunk_b..chunk_c]) - .read(&data[chunk_c..]); - + let data = Arc::new(data); + let d0 = data.clone(); + let d1 = data.clone(); + let d2 = data.clone(); + let d3 = data.clone(); + + let mut seq = Sequence::new(); + let mut file = MockFile::default(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(move |buf| { + buf[0..chunk_a].copy_from_slice(&d0[0..chunk_a]); + 
Ok(chunk_a) + }); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(move |buf| { + buf[..chunk_a].copy_from_slice(&d1[chunk_a..chunk_b]); + Ok(chunk_b - chunk_a) + }); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(move |buf| { + buf[..chunk_a].copy_from_slice(&d2[chunk_b..chunk_c]); + Ok(chunk_c - chunk_b) + }); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(move |buf| { + buf[..chunk_a - 1].copy_from_slice(&d3[chunk_c..]); + Ok(chunk_a - 1) + }); let mut file = File::from_std(file); let mut actual = vec![0; chunk_d]; @@ -296,8 +295,7 @@ fn read_with_buffer_larger_than_max() { pos += n; } - assert_eq!(mock.remaining(), 0); - assert_eq!(data, &actual[..data.len()]); + assert_eq!(&data[..], &actual[..data.len()]); } #[test] @@ -314,12 +312,34 @@ fn write_with_buffer_larger_than_max() { for i in 0..(chunk_d - 1) { data.push((i % 151) as u8); } - - let (mock, file) = sys::File::mock(); - mock.write(&data[0..chunk_a]) - .write(&data[chunk_a..chunk_b]) - .write(&data[chunk_b..chunk_c]) - .write(&data[chunk_c..]); + let data = Arc::new(data); + let d0 = data.clone(); + let d1 = data.clone(); + let d2 = data.clone(); + let d3 = data.clone(); + + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .withf(move |buf| buf == &d0[0..chunk_a]) + .returning(|buf| Ok(buf.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .withf(move |buf| buf == &d1[chunk_a..chunk_b]) + .returning(|buf| Ok(buf.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .withf(move |buf| buf == &d2[chunk_b..chunk_c]) + .returning(|buf| Ok(buf.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .withf(move |buf| buf == &d3[chunk_c..chunk_d - 1]) + .returning(|buf| Ok(buf.len())); let mut file = File::from_std(file); @@ -344,14 +364,22 @@ fn write_with_buffer_larger_than_max() { } 
pool::run_one(); - - assert_eq!(mock.remaining(), 0); } #[test] fn write_twice_before_dispatch() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).write(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|buf| Ok(buf.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(FOO)) + .returning(|buf| Ok(buf.len())); let mut file = File::from_std(file); @@ -380,10 +408,24 @@ fn write_twice_before_dispatch() { #[test] fn incomplete_read_followed_by_write() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO) - .seek_current_ok(-(HELLO.len() as i64), 0) - .write(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + file.expect_inner_seek() + .once() + .with(eq(SeekFrom::Current(-(HELLO.len() as i64)))) + .in_sequence(&mut seq) + .returning(|_| Ok(0)); + file.expect_inner_write() + .once() + .with(eq(FOO)) + .returning(|_| Ok(FOO.len())); let mut file = File::from_std(file); @@ -406,8 +448,25 @@ fn incomplete_read_followed_by_write() { #[test] fn incomplete_partial_read_followed_by_write() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO).seek_current_ok(-10, 0).write(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + file.expect_inner_seek() + .once() + .in_sequence(&mut seq) + .with(eq(SeekFrom::Current(-10))) + .returning(|_| Ok(0)); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(FOO)) + .returning(|_| Ok(FOO.len())); let mut file = File::from_std(file); @@ -433,10 +492,25 @@ fn 
incomplete_partial_read_followed_by_write() { #[test] fn incomplete_read_followed_by_flush() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO) - .seek_current_ok(-(HELLO.len() as i64), 0) - .write(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + file.expect_inner_seek() + .once() + .in_sequence(&mut seq) + .with(eq(SeekFrom::Current(-(HELLO.len() as i64)))) + .returning(|_| Ok(0)); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(FOO)) + .returning(|_| Ok(FOO.len())); let mut file = File::from_std(file); @@ -458,8 +532,18 @@ fn incomplete_read_followed_by_flush() { #[test] fn incomplete_flush_followed_by_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).write(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(FOO)) + .returning(|_| Ok(FOO.len())); let mut file = File::from_std(file); @@ -484,8 +568,10 @@ fn incomplete_flush_followed_by_write() { #[test] fn read_err() { - let (mock, file) = sys::File::mock(); - mock.read_err(); + let mut file = MockFile::default(); + file.expect_inner_read() + .once() + .returning(|_| Err(io::ErrorKind::Other.into())); let mut file = File::from_std(file); @@ -502,8 +588,10 @@ fn read_err() { #[test] fn write_write_err() { - let (mock, file) = sys::File::mock(); - mock.write_err(); + let mut file = MockFile::default(); + file.expect_inner_write() + .once() + .returning(|_| Err(io::ErrorKind::Other.into())); let mut file = File::from_std(file); @@ -518,8 +606,19 @@ fn write_write_err() { #[test] fn write_read_write_err() { - let (mock, file) = sys::File::mock(); - 
mock.write_err().read(HELLO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); @@ -541,8 +640,19 @@ fn write_read_write_err() { #[test] fn write_read_flush_err() { - let (mock, file) = sys::File::mock(); - mock.write_err().read(HELLO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); let mut file = File::from_std(file); @@ -564,8 +674,17 @@ fn write_read_flush_err() { #[test] fn write_seek_write_err() { - let (mock, file) = sys::File::mock(); - mock.write_err().seek_start_ok(0); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_seek() + .once() + .with(eq(SeekFrom::Start(0))) + .in_sequence(&mut seq) + .returning(|_| Ok(0)); let mut file = File::from_std(file); @@ -587,8 +706,17 @@ fn write_seek_write_err() { #[test] fn write_seek_flush_err() { - let (mock, file) = sys::File::mock(); - mock.write_err().seek_start_ok(0); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .returning(|_| Err(io::ErrorKind::Other.into())); + file.expect_inner_seek() + .once() + .with(eq(SeekFrom::Start(0))) + .in_sequence(&mut seq) + .returning(|_| Ok(0)); let mut file = File::from_std(file); @@ -610,8 +738,14 @@ fn write_seek_flush_err() { 
#[test] fn sync_all_ordered_after_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).sync_all(); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_sync_all().once().returning(|| Ok(())); let mut file = File::from_std(file); let mut t = task::spawn(file.write(HELLO)); @@ -635,8 +769,16 @@ fn sync_all_ordered_after_write() { #[test] fn sync_all_err_ordered_after_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).sync_all_err(); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_sync_all() + .once() + .returning(|| Err(io::ErrorKind::Other.into())); let mut file = File::from_std(file); let mut t = task::spawn(file.write(HELLO)); @@ -660,8 +802,14 @@ fn sync_all_err_ordered_after_write() { #[test] fn sync_data_ordered_after_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).sync_data(); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_sync_data().once().returning(|| Ok(())); let mut file = File::from_std(file); let mut t = task::spawn(file.write(HELLO)); @@ -685,8 +833,16 @@ fn sync_data_ordered_after_write() { #[test] fn sync_data_err_ordered_after_write() { - let (mock, file) = sys::File::mock(); - mock.write(HELLO).sync_data_err(); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_write() + .once() + .in_sequence(&mut seq) + .with(eq(HELLO)) + .returning(|_| Ok(HELLO.len())); + file.expect_sync_data() + .once() + .returning(|| Err(io::ErrorKind::Other.into())); let mut file = File::from_std(file); let mut t = 
task::spawn(file.write(HELLO)); @@ -710,17 +866,15 @@ fn sync_data_err_ordered_after_write() { #[test] fn open_set_len_ok() { - let (mock, file) = sys::File::mock(); - mock.set_len(123); + let mut file = MockFile::default(); + file.expect_set_len().with(eq(123)).returning(|_| Ok(())); let file = File::from_std(file); let mut t = task::spawn(file.set_len(123)); assert_pending!(t.poll()); - assert_eq!(1, mock.remaining()); pool::run_one(); - assert_eq!(0, mock.remaining()); assert!(t.is_woken()); assert_ready_ok!(t.poll()); @@ -728,17 +882,17 @@ fn open_set_len_ok() { #[test] fn open_set_len_err() { - let (mock, file) = sys::File::mock(); - mock.set_len_err(123); + let mut file = MockFile::default(); + file.expect_set_len() + .with(eq(123)) + .returning(|_| Err(io::ErrorKind::Other.into())); let file = File::from_std(file); let mut t = task::spawn(file.set_len(123)); assert_pending!(t.poll()); - assert_eq!(1, mock.remaining()); pool::run_one(); - assert_eq!(0, mock.remaining()); assert!(t.is_woken()); assert_ready_err!(t.poll()); @@ -746,11 +900,32 @@ fn open_set_len_err() { #[test] fn partial_read_set_len_ok() { - let (mock, file) = sys::File::mock(); - mock.read(HELLO) - .seek_current_ok(-14, 0) - .set_len(123) - .read(FOO); + let mut file = MockFile::default(); + let mut seq = Sequence::new(); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..HELLO.len()].copy_from_slice(HELLO); + Ok(HELLO.len()) + }); + file.expect_inner_seek() + .once() + .with(eq(SeekFrom::Current(-(HELLO.len() as i64)))) + .in_sequence(&mut seq) + .returning(|_| Ok(0)); + file.expect_set_len() + .once() + .in_sequence(&mut seq) + .with(eq(123)) + .returning(|_| Ok(())); + file.expect_inner_read() + .once() + .in_sequence(&mut seq) + .returning(|buf| { + buf[0..FOO.len()].copy_from_slice(FOO); + Ok(FOO.len()) + }); let mut buf = [0; 32]; let mut file = File::from_std(file); diff --git a/src/fs/mocks.rs b/src/fs/mocks.rs new file mode 100644 index 
0000000..68ef4f3 --- /dev/null +++ b/src/fs/mocks.rs @@ -0,0 +1,136 @@ +//! Mock version of std::fs::File; +use mockall::mock; + +use crate::sync::oneshot; +use std::{ + cell::RefCell, + collections::VecDeque, + fs::{Metadata, Permissions}, + future::Future, + io::{self, Read, Seek, SeekFrom, Write}, + path::PathBuf, + pin::Pin, + task::{Context, Poll}, +}; + +mock! { + #[derive(Debug)] + pub File { + pub fn create(pb: PathBuf) -> io::Result<Self>; + // These inner_ methods exist because std::fs::File has two + // implementations for each of these methods: one on "&mut self" and + // one on "&&self". Defining both of those in terms of an inner_ method + // allows us to specify the expectation the same way, regardless of + // which method is used. + pub fn inner_flush(&self) -> io::Result<()>; + pub fn inner_read(&self, dst: &mut [u8]) -> io::Result<usize>; + pub fn inner_seek(&self, pos: SeekFrom) -> io::Result<u64>; + pub fn inner_write(&self, src: &[u8]) -> io::Result<usize>; + pub fn metadata(&self) -> io::Result<Metadata>; + pub fn open(pb: PathBuf) -> io::Result<Self>; + pub fn set_len(&self, size: u64) -> io::Result<()>; + pub fn set_permissions(&self, _perm: Permissions) -> io::Result<()>; + pub fn sync_all(&self) -> io::Result<()>; + pub fn sync_data(&self) -> io::Result<()>; + pub fn try_clone(&self) -> io::Result<Self>; + } + #[cfg(windows)] + impl std::os::windows::io::AsRawHandle for File { + fn as_raw_handle(&self) -> std::os::windows::io::RawHandle; + } + #[cfg(windows)] + impl std::os::windows::io::FromRawHandle for File { + unsafe fn from_raw_handle(h: std::os::windows::io::RawHandle) -> Self; + } + #[cfg(unix)] + impl std::os::unix::io::AsRawFd for File { + fn as_raw_fd(&self) -> std::os::unix::io::RawFd; + } + + #[cfg(unix)] + impl std::os::unix::io::FromRawFd for File { + unsafe fn from_raw_fd(h: std::os::unix::io::RawFd) -> Self; + } +} + +impl Read for MockFile { + fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> { + self.inner_read(dst) 
+ } +} + +impl Read for &'_ MockFile { + fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> { + self.inner_read(dst) + } +} + +impl Seek for &'_ MockFile { + fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> { + self.inner_seek(pos) + } +} + +impl Write for &'_ MockFile { + fn write(&mut self, src: &[u8]) -> io::Result<usize> { + self.inner_write(src) + } + + fn flush(&mut self) -> io::Result<()> { + self.inner_flush() + } +} + +thread_local! { + static QUEUE: RefCell<VecDeque<Box<dyn FnOnce() + Send>>> = RefCell::new(VecDeque::new()) +} + +#[derive(Debug)] +pub(super) struct JoinHandle<T> { + rx: oneshot::Receiver<T>, +} + +pub(super) fn spawn_blocking<F, R>(f: F) -> JoinHandle<R> +where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, +{ + let (tx, rx) = oneshot::channel(); + let task = Box::new(move || { + let _ = tx.send(f()); + }); + + QUEUE.with(|cell| cell.borrow_mut().push_back(task)); + + JoinHandle { rx } +} + +impl<T> Future for JoinHandle<T> { + type Output = Result<T, io::Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + use std::task::Poll::*; + + match Pin::new(&mut self.rx).poll(cx) { + Ready(Ok(v)) => Ready(Ok(v)), + Ready(Err(e)) => panic!("error = {:?}", e), + Pending => Pending, + } + } +} + +pub(super) mod pool { + use super::*; + + pub(in super::super) fn len() -> usize { + QUEUE.with(|cell| cell.borrow().len()) + } + + pub(in super::super) fn run_one() { + let task = QUEUE + .with(|cell| cell.borrow_mut().pop_front()) + .expect("expected task to run, but none ready"); + + task(); + } +} diff --git a/src/fs/mod.rs b/src/fs/mod.rs index d4f0074..ca0264b 100644 --- a/src/fs/mod.rs +++ b/src/fs/mod.rs @@ -84,6 +84,9 @@ pub use self::write::write; mod copy; pub use self::copy::copy; +#[cfg(test)] +mod mocks; + feature! { #![unix] @@ -103,12 +106,17 @@ feature! 
{ use std::io; +#[cfg(not(test))] +use crate::blocking::spawn_blocking; +#[cfg(test)] +use mocks::spawn_blocking; + pub(crate) async fn asyncify<F, T>(f: F) -> io::Result<T> where F: FnOnce() -> io::Result<T> + Send + 'static, T: Send + 'static, { - match sys::run(f).await { + match spawn_blocking(f).await { Ok(res) => res, Err(_) => Err(io::Error::new( io::ErrorKind::Other, @@ -116,12 +124,3 @@ where )), } } - -/// Types in this module can be mocked out in tests. -mod sys { - pub(crate) use std::fs::File; - - // TODO: don't rename - pub(crate) use crate::blocking::spawn_blocking as run; - pub(crate) use crate::blocking::JoinHandle as Blocking; -} diff --git a/src/fs/open_options.rs b/src/fs/open_options.rs index fa37a60..f3b4654 100644 --- a/src/fs/open_options.rs +++ b/src/fs/open_options.rs @@ -3,6 +3,13 @@ use crate::fs::{asyncify, File}; use std::io; use std::path::Path; +#[cfg(test)] +mod mock_open_options; +#[cfg(test)] +use mock_open_options::MockOpenOptions as StdOpenOptions; +#[cfg(not(test))] +use std::fs::OpenOptions as StdOpenOptions; + /// Options and flags which can be used to configure how a file is opened. /// /// This builder exposes the ability to configure how a [`File`] is opened and @@ -69,7 +76,7 @@ use std::path::Path; /// } /// ``` #[derive(Clone, Debug)] -pub struct OpenOptions(std::fs::OpenOptions); +pub struct OpenOptions(StdOpenOptions); impl OpenOptions { /// Creates a blank new set of options ready for configuration. @@ -89,7 +96,7 @@ impl OpenOptions { /// let future = options.read(true).open("foo.txt"); /// ``` pub fn new() -> OpenOptions { - OpenOptions(std::fs::OpenOptions::new()) + OpenOptions(StdOpenOptions::new()) } /// Sets the option for read access. 
@@ -384,7 +391,7 @@ impl OpenOptions { } /// Returns a mutable reference to the underlying `std::fs::OpenOptions` - pub(super) fn as_inner_mut(&mut self) -> &mut std::fs::OpenOptions { + pub(super) fn as_inner_mut(&mut self) -> &mut StdOpenOptions { &mut self.0 } } @@ -423,7 +430,7 @@ feature! { self } - /// Pass custom flags to the `flags` argument of `open`. + /// Passes custom flags to the `flags` argument of `open`. /// /// The bits that define the access mode are masked out with `O_ACCMODE`, to /// ensure they do not interfere with the access mode set by Rusts options. @@ -645,8 +652,8 @@ feature! { } } -impl From<std::fs::OpenOptions> for OpenOptions { - fn from(options: std::fs::OpenOptions) -> OpenOptions { +impl From<StdOpenOptions> for OpenOptions { + fn from(options: StdOpenOptions) -> OpenOptions { OpenOptions(options) } } diff --git a/src/fs/open_options/mock_open_options.rs b/src/fs/open_options/mock_open_options.rs new file mode 100644 index 0000000..cbbda0e --- /dev/null +++ b/src/fs/open_options/mock_open_options.rs @@ -0,0 +1,38 @@ +//! Mock version of std::fs::OpenOptions; +use mockall::mock; + +use crate::fs::mocks::MockFile; +#[cfg(unix)] +use std::os::unix::fs::OpenOptionsExt; +#[cfg(windows)] +use std::os::windows::fs::OpenOptionsExt; +use std::{io, path::Path}; + +mock! 
{ + #[derive(Debug)] + pub OpenOptions { + pub fn append(&mut self, append: bool) -> &mut Self; + pub fn create(&mut self, create: bool) -> &mut Self; + pub fn create_new(&mut self, create_new: bool) -> &mut Self; + pub fn open<P: AsRef<Path> + 'static>(&self, path: P) -> io::Result<MockFile>; + pub fn read(&mut self, read: bool) -> &mut Self; + pub fn truncate(&mut self, truncate: bool) -> &mut Self; + pub fn write(&mut self, write: bool) -> &mut Self; + } + impl Clone for OpenOptions { + fn clone(&self) -> Self; + } + #[cfg(unix)] + impl OpenOptionsExt for OpenOptions { + fn custom_flags(&mut self, flags: i32) -> &mut Self; + fn mode(&mut self, mode: u32) -> &mut Self; + } + #[cfg(windows)] + impl OpenOptionsExt for OpenOptions { + fn access_mode(&mut self, access: u32) -> &mut Self; + fn share_mode(&mut self, val: u32) -> &mut Self; + fn custom_flags(&mut self, flags: u32) -> &mut Self; + fn attributes(&mut self, val: u32) -> &mut Self; + fn security_qos_flags(&mut self, flags: u32) -> &mut Self; + } +} diff --git a/src/fs/read.rs b/src/fs/read.rs index 2d80eb5..ada5ba3 100644 --- a/src/fs/read.rs +++ b/src/fs/read.rs @@ -13,8 +13,12 @@ use std::{io, path::Path}; /// buffer based on the file size when available, so it is generally faster than /// reading into a vector created with `Vec::new()`. /// +/// This operation is implemented by running the equivalent blocking operation +/// on a separate thread pool using [`spawn_blocking`]. 
+/// /// [`File::open`]: super::File::open /// [`read_to_end`]: crate::io::AsyncReadExt::read_to_end +/// [`spawn_blocking`]: crate::task::spawn_blocking /// /// # Errors /// diff --git a/src/fs/read_dir.rs b/src/fs/read_dir.rs index aedaf7b..281ea4c 100644 --- a/src/fs/read_dir.rs +++ b/src/fs/read_dir.rs @@ -1,4 +1,4 @@ -use crate::fs::{asyncify, sys}; +use crate::fs::asyncify; use std::ffi::OsString; use std::fs::{FileType, Metadata}; @@ -10,9 +10,23 @@ use std::sync::Arc; use std::task::Context; use std::task::Poll; +#[cfg(test)] +use super::mocks::spawn_blocking; +#[cfg(test)] +use super::mocks::JoinHandle; +#[cfg(not(test))] +use crate::blocking::spawn_blocking; +#[cfg(not(test))] +use crate::blocking::JoinHandle; + /// Returns a stream over the entries within a directory. /// /// This is an async version of [`std::fs::read_dir`](std::fs::read_dir) +/// +/// This operation is implemented by running the equivalent blocking +/// operation on a separate thread pool using [`spawn_blocking`]. +/// +/// [`spawn_blocking`]: crate::task::spawn_blocking pub async fn read_dir(path: impl AsRef<Path>) -> io::Result<ReadDir> { let path = path.as_ref().to_owned(); let std = asyncify(|| std::fs::read_dir(path)).await?; @@ -20,7 +34,7 @@ pub async fn read_dir(path: impl AsRef<Path>) -> io::Result<ReadDir> { Ok(ReadDir(State::Idle(Some(std)))) } -/// Read the the entries in a directory. +/// Reads the the entries in a directory. /// /// This struct is returned from the [`read_dir`] function of this module and /// will yield instances of [`DirEntry`]. Through a [`DirEntry`] information @@ -45,11 +59,15 @@ pub struct ReadDir(State); #[derive(Debug)] enum State { Idle(Option<std::fs::ReadDir>), - Pending(sys::Blocking<(Option<io::Result<std::fs::DirEntry>>, std::fs::ReadDir)>), + Pending(JoinHandle<(Option<io::Result<std::fs::DirEntry>>, std::fs::ReadDir)>), } impl ReadDir { /// Returns the next entry in the directory stream. 
+ /// + /// # Cancel safety + /// + /// This method is cancellation safe. pub async fn next_entry(&mut self) -> io::Result<Option<DirEntry>> { use crate::future::poll_fn; poll_fn(|cx| self.poll_next_entry(cx)).await @@ -79,7 +97,7 @@ impl ReadDir { State::Idle(ref mut std) => { let mut std = std.take().unwrap(); - self.0 = State::Pending(sys::run(move || { + self.0 = State::Pending(spawn_blocking(move || { let ret = std.next(); (ret, std) })); @@ -269,7 +287,7 @@ impl DirEntry { asyncify(move || std.file_type()).await } - /// Returns a reference to the underlying `std::fs::DirEntry` + /// Returns a reference to the underlying `std::fs::DirEntry`. #[cfg(unix)] pub(super) fn as_inner(&self) -> &std::fs::DirEntry { &self.0 diff --git a/src/fs/read_to_string.rs b/src/fs/read_to_string.rs index 4f37986..26228d9 100644 --- a/src/fs/read_to_string.rs +++ b/src/fs/read_to_string.rs @@ -7,6 +7,10 @@ use std::{io, path::Path}; /// /// This is the async equivalent of [`std::fs::read_to_string`][std]. /// +/// This operation is implemented by running the equivalent blocking operation +/// on a separate thread pool using [`spawn_blocking`]. +/// +/// [`spawn_blocking`]: crate::task::spawn_blocking /// [std]: fn@std::fs::read_to_string /// /// # Examples diff --git a/src/fs/write.rs b/src/fs/write.rs index 0ed9082..28606fb 100644 --- a/src/fs/write.rs +++ b/src/fs/write.rs @@ -7,6 +7,10 @@ use std::{io, path::Path}; /// /// This is the async equivalent of [`std::fs::write`][std]. /// +/// This operation is implemented by running the equivalent blocking operation +/// on a separate thread pool using [`spawn_blocking`]. +/// +/// [`spawn_blocking`]: crate::task::spawn_blocking /// [std]: fn@std::fs::write /// /// # Examples diff --git a/src/future/maybe_done.rs b/src/future/maybe_done.rs index 1e083ad..486efbe 100644 --- a/src/future/maybe_done.rs +++ b/src/future/maybe_done.rs @@ -1,4 +1,4 @@ -//! Definition of the MaybeDone combinator +//! Definition of the MaybeDone combinator. 
use std::future::Future; use std::mem; @@ -8,9 +8,9 @@ use std::task::{Context, Poll}; /// A future that may have completed. #[derive(Debug)] pub enum MaybeDone<Fut: Future> { - /// A not-yet-completed future + /// A not-yet-completed future. Future(Fut), - /// The output of the completed future + /// The output of the completed future. Done(Fut::Output), /// The empty variant after the result of a [`MaybeDone`] has been /// taken using the [`take_output`](MaybeDone::take_output) method. @@ -20,7 +20,7 @@ pub enum MaybeDone<Fut: Future> { // Safe because we never generate `Pin<&mut Fut::Output>` impl<Fut: Future + Unpin> Unpin for MaybeDone<Fut> {} -/// Wraps a future into a `MaybeDone` +/// Wraps a future into a `MaybeDone`. pub fn maybe_done<Fut: Future>(future: Fut) -> MaybeDone<Fut> { MaybeDone::Future(future) } diff --git a/src/future/mod.rs b/src/future/mod.rs index f7d93c9..96483ac 100644 --- a/src/future/mod.rs +++ b/src/future/mod.rs @@ -22,3 +22,14 @@ cfg_sync! { mod block_on; pub(crate) use block_on::block_on; } + +cfg_trace! { + mod trace; + pub(crate) use trace::InstrumentedFuture as Future; +} + +cfg_not_trace! { + cfg_rt! { + pub(crate) use std::future::Future; + } +} diff --git a/src/future/poll_fn.rs b/src/future/poll_fn.rs index 0169bd5..d82ce89 100644 --- a/src/future/poll_fn.rs +++ b/src/future/poll_fn.rs @@ -1,6 +1,6 @@ #![allow(dead_code)] -//! Definition of the `PollFn` adapter combinator +//! Definition of the `PollFn` adapter combinator. 
use std::fmt; use std::future::Future; diff --git a/src/future/trace.rs b/src/future/trace.rs new file mode 100644 index 0000000..28789a6 --- /dev/null +++ b/src/future/trace.rs @@ -0,0 +1,11 @@ +use std::future::Future; + +pub(crate) trait InstrumentedFuture: Future { + fn id(&self) -> Option<tracing::Id>; +} + +impl<F: Future> InstrumentedFuture for tracing::instrument::Instrumented<F> { + fn id(&self) -> Option<tracing::Id> { + self.span().id() + } +} diff --git a/src/io/async_fd.rs b/src/io/async_fd.rs index 5a68d30..9ec5b7f 100644 --- a/src/io/async_fd.rs +++ b/src/io/async_fd.rs @@ -205,13 +205,13 @@ impl<T: AsRawFd> AsyncFd<T> { }) } - /// Returns a shared reference to the backing object of this [`AsyncFd`] + /// Returns a shared reference to the backing object of this [`AsyncFd`]. #[inline] pub fn get_ref(&self) -> &T { self.inner.as_ref().unwrap() } - /// Returns a mutable reference to the backing object of this [`AsyncFd`] + /// Returns a mutable reference to the backing object of this [`AsyncFd`]. #[inline] pub fn get_mut(&mut self) -> &mut T { self.inner.as_mut().unwrap() @@ -540,6 +540,16 @@ impl<'a, Inner: AsRawFd> AsyncFdReadyGuard<'a, Inner> { result => Ok(result), } } + + /// Returns a shared reference to the inner [`AsyncFd`]. + pub fn get_ref(&self) -> &AsyncFd<Inner> { + self.async_fd + } + + /// Returns a shared reference to the backing object of the inner [`AsyncFd`]. + pub fn get_inner(&self) -> &Inner { + self.get_ref().get_ref() + } } impl<'a, Inner: AsRawFd> AsyncFdReadyMutGuard<'a, Inner> { @@ -601,6 +611,26 @@ impl<'a, Inner: AsRawFd> AsyncFdReadyMutGuard<'a, Inner> { result => Ok(result), } } + + /// Returns a shared reference to the inner [`AsyncFd`]. + pub fn get_ref(&self) -> &AsyncFd<Inner> { + self.async_fd + } + + /// Returns a mutable reference to the inner [`AsyncFd`]. + pub fn get_mut(&mut self) -> &mut AsyncFd<Inner> { + self.async_fd + } + + /// Returns a shared reference to the backing object of the inner [`AsyncFd`]. 
+ pub fn get_inner(&self) -> &Inner { + self.get_ref().get_ref() + } + + /// Returns a mutable reference to the backing object of the inner [`AsyncFd`]. + pub fn get_inner_mut(&mut self) -> &mut Inner { + self.get_mut().get_mut() + } } impl<'a, T: std::fmt::Debug + AsRawFd> std::fmt::Debug for AsyncFdReadyGuard<'a, T> { diff --git a/src/io/async_write.rs b/src/io/async_write.rs index 569fb9c..7ec1a30 100644 --- a/src/io/async_write.rs +++ b/src/io/async_write.rs @@ -45,7 +45,11 @@ use std::task::{Context, Poll}; pub trait AsyncWrite { /// Attempt to write bytes from `buf` into the object. /// - /// On success, returns `Poll::Ready(Ok(num_bytes_written))`. + /// On success, returns `Poll::Ready(Ok(num_bytes_written))`. If successful, + /// then it must be guaranteed that `n <= buf.len()`. A return value of `0` + /// typically means that the underlying object is no longer able to accept + /// bytes and will likely not be able to in the future as well, or that the + /// buffer provided is empty. /// /// If the object is not ready for writing, the method returns /// `Poll::Pending` and arranges for the current task (via diff --git a/src/io/blocking.rs b/src/io/blocking.rs index 94a3484..1d79ee7 100644 --- a/src/io/blocking.rs +++ b/src/io/blocking.rs @@ -16,7 +16,7 @@ use self::State::*; pub(crate) struct Blocking<T> { inner: Option<T>, state: State<T>, - /// `true` if the lower IO layer needs flushing + /// `true` if the lower IO layer needs flushing. need_flush: bool, } @@ -175,7 +175,7 @@ where } } -/// Repeats operations that are interrupted +/// Repeats operations that are interrupted. macro_rules! uninterruptibly { ($e:expr) => {{ loop { diff --git a/src/io/bsd/poll_aio.rs b/src/io/bsd/poll_aio.rs new file mode 100644 index 0000000..f1ac4b2 --- /dev/null +++ b/src/io/bsd/poll_aio.rs @@ -0,0 +1,195 @@ +//! Use POSIX AIO futures with Tokio. 
+ +use crate::io::driver::{Handle, Interest, ReadyEvent, Registration}; +use mio::event::Source; +use mio::Registry; +use mio::Token; +use std::fmt; +use std::io; +use std::ops::{Deref, DerefMut}; +use std::os::unix::io::AsRawFd; +use std::os::unix::prelude::RawFd; +use std::task::{Context, Poll}; + +/// Like [`mio::event::Source`], but for POSIX AIO only. +/// +/// Tokio's consumer must pass an implementor of this trait to create a +/// [`Aio`] object. +pub trait AioSource { + /// Registers this AIO event source with Tokio's reactor. + fn register(&mut self, kq: RawFd, token: usize); + + /// Deregisters this AIO event source with Tokio's reactor. + fn deregister(&mut self); +} + +/// Wraps the user's AioSource in order to implement mio::event::Source, which +/// is what the rest of the crate wants. +struct MioSource<T>(T); + +impl<T: AioSource> Source for MioSource<T> { + fn register( + &mut self, + registry: &Registry, + token: Token, + interests: mio::Interest, + ) -> io::Result<()> { + assert!(interests.is_aio() || interests.is_lio()); + self.0.register(registry.as_raw_fd(), usize::from(token)); + Ok(()) + } + + fn deregister(&mut self, _registry: &Registry) -> io::Result<()> { + self.0.deregister(); + Ok(()) + } + + fn reregister( + &mut self, + registry: &Registry, + token: Token, + interests: mio::Interest, + ) -> io::Result<()> { + assert!(interests.is_aio() || interests.is_lio()); + self.0.register(registry.as_raw_fd(), usize::from(token)); + Ok(()) + } +} + +/// Associates a POSIX AIO control block with the reactor that drives it. +/// +/// `Aio`'s wrapped type must implement [`AioSource`] to be driven +/// by the reactor. +/// +/// The wrapped source may be accessed through the `Aio` via the `Deref` and +/// `DerefMut` traits. +/// +/// ## Clearing readiness +/// +/// If [`Aio::poll_ready`] returns ready, but the consumer determines that the +/// Source is not completely ready and must return to the Pending state, +/// [`Aio::clear_ready`] may be used. 
This can be useful with +/// [`lio_listio`], which may generate a kevent when only a portion of the +/// operations have completed. +/// +/// ## Platforms +/// +/// Only FreeBSD implements POSIX AIO with kqueue notification, so +/// `Aio` is only available for that operating system. +/// +/// [`lio_listio`]: https://pubs.opengroup.org/onlinepubs/9699919799/functions/lio_listio.html +// Note: Unlike every other kqueue event source, POSIX AIO registers events not +// via kevent(2) but when the aiocb is submitted to the kernel via aio_read, +// aio_write, etc. It needs the kqueue's file descriptor to do that. So +// AsyncFd can't be used for POSIX AIO. +// +// Note that Aio doesn't implement Drop. There's no need. Unlike other +// kqueue sources, simply dropping the object effectively deregisters it. +pub struct Aio<E> { + io: MioSource<E>, + registration: Registration, +} + +// ===== impl Aio ===== + +impl<E: AioSource> Aio<E> { + /// Creates a new `Aio` suitable for use with POSIX AIO functions. + /// + /// It will be associated with the default reactor. The runtime is usually + /// set implicitly when this function is called from a future driven by a + /// Tokio runtime, otherwise runtime can be set explicitly with + /// [`Runtime::enter`](crate::runtime::Runtime::enter) function. + pub fn new_for_aio(io: E) -> io::Result<Self> { + Self::new_with_interest(io, Interest::AIO) + } + + /// Creates a new `Aio` suitable for use with [`lio_listio`]. + /// + /// It will be associated with the default reactor. The runtime is usually + /// set implicitly when this function is called from a future driven by a + /// Tokio runtime, otherwise runtime can be set explicitly with + /// [`Runtime::enter`](crate::runtime::Runtime::enter) function. 
+ /// + /// [`lio_listio`]: https://pubs.opengroup.org/onlinepubs/9699919799/functions/lio_listio.html + pub fn new_for_lio(io: E) -> io::Result<Self> { + Self::new_with_interest(io, Interest::LIO) + } + + fn new_with_interest(io: E, interest: Interest) -> io::Result<Self> { + let mut io = MioSource(io); + let handle = Handle::current(); + let registration = Registration::new_with_interest_and_handle(&mut io, interest, handle)?; + Ok(Self { io, registration }) + } + + /// Indicates to Tokio that the source is no longer ready. The internal + /// readiness flag will be cleared, and tokio will wait for the next + /// edge-triggered readiness notification from the OS. + /// + /// It is critical that this method not be called unless your code + /// _actually observes_ that the source is _not_ ready. The OS must + /// deliver a subsequent notification, or this source will block + /// forever. It is equally critical that you `do` call this method if you + /// resubmit the same structure to the kernel and poll it again. + /// + /// This method is not very useful with AIO readiness, since each `aiocb` + /// structure is typically only used once. It's main use with + /// [`lio_listio`], which will sometimes send notification when only a + /// portion of its elements are complete. In that case, the caller must + /// call `clear_ready` before resubmitting it. + /// + /// [`lio_listio`]: https://pubs.opengroup.org/onlinepubs/9699919799/functions/lio_listio.html + pub fn clear_ready(&self, ev: AioEvent) { + self.registration.clear_readiness(ev.0) + } + + /// Destroy the [`Aio`] and return its inner source. + pub fn into_inner(self) -> E { + self.io.0 + } + + /// Polls for readiness. Either AIO or LIO counts. + /// + /// This method returns: + /// * `Poll::Pending` if the underlying operation is not complete, whether + /// or not it completed successfully. This will be true if the OS is + /// still processing it, or if it has not yet been submitted to the OS. 
+ /// * `Poll::Ready(Ok(_))` if the underlying operation is complete. + /// * `Poll::Ready(Err(_))` if the reactor has been shutdown. This does + /// _not_ indicate that the underlying operation encountered an error. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` + /// is scheduled to receive a wakeup when the underlying operation + /// completes. Note that on multiple calls to `poll_ready`, only the `Waker` from the + /// `Context` passed to the most recent call is scheduled to receive a wakeup. + pub fn poll_ready<'a>(&'a self, cx: &mut Context<'_>) -> Poll<io::Result<AioEvent>> { + let ev = ready!(self.registration.poll_read_ready(cx))?; + Poll::Ready(Ok(AioEvent(ev))) + } +} + +impl<E: AioSource> Deref for Aio<E> { + type Target = E; + + fn deref(&self) -> &E { + &self.io.0 + } +} + +impl<E: AioSource> DerefMut for Aio<E> { + fn deref_mut(&mut self) -> &mut E { + &mut self.io.0 + } +} + +impl<E: AioSource + fmt::Debug> fmt::Debug for Aio<E> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Aio").field("io", &self.io.0).finish() + } +} + +/// Opaque data returned by [`Aio::poll_ready`]. +/// +/// It can be fed back to [`Aio::clear_ready`]. +#[derive(Debug)] +pub struct AioEvent(ReadyEvent); diff --git a/src/io/driver/interest.rs b/src/io/driver/interest.rs index 9eead08..d6b46df 100644 --- a/src/io/driver/interest.rs +++ b/src/io/driver/interest.rs @@ -5,7 +5,7 @@ use crate::io::driver::Ready; use std::fmt; use std::ops; -/// Readiness event interest +/// Readiness event interest. /// /// Specifies the readiness events the caller is interested in when awaiting on /// I/O resource readiness states. @@ -14,12 +14,32 @@ use std::ops; pub struct Interest(mio::Interest); impl Interest { + // The non-FreeBSD definitions in this block are active only when + // building documentation. + cfg_aio! { + /// Interest for POSIX AIO. 
+ #[cfg(target_os = "freebsd")] + pub const AIO: Interest = Interest(mio::Interest::AIO); + + /// Interest for POSIX AIO. + #[cfg(not(target_os = "freebsd"))] + pub const AIO: Interest = Interest(mio::Interest::READABLE); + + /// Interest for POSIX AIO lio_listio events. + #[cfg(target_os = "freebsd")] + pub const LIO: Interest = Interest(mio::Interest::LIO); + + /// Interest for POSIX AIO lio_listio events. + #[cfg(not(target_os = "freebsd"))] + pub const LIO: Interest = Interest(mio::Interest::READABLE); + } + /// Interest in all readable events. /// /// Readable interest includes read-closed events. pub const READABLE: Interest = Interest(mio::Interest::READABLE); - /// Interest in all writable events + /// Interest in all writable events. /// /// Writable interest includes write-closed events. pub const WRITABLE: Interest = Interest(mio::Interest::WRITABLE); @@ -58,7 +78,7 @@ impl Interest { self.0.is_writable() } - /// Add together two `Interst` values. + /// Add together two `Interest` values. /// /// This function works from a `const` context. /// diff --git a/src/io/driver/mod.rs b/src/io/driver/mod.rs index fa2d420..19f67a2 100644 --- a/src/io/driver/mod.rs +++ b/src/io/driver/mod.rs @@ -23,10 +23,10 @@ use std::io; use std::sync::{Arc, Weak}; use std::time::Duration; -/// I/O driver, backed by Mio +/// I/O driver, backed by Mio. pub(crate) struct Driver { /// Tracks the number of times `turn` is called. It is safe for this to wrap - /// as it is mostly used to determine when to call `compact()` + /// as it is mostly used to determine when to call `compact()`. tick: u8, /// Reuse the `mio::Events` value across calls to poll. @@ -35,22 +35,23 @@ pub(crate) struct Driver { /// Primary slab handle containing the state for each resource registered /// with this driver. During Drop this is moved into the Inner structure, so /// this is an Option to allow it to be vacated (until Drop this is always - /// Some) + /// Some). 
resources: Option<Slab<ScheduledIo>>, - /// The system event queue + /// The system event queue. poll: mio::Poll, /// State shared between the reactor and the handles. inner: Arc<Inner>, } -/// A reference to an I/O driver +/// A reference to an I/O driver. #[derive(Clone)] pub(crate) struct Handle { inner: Weak<Inner>, } +#[derive(Debug)] pub(crate) struct ReadyEvent { tick: u8, pub(crate) ready: Ready, @@ -65,13 +66,13 @@ pub(super) struct Inner { /// without risking new ones being registered in the meantime. resources: Mutex<Option<Slab<ScheduledIo>>>, - /// Registers I/O resources + /// Registers I/O resources. registry: mio::Registry, /// Allocates `ScheduledIo` handles when creating new resources. pub(super) io_dispatch: slab::Allocator<ScheduledIo>, - /// Used to wake up the reactor from a call to `turn` + /// Used to wake up the reactor from a call to `turn`. waker: mio::Waker, } @@ -96,7 +97,7 @@ const ADDRESS: bit::Pack = bit::Pack::least_significant(24); // // The generation prevents a race condition where a slab slot is reused for a // new socket while the I/O driver is about to apply a readiness event. The -// generaton value is checked when setting new readiness. If the generation do +// generation value is checked when setting new readiness. If the generation do // not match, then the readiness event is discarded. const GENERATION: bit::Pack = ADDRESS.then(7); @@ -252,7 +253,7 @@ impl fmt::Debug for Driver { cfg_rt! { impl Handle { - /// Returns a handle to the current reactor + /// Returns a handle to the current reactor. /// /// # Panics /// @@ -266,14 +267,14 @@ cfg_rt! { cfg_not_rt! { impl Handle { - /// Returns a handle to the current reactor + /// Returns a handle to the current reactor. /// /// # Panics /// /// This function panics if there is no current reactor set, or if the `rt` /// feature flag is not enabled. 
pub(super) fn current() -> Self { - panic!(crate::util::error::CONTEXT_MISSING_ERROR) + panic!("{}", crate::util::error::CONTEXT_MISSING_ERROR) } } } diff --git a/src/io/driver/ready.rs b/src/io/driver/ready.rs index 2ac01bd..2430d30 100644 --- a/src/io/driver/ready.rs +++ b/src/io/driver/ready.rs @@ -38,6 +38,17 @@ impl Ready { pub(crate) fn from_mio(event: &mio::event::Event) -> Ready { let mut ready = Ready::EMPTY; + #[cfg(all(target_os = "freebsd", feature = "net"))] + { + if event.is_aio() { + ready |= Ready::READABLE; + } + + if event.is_lio() { + ready |= Ready::READABLE; + } + } + if event.is_readable() { ready |= Ready::READABLE; } @@ -57,7 +68,7 @@ impl Ready { ready } - /// Returns true if `Ready` is the empty set + /// Returns true if `Ready` is the empty set. /// /// # Examples /// @@ -71,7 +82,7 @@ impl Ready { self == Ready::EMPTY } - /// Returns `true` if the value includes `readable` + /// Returns `true` if the value includes `readable`. /// /// # Examples /// @@ -87,7 +98,7 @@ impl Ready { self.contains(Ready::READABLE) || self.is_read_closed() } - /// Returns `true` if the value includes writable `readiness` + /// Returns `true` if the value includes writable `readiness`. /// /// # Examples /// @@ -103,7 +114,7 @@ impl Ready { self.contains(Ready::WRITABLE) || self.is_write_closed() } - /// Returns `true` if the value includes read-closed `readiness` + /// Returns `true` if the value includes read-closed `readiness`. /// /// # Examples /// @@ -118,7 +129,7 @@ impl Ready { self.contains(Ready::READ_CLOSED) } - /// Returns `true` if the value includes write-closed `readiness` + /// Returns `true` if the value includes write-closed `readiness`. /// /// # Examples /// @@ -143,7 +154,7 @@ impl Ready { (self & other) == other } - /// Create a `Ready` instance using the given `usize` representation. + /// Creates a `Ready` instance using the given `usize` representation. 
/// /// The `usize` representation must have been obtained from a call to /// `Readiness::as_usize`. diff --git a/src/io/driver/registration.rs b/src/io/driver/registration.rs index 8251fe6..7350be6 100644 --- a/src/io/driver/registration.rs +++ b/src/io/driver/registration.rs @@ -14,8 +14,9 @@ cfg_io_driver! { /// that it will receive task notifications on readiness. This is the lowest /// level API for integrating with a reactor. /// - /// The association between an I/O resource is made by calling [`new`]. Once - /// the association is established, it remains established until the + /// The association between an I/O resource is made by calling + /// [`new_with_interest_and_handle`]. + /// Once the association is established, it remains established until the /// registration instance is dropped. /// /// A registration instance represents two separate readiness streams. One @@ -36,7 +37,7 @@ cfg_io_driver! { /// stream. The write readiness event stream is only for `Ready::writable()` /// events. /// - /// [`new`]: method@Self::new + /// [`new_with_interest_and_handle`]: method@Self::new_with_interest_and_handle /// [`poll_read_ready`]: method@Self::poll_read_ready` /// [`poll_write_ready`]: method@Self::poll_write_ready` #[derive(Debug)] diff --git a/src/io/driver/scheduled_io.rs b/src/io/driver/scheduled_io.rs index 2626b40..76f9343 100644 --- a/src/io/driver/scheduled_io.rs +++ b/src/io/driver/scheduled_io.rs @@ -3,6 +3,7 @@ use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::Mutex; use crate::util::bit; use crate::util::slab::Entry; +use crate::util::WakeList; use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; use std::task::{Context, Poll, Waker}; @@ -35,16 +36,16 @@ cfg_io_readiness! { #[derive(Debug, Default)] struct Waiters { #[cfg(feature = "net")] - /// List of all current waiters + /// List of all current waiters. list: WaitList, - /// Waker used for AsyncRead + /// Waker used for AsyncRead. 
reader: Option<Waker>, - /// Waker used for AsyncWrite + /// Waker used for AsyncWrite. writer: Option<Waker>, - /// True if this ScheduledIo has been killed due to IO driver shutdown + /// True if this ScheduledIo has been killed due to IO driver shutdown. is_shutdown: bool, } @@ -53,19 +54,19 @@ cfg_io_readiness! { struct Waiter { pointers: linked_list::Pointers<Waiter>, - /// The waker for this task + /// The waker for this task. waker: Option<Waker>, - /// The interest this waiter is waiting on + /// The interest this waiter is waiting on. interest: Interest, is_ready: bool, - /// Should never be `!Unpin` + /// Should never be `!Unpin`. _p: PhantomPinned, } - /// Future returned by `readiness()` + /// Future returned by `readiness()`. struct Readiness<'a> { scheduled_io: &'a ScheduledIo, @@ -84,9 +85,9 @@ cfg_io_readiness! { // The `ScheduledIo::readiness` (`AtomicUsize`) is packed full of goodness. // -// | reserved | generation | driver tick | readinesss | -// |----------+------------+--------------+------------| -// | 1 bit | 7 bits + 8 bits + 16 bits | +// | reserved | generation | driver tick | readiness | +// |----------+------------+--------------+-----------| +// | 1 bit | 7 bits + 8 bits + 16 bits | const READINESS: bit::Pack = bit::Pack::least_significant(16); @@ -212,10 +213,7 @@ impl ScheduledIo { } fn wake0(&self, ready: Ready, shutdown: bool) { - const NUM_WAKERS: usize = 32; - - let mut wakers: [Option<Waker>; NUM_WAKERS] = Default::default(); - let mut curr = 0; + let mut wakers = WakeList::new(); let mut waiters = self.waiters.lock(); @@ -224,16 +222,14 @@ impl ScheduledIo { // check for AsyncRead slot if ready.is_readable() { if let Some(waker) = waiters.reader.take() { - wakers[curr] = Some(waker); - curr += 1; + wakers.push(waker); } } // check for AsyncWrite slot if ready.is_writable() { if let Some(waker) = waiters.writer.take() { - wakers[curr] = Some(waker); - curr += 1; + wakers.push(waker); } } @@ -241,15 +237,14 @@ impl ScheduledIo { 
'outer: loop { let mut iter = waiters.list.drain_filter(|w| ready.satisfies(w.interest)); - while curr < NUM_WAKERS { + while wakers.can_push() { match iter.next() { Some(waiter) => { let waiter = unsafe { &mut *waiter.as_ptr() }; if let Some(waker) = waiter.waker.take() { waiter.is_ready = true; - wakers[curr] = Some(waker); - curr += 1; + wakers.push(waker); } } None => { @@ -260,11 +255,7 @@ impl ScheduledIo { drop(waiters); - for waker in wakers.iter_mut().take(curr) { - waker.take().unwrap().wake(); - } - - curr = 0; + wakers.wake_all(); // Acquire the lock again. waiters = self.waiters.lock(); @@ -273,9 +264,7 @@ impl ScheduledIo { // Release the lock before notifying drop(waiters); - for waker in wakers.iter_mut().take(curr) { - waker.take().unwrap().wake(); - } + wakers.wake_all(); } pub(super) fn ready_event(&self, interest: Interest) -> ReadyEvent { @@ -287,7 +276,7 @@ impl ScheduledIo { } } - /// Poll version of checking readiness for a certain direction. + /// Polls for readiness events in a given direction. /// /// These are to support `AsyncRead` and `AsyncWrite` polling methods, /// which cannot use the `async fn` version. This uses reserved reader @@ -374,7 +363,7 @@ unsafe impl Sync for ScheduledIo {} cfg_io_readiness! { impl ScheduledIo { - /// An async version of `poll_readiness` which uses a linked list of wakers + /// An async version of `poll_readiness` which uses a linked list of wakers. pub(crate) async fn readiness(&self, interest: Interest) -> ReadyEvent { self.readiness_fut(interest).await } diff --git a/src/io/mod.rs b/src/io/mod.rs index 14a4a63..cfdda61 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -217,6 +217,15 @@ cfg_io_driver_impl! { pub(crate) use poll_evented::PollEvented; } +cfg_aio! { + /// BSD-specific I/O types. + pub mod bsd { + mod poll_aio; + + pub use poll_aio::{Aio, AioEvent, AioSource}; + } +} + cfg_net_unix! 
{ mod async_fd; diff --git a/src/io/poll_evented.rs b/src/io/poll_evented.rs index 47ae558..44e68a2 100644 --- a/src/io/poll_evented.rs +++ b/src/io/poll_evented.rs @@ -10,10 +10,10 @@ cfg_io_driver! { /// [`std::io::Write`] traits with the reactor that drives it. /// /// `PollEvented` uses [`Registration`] internally to take a type that - /// implements [`mio::Evented`] as well as [`std::io::Read`] and or + /// implements [`mio::event::Source`] as well as [`std::io::Read`] and or /// [`std::io::Write`] and associate it with a reactor that will drive it. /// - /// Once the [`mio::Evented`] type is wrapped by `PollEvented`, it can be + /// Once the [`mio::event::Source`] type is wrapped by `PollEvented`, it can be /// used from within the future's execution model. As such, the /// `PollEvented` type provides [`AsyncRead`] and [`AsyncWrite`] /// implementations using the underlying I/O resource as well as readiness @@ -40,13 +40,12 @@ cfg_io_driver! { /// [`poll_read_ready`] again will also indicate read readiness. /// /// When the operation is attempted and is unable to succeed due to the I/O - /// resource not being ready, the caller must call [`clear_read_ready`] or - /// [`clear_write_ready`]. This clears the readiness state until a new - /// readiness event is received. + /// resource not being ready, the caller must call `clear_readiness`. + /// This clears the readiness state until a new readiness event is received. /// /// This allows the caller to implement additional functions. For example, /// [`TcpListener`] implements poll_accept by using [`poll_read_ready`] and - /// [`clear_read_ready`]. + /// `clear_read_ready`. /// /// ## Platform-specific events /// @@ -54,17 +53,11 @@ cfg_io_driver! { /// These events are included as part of the read readiness event stream. The /// write readiness event stream is only for `Ready::writable()` events. 
/// - /// [`std::io::Read`]: trait@std::io::Read - /// [`std::io::Write`]: trait@std::io::Write - /// [`AsyncRead`]: trait@AsyncRead - /// [`AsyncWrite`]: trait@AsyncWrite - /// [`mio::Evented`]: trait@mio::Evented - /// [`Registration`]: struct@Registration - /// [`TcpListener`]: struct@crate::net::TcpListener - /// [`clear_read_ready`]: method@Self::clear_read_ready - /// [`clear_write_ready`]: method@Self::clear_write_ready - /// [`poll_read_ready`]: method@Self::poll_read_ready - /// [`poll_write_ready`]: method@Self::poll_write_ready + /// [`AsyncRead`]: crate::io::AsyncRead + /// [`AsyncWrite`]: crate::io::AsyncWrite + /// [`TcpListener`]: crate::net::TcpListener + /// [`poll_read_ready`]: Registration::poll_read_ready + /// [`poll_write_ready`]: Registration::poll_write_ready pub(crate) struct PollEvented<E: Source> { io: Option<E>, registration: Registration, @@ -120,7 +113,7 @@ impl<E: Source> PollEvented<E> { }) } - /// Returns a reference to the registration + /// Returns a reference to the registration. #[cfg(any( feature = "net", all(unix, feature = "process"), @@ -130,7 +123,7 @@ impl<E: Source> PollEvented<E> { &self.registration } - /// Deregister the inner io from the registration and returns a Result containing the inner io + /// Deregisters the inner io from the registration and returns a Result containing the inner io. #[cfg(any(feature = "net", feature = "process"))] pub(crate) fn into_inner(mut self) -> io::Result<E> { let mut inner = self.io.take().unwrap(); // As io shouldn't ever be None, just unwrap here. diff --git a/src/io/read_buf.rs b/src/io/read_buf.rs index 38e857d..ad58cbe 100644 --- a/src/io/read_buf.rs +++ b/src/io/read_buf.rs @@ -45,7 +45,7 @@ impl<'a> ReadBuf<'a> { /// Creates a new `ReadBuf` from a fully uninitialized buffer. /// - /// Use `assume_init` if part of the buffer is known to be already inintialized. + /// Use `assume_init` if part of the buffer is known to be already initialized. 
#[inline] pub fn uninit(buf: &'a mut [MaybeUninit<u8>]) -> ReadBuf<'a> { ReadBuf { @@ -85,7 +85,7 @@ impl<'a> ReadBuf<'a> { #[inline] pub fn take(&mut self, n: usize) -> ReadBuf<'_> { let max = std::cmp::min(self.remaining(), n); - // Saftey: We don't set any of the `unfilled_mut` with `MaybeUninit::uninit`. + // Safety: We don't set any of the `unfilled_mut` with `MaybeUninit::uninit`. unsafe { ReadBuf::uninit(&mut self.unfilled_mut()[..max]) } } @@ -217,7 +217,7 @@ impl<'a> ReadBuf<'a> { /// /// # Panics /// - /// Panics if the filled region of the buffer would become larger than the intialized region. + /// Panics if the filled region of the buffer would become larger than the initialized region. #[inline] pub fn set_filled(&mut self, n: usize) { assert!( diff --git a/src/io/split.rs b/src/io/split.rs index 732eb3b..8258a0f 100644 --- a/src/io/split.rs +++ b/src/io/split.rs @@ -63,7 +63,7 @@ impl<T> ReadHalf<T> { /// Checks if this `ReadHalf` and some `WriteHalf` were split from the same /// stream. pub fn is_pair_of(&self, other: &WriteHalf<T>) -> bool { - other.is_pair_of(&self) + other.is_pair_of(self) } /// Reunites with a previously split `WriteHalf`. @@ -90,7 +90,7 @@ impl<T> ReadHalf<T> { } impl<T> WriteHalf<T> { - /// Check if this `WriteHalf` and some `ReadHalf` were split from the same + /// Checks if this `WriteHalf` and some `ReadHalf` were split from the same /// stream. pub fn is_pair_of(&self, other: &ReadHalf<T>) -> bool { Arc::ptr_eq(&self.inner, &other.inner) diff --git a/src/io/stdio_common.rs b/src/io/stdio_common.rs index d21c842..7e4a198 100644 --- a/src/io/stdio_common.rs +++ b/src/io/stdio_common.rs @@ -7,7 +7,7 @@ use std::task::{Context, Poll}; /// if buffer contents seems to be utf8. Otherwise it only trims buffer down to MAX_BUF. /// That's why, wrapped writer will always receive well-formed utf-8 bytes. /// # Other platforms -/// passes data to `inner` as is +/// Passes data to `inner` as is. 
#[derive(Debug)] pub(crate) struct SplitByUtf8BoundaryIfWindows<W> { inner: W, @@ -52,10 +52,10 @@ where buf = &buf[..crate::io::blocking::MAX_BUF]; - // Now there are two possibilites. + // Now there are two possibilities. // If caller gave is binary buffer, we **should not** shrink it // anymore, because excessive shrinking hits performance. - // If caller gave as binary buffer, we **must** additionaly + // If caller gave as binary buffer, we **must** additionally // shrink it to strip incomplete char at the end of buffer. // that's why check we will perform now is allowed to have // false-positive. diff --git a/src/io/util/async_buf_read_ext.rs b/src/io/util/async_buf_read_ext.rs index 233ac31..b241e35 100644 --- a/src/io/util/async_buf_read_ext.rs +++ b/src/io/util/async_buf_read_ext.rs @@ -1,3 +1,4 @@ +use crate::io::util::fill_buf::{fill_buf, FillBuf}; use crate::io::util::lines::{lines, Lines}; use crate::io::util::read_line::{read_line, ReadLine}; use crate::io::util::read_until::{read_until, ReadUntil}; @@ -36,6 +37,18 @@ cfg_io_util! { /// [`fill_buf`]: AsyncBufRead::poll_fill_buf /// [`ErrorKind::Interrupted`]: std::io::ErrorKind::Interrupted /// + /// # Cancel safety + /// + /// If the method is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then some data may have been partially read. Any + /// partially read bytes are appended to `buf`, and the method can be + /// called again to continue reading until `byte`. + /// + /// This method returns the total number of bytes read. If you cancel + /// the call to `read_until` and then call it again to continue reading, + /// the counter is reset. + /// /// # Examples /// /// [`std::io::Cursor`][`Cursor`] is a type that implements `BufRead`. In @@ -114,6 +127,30 @@ cfg_io_util! { /// /// [`read_until`]: AsyncBufReadExt::read_until /// + /// # Cancel safety + /// + /// This method is not cancellation safe. 
If the method is used as the + /// event in a [`tokio::select!`](crate::select) statement and some + /// other branch completes first, then some data may have been partially + /// read, and this data is lost. There are no guarantees regarding the + /// contents of `buf` when the call is cancelled. The current + /// implementation replaces `buf` with the empty string, but this may + /// change in the future. + /// + /// This function does not behave like [`read_until`] because of the + /// requirement that a string contains only valid utf-8. If you need a + /// cancellation safe `read_line`, there are three options: + /// + /// * Call [`read_until`] with a newline character and manually perform the utf-8 check. + /// * The stream returned by [`lines`] has a cancellation safe + /// [`next_line`] method. + /// * Use [`tokio_util::codec::LinesCodec`][LinesCodec]. + /// + /// [LinesCodec]: https://docs.rs/tokio-util/0.6/tokio_util/codec/struct.LinesCodec.html + /// [`read_until`]: Self::read_until + /// [`lines`]: Self::lines + /// [`next_line`]: crate::io::Lines::next_line + /// /// # Examples /// /// [`std::io::Cursor`][`Cursor`] is a type that implements @@ -173,10 +210,11 @@ cfg_io_util! { /// [`BufRead::split`](std::io::BufRead::split). /// /// The stream returned from this function will yield instances of - /// [`io::Result`]`<`[`Vec<u8>`]`>`. Each vector returned will *not* have + /// [`io::Result`]`<`[`Option`]`<`[`Vec<u8>`]`>>`. Each vector returned will *not* have /// the delimiter byte at the end. /// /// [`io::Result`]: std::io::Result + /// [`Option`]: core::option::Option /// [`Vec<u8>`]: std::vec::Vec /// /// # Errors @@ -206,14 +244,68 @@ cfg_io_util! { split(self, byte) } + /// Returns the contents of the internal buffer, filling it with more + /// data from the inner reader if it is empty. + /// + /// This function is a lower-level call. It needs to be paired with the + /// [`consume`] method to function properly. 
When calling this method, + /// none of the contents will be "read" in the sense that later calling + /// `read` may return the same contents. As such, [`consume`] must be + /// called with the number of bytes that are consumed from this buffer + /// to ensure that the bytes are never returned twice. + /// + /// An empty buffer returned indicates that the stream has reached EOF. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn fill_buf(&mut self) -> io::Result<&[u8]>; + /// ``` + /// + /// # Errors + /// + /// This function will return an I/O error if the underlying reader was + /// read, but returned an error. + /// + /// [`consume`]: crate::io::AsyncBufReadExt::consume + fn fill_buf(&mut self) -> FillBuf<'_, Self> + where + Self: Unpin, + { + fill_buf(self) + } + + /// Tells this buffer that `amt` bytes have been consumed from the + /// buffer, so they should no longer be returned in calls to [`read`]. + /// + /// This function is a lower-level call. It needs to be paired with the + /// [`fill_buf`] method to function properly. This function does not + /// perform any I/O, it simply informs this object that some amount of + /// its buffer, returned from [`fill_buf`], has been consumed and should + /// no longer be returned. As such, this function may do odd things if + /// [`fill_buf`] isn't called before calling it. + /// + /// The `amt` must be less than the number of bytes in the buffer + /// returned by [`fill_buf`]. + /// + /// [`read`]: crate::io::AsyncReadExt::read + /// [`fill_buf`]: crate::io::AsyncBufReadExt::fill_buf + fn consume(&mut self, amt: usize) + where + Self: Unpin, + { + std::pin::Pin::new(self).consume(amt) + } + /// Returns a stream over the lines of this reader. /// This method is the async equivalent to [`BufRead::lines`](std::io::BufRead::lines). /// /// The stream returned from this function will yield instances of - /// [`io::Result`]`<`[`String`]`>`. 
Each string returned will *not* have a newline + /// [`io::Result`]`<`[`Option`]`<`[`String`]`>>`. Each string returned will *not* have a newline /// byte (the 0xA byte) or CRLF (0xD, 0xA bytes) at the end. /// /// [`io::Result`]: std::io::Result + /// [`Option`]: core::option::Option /// [`String`]: String /// /// # Errors diff --git a/src/io/util/async_read_ext.rs b/src/io/util/async_read_ext.rs index e715f9d..df5445c 100644 --- a/src/io/util/async_read_ext.rs +++ b/src/io/util/async_read_ext.rs @@ -2,6 +2,7 @@ use crate::io::util::chain::{chain, Chain}; use crate::io::util::read::{read, Read}; use crate::io::util::read_buf::{read_buf, ReadBuf}; use crate::io::util::read_exact::{read_exact, ReadExact}; +use crate::io::util::read_int::{ReadF32, ReadF32Le, ReadF64, ReadF64Le}; use crate::io::util::read_int::{ ReadI128, ReadI128Le, ReadI16, ReadI16Le, ReadI32, ReadI32Le, ReadI64, ReadI64Le, ReadI8, }; @@ -105,8 +106,10 @@ cfg_io_util! { /// async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize>; /// ``` /// - /// This function does not provide any guarantees about whether it - /// completes immediately or asynchronously + /// This method does not provide any guarantees about whether it + /// completes immediately or asynchronously. + /// + /// # Return /// /// If the return value of this method is `Ok(n)`, then it must be /// guaranteed that `0 <= n <= buf.len()`. A nonzero `n` value indicates @@ -136,6 +139,12 @@ cfg_io_util! { /// variant will be returned. If an error is returned then it must be /// guaranteed that no bytes were read. /// + /// # Cancel safety + /// + /// This method is cancel safe. If you use it as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no data was read. + /// /// # Examples /// /// [`File`][crate::fs::File]s implement `Read`: @@ -175,14 +184,19 @@ cfg_io_util! 
{ /// Usually, only a single `read` syscall is issued, even if there is /// more space in the supplied buffer. /// - /// This function does not provide any guarantees about whether it - /// completes immediately or asynchronously + /// This method does not provide any guarantees about whether it + /// completes immediately or asynchronously. /// /// # Return /// - /// On a successful read, the number of read bytes is returned. If the - /// supplied buffer is not empty and the function returns `Ok(0)` then - /// the source has reached an "end-of-file" event. + /// A nonzero `n` value indicates that the buffer `buf` has been filled + /// in with `n` bytes of data from this source. If `n` is `0`, then it + /// can indicate one of two scenarios: + /// + /// 1. This reader has reached its "end of file" and will likely no longer + /// be able to produce bytes. Note that this does not mean that the + /// reader will *always* no longer be able to produce bytes. + /// 2. The buffer specified had a remaining capacity of zero. /// /// # Errors /// @@ -190,6 +204,12 @@ cfg_io_util! { /// variant will be returned. If an error is returned then it must be /// guaranteed that no bytes were read. /// + /// # Cancel safety + /// + /// This method is cancel safe. If you use it as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no data was read. + /// /// # Examples /// /// [`File`] implements `Read` and [`BytesMut`] implements [`BufMut`]: @@ -254,6 +274,13 @@ cfg_io_util! { /// it has read, but it will never read more than would be necessary to /// completely fill the buffer. /// + /// # Cancel safety + /// + /// This method is not cancellation safe. If the method is used as the + /// event in a [`tokio::select!`](crate::select) statement and some + /// other branch completes first, then some data may already have been + /// read into `buf`. 
+ /// /// # Examples /// /// [`File`][crate::fs::File]s implement `Read`: @@ -579,7 +606,7 @@ cfg_io_util! { /// async fn main() -> io::Result<()> { /// let mut reader = Cursor::new(vec![0x80, 0, 0, 0, 0, 0, 0, 0]); /// - /// assert_eq!(i64::min_value(), reader.read_i64().await?); + /// assert_eq!(i64::MIN, reader.read_i64().await?); /// Ok(()) /// } /// ``` @@ -659,12 +686,88 @@ cfg_io_util! { /// 0, 0, 0, 0, 0, 0, 0, 0 /// ]); /// - /// assert_eq!(i128::min_value(), reader.read_i128().await?); + /// assert_eq!(i128::MIN, reader.read_i128().await?); /// Ok(()) /// } /// ``` fn read_i128(&mut self) -> ReadI128; + /// Reads an 32-bit floating point type in big-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_f32(&mut self) -> io::Result<f32>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read 32-bit floating point type from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0xff, 0x7f, 0xff, 0xff]); + /// + /// assert_eq!(f32::MIN, reader.read_f32().await?); + /// Ok(()) + /// } + /// ``` + fn read_f32(&mut self) -> ReadF32; + + /// Reads an 64-bit floating point type in big-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_f64(&mut self) -> io::Result<f64>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. 
+ /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read 64-bit floating point type from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0xff, 0xef, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff + /// ]); + /// + /// assert_eq!(f64::MIN, reader.read_f64().await?); + /// Ok(()) + /// } + /// ``` + fn read_f64(&mut self) -> ReadF64; + /// Reads an unsigned 16-bit integer in little-endian order from the /// underlying reader. /// @@ -971,6 +1074,82 @@ cfg_io_util! { /// } /// ``` fn read_i128_le(&mut self) -> ReadI128Le; + + /// Reads an 32-bit floating point type in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_f32_le(&mut self) -> io::Result<f32>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read 32-bit floating point type from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![0xff, 0xff, 0x7f, 0xff]); + /// + /// assert_eq!(f32::MIN, reader.read_f32_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_f32_le(&mut self) -> ReadF32Le; + + /// Reads an 64-bit floating point type in little-endian order from the + /// underlying reader. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn read_f64_le(&mut self) -> io::Result<f64>; + /// ``` + /// + /// It is recommended to use a buffered reader to avoid excessive + /// syscalls. 
+ /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncReadExt::read_exact`]. + /// + /// [`AsyncReadExt::read_exact`]: AsyncReadExt::read_exact + /// + /// # Examples + /// + /// Read 64-bit floating point type from a `AsyncRead`: + /// + /// ```rust + /// use tokio::io::{self, AsyncReadExt}; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut reader = Cursor::new(vec![ + /// 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xef, 0xff + /// ]); + /// + /// assert_eq!(f64::MIN, reader.read_f64_le().await?); + /// Ok(()) + /// } + /// ``` + fn read_f64_le(&mut self) -> ReadF64Le; } /// Reads all bytes until EOF in this source, placing them into `buf`. diff --git a/src/io/util/async_seek_ext.rs b/src/io/util/async_seek_ext.rs index 297a4a6..46b3e6c 100644 --- a/src/io/util/async_seek_ext.rs +++ b/src/io/util/async_seek_ext.rs @@ -67,6 +67,16 @@ cfg_io_util! { seek(self, pos) } + /// Creates a future which will rewind to the beginning of the stream. + /// + /// This is convenience method, equivalent to to `self.seek(SeekFrom::Start(0))`. + fn rewind(&mut self) -> Seek<'_, Self> + where + Self: Unpin, + { + self.seek(SeekFrom::Start(0)) + } + /// Creates a future which will return the current seek position from the /// start of the stream. 
/// diff --git a/src/io/util/async_write_ext.rs b/src/io/util/async_write_ext.rs index d011d82..93a3183 100644 --- a/src/io/util/async_write_ext.rs +++ b/src/io/util/async_write_ext.rs @@ -2,7 +2,9 @@ use crate::io::util::flush::{flush, Flush}; use crate::io::util::shutdown::{shutdown, Shutdown}; use crate::io::util::write::{write, Write}; use crate::io::util::write_all::{write_all, WriteAll}; +use crate::io::util::write_all_buf::{write_all_buf, WriteAllBuf}; use crate::io::util::write_buf::{write_buf, WriteBuf}; +use crate::io::util::write_int::{WriteF32, WriteF32Le, WriteF64, WriteF64Le}; use crate::io::util::write_int::{ WriteI128, WriteI128Le, WriteI16, WriteI16Le, WriteI32, WriteI32Le, WriteI64, WriteI64Le, WriteI8, @@ -18,7 +20,7 @@ use std::io::IoSlice; use bytes::Buf; cfg_io_util! { - /// Defines numeric writer + /// Defines numeric writer. macro_rules! write_impl { ( $( @@ -96,6 +98,13 @@ cfg_io_util! { /// It is **not** considered an error if the entire buffer could not be /// written to this writer. /// + /// # Cancel safety + /// + /// This method is cancellation safe in the sense that if it is used as + /// the event in a [`tokio::select!`](crate::select) statement and some + /// other branch completes first, then it is guaranteed that no data was + /// written to this `AsyncWrite`. + /// /// # Examples /// /// ```no_run @@ -128,6 +137,13 @@ cfg_io_util! { /// /// See [`AsyncWrite::poll_write_vectored`] for more details. /// + /// # Cancel safety + /// + /// This method is cancellation safe in the sense that if it is used as + /// the event in a [`tokio::select!`](crate::select) statement and some + /// other branch completes first, then it is guaranteed that no data was + /// written to this `AsyncWrite`. + /// /// # Examples /// /// ```no_run @@ -159,7 +175,6 @@ cfg_io_util! { write_vectored(self, bufs) } - /// Writes a buffer into this writer, advancing the buffer's internal /// cursor. /// @@ -195,12 +210,20 @@ cfg_io_util! 
{ /// It is **not** considered an error if the entire buffer could not be /// written to this writer. /// + /// # Cancel safety + /// + /// This method is cancellation safe in the sense that if it is used as + /// the event in a [`tokio::select!`](crate::select) statement and some + /// other branch completes first, then it is guaranteed that no data was + /// written to this `AsyncWrite`. + /// /// # Examples /// - /// [`File`] implements `Read` and [`Cursor<&[u8]>`] implements [`Buf`]: + /// [`File`] implements [`AsyncWrite`] and [`Cursor`]`<&[u8]>` implements [`Buf`]: /// /// [`File`]: crate::fs::File /// [`Buf`]: bytes::Buf + /// [`Cursor`]: std::io::Cursor /// /// ```no_run /// use tokio::io::{self, AsyncWriteExt}; @@ -238,6 +261,70 @@ cfg_io_util! { /// Equivalent to: /// /// ```ignore + /// async fn write_all_buf(&mut self, buf: impl Buf) -> Result<(), io::Error> { + /// while buf.has_remaining() { + /// self.write_buf(&mut buf).await?; + /// } + /// Ok(()) + /// } + /// ``` + /// + /// This method will continuously call [`write`] until + /// [`buf.has_remaining()`](bytes::Buf::has_remaining) returns false. This method will not + /// return until the entire buffer has been successfully written or an error occurs. The + /// first error generated will be returned. + /// + /// The buffer is advanced after each chunk is successfully written. After failure, + /// `src.chunk()` will return the chunk that failed to write. + /// + /// # Cancel safety + /// + /// If `write_all_buf` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then the data in the provided buffer may have been + /// partially written. However, it is guaranteed that the provided + /// buffer has been [advanced] by the amount of bytes that have been + /// partially written. 
+ /// + /// # Examples + /// + /// [`File`] implements [`AsyncWrite`] and [`Cursor`]`<&[u8]>` implements [`Buf`]: + /// + /// [`File`]: crate::fs::File + /// [`Buf`]: bytes::Buf + /// [`Cursor`]: std::io::Cursor + /// [advanced]: bytes::Buf::advance + /// + /// ```no_run + /// use tokio::io::{self, AsyncWriteExt}; + /// use tokio::fs::File; + /// + /// use std::io::Cursor; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut file = File::create("foo.txt").await?; + /// let mut buffer = Cursor::new(b"data to write"); + /// + /// file.write_all_buf(&mut buffer).await?; + /// Ok(()) + /// } + /// ``` + /// + /// [`write`]: AsyncWriteExt::write + fn write_all_buf<'a, B>(&'a mut self, src: &'a mut B) -> WriteAllBuf<'a, Self, B> + where + Self: Sized + Unpin, + B: Buf, + { + write_all_buf(self, src) + } + + /// Attempts to write an entire buffer into this writer. + /// + /// Equivalent to: + /// + /// ```ignore /// async fn write_all(&mut self, buf: &[u8]) -> io::Result<()>; /// ``` /// @@ -246,6 +333,14 @@ cfg_io_util! { /// has been successfully written or such an error occurs. The first /// error generated from this method will be returned. /// + /// # Cancel safety + /// + /// This method is not cancellation safe. If it is used as the event + /// in a [`tokio::select!`](crate::select) statement and some other + /// branch completes first, then the provided buffer may have been + /// partially written, but future calls to `write_all` will start over + /// from the beginning of the buffer. + /// /// # Errors /// /// This function will return the first error that [`write`] returns. @@ -258,9 +353,9 @@ cfg_io_util! { /// /// #[tokio::main] /// async fn main() -> io::Result<()> { - /// let mut buffer = File::create("foo.txt").await?; + /// let mut file = File::create("foo.txt").await?; /// - /// buffer.write_all(b"some bytes").await?; + /// file.write_all(b"some bytes").await?; /// Ok(()) /// } /// ``` @@ -567,8 +662,8 @@ cfg_io_util! 
{ /// async fn main() -> io::Result<()> { /// let mut writer = Vec::new(); /// - /// writer.write_i64(i64::min_value()).await?; - /// writer.write_i64(i64::max_value()).await?; + /// writer.write_i64(i64::MIN).await?; + /// writer.write_i64(i64::MAX).await?; /// /// assert_eq!(writer, b"\x80\x00\x00\x00\x00\x00\x00\x00\x7f\xff\xff\xff\xff\xff\xff\xff"); /// Ok(()) @@ -645,7 +740,7 @@ cfg_io_util! { /// async fn main() -> io::Result<()> { /// let mut writer = Vec::new(); /// - /// writer.write_i128(i128::min_value()).await?; + /// writer.write_i128(i128::MIN).await?; /// /// assert_eq!(writer, vec![ /// 0x80, 0, 0, 0, 0, 0, 0, 0, @@ -656,6 +751,81 @@ cfg_io_util! { /// ``` fn write_i128(&mut self, n: i128) -> WriteI128; + /// Writes an 32-bit floating point type in big-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_f32(&mut self, n: f32) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write 32-bit floating point type to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_f32(f32::MIN).await?; + /// + /// assert_eq!(writer, vec![0xff, 0x7f, 0xff, 0xff]); + /// Ok(()) + /// } + /// ``` + fn write_f32(&mut self, n: f32) -> WriteF32; + + /// Writes an 64-bit floating point type in big-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_f64(&mut self, n: f64) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. 
+ /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write 64-bit floating point type to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_f64(f64::MIN).await?; + /// + /// assert_eq!(writer, vec![ + /// 0xff, 0xef, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff + /// ]); + /// Ok(()) + /// } + /// ``` + fn write_f64(&mut self, n: f64) -> WriteF64; /// Writes an unsigned 16-bit integer in little-endian order to the /// underlying writer. @@ -876,8 +1046,8 @@ cfg_io_util! { /// async fn main() -> io::Result<()> { /// let mut writer = Vec::new(); /// - /// writer.write_i64_le(i64::min_value()).await?; - /// writer.write_i64_le(i64::max_value()).await?; + /// writer.write_i64_le(i64::MIN).await?; + /// writer.write_i64_le(i64::MAX).await?; /// /// assert_eq!(writer, b"\x00\x00\x00\x00\x00\x00\x00\x80\xff\xff\xff\xff\xff\xff\xff\x7f"); /// Ok(()) @@ -954,7 +1124,7 @@ cfg_io_util! { /// async fn main() -> io::Result<()> { /// let mut writer = Vec::new(); /// - /// writer.write_i128_le(i128::min_value()).await?; + /// writer.write_i128_le(i128::MIN).await?; /// /// assert_eq!(writer, vec![ /// 0, 0, 0, 0, 0, 0, 0, @@ -964,6 +1134,82 @@ cfg_io_util! { /// } /// ``` fn write_i128_le(&mut self, n: i128) -> WriteI128Le; + + /// Writes an 32-bit floating point type in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_f32_le(&mut self, n: f32) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. 
+ /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write 32-bit floating point type to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_f32_le(f32::MIN).await?; + /// + /// assert_eq!(writer, vec![0xff, 0xff, 0x7f, 0xff]); + /// Ok(()) + /// } + /// ``` + fn write_f32_le(&mut self, n: f32) -> WriteF32Le; + + /// Writes an 64-bit floating point type in little-endian order to the + /// underlying writer. + /// + /// Equivalent to: + /// + /// ```ignore + /// async fn write_f64_le(&mut self, n: f64) -> io::Result<()>; + /// ``` + /// + /// It is recommended to use a buffered writer to avoid excessive + /// syscalls. + /// + /// # Errors + /// + /// This method returns the same errors as [`AsyncWriteExt::write_all`]. + /// + /// [`AsyncWriteExt::write_all`]: AsyncWriteExt::write_all + /// + /// # Examples + /// + /// Write 64-bit floating point type to a `AsyncWrite`: + /// + /// ```rust + /// use tokio::io::{self, AsyncWriteExt}; + /// + /// #[tokio::main] + /// async fn main() -> io::Result<()> { + /// let mut writer = Vec::new(); + /// + /// writer.write_f64_le(f64::MIN).await?; + /// + /// assert_eq!(writer, vec![ + /// 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xef, 0xff + /// ]); + /// Ok(()) + /// } + /// ``` + fn write_f64_le(&mut self, n: f64) -> WriteF64Le; } /// Flushes this output stream, ensuring that all intermediately buffered diff --git a/src/io/util/buf_reader.rs b/src/io/util/buf_reader.rs index 271f61b..7df610b 100644 --- a/src/io/util/buf_reader.rs +++ b/src/io/util/buf_reader.rs @@ -1,11 +1,11 @@ use crate::io::util::DEFAULT_BUF_SIZE; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; -use std::io; +use std::io::{self, 
IoSlice, SeekFrom}; use std::pin::Pin; use std::task::{Context, Poll}; -use std::{cmp, fmt}; +use std::{cmp, fmt, mem}; pin_project! { /// The `BufReader` struct adds buffering to any reader. @@ -30,6 +30,7 @@ pin_project! { pub(super) buf: Box<[u8]>, pub(super) pos: usize, pub(super) cap: usize, + pub(super) seek_state: SeekState, } } @@ -48,6 +49,7 @@ impl<R: AsyncRead> BufReader<R> { buf: buffer.into_boxed_slice(), pos: 0, cap: 0, + seek_state: SeekState::Init, } } @@ -141,6 +143,122 @@ impl<R: AsyncRead> AsyncBufRead for BufReader<R> { } } +#[derive(Debug, Clone, Copy)] +pub(super) enum SeekState { + /// start_seek has not been called. + Init, + /// start_seek has been called, but poll_complete has not yet been called. + Start(SeekFrom), + /// Waiting for completion of the first poll_complete in the `n.checked_sub(remainder).is_none()` branch. + PendingOverflowed(i64), + /// Waiting for completion of poll_complete. + Pending, +} + +/// Seeks to an offset, in bytes, in the underlying reader. +/// +/// The position used for seeking with `SeekFrom::Current(_)` is the +/// position the underlying reader would be at if the `BufReader` had no +/// internal buffer. +/// +/// Seeking always discards the internal buffer, even if the seek position +/// would otherwise fall within it. This guarantees that calling +/// `.into_inner()` immediately after a seek yields the underlying reader +/// at the same position. +/// +/// See [`AsyncSeek`] for more details. +/// +/// Note: In the edge case where you're seeking with `SeekFrom::Current(n)` +/// where `n` minus the internal buffer length overflows an `i64`, two +/// seeks will be performed instead of one. If the second seek returns +/// `Err`, the underlying reader will be left at the same position it would +/// have if you called `seek` with `SeekFrom::Current(0)`. 
+impl<R: AsyncRead + AsyncSeek> AsyncSeek for BufReader<R> { + fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + // We needs to call seek operation multiple times. + // And we should always call both start_seek and poll_complete, + // as start_seek alone cannot guarantee that the operation will be completed. + // poll_complete receives a Context and returns a Poll, so it cannot be called + // inside start_seek. + *self.project().seek_state = SeekState::Start(pos); + Ok(()) + } + + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { + let res = match mem::replace(self.as_mut().project().seek_state, SeekState::Init) { + SeekState::Init => { + // 1.x AsyncSeek recommends calling poll_complete before start_seek. + // We don't have to guarantee that the value returned by + // poll_complete called without start_seek is correct, + // so we'll return 0. + return Poll::Ready(Ok(0)); + } + SeekState::Start(SeekFrom::Current(n)) => { + let remainder = (self.cap - self.pos) as i64; + // it should be safe to assume that remainder fits within an i64 as the alternative + // means we managed to allocate 8 exbibytes and that's absurd. + // But it's not out of the realm of possibility for some weird underlying reader to + // support seeking by i64::MIN so we need to handle underflow when subtracting + // remainder. + if let Some(offset) = n.checked_sub(remainder) { + self.as_mut() + .get_pin_mut() + .start_seek(SeekFrom::Current(offset))?; + self.as_mut().get_pin_mut().poll_complete(cx)? 
+ } else { + // seek backwards by our remainder, and then by the offset + self.as_mut() + .get_pin_mut() + .start_seek(SeekFrom::Current(-remainder))?; + if self.as_mut().get_pin_mut().poll_complete(cx)?.is_pending() { + *self.as_mut().project().seek_state = SeekState::PendingOverflowed(n); + return Poll::Pending; + } + + // https://github.com/rust-lang/rust/pull/61157#issuecomment-495932676 + self.as_mut().discard_buffer(); + + self.as_mut() + .get_pin_mut() + .start_seek(SeekFrom::Current(n))?; + self.as_mut().get_pin_mut().poll_complete(cx)? + } + } + SeekState::PendingOverflowed(n) => { + if self.as_mut().get_pin_mut().poll_complete(cx)?.is_pending() { + *self.as_mut().project().seek_state = SeekState::PendingOverflowed(n); + return Poll::Pending; + } + + // https://github.com/rust-lang/rust/pull/61157#issuecomment-495932676 + self.as_mut().discard_buffer(); + + self.as_mut() + .get_pin_mut() + .start_seek(SeekFrom::Current(n))?; + self.as_mut().get_pin_mut().poll_complete(cx)? + } + SeekState::Start(pos) => { + // Seeking with Start/End doesn't care about our buffer length. + self.as_mut().get_pin_mut().start_seek(pos)?; + self.as_mut().get_pin_mut().poll_complete(cx)? 
+ } + SeekState::Pending => self.as_mut().get_pin_mut().poll_complete(cx)?, + }; + + match res { + Poll::Ready(res) => { + self.discard_buffer(); + Poll::Ready(Ok(res)) + } + Poll::Pending => { + *self.as_mut().project().seek_state = SeekState::Pending; + Poll::Pending + } + } + } +} + impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> { fn poll_write( self: Pin<&mut Self>, @@ -150,6 +268,18 @@ impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> { self.get_pin_mut().poll_write(cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.get_pin_mut().poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.get_ref().is_write_vectored() + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { self.get_pin_mut().poll_flush(cx) } diff --git a/src/io/util/buf_stream.rs b/src/io/util/buf_stream.rs index cc857e2..595c142 100644 --- a/src/io/util/buf_stream.rs +++ b/src/io/util/buf_stream.rs @@ -1,8 +1,8 @@ use crate::io::util::{BufReader, BufWriter}; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; -use std::io; +use std::io::{self, IoSlice, SeekFrom}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -94,9 +94,11 @@ impl<RW> From<BufWriter<BufReader<RW>>> for BufStream<RW> { buf: rbuf, pos, cap, + seek_state: rseek_state, }, buf: wbuf, written, + seek_state: wseek_state, } = b; BufStream { @@ -105,10 +107,12 @@ impl<RW> From<BufWriter<BufReader<RW>>> for BufStream<RW> { inner, buf: wbuf, written, + seek_state: wseek_state, }, buf: rbuf, pos, cap, + seek_state: rseek_state, }, } } @@ -123,6 +127,18 @@ impl<RW: AsyncRead + AsyncWrite> AsyncWrite for BufStream<RW> { self.project().inner.poll_write(cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + 
bufs: &[IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.project().inner.poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() + } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { self.project().inner.poll_flush(cx) } @@ -142,6 +158,34 @@ impl<RW: AsyncRead + AsyncWrite> AsyncRead for BufStream<RW> { } } +/// Seek to an offset, in bytes, in the underlying stream. +/// +/// The position used for seeking with `SeekFrom::Current(_)` is the +/// position the underlying stream would be at if the `BufStream` had no +/// internal buffer. +/// +/// Seeking always discards the internal buffer, even if the seek position +/// would otherwise fall within it. This guarantees that calling +/// `.into_inner()` immediately after a seek yields the underlying reader +/// at the same position. +/// +/// See [`AsyncSeek`] for more details. +/// +/// Note: In the edge case where you're seeking with `SeekFrom::Current(n)` +/// where `n` minus the internal buffer length overflows an `i64`, two +/// seeks will be performed instead of one. If the second seek returns +/// `Err`, the underlying reader will be left at the same position it would +/// have if you called `seek` with `SeekFrom::Current(0)`. 
+impl<RW: AsyncRead + AsyncWrite + AsyncSeek> AsyncSeek for BufStream<RW> { + fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> { + self.project().inner.start_seek(position) + } + + fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { + self.project().inner.poll_complete(cx) + } +} + impl<RW: AsyncRead + AsyncWrite> AsyncBufRead for BufStream<RW> { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { self.project().inner.poll_fill_buf(cx) diff --git a/src/io/util/buf_writer.rs b/src/io/util/buf_writer.rs index 5e3d4b7..8dd1bba 100644 --- a/src/io/util/buf_writer.rs +++ b/src/io/util/buf_writer.rs @@ -1,9 +1,9 @@ use crate::io::util::DEFAULT_BUF_SIZE; -use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf}; +use crate::io::{AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, ReadBuf}; use pin_project_lite::pin_project; use std::fmt; -use std::io::{self, Write}; +use std::io::{self, IoSlice, SeekFrom, Write}; use std::pin::Pin; use std::task::{Context, Poll}; @@ -34,6 +34,7 @@ pin_project! { pub(super) inner: W, pub(super) buf: Vec<u8>, pub(super) written: usize, + pub(super) seek_state: SeekState, } } @@ -50,6 +51,7 @@ impl<W: AsyncWrite> BufWriter<W> { inner, buf: Vec::with_capacity(cap), written: 0, + seek_state: SeekState::Init, } } @@ -131,6 +133,72 @@ impl<W: AsyncWrite> AsyncWrite for BufWriter<W> { } } + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut bufs: &[IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + if self.inner.is_write_vectored() { + let total_len = bufs + .iter() + .fold(0usize, |acc, b| acc.saturating_add(b.len())); + if total_len > self.buf.capacity() - self.buf.len() { + ready!(self.as_mut().flush_buf(cx))?; + } + let me = self.as_mut().project(); + if total_len >= me.buf.capacity() { + // It's more efficient to pass the slices directly to the + // underlying writer than to buffer them. 
+ // The case when the total_len calculation saturates at + // usize::MAX is also handled here. + me.inner.poll_write_vectored(cx, bufs) + } else { + bufs.iter().for_each(|b| me.buf.extend_from_slice(b)); + Poll::Ready(Ok(total_len)) + } + } else { + // Remove empty buffers at the beginning of bufs. + while bufs.first().map(|buf| buf.len()) == Some(0) { + bufs = &bufs[1..]; + } + if bufs.is_empty() { + return Poll::Ready(Ok(0)); + } + // Flush if the first buffer doesn't fit. + let first_len = bufs[0].len(); + if first_len > self.buf.capacity() - self.buf.len() { + ready!(self.as_mut().flush_buf(cx))?; + debug_assert!(self.buf.is_empty()); + } + let me = self.as_mut().project(); + if first_len >= me.buf.capacity() { + // The slice is at least as large as the buffering capacity, + // so it's better to write it directly, bypassing the buffer. + debug_assert!(me.buf.is_empty()); + return me.inner.poll_write(cx, &bufs[0]); + } else { + me.buf.extend_from_slice(&bufs[0]); + bufs = &bufs[1..]; + } + let mut total_written = first_len; + debug_assert!(total_written != 0); + // Append the buffers that fit in the internal buffer. + for buf in bufs { + if buf.len() > me.buf.capacity() - me.buf.len() { + break; + } else { + me.buf.extend_from_slice(buf); + total_written += buf.len(); + } + } + Poll::Ready(Ok(total_written)) + } + } + + fn is_write_vectored(&self) -> bool { + true + } + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { ready!(self.as_mut().flush_buf(cx))?; self.get_pin_mut().poll_flush(cx) @@ -142,6 +210,62 @@ impl<W: AsyncWrite> AsyncWrite for BufWriter<W> { } } +#[derive(Debug, Clone, Copy)] +pub(super) enum SeekState { + /// start_seek has not been called. + Init, + /// start_seek has been called, but poll_complete has not yet been called. + Start(SeekFrom), + /// Waiting for completion of poll_complete. + Pending, +} + +/// Seek to the offset, in bytes, in the underlying writer. 
+/// +/// Seeking always writes out the internal buffer before seeking. +impl<W: AsyncWrite + AsyncSeek> AsyncSeek for BufWriter<W> { + fn start_seek(self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + // We need to flush the internal buffer before seeking. + // It receives a `Context` and returns a `Poll`, so it cannot be called + // inside `start_seek`. + *self.project().seek_state = SeekState::Start(pos); + Ok(()) + } + + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { + let pos = match self.seek_state { + SeekState::Init => { + return self.project().inner.poll_complete(cx); + } + SeekState::Start(pos) => Some(pos), + SeekState::Pending => None, + }; + + // Flush the internal buffer before seeking. + ready!(self.as_mut().flush_buf(cx))?; + + let mut me = self.project(); + if let Some(pos) = pos { + // Ensure previous seeks have finished before starting a new one + ready!(me.inner.as_mut().poll_complete(cx))?; + if let Err(e) = me.inner.as_mut().start_seek(pos) { + *me.seek_state = SeekState::Init; + return Poll::Ready(Err(e)); + } + } + match me.inner.poll_complete(cx) { + Poll::Ready(res) => { + *me.seek_state = SeekState::Init; + Poll::Ready(res) + } + Poll::Pending => { + *me.seek_state = SeekState::Pending; + Poll::Pending + } + } + } +} + impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> { fn poll_read( self: Pin<&mut Self>, diff --git a/src/io/util/copy.rs b/src/io/util/copy.rs index 3cd425b..d0ab7cb 100644 --- a/src/io/util/copy.rs +++ b/src/io/util/copy.rs @@ -8,6 +8,7 @@ use std::task::{Context, Poll}; #[derive(Debug)] pub(super) struct CopyBuffer { read_done: bool, + need_flush: bool, pos: usize, cap: usize, amt: u64, @@ -18,10 +19,11 @@ impl CopyBuffer { pub(super) fn new() -> Self { Self { read_done: false, + need_flush: false, pos: 0, cap: 0, amt: 0, - buf: vec![0; 2048].into_boxed_slice(), + buf: vec![0; super::DEFAULT_BUF_SIZE].into_boxed_slice(), } } @@ -41,7 +43,22 @@ impl CopyBuffer { 
if self.pos == self.cap && !self.read_done { let me = &mut *self; let mut buf = ReadBuf::new(&mut me.buf); - ready!(reader.as_mut().poll_read(cx, &mut buf))?; + + match reader.as_mut().poll_read(cx, &mut buf) { + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => { + // Try flushing when the reader has no progress to avoid deadlock + // when the reader depends on buffered writer. + if self.need_flush { + ready!(writer.as_mut().poll_flush(cx))?; + self.need_flush = false; + } + + return Poll::Pending; + } + } + let n = buf.filled().len(); if n == 0 { self.read_done = true; @@ -63,9 +80,18 @@ impl CopyBuffer { } else { self.pos += i; self.amt += i as u64; + self.need_flush = true; } } + // If pos larger than cap, this loop will never stop. + // In particular, user's wrong poll_write implementation returning + // incorrect written length may lead to thread blocking. + debug_assert!( + self.pos <= self.cap, + "writer returned length larger than input slice" + ); + // If we've written all the data and we've seen EOF, flush out the // data and finish the transfer. if self.pos == self.cap && self.read_done { diff --git a/src/io/util/copy_bidirectional.rs b/src/io/util/copy_bidirectional.rs index cc43f0f..c93060b 100644 --- a/src/io/util/copy_bidirectional.rs +++ b/src/io/util/copy_bidirectional.rs @@ -104,6 +104,7 @@ where /// # Return value /// /// Returns a tuple of bytes copied `a` to `b` and bytes copied `b` to `a`. 
+#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] pub async fn copy_bidirectional<A, B>(a: &mut A, b: &mut B) -> Result<(u64, u64), std::io::Error> where A: AsyncRead + AsyncWrite + Unpin + ?Sized, diff --git a/src/io/util/fill_buf.rs b/src/io/util/fill_buf.rs new file mode 100644 index 0000000..3655c01 --- /dev/null +++ b/src/io/util/fill_buf.rs @@ -0,0 +1,53 @@ +use crate::io::AsyncBufRead; + +use pin_project_lite::pin_project; +use std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// Future for the [`fill_buf`](crate::io::AsyncBufReadExt::fill_buf) method. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct FillBuf<'a, R: ?Sized> { + reader: Option<&'a mut R>, + #[pin] + _pin: PhantomPinned, + } +} + +pub(crate) fn fill_buf<R>(reader: &mut R) -> FillBuf<'_, R> +where + R: AsyncBufRead + ?Sized + Unpin, +{ + FillBuf { + reader: Some(reader), + _pin: PhantomPinned, + } +} + +impl<'a, R: AsyncBufRead + ?Sized + Unpin> Future for FillBuf<'a, R> { + type Output = io::Result<&'a [u8]>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let me = self.project(); + + let reader = me.reader.take().expect("Polled after completion."); + match Pin::new(&mut *reader).poll_fill_buf(cx) { + Poll::Ready(Ok(slice)) => unsafe { + // Safety: This is necessary only due to a limitation in the + // borrow checker. Once Rust starts using the polonius borrow + // checker, this can be simplified. 
+ let slice = std::mem::transmute::<&[u8], &'a [u8]>(slice); + Poll::Ready(Ok(slice)) + }, + Poll::Ready(Err(err)) => Poll::Ready(Err(err)), + Poll::Pending => { + *me.reader = Some(reader); + Poll::Pending + } + } + } +} diff --git a/src/io/util/lines.rs b/src/io/util/lines.rs index ed6a944..717f633 100644 --- a/src/io/util/lines.rs +++ b/src/io/util/lines.rs @@ -8,7 +8,7 @@ use std::pin::Pin; use std::task::{Context, Poll}; pin_project! { - /// Read lines from an [`AsyncBufRead`]. + /// Reads lines from an [`AsyncBufRead`]. /// /// A `Lines` can be turned into a `Stream` with [`LinesStream`]. /// @@ -47,6 +47,10 @@ where { /// Returns the next line in the stream. /// + /// # Cancel safety + /// + /// This method is cancellation safe. + /// /// # Examples /// /// ``` @@ -68,12 +72,12 @@ where poll_fn(|cx| Pin::new(&mut *self).poll_next_line(cx)).await } - /// Obtain a mutable reference to the underlying reader + /// Obtains a mutable reference to the underlying reader. pub fn get_mut(&mut self) -> &mut R { &mut self.reader } - /// Obtain a reference to the underlying reader + /// Obtains a reference to the underlying reader. pub fn get_ref(&mut self) -> &R { &self.reader } @@ -102,11 +106,9 @@ where /// /// When the method returns `Poll::Pending`, the `Waker` in the provided /// `Context` is scheduled to receive a wakeup when more bytes become - /// available on the underlying IO resource. - /// - /// Note that on multiple calls to `poll_next_line`, only the `Waker` from - /// the `Context` passed to the most recent call is scheduled to receive a - /// wakeup. + /// available on the underlying IO resource. Note that on multiple calls to + /// `poll_next_line`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. 
pub fn poll_next_line( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -128,7 +130,7 @@ where } } - Poll::Ready(Ok(Some(mem::replace(me.buf, String::new())))) + Poll::Ready(Ok(Some(mem::take(me.buf)))) } } diff --git a/src/io/util/mem.rs b/src/io/util/mem.rs index e91a932..4eefe7b 100644 --- a/src/io/util/mem.rs +++ b/src/io/util/mem.rs @@ -16,6 +16,14 @@ use std::{ /// that can be used as in-memory IO types. Writing to one of the pairs will /// allow that data to be read from the other, and vice versa. /// +/// # Closing a `DuplexStream` +/// +/// If one end of the `DuplexStream` channel is dropped, any pending reads on +/// the other side will continue to read data until the buffer is drained, then +/// they will signal EOF by returning 0 bytes. Any writes to the other side, +/// including pending ones (that are waiting for free space in the buffer) will +/// return `Err(BrokenPipe)` immediately. +/// /// # Example /// /// ``` @@ -37,6 +45,7 @@ use std::{ /// # } /// ``` #[derive(Debug)] +#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] pub struct DuplexStream { read: Arc<Mutex<Pipe>>, write: Arc<Mutex<Pipe>>, @@ -72,6 +81,7 @@ struct Pipe { /// /// The `max_buf_size` argument is the maximum amount of bytes that can be /// written to a side before the write returns `Poll::Pending`. 
+#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))] pub fn duplex(max_buf_size: usize) -> (DuplexStream, DuplexStream) { let one = Arc::new(Mutex::new(Pipe::new(max_buf_size))); let two = Arc::new(Mutex::new(Pipe::new(max_buf_size))); @@ -134,7 +144,8 @@ impl AsyncWrite for DuplexStream { impl Drop for DuplexStream { fn drop(&mut self) { // notify the other side of the closure - self.write.lock().close(); + self.write.lock().close_write(); + self.read.lock().close_read(); } } @@ -151,12 +162,21 @@ impl Pipe { } } - fn close(&mut self) { + fn close_write(&mut self) { self.is_closed = true; + // needs to notify any readers that no more data will come if let Some(waker) = self.read_waker.take() { waker.wake(); } } + + fn close_read(&mut self) { + self.is_closed = true; + // needs to notify any writers that they have to abort + if let Some(waker) = self.write_waker.take() { + waker.wake(); + } + } } impl AsyncRead for Pipe { @@ -217,7 +237,7 @@ impl AsyncWrite for Pipe { mut self: Pin<&mut Self>, _: &mut task::Context<'_>, ) -> Poll<std::io::Result<()>> { - self.close(); + self.close_write(); Poll::Ready(Ok(())) } } diff --git a/src/io/util/mod.rs b/src/io/util/mod.rs index ab38664..21199d0 100644 --- a/src/io/util/mod.rs +++ b/src/io/util/mod.rs @@ -49,6 +49,7 @@ cfg_io_util! { mod read_exact; mod read_int; mod read_line; + mod fill_buf; mod read_to_end; mod vec_with_initialized; @@ -77,6 +78,7 @@ cfg_io_util! 
{ mod write_vectored; mod write_all; mod write_buf; + mod write_all_buf; mod write_int; diff --git a/src/io/util/read_int.rs b/src/io/util/read_int.rs index 5b9fb7b..164dcf5 100644 --- a/src/io/util/read_int.rs +++ b/src/io/util/read_int.rs @@ -142,6 +142,9 @@ reader!(ReadI32, i32, get_i32); reader!(ReadI64, i64, get_i64); reader!(ReadI128, i128, get_i128); +reader!(ReadF32, f32, get_f32); +reader!(ReadF64, f64, get_f64); + reader!(ReadU16Le, u16, get_u16_le); reader!(ReadU32Le, u32, get_u32_le); reader!(ReadU64Le, u64, get_u64_le); @@ -151,3 +154,6 @@ reader!(ReadI16Le, i16, get_i16_le); reader!(ReadI32Le, i32, get_i32_le); reader!(ReadI64Le, i64, get_i64_le); reader!(ReadI128Le, i128, get_i128_le); + +reader!(ReadF32Le, f32, get_f32_le); +reader!(ReadF64Le, f64, get_f64_le); diff --git a/src/io/util/read_line.rs b/src/io/util/read_line.rs index d38ffaf..e641f51 100644 --- a/src/io/util/read_line.rs +++ b/src/io/util/read_line.rs @@ -36,7 +36,7 @@ where { ReadLine { reader, - buf: mem::replace(string, String::new()).into_bytes(), + buf: mem::take(string).into_bytes(), output: string, read: 0, _pin: PhantomPinned, @@ -99,7 +99,7 @@ pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>( read: &mut usize, ) -> Poll<io::Result<usize>> { let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read)); - let utf8_res = String::from_utf8(mem::replace(buf, Vec::new())); + let utf8_res = String::from_utf8(mem::take(buf)); // At this point both buf and output are empty. The allocation is in utf8_res. 
diff --git a/src/io/util/read_to_string.rs b/src/io/util/read_to_string.rs index 2c17383..b3d82a2 100644 --- a/src/io/util/read_to_string.rs +++ b/src/io/util/read_to_string.rs @@ -37,7 +37,7 @@ pub(crate) fn read_to_string<'a, R>( where R: AsyncRead + ?Sized + Unpin, { - let buf = mem::replace(string, String::new()).into_bytes(); + let buf = mem::take(string).into_bytes(); ReadToString { reader, buf: VecWithInitialized::new(buf), diff --git a/src/io/util/read_until.rs b/src/io/util/read_until.rs index 3599cff..90a0e8a 100644 --- a/src/io/util/read_until.rs +++ b/src/io/util/read_until.rs @@ -10,12 +10,12 @@ use std::task::{Context, Poll}; pin_project! { /// Future for the [`read_until`](crate::io::AsyncBufReadExt::read_until) method. - /// The delimeter is included in the resulting vector. + /// The delimiter is included in the resulting vector. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadUntil<'a, R: ?Sized> { reader: &'a mut R, - delimeter: u8, + delimiter: u8, buf: &'a mut Vec<u8>, // The number of bytes appended to buf. This can be less than buf.len() if // the buffer was not empty when the operation was started. @@ -28,7 +28,7 @@ pin_project! 
{ pub(crate) fn read_until<'a, R>( reader: &'a mut R, - delimeter: u8, + delimiter: u8, buf: &'a mut Vec<u8>, ) -> ReadUntil<'a, R> where @@ -36,7 +36,7 @@ where { ReadUntil { reader, - delimeter, + delimiter, buf, read: 0, _pin: PhantomPinned, @@ -46,14 +46,14 @@ where pub(super) fn read_until_internal<R: AsyncBufRead + ?Sized>( mut reader: Pin<&mut R>, cx: &mut Context<'_>, - delimeter: u8, + delimiter: u8, buf: &mut Vec<u8>, read: &mut usize, ) -> Poll<io::Result<usize>> { loop { let (done, used) = { let available = ready!(reader.as_mut().poll_fill_buf(cx))?; - if let Some(i) = memchr::memchr(delimeter, available) { + if let Some(i) = memchr::memchr(delimiter, available) { buf.extend_from_slice(&available[..=i]); (true, i + 1) } else { @@ -74,6 +74,6 @@ impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadUntil<'_, R> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { let me = self.project(); - read_until_internal(Pin::new(*me.reader), cx, *me.delimeter, me.buf, me.read) + read_until_internal(Pin::new(*me.reader), cx, *me.delimiter, me.buf, me.read) } } diff --git a/src/io/util/split.rs b/src/io/util/split.rs index 4f3ce4e..7489c24 100644 --- a/src/io/util/split.rs +++ b/src/io/util/split.rs @@ -95,7 +95,7 @@ where let n = ready!(read_until_internal( me.reader, cx, *me.delim, me.buf, me.read, ))?; - // read_until_internal resets me.read to zero once it finds the delimeter + // read_until_internal resets me.read to zero once it finds the delimiter debug_assert_eq!(*me.read, 0); if n == 0 && me.buf.is_empty() { @@ -106,7 +106,7 @@ where me.buf.pop(); } - Poll::Ready(Ok(Some(mem::replace(me.buf, Vec::new())))) + Poll::Ready(Ok(Some(mem::take(me.buf)))) } } diff --git a/src/io/util/write_all_buf.rs b/src/io/util/write_all_buf.rs new file mode 100644 index 0000000..05af7fe --- /dev/null +++ b/src/io/util/write_all_buf.rs @@ -0,0 +1,56 @@ +use crate::io::AsyncWrite; + +use bytes::Buf; +use pin_project_lite::pin_project; +use 
std::future::Future; +use std::io; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project! { + /// A future to write some of the buffer to an `AsyncWrite`. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + pub struct WriteAllBuf<'a, W, B> { + writer: &'a mut W, + buf: &'a mut B, + #[pin] + _pin: PhantomPinned, + } +} + +/// Tries to write some bytes from the given `buf` to the writer in an +/// asynchronous manner, returning a future. +pub(crate) fn write_all_buf<'a, W, B>(writer: &'a mut W, buf: &'a mut B) -> WriteAllBuf<'a, W, B> +where + W: AsyncWrite + Unpin, + B: Buf, +{ + WriteAllBuf { + writer, + buf, + _pin: PhantomPinned, + } +} + +impl<W, B> Future for WriteAllBuf<'_, W, B> +where + W: AsyncWrite + Unpin, + B: Buf, +{ + type Output = io::Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + let me = self.project(); + while me.buf.has_remaining() { + let n = ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf.chunk())?); + me.buf.advance(n); + if n == 0 { + return Poll::Ready(Err(io::ErrorKind::WriteZero.into())); + } + } + + Poll::Ready(Ok(())) + } +} diff --git a/src/io/util/write_int.rs b/src/io/util/write_int.rs index 13bc191..63cd491 100644 --- a/src/io/util/write_int.rs +++ b/src/io/util/write_int.rs @@ -135,6 +135,9 @@ writer!(WriteI32, i32, put_i32); writer!(WriteI64, i64, put_i64); writer!(WriteI128, i128, put_i128); +writer!(WriteF32, f32, put_f32); +writer!(WriteF64, f64, put_f64); + writer!(WriteU16Le, u16, put_u16_le); writer!(WriteU32Le, u32, put_u32_le); writer!(WriteU64Le, u64, put_u64_le); @@ -144,3 +147,6 @@ writer!(WriteI16Le, i16, put_i16_le); writer!(WriteI32Le, i32, put_i32_le); writer!(WriteI64Le, i64, put_i64_le); writer!(WriteI128Le, i128, put_i128_le); + +writer!(WriteF32Le, f32, put_f32_le); +writer!(WriteF64Le, f64, put_f64_le); @@ -9,12 +9,18 @@ rust_2018_idioms, unreachable_pub )] 
-#![cfg_attr(docsrs, deny(broken_intra_doc_links))] +#![deny(unused_must_use)] +#![cfg_attr(docsrs, deny(rustdoc::broken_intra_doc_links))] #![doc(test( no_crate_inject, attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables)) ))] #![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg_hide))] +#![cfg_attr(docsrs, doc(cfg_hide(docsrs)))] +#![cfg_attr(docsrs, doc(cfg_hide(loom)))] +#![cfg_attr(docsrs, doc(cfg_hide(not(loom))))] +#![cfg_attr(docsrs, allow(unused_attributes))] //! A runtime for writing reliable network applications without compromising speed. //! @@ -159,8 +165,8 @@ //! [`tokio::runtime`]: crate::runtime //! [`Builder`]: crate::runtime::Builder //! [`Runtime`]: crate::runtime::Runtime -//! [rt]: runtime/index.html#basic-scheduler -//! [rt-multi-thread]: runtime/index.html#threaded-scheduler +//! [rt]: runtime/index.html#current-thread-scheduler +//! [rt-multi-thread]: runtime/index.html#multi-thread-scheduler //! [rt-features]: runtime/index.html#runtime-scheduler //! //! ## CPU-bound tasks and blocking code @@ -203,9 +209,15 @@ //! ``` //! //! If your code is CPU-bound and you wish to limit the number of threads used -//! to run it, you should run it on another thread pool such as [rayon]. You -//! can use an [`oneshot`] channel to send the result back to Tokio when the -//! rayon task finishes. +//! to run it, you should use a separate thread pool dedicated to CPU bound tasks. +//! For example, you could consider using the [rayon] library for CPU-bound +//! tasks. It is also possible to create an extra Tokio runtime dedicated to +//! CPU-bound tasks, but if you do this, you should be careful that the extra +//! runtime runs _only_ CPU-bound tasks, as IO-bound tasks on that runtime +//! will behave poorly. +//! +//! Hint: If using rayon, you can use a [`oneshot`] channel to send the result back +//! to Tokio when the rayon task finishes. //! //! [rayon]: https://docs.rs/rayon //! 
[`oneshot`]: crate::sync::oneshot @@ -306,8 +318,9 @@ //! - `rt-multi-thread`: Enables the heavier, multi-threaded, work-stealing scheduler. //! - `io-util`: Enables the IO based `Ext` traits. //! - `io-std`: Enable `Stdout`, `Stdin` and `Stderr` types. -//! - `net`: Enables `tokio::net` types such as `TcpStream`, `UnixStream` and `UdpSocket`, -//! as well as (on Unix-like systems) `AsyncFd` +//! - `net`: Enables `tokio::net` types such as `TcpStream`, `UnixStream` and +//! `UdpSocket`, as well as (on Unix-like systems) `AsyncFd` and (on +//! FreeBSD) `PollAio`. //! - `time`: Enables `tokio::time` types and allows the schedulers to enable //! the built in timer. //! - `process`: Enables `tokio::process` types. @@ -341,6 +354,19 @@ //! //! [feature flags]: https://doc.rust-lang.org/cargo/reference/manifest.html#the-features-section +// Test that pointer width is compatible. This asserts that e.g. usize is at +// least 32 bits, which a lot of components in Tokio currently assumes. +// +// TODO: improve once we have MSRV access to const eval to make more flexible. +#[cfg(not(any( + target_pointer_width = "32", + target_pointer_width = "64", + target_pointer_width = "128" +)))] +compile_error! { + "Tokio requires the platform pointer width to be 32, 64, or 128 bits" +} + // Includes re-exports used by macros. // // This module is not intended to be part of the public API. In general, any @@ -442,6 +468,28 @@ mod util; /// ``` pub mod stream {} +// local re-exports of platform specific things, allowing for decent +// documentation to be shimmed in on docs.rs + +#[cfg(docsrs)] +pub mod doc; + +#[cfg(docsrs)] +#[allow(unused)] +pub(crate) use self::doc::os; + +#[cfg(not(docsrs))] +#[allow(unused)] +pub(crate) use std::os; + +#[cfg(docsrs)] +#[allow(unused)] +pub(crate) use self::doc::winapi; + +#[cfg(all(not(docsrs), windows, feature = "net"))] +#[allow(unused)] +pub(crate) use ::winapi; + cfg_macros! { /// Implementation detail of the `select!` macro. 
This macro is **not** /// intended to be used as part of the public API and is permitted to @@ -449,19 +497,30 @@ cfg_macros! { #[doc(hidden)] pub use tokio_macros::select_priv_declare_output_enum; + /// Implementation detail of the `select!` macro. This macro is **not** + /// intended to be used as part of the public API and is permitted to + /// change. + #[doc(hidden)] + pub use tokio_macros::select_priv_clean_pattern; + cfg_rt! { #[cfg(feature = "rt-multi-thread")] #[cfg(not(test))] // Work around for rust-lang/rust#62127 #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] + #[doc(inline)] pub use tokio_macros::main; #[cfg(feature = "rt-multi-thread")] #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] + #[doc(inline)] pub use tokio_macros::test; cfg_not_rt_multi_thread! { #[cfg(not(test))] // Work around for rust-lang/rust#62127 + #[doc(inline)] pub use tokio_macros::main_rt as main; + + #[doc(inline)] pub use tokio_macros::test_rt as test; } } @@ -469,7 +528,10 @@ cfg_macros! { // Always fail if rt is not enabled. cfg_not_rt! { #[cfg(not(test))] + #[doc(inline)] pub use tokio_macros::main_fail as main; + + #[doc(inline)] pub use tokio_macros::test_fail as test; } } diff --git a/src/loom/std/atomic_u64.rs b/src/loom/std/atomic_u64.rs index a86a195..8ea6bd4 100644 --- a/src/loom/std/atomic_u64.rs +++ b/src/loom/std/atomic_u64.rs @@ -2,21 +2,17 @@ //! re-export of `AtomicU64`. On 32 bit platforms, this is implemented using a //! `Mutex`. -pub(crate) use self::imp::AtomicU64; - // `AtomicU64` can only be used on targets with `target_has_atomic` is 64 or greater. // Once `cfg_target_has_atomic` feature is stable, we can replace it with // `#[cfg(target_has_atomic = "64")]`. // Refs: https://github.com/rust-lang/rust/tree/master/src/librustc_target -#[cfg(not(any(target_arch = "arm", target_arch = "mips", target_arch = "powerpc")))] -mod imp { +cfg_has_atomic_u64! 
{ pub(crate) use std::sync::atomic::AtomicU64; } -#[cfg(any(target_arch = "arm", target_arch = "mips", target_arch = "powerpc"))] -mod imp { +cfg_not_has_atomic_u64! { + use crate::loom::sync::Mutex; use std::sync::atomic::Ordering; - use std::sync::Mutex; #[derive(Debug)] pub(crate) struct AtomicU64 { @@ -31,15 +27,15 @@ mod imp { } pub(crate) fn load(&self, _: Ordering) -> u64 { - *self.inner.lock().unwrap() + *self.inner.lock() } pub(crate) fn store(&self, val: u64, _: Ordering) { - *self.inner.lock().unwrap() = val; + *self.inner.lock() = val; } pub(crate) fn fetch_or(&self, val: u64, _: Ordering) -> u64 { - let mut lock = self.inner.lock().unwrap(); + let mut lock = self.inner.lock(); let prev = *lock; *lock = prev | val; prev @@ -52,7 +48,7 @@ mod imp { _success: Ordering, _failure: Ordering, ) -> Result<u64, u64> { - let mut lock = self.inner.lock().unwrap(); + let mut lock = self.inner.lock(); if *lock == current { *lock = new; diff --git a/src/loom/std/mod.rs b/src/loom/std/mod.rs index b29cbee..8b6e8bc 100644 --- a/src/loom/std/mod.rs +++ b/src/loom/std/mod.rs @@ -93,4 +93,17 @@ pub(crate) mod sys { } } -pub(crate) use std::thread; +pub(crate) mod thread { + #[inline] + pub(crate) fn yield_now() { + // TODO: once we bump MSRV to 1.49+, use `hint::spin_loop` instead. + #[allow(deprecated)] + std::sync::atomic::spin_loop_hint(); + } + + #[allow(unused_imports)] + pub(crate) use std::thread::{ + current, panicking, park, park_timeout, sleep, spawn, Builder, JoinHandle, LocalKey, + Result, Thread, ThreadId, + }; +} diff --git a/src/loom/std/mutex.rs b/src/loom/std/mutex.rs index bf14d62..3f686e0 100644 --- a/src/loom/std/mutex.rs +++ b/src/loom/std/mutex.rs @@ -1,7 +1,7 @@ use std::sync::{self, MutexGuard, TryLockError}; /// Adapter for `std::Mutex` that removes the poisoning aspects -// from its api +/// from its api. 
#[derive(Debug)] pub(crate) struct Mutex<T: ?Sized>(sync::Mutex<T>); diff --git a/src/macros/cfg.rs b/src/macros/cfg.rs index 3442612..606bce7 100644 --- a/src/macros/cfg.rs +++ b/src/macros/cfg.rs @@ -13,7 +13,7 @@ macro_rules! feature { } } -/// Enables enter::block_on +/// Enables enter::block_on. macro_rules! cfg_block_on { ($($item:item)*) => { $( @@ -28,7 +28,7 @@ macro_rules! cfg_block_on { } } -/// Enables internal `AtomicWaker` impl +/// Enables internal `AtomicWaker` impl. macro_rules! cfg_atomic_waker_impl { ($($item:item)*) => { $( @@ -45,6 +45,18 @@ macro_rules! cfg_atomic_waker_impl { } } +macro_rules! cfg_aio { + ($($item:item)*) => { + $( + #[cfg(all(any(docsrs, target_os = "freebsd"), feature = "net"))] + #[cfg_attr(docsrs, + doc(cfg(all(target_os = "freebsd", feature = "net"))) + )] + $item + )* + } +} + macro_rules! cfg_fs { ($($item:item)*) => { $( @@ -87,6 +99,7 @@ macro_rules! cfg_io_driver_impl { feature = "process", all(unix, feature = "signal"), ))] + #[cfg_attr(docsrs, doc(cfg(all())))] $item )* } @@ -157,7 +170,25 @@ macro_rules! cfg_macros { $( #[cfg(feature = "macros")] #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] - #[doc(inline)] + $item + )* + } +} + +macro_rules! cfg_stats { + ($($item:item)*) => { + $( + #[cfg(all(tokio_unstable, feature = "stats"))] + #[cfg_attr(docsrs, doc(cfg(feature = "stats")))] + $item + )* + } +} + +macro_rules! cfg_not_stats { + ($($item:item)*) => { + $( + #[cfg(not(all(tokio_unstable, feature = "stats")))] $item )* } @@ -177,7 +208,17 @@ macro_rules! cfg_net_unix { ($($item:item)*) => { $( #[cfg(all(unix, feature = "net"))] - #[cfg_attr(docsrs, doc(cfg(feature = "net")))] + #[cfg_attr(docsrs, doc(cfg(all(unix, feature = "net"))))] + $item + )* + } +} + +macro_rules! cfg_net_windows { + ($($item:item)*) => { + $( + #[cfg(all(any(all(doc, docsrs), windows), feature = "net"))] + #[cfg_attr(docsrs, doc(cfg(all(windows, feature = "net"))))] $item )* } @@ -375,3 +416,31 @@ macro_rules! 
cfg_not_coop { )* } } + +macro_rules! cfg_has_atomic_u64 { + ($($item:item)*) => { + $( + #[cfg(not(any( + target_arch = "arm", + target_arch = "mips", + target_arch = "powerpc", + target_arch = "riscv32" + )))] + $item + )* + } +} + +macro_rules! cfg_not_has_atomic_u64 { + ($($item:item)*) => { + $( + #[cfg(any( + target_arch = "arm", + target_arch = "mips", + target_arch = "powerpc", + target_arch = "riscv32" + ))] + $item + )* + } +} diff --git a/src/macros/join.rs b/src/macros/join.rs index 5f37af5..f91b5f1 100644 --- a/src/macros/join.rs +++ b/src/macros/join.rs @@ -1,4 +1,4 @@ -/// Wait on multiple concurrent branches, returning when **all** branches +/// Waits on multiple concurrent branches, returning when **all** branches /// complete. /// /// The `join!` macro must be used inside of async functions, closures, and diff --git a/src/macros/mod.rs b/src/macros/mod.rs index b0af521..a1839c8 100644 --- a/src/macros/mod.rs +++ b/src/macros/mod.rs @@ -15,6 +15,11 @@ mod ready; #[macro_use] mod thread_local; +cfg_trace! { + #[macro_use] + mod trace; +} + #[macro_use] #[cfg(feature = "rt")] pub(crate) mod scoped_tls; diff --git a/src/macros/scoped_tls.rs b/src/macros/scoped_tls.rs index a00aae2..f2504cb 100644 --- a/src/macros/scoped_tls.rs +++ b/src/macros/scoped_tls.rs @@ -3,7 +3,7 @@ use crate::loom::thread::LocalKey; use std::cell::Cell; use std::marker; -/// Set a reference as a thread-local +/// Sets a reference as a thread-local. macro_rules! scoped_thread_local { ($(#[$attrs:meta])* $vis:vis static $name:ident: $ty:ty) => ( $(#[$attrs])* diff --git a/src/macros/select.rs b/src/macros/select.rs index 3ba16b6..051f8cb 100644 --- a/src/macros/select.rs +++ b/src/macros/select.rs @@ -1,4 +1,4 @@ -/// Wait on multiple concurrent branches, returning when the **first** branch +/// Waits on multiple concurrent branches, returning when the **first** branch /// completes, cancelling the remaining branches. 
/// /// The `select!` macro must be used inside of async functions, closures, and @@ -23,10 +23,10 @@ /// returns the result of evaluating the completed branch's `<handler>` /// expression. /// -/// Additionally, each branch may include an optional `if` precondition. This -/// precondition is evaluated **before** the `<async expression>`. If the -/// precondition returns `false`, the branch is entirely disabled. This -/// capability is useful when using `select!` within a loop. +/// Additionally, each branch may include an optional `if` precondition. If the +/// precondition returns `false`, then the branch is disabled. The provided +/// `<async expression>` is still evaluated but the resulting future is never +/// polled. This capability is useful when using `select!` within a loop. /// /// The complete lifecycle of a `select!` expression is as follows: /// @@ -42,12 +42,10 @@ /// to the provided `<pattern>`, if the pattern matches, evaluate `<handler>` /// and return. If the pattern **does not** match, disable the current branch /// and for the remainder of the current call to `select!`. Continue from step 3. -/// 5. If **all** branches are disabled, evaluate the `else` expression. If none -/// is provided, panic. +/// 5. If **all** branches are disabled, evaluate the `else` expression. If no +/// else branch is provided, panic. /// -/// # Notes -/// -/// ### Runtime characteristics +/// # Runtime characteristics /// /// By running all async expressions on the current task, the expressions are /// able to run **concurrently** but not in **parallel**. This means all @@ -58,76 +56,7 @@ /// /// [`tokio::spawn`]: crate::spawn /// -/// ### Avoid racy `if` preconditions -/// -/// Given that `if` preconditions are used to disable `select!` branches, some -/// caution must be used to avoid missing values. -/// -/// For example, here is **incorrect** usage of `sleep` with `if`. The objective -/// is to repeatedly run an asynchronous task for up to 50 milliseconds. 
-/// However, there is a potential for the `sleep` completion to be missed. -/// -/// ```no_run -/// use tokio::time::{self, Duration}; -/// -/// async fn some_async_work() { -/// // do work -/// } -/// -/// #[tokio::main] -/// async fn main() { -/// let sleep = time::sleep(Duration::from_millis(50)); -/// tokio::pin!(sleep); -/// -/// while !sleep.is_elapsed() { -/// tokio::select! { -/// _ = &mut sleep, if !sleep.is_elapsed() => { -/// println!("operation timed out"); -/// } -/// _ = some_async_work() => { -/// println!("operation completed"); -/// } -/// } -/// } -/// } -/// ``` -/// -/// In the above example, `sleep.is_elapsed()` may return `true` even if -/// `sleep.poll()` never returned `Ready`. This opens up a potential race -/// condition where `sleep` expires between the `while !sleep.is_elapsed()` -/// check and the call to `select!` resulting in the `some_async_work()` call to -/// run uninterrupted despite the sleep having elapsed. -/// -/// One way to write the above example without the race would be: -/// -/// ``` -/// use tokio::time::{self, Duration}; -/// -/// async fn some_async_work() { -/// # time::sleep(Duration::from_millis(10)).await; -/// // do work -/// } -/// -/// #[tokio::main] -/// async fn main() { -/// let sleep = time::sleep(Duration::from_millis(50)); -/// tokio::pin!(sleep); -/// -/// loop { -/// tokio::select! { -/// _ = &mut sleep => { -/// println!("operation timed out"); -/// break; -/// } -/// _ = some_async_work() => { -/// println!("operation completed"); -/// } -/// } -/// } -/// } -/// ``` -/// -/// ### Fairness +/// # Fairness /// /// By default, `select!` randomly picks a branch to check first. This provides /// some level of fairness when calling `select!` in a loop with branches that @@ -151,10 +80,60 @@ /// /// # Panics /// -/// `select!` panics if all branches are disabled **and** there is no provided -/// `else` branch. 
A branch is disabled when the provided `if` precondition -/// returns `false` **or** when the pattern does not match the result of `<async -/// expression>. +/// The `select!` macro panics if all branches are disabled **and** there is no +/// provided `else` branch. A branch is disabled when the provided `if` +/// precondition returns `false` **or** when the pattern does not match the +/// result of `<async expression>`. +/// +/// # Cancellation safety +/// +/// When using `select!` in a loop to receive messages from multiple sources, +/// you should make sure that the receive call is cancellation safe to avoid +/// losing messages. This section goes through various common methods and +/// describes whether they are cancel safe. The lists in this section are not +/// exhaustive. +/// +/// The following methods are cancellation safe: +/// +/// * [`tokio::sync::mpsc::Receiver::recv`](crate::sync::mpsc::Receiver::recv) +/// * [`tokio::sync::mpsc::UnboundedReceiver::recv`](crate::sync::mpsc::UnboundedReceiver::recv) +/// * [`tokio::sync::broadcast::Receiver::recv`](crate::sync::broadcast::Receiver::recv) +/// * [`tokio::sync::watch::Receiver::changed`](crate::sync::watch::Receiver::changed) +/// * [`tokio::net::TcpListener::accept`](crate::net::TcpListener::accept) +/// * [`tokio::net::UnixListener::accept`](crate::net::UnixListener::accept) +/// * [`tokio::io::AsyncReadExt::read`](crate::io::AsyncReadExt::read) on any `AsyncRead` +/// * [`tokio::io::AsyncReadExt::read_buf`](crate::io::AsyncReadExt::read_buf) on any `AsyncRead` +/// * [`tokio::io::AsyncWriteExt::write`](crate::io::AsyncWriteExt::write) on any `AsyncWrite` +/// * [`tokio::io::AsyncWriteExt::write_buf`](crate::io::AsyncWriteExt::write_buf) on any `AsyncWrite` +/// * [`tokio_stream::StreamExt::next`](https://docs.rs/tokio-stream/0.1/tokio_stream/trait.StreamExt.html#method.next) on any `Stream` +/// * 
[`futures::stream::StreamExt::next`](https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.next) on any `Stream` +/// +/// The following methods are not cancellation safe and can lead to loss of data: +/// +/// * [`tokio::io::AsyncReadExt::read_exact`](crate::io::AsyncReadExt::read_exact) +/// * [`tokio::io::AsyncReadExt::read_to_end`](crate::io::AsyncReadExt::read_to_end) +/// * [`tokio::io::AsyncReadExt::read_to_string`](crate::io::AsyncReadExt::read_to_string) +/// * [`tokio::io::AsyncWriteExt::write_all`](crate::io::AsyncWriteExt::write_all) +/// +/// The following methods are not cancellation safe because they use a queue for +/// fairness and cancellation makes you lose your place in the queue: +/// +/// * [`tokio::sync::Mutex::lock`](crate::sync::Mutex::lock) +/// * [`tokio::sync::RwLock::read`](crate::sync::RwLock::read) +/// * [`tokio::sync::RwLock::write`](crate::sync::RwLock::write) +/// * [`tokio::sync::Semaphore::acquire`](crate::sync::Semaphore::acquire) +/// * [`tokio::sync::Notify::notified`](crate::sync::Notify::notified) +/// +/// To determine whether your own methods are cancellation safe, look for the +/// location of uses of `.await`. This is because when an asynchronous method is +/// cancelled, that always happens at an `.await`. If your function behaves +/// correctly even if it is restarted while waiting at an `.await`, then it is +/// cancellation safe. +/// +/// Be aware that cancelling something that is not cancellation safe is not +/// necessarily wrong. For example, if you are cancelling a task because the +/// application is shutting down, then you probably don't care that partially +/// read data is lost. /// /// # Examples /// @@ -310,7 +289,7 @@ /// loop { /// tokio::select! { /// // If you run this example without `biased;`, the polling order is -/// // psuedo-random, and the assertions on the value of count will +/// // pseudo-random, and the assertions on the value of count will /// // (probably) fail. 
/// biased; /// @@ -338,6 +317,77 @@ /// } /// } /// ``` +/// +/// ## Avoid racy `if` preconditions +/// +/// Given that `if` preconditions are used to disable `select!` branches, some +/// caution must be used to avoid missing values. +/// +/// For example, here is **incorrect** usage of `sleep` with `if`. The objective +/// is to repeatedly run an asynchronous task for up to 50 milliseconds. +/// However, there is a potential for the `sleep` completion to be missed. +/// +/// ```no_run,should_panic +/// use tokio::time::{self, Duration}; +/// +/// async fn some_async_work() { +/// // do work +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let sleep = time::sleep(Duration::from_millis(50)); +/// tokio::pin!(sleep); +/// +/// while !sleep.is_elapsed() { +/// tokio::select! { +/// _ = &mut sleep, if !sleep.is_elapsed() => { +/// println!("operation timed out"); +/// } +/// _ = some_async_work() => { +/// println!("operation completed"); +/// } +/// } +/// } +/// +/// panic!("This example shows how not to do it!"); +/// } +/// ``` +/// +/// In the above example, `sleep.is_elapsed()` may return `true` even if +/// `sleep.poll()` never returned `Ready`. This opens up a potential race +/// condition where `sleep` expires between the `while !sleep.is_elapsed()` +/// check and the call to `select!` resulting in the `some_async_work()` call to +/// run uninterrupted despite the sleep having elapsed. +/// +/// One way to write the above example without the race would be: +/// +/// ``` +/// use tokio::time::{self, Duration}; +/// +/// async fn some_async_work() { +/// # time::sleep(Duration::from_millis(10)).await; +/// // do work +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let sleep = time::sleep(Duration::from_millis(50)); +/// tokio::pin!(sleep); +/// +/// loop { +/// tokio::select! 
{ +/// _ = &mut sleep => { +/// println!("operation timed out"); +/// break; +/// } +/// _ = some_async_work() => { +/// println!("operation completed"); +/// } +/// } +/// } +/// } +/// ``` #[macro_export] #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] macro_rules! select { @@ -398,7 +448,7 @@ macro_rules! select { // set the appropriate bit in `disabled`. $( if !$c { - let mask = 1 << $crate::count!( $($skip)* ); + let mask: util::Mask = 1 << $crate::count!( $($skip)* ); disabled |= mask; } )* @@ -417,7 +467,7 @@ macro_rules! select { let mut is_pending = false; // Choose a starting index to begin polling the futures at. In - // practice, this will either be a psuedo-randomly generrated + // practice, this will either be a pseudo-randomly generated // number by default, or the constant 0 if `biased;` is // supplied. let start = $start; @@ -452,7 +502,7 @@ macro_rules! select { let mut fut = unsafe { Pin::new_unchecked(fut) }; // Try polling it - let out = match fut.poll(cx) { + let out = match Future::poll(fut, cx) { Ready(out) => out, Pending => { // Track that at least one future is @@ -470,7 +520,7 @@ macro_rules! select { #[allow(unused_variables)] #[allow(unused_mut)] match &out { - $bind => {} + $crate::select_priv_clean_pattern!($bind) => {} _ => continue, } diff --git a/src/macros/trace.rs b/src/macros/trace.rs new file mode 100644 index 0000000..31dde2f --- /dev/null +++ b/src/macros/trace.rs @@ -0,0 +1,27 @@ +cfg_trace! { + macro_rules! trace_op { + ($name:literal, $readiness:literal, $parent:expr) => { + tracing::trace!( + target: "runtime::resource::poll_op", + parent: $parent, + op_name = $name, + is_ready = $readiness + ); + } + } + + macro_rules! 
trace_poll_op { + ($name:literal, $poll:expr, $parent:expr $(,)*) => { + match $poll { + std::task::Poll::Ready(t) => { + trace_op!($name, true, $parent); + std::task::Poll::Ready(t) + } + std::task::Poll::Pending => { + trace_op!($name, false, $parent); + return std::task::Poll::Pending; + } + } + }; + } +} diff --git a/src/macros/try_join.rs b/src/macros/try_join.rs index fa5850e..6d3a893 100644 --- a/src/macros/try_join.rs +++ b/src/macros/try_join.rs @@ -1,4 +1,4 @@ -/// Wait on multiple concurrent branches, returning when **all** branches +/// Waits on multiple concurrent branches, returning when **all** branches /// complete with `Ok(_)` or on the first `Err(_)`. /// /// The `try_join!` macro must be used inside of async functions, closures, and @@ -59,6 +59,45 @@ /// } /// } /// ``` +/// +/// Using `try_join!` with spawned tasks. +/// +/// ``` +/// use tokio::task::JoinHandle; +/// +/// async fn do_stuff_async() -> Result<(), &'static str> { +/// // async work +/// # Err("failed") +/// } +/// +/// async fn more_async_work() -> Result<(), &'static str> { +/// // more here +/// # Ok(()) +/// } +/// +/// async fn flatten<T>(handle: JoinHandle<Result<T, &'static str>>) -> Result<T, &'static str> { +/// match handle.await { +/// Ok(Ok(result)) => Ok(result), +/// Ok(Err(err)) => Err(err), +/// Err(err) => Err("handling failed"), +/// } +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let handle1 = tokio::spawn(do_stuff_async()); +/// let handle2 = tokio::spawn(more_async_work()); +/// match tokio::try_join!(flatten(handle1), flatten(handle2)) { +/// Ok(val) => { +/// // do something with the values +/// } +/// Err(err) => { +/// println!("Failed with {}.", err); +/// # assert_eq!(err, "failed"); +/// } +/// } +/// } +/// ``` #[macro_export] #[cfg_attr(docsrs, doc(cfg(feature = "macros")))] macro_rules! 
try_join { diff --git a/src/net/mod.rs b/src/net/mod.rs index 2f17f9e..0b8c1ec 100644 --- a/src/net/mod.rs +++ b/src/net/mod.rs @@ -46,3 +46,7 @@ cfg_net_unix! { pub use unix::listener::UnixListener; pub use unix::stream::UnixStream; } + +cfg_net_windows! { + pub mod windows; +} diff --git a/src/net/tcp/listener.rs b/src/net/tcp/listener.rs index 5c093bb..8aecb21 100644 --- a/src/net/tcp/listener.rs +++ b/src/net/tcp/listener.rs @@ -125,6 +125,13 @@ impl TcpListener { /// established, the corresponding [`TcpStream`] and the remote peer's /// address will be returned. /// + /// # Cancel safety + /// + /// This method is cancel safe. If the method is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no new connections were + /// accepted by this method. + /// /// [`TcpStream`]: struct@crate::net::TcpStream /// /// # Examples @@ -220,7 +227,7 @@ impl TcpListener { Ok(TcpListener { io }) } - /// Turn a [`tokio::net::TcpListener`] into a [`std::net::TcpListener`]. + /// Turns a [`tokio::net::TcpListener`] into a [`std::net::TcpListener`]. /// /// The returned [`std::net::TcpListener`] will have nonblocking mode set as /// `true`. Use [`set_nonblocking`] to change the blocking mode if needed. diff --git a/src/net/tcp/mod.rs b/src/net/tcp/mod.rs index 7f0f6d9..cb8a8b2 100644 --- a/src/net/tcp/mod.rs +++ b/src/net/tcp/mod.rs @@ -1,4 +1,4 @@ -//! TCP utility types +//! TCP utility types. pub(crate) mod listener; diff --git a/src/net/tcp/socket.rs b/src/net/tcp/socket.rs index 4bcbe3f..f54ff95 100644 --- a/src/net/tcp/socket.rs +++ b/src/net/tcp/socket.rs @@ -87,7 +87,7 @@ cfg_net! { } impl TcpSocket { - /// Create a new socket configured for IPv4. + /// Creates a new socket configured for IPv4. /// /// Calls `socket(2)` with `AF_INET` and `SOCK_STREAM`. /// @@ -121,7 +121,7 @@ impl TcpSocket { Ok(TcpSocket { inner }) } - /// Create a new socket configured for IPv6. 
+ /// Creates a new socket configured for IPv6. /// /// Calls `socket(2)` with `AF_INET6` and `SOCK_STREAM`. /// @@ -155,7 +155,7 @@ impl TcpSocket { Ok(TcpSocket { inner }) } - /// Allow the socket to bind to an in-use address. + /// Allows the socket to bind to an in-use address. /// /// Behavior is platform specific. Refer to the target platform's /// documentation for more details. @@ -185,7 +185,7 @@ impl TcpSocket { self.inner.set_reuseaddr(reuseaddr) } - /// Retrieves the value set for `SO_REUSEADDR` on this socket + /// Retrieves the value set for `SO_REUSEADDR` on this socket. /// /// # Examples /// @@ -211,7 +211,7 @@ impl TcpSocket { self.inner.get_reuseaddr() } - /// Allow the socket to bind to an in-use port. Only available for unix systems + /// Allows the socket to bind to an in-use port. Only available for unix systems /// (excluding Solaris & Illumos). /// /// Behavior is platform specific. Refer to the target platform's @@ -245,7 +245,7 @@ impl TcpSocket { self.inner.set_reuseport(reuseport) } - /// Allow the socket to bind to an in-use port. Only available for unix systems + /// Allows the socket to bind to an in-use port. Only available for unix systems /// (excluding Solaris & Illumos). /// /// Behavior is platform specific. Refer to the target platform's @@ -348,7 +348,7 @@ impl TcpSocket { self.inner.get_recv_buffer_size() } - /// Get the local address of this socket. + /// Gets the local address of this socket. /// /// Will fail on windows if called before `bind`. /// @@ -374,7 +374,7 @@ impl TcpSocket { self.inner.get_localaddr() } - /// Bind the socket to the given address. + /// Binds the socket to the given address. /// /// This calls the `bind(2)` operating-system function. Behavior is /// platform specific. Refer to the target platform's documentation for more @@ -406,7 +406,7 @@ impl TcpSocket { self.inner.bind(addr) } - /// Establish a TCP connection with a peer at the specified socket address. 
+ /// Establishes a TCP connection with a peer at the specified socket address. /// /// The `TcpSocket` is consumed. Once the connection is established, a /// connected [`TcpStream`] is returned. If the connection fails, the @@ -443,7 +443,7 @@ impl TcpSocket { TcpStream::connect_mio(mio).await } - /// Convert the socket into a `TcpListener`. + /// Converts the socket into a `TcpListener`. /// /// `backlog` defines the maximum number of pending connections are queued /// by the operating system at any given time. Connection are removed from @@ -482,6 +482,48 @@ impl TcpSocket { let mio = self.inner.listen(backlog)?; TcpListener::new(mio) } + + /// Converts a [`std::net::TcpStream`] into a `TcpSocket`. The provided + /// socket must not have been connected prior to calling this function. This + /// function is typically used together with crates such as [`socket2`] to + /// configure socket options that are not available on `TcpSocket`. + /// + /// [`std::net::TcpStream`]: struct@std::net::TcpStream + /// [`socket2`]: https://docs.rs/socket2/ + /// + /// # Examples + /// + /// ``` + /// use tokio::net::TcpSocket; + /// use socket2::{Domain, Socket, Type}; + /// + /// #[tokio::main] + /// async fn main() -> std::io::Result<()> { + /// + /// let socket2_socket = Socket::new(Domain::IPV4, Type::STREAM, None)?; + /// + /// let socket = TcpSocket::from_std_stream(socket2_socket.into()); + /// + /// Ok(()) + /// } + /// ``` + pub fn from_std_stream(std_stream: std::net::TcpStream) -> TcpSocket { + #[cfg(unix)] + { + use std::os::unix::io::{FromRawFd, IntoRawFd}; + + let raw_fd = std_stream.into_raw_fd(); + unsafe { TcpSocket::from_raw_fd(raw_fd) } + } + + #[cfg(windows)] + { + use std::os::windows::io::{FromRawSocket, IntoRawSocket}; + + let raw_socket = std_stream.into_raw_socket(); + unsafe { TcpSocket::from_raw_socket(raw_socket) } + } + } } impl fmt::Debug for TcpSocket { diff --git a/src/net/tcp/split.rs b/src/net/tcp/split.rs index 78bd688..0e02928 100644 --- 
a/src/net/tcp/split.rs +++ b/src/net/tcp/split.rs @@ -9,14 +9,18 @@ //! level. use crate::future::poll_fn; -use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; +use crate::io::{AsyncRead, AsyncWrite, Interest, ReadBuf, Ready}; use crate::net::TcpStream; use std::io; -use std::net::Shutdown; +use std::net::{Shutdown, SocketAddr}; use std::pin::Pin; use std::task::{Context, Poll}; +cfg_io_util! { + use bytes::BufMut; +} + /// Borrowed read half of a [`TcpStream`], created by [`split`]. /// /// Reading from a `ReadHalf` is usually done using the convenience methods found on the @@ -30,7 +34,7 @@ pub struct ReadHalf<'a>(&'a TcpStream); /// Borrowed write half of a [`TcpStream`], created by [`split`]. /// -/// Note that in the [`AsyncWrite`] implemenation of this type, [`poll_shutdown`] will +/// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will /// shut down the TCP stream in the write direction. /// /// Writing to an `WriteHalf` is usually done using the convenience methods found @@ -49,7 +53,7 @@ pub(crate) fn split(stream: &mut TcpStream) -> (ReadHalf<'_>, WriteHalf<'_>) { } impl ReadHalf<'_> { - /// Attempt to receive data on the socket, without removing that data from + /// Attempts to receive data on the socket, without removing that data from /// the queue, registering the current task for wakeup if data is not yet /// available. /// @@ -57,7 +61,7 @@ impl ReadHalf<'_> { /// `Waker` from the `Context` passed to the most recent call is scheduled /// to receive a wakeup. /// - /// See the [`TcpStream::poll_peek`] level documenation for more details. + /// See the [`TcpStream::poll_peek`] level documentation for more details. /// /// # Examples /// @@ -95,7 +99,7 @@ impl ReadHalf<'_> { /// connected, without removing that data from the queue. On success, /// returns the number of bytes peeked. /// - /// See the [`TcpStream::peek`] level documenation for more details. 
+ /// See the [`TcpStream::peek`] level documentation for more details. /// /// [`TcpStream::peek`]: TcpStream::peek /// @@ -134,6 +138,211 @@ impl ReadHalf<'_> { let mut buf = ReadBuf::new(buf); poll_fn(|cx| self.poll_peek(cx, &mut buf)).await } + + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// This function is equivalent to [`TcpStream::ready`]. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + self.0.ready(interest).await + } + + /// Waits for the socket to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. + /// + /// This function is also equivalent to [`TcpStream::ready`]. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn readable(&self) -> io::Result<()> { + self.0.readable().await + } + + /// Tries to read data from the stream into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. 
+ /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.0.try_read(buf) + } + + /// Tries to read data from the stream into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`try_read()`]: Self::try_read() + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + self.0.try_read_vectored(bufs) + } + + cfg_io_util! 
{ + /// Tries to read data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> { + self.0.try_read_buf(buf) + } + } + + /// Returns the remote address that this stream is connected to. + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.0.peer_addr() + } + + /// Returns the local address that this stream is bound to. + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.0.local_addr() + } +} + +impl WriteHalf<'_> { + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// This function is equivalent to [`TcpStream::ready`]. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. 
+ pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + self.0.ready(interest).await + } + + /// Waits for the socket to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn writable(&self) -> io::Result<()> { + self.0.writable().await + } + + /// Tries to write a buffer to the stream, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + self.0.try_write(buf) + } + + /// Tries to write several buffers to the stream, returning how many bytes + /// were written. + /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: Self::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. 
+ pub fn try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> { + self.0.try_write_vectored(bufs) + } + + /// Returns the remote address that this stream is connected to. + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.0.peer_addr() + } + + /// Returns the local address that this stream is bound to. + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.0.local_addr() + } } impl AsyncRead for ReadHalf<'_> { diff --git a/src/net/tcp/split_owned.rs b/src/net/tcp/split_owned.rs index d52c2f6..ef4e7b5 100644 --- a/src/net/tcp/split_owned.rs +++ b/src/net/tcp/split_owned.rs @@ -9,16 +9,20 @@ //! level. use crate::future::poll_fn; -use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; +use crate::io::{AsyncRead, AsyncWrite, Interest, ReadBuf, Ready}; use crate::net::TcpStream; use std::error::Error; -use std::net::Shutdown; +use std::net::{Shutdown, SocketAddr}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use std::{fmt, io}; +cfg_io_util! { + use bytes::BufMut; +} + /// Owned read half of a [`TcpStream`], created by [`into_split`]. /// /// Reading from an `OwnedReadHalf` is usually done using the convenience methods found @@ -112,7 +116,7 @@ impl OwnedReadHalf { /// `Waker` from the `Context` passed to the most recent call is scheduled /// to receive a wakeup. /// - /// See the [`TcpStream::poll_peek`] level documenation for more details. + /// See the [`TcpStream::poll_peek`] level documentation for more details. /// /// # Examples /// @@ -150,7 +154,7 @@ impl OwnedReadHalf { /// connected, without removing that data from the queue. On success, /// returns the number of bytes peeked. /// - /// See the [`TcpStream::peek`] level documenation for more details. + /// See the [`TcpStream::peek`] level documentation for more details. 
/// /// [`TcpStream::peek`]: TcpStream::peek /// @@ -189,6 +193,128 @@ impl OwnedReadHalf { let mut buf = ReadBuf::new(buf); poll_fn(|cx| self.poll_peek(cx, &mut buf)).await } + + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// This function is equivalent to [`TcpStream::ready`]. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + self.inner.ready(interest).await + } + + /// Waits for the socket to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. + /// + /// This function is also equivalent to [`TcpStream::ready`]. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn readable(&self) -> io::Result<()> { + self.inner.readable().await + } + + /// Tries to read data from the stream into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. 
+ /// + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.inner.try_read(buf) + } + + /// Tries to read data from the stream into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`try_read()`]: Self::try_read() + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + self.inner.try_read_vectored(bufs) + } + + cfg_io_util! { + /// Tries to read data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. 
+ /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> { + self.inner.try_read_buf(buf) + } + } + + /// Returns the remote address that this stream is connected to. + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.inner.peer_addr() + } + + /// Returns the local address that this stream is bound to. + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.inner.local_addr() + } } impl AsyncRead for OwnedReadHalf { @@ -211,13 +337,94 @@ impl OwnedWriteHalf { reunite(other, self) } - /// Destroy the write half, but don't close the write half of the stream + /// Destroys the write half, but don't close the write half of the stream /// until the read half is dropped. If the read half has already been /// dropped, this closes the stream. pub fn forget(mut self) { self.shutdown_on_drop = false; drop(self); } + + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// This function is equivalent to [`TcpStream::ready`]. + /// + /// # Cancel safety + /// + /// This method is cancel safe. 
Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + self.inner.ready(interest).await + } + + /// Waits for the socket to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn writable(&self) -> io::Result<()> { + self.inner.writable().await + } + + /// Tries to write a buffer to the stream, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + self.inner.try_write(buf) + } + + /// Tries to write several buffers to the stream, returning how many bytes + /// were written. + /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. 
+ /// + /// [`try_write()`]: Self::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> { + self.inner.try_write_vectored(bufs) + } + + /// Returns the remote address that this stream is connected to. + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.inner.peer_addr() + } + + /// Returns the local address that this stream is bound to. + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.inner.local_addr() + } } impl Drop for OwnedWriteHalf { diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index e231e5d..60d20fd 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -192,7 +192,7 @@ impl TcpStream { Ok(TcpStream { io }) } - /// Turn a [`tokio::net::TcpStream`] into a [`std::net::TcpStream`]. + /// Turns a [`tokio::net::TcpStream`] into a [`std::net::TcpStream`]. /// /// The returned [`std::net::TcpStream`] will have nonblocking mode set as `true`. /// Use [`set_nonblocking`] to change the blocking mode if needed. @@ -350,12 +350,19 @@ impl TcpStream { } } - /// Wait for any of the requested ready states. + /// Waits for any of the requested ready states. /// /// This function is usually paired with `try_read()` or `try_write()`. It /// can be used to concurrently read / write to the same socket on a single /// task without splitting the socket. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. 
+ /// /// # Examples /// /// Concurrently read and write to the stream on the same task without @@ -415,11 +422,18 @@ impl TcpStream { Ok(event.ready) } - /// Wait for the socket to become readable. + /// Waits for the socket to become readable. /// /// This function is equivalent to `ready(Interest::READABLE)` and is usually /// paired with `try_read()`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// ```no_run @@ -496,7 +510,7 @@ impl TcpStream { self.io.registration().poll_read_ready(cx).map_ok(|_| ()) } - /// Try to read data from the stream into the provided buffer, returning how + /// Tries to read data from the stream into the provided buffer, returning how /// many bytes were read. /// /// Receives any pending data from the socket but does not wait for new data @@ -563,8 +577,86 @@ impl TcpStream { .try_io(Interest::READABLE, || (&*self.io).read(buf)) } + /// Tries to read data from the stream into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. 
+ /// + /// [`try_read()`]: TcpStream::try_read() + /// [`readable()`]: TcpStream::readable() + /// [`ready()`]: TcpStream::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use std::error::Error; + /// use std::io::{self, IoSliceMut}; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// stream.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf_a = [0; 512]; + /// let mut buf_b = [0; 1024]; + /// let mut bufs = [ + /// IoSliceMut::new(&mut buf_a), + /// IoSliceMut::new(&mut buf_b), + /// ]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_read_vectored(&mut bufs) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + use std::io::Read; + + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read_vectored(bufs)) + } + cfg_io_util! 
{ - /// Try to read data from the stream into the provided buffer, advancing the + /// Tries to read data from the stream into the provided buffer, advancing the /// buffer's internal cursor, returning how many bytes were read. /// /// Receives any pending data from the socket but does not wait for new data @@ -642,11 +734,18 @@ impl TcpStream { } } - /// Wait for the socket to become writable. + /// Waits for the socket to become writable. /// /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually /// paired with `try_write()`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// ```no_run @@ -775,6 +874,103 @@ impl TcpStream { .try_io(Interest::WRITABLE, || (&*self.io).write(buf)) } + /// Tries to write several buffers to the stream, returning how many bytes + /// were written. + /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: TcpStream::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. 
+ /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::TcpStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let stream = TcpStream::connect("127.0.0.1:8080").await?; + /// + /// let bufs = [io::IoSlice::new(b"hello "), io::IoSlice::new(b"world")]; + /// + /// loop { + /// // Wait for the socket to be writable + /// stream.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_write_vectored(&bufs) { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> { + use std::io::Write; + + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write_vectored(bufs)) + } + + /// Tries to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. 
Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the methods + /// defined on the Tokio `TcpStream` type, as this will mess with the + /// readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: TcpStream::readable() + /// [`writable()`]: TcpStream::writable() + /// [`ready()`]: TcpStream::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } + /// Receives data on the socket from the remote address to which it is /// connected, without removing that data from the queue. On success, /// returns the number of bytes peeked. @@ -1012,6 +1208,12 @@ impl TcpStream { split_owned(self) } + // == Poll IO functions that takes `&self` == + // + // To read or write without mutable access to the `UnixStream`, combine the + // `poll_read_ready` or `poll_write_ready` methods with the `try_read` or + // `try_write` methods. + pub(crate) fn poll_read_priv( &self, cx: &mut Context<'_>, diff --git a/src/net/udp.rs b/src/net/udp.rs index 6e63355..504d74e 100644 --- a/src/net/udp.rs +++ b/src/net/udp.rs @@ -12,7 +12,7 @@ cfg_io_util! { } cfg_net! { - /// A UDP socket + /// A UDP socket. /// /// UDP is "connectionless", unlike TCP. Meaning, regardless of what address you've bound to, a `UdpSocket` /// is free to communicate with many different remotes. In tokio there are basically two main ways to use `UdpSocket`: @@ -211,7 +211,7 @@ impl UdpSocket { UdpSocket::new(io) } - /// Turn a [`tokio::net::UdpSocket`] into a [`std::net::UdpSocket`]. + /// Turns a [`tokio::net::UdpSocket`] into a [`std::net::UdpSocket`]. 
/// /// The returned [`std::net::UdpSocket`] will have nonblocking mode set as /// `true`. Use [`set_nonblocking`] to change the blocking mode if needed. @@ -317,7 +317,7 @@ impl UdpSocket { })) } - /// Wait for any of the requested ready states. + /// Waits for any of the requested ready states. /// /// This function is usually paired with `try_recv()` or `try_send()`. It /// can be used to concurrently recv / send to the same socket on a single @@ -327,6 +327,13 @@ impl UdpSocket { /// false-positive and attempting an operation will return with /// `io::ErrorKind::WouldBlock`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// Concurrently receive from and send to the socket on the same task @@ -381,7 +388,7 @@ impl UdpSocket { Ok(event.ready) } - /// Wait for the socket to become writable. + /// Waits for the socket to become writable. /// /// This function is equivalent to `ready(Interest::WRITABLE)` and is /// usually paired with `try_send()` or `try_send_to()`. @@ -390,6 +397,13 @@ impl UdpSocket { /// false-positive and attempting a `try_send()` will return with /// `io::ErrorKind::WouldBlock`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// ```no_run @@ -429,6 +443,39 @@ impl UdpSocket { Ok(()) } + /// Polls for write/send readiness. + /// + /// If the udp stream is not currently ready for sending, this method will + /// store a clone of the `Waker` from the provided `Context`. 
When the udp + /// stream becomes ready for sending, `Waker::wake` will be called on the + /// waker. + /// + /// Note that on multiple calls to `poll_send_ready` or `poll_send`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_recv_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the udp stream is not ready for writing. + /// * `Poll::Ready(Ok(()))` if the udp stream is ready for writing. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`writable`]: method@Self::writable + pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_write_ready(cx).map_ok(|_| ()) + } + /// Sends data on the socket to the remote address that the socket is /// connected to. /// @@ -442,6 +489,12 @@ impl UdpSocket { /// On success, the number of bytes sent is returned, otherwise, the /// encountered error is returned. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `send` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that the message was not sent. + /// /// # Examples /// /// ```no_run @@ -496,7 +549,7 @@ impl UdpSocket { .poll_write_io(cx, || self.io.send(buf)) } - /// Try to send data on the socket to the remote address to which it is + /// Tries to send data on the socket to the remote address to which it is /// connected. 
/// /// When the socket buffer is full, `Err(io::ErrorKind::WouldBlock)` is @@ -550,7 +603,7 @@ impl UdpSocket { .try_io(Interest::WRITABLE, || self.io.send(buf)) } - /// Wait for the socket to become readable. + /// Waits for the socket to become readable. /// /// This function is equivalent to `ready(Interest::READABLE)` and is usually /// paired with `try_recv()`. @@ -559,6 +612,13 @@ impl UdpSocket { /// false-positive and attempting a `try_recv()` will return with /// `io::ErrorKind::WouldBlock`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// ```no_run @@ -603,6 +663,39 @@ impl UdpSocket { Ok(()) } + /// Polls for read/receive readiness. + /// + /// If the udp stream is not currently ready for receiving, this method will + /// store a clone of the `Waker` from the provided `Context`. When the udp + /// socket becomes ready for reading, `Waker::wake` will be called on the + /// waker. + /// + /// Note that on multiple calls to `poll_recv_ready`, `poll_recv` or + /// `poll_peek`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. (However, + /// `poll_send_ready` retains a second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the udp stream is not ready for reading. + /// * `Poll::Ready(Ok(()))` if the udp stream is ready for reading. + /// * `Poll::Ready(Err(e))` if an error is encountered. 
+ /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`readable`]: method@Self::readable + pub fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_read_ready(cx).map_ok(|_| ()) + } + /// Receives a single datagram message on the socket from the remote address /// to which it is connected. On success, returns the number of bytes read. /// @@ -613,6 +706,13 @@ impl UdpSocket { /// The [`connect`] method will connect this socket to a remote address. /// This method will fail if the socket is not connected. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv_from` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// socket. + /// /// [`connect`]: method@Self::connect /// /// ```no_run @@ -665,7 +765,7 @@ impl UdpSocket { /// [`connect`]: method@Self::connect pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> { let n = ready!(self.io.registration().poll_read_io(cx, || { - // Safety: will not read the maybe uinitialized bytes. + // Safety: will not read the maybe uninitialized bytes. let b = unsafe { &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; @@ -681,7 +781,7 @@ impl UdpSocket { Poll::Ready(Ok(())) } - /// Try to receive a single datagram message on the socket from the remote + /// Tries to receive a single datagram message on the socket from the remote /// address to which it is connected. On success, returns the number of /// bytes read. /// @@ -738,7 +838,7 @@ impl UdpSocket { } cfg_io_util! { - /// Try to receive data from the stream into the provided buffer, advancing the + /// Tries to receive data from the stream into the provided buffer, advancing the /// buffer's internal cursor, returning how many bytes were read. 
/// /// The function must be called with valid byte array buf of sufficient size @@ -803,7 +903,7 @@ impl UdpSocket { }) } - /// Try to receive a single datagram message on the socket. On success, + /// Tries to receive a single datagram message on the socket. On success, /// returns the number of bytes read and the origin. /// /// The function must be called with valid byte array buf of sufficient size @@ -882,6 +982,12 @@ impl UdpSocket { /// /// [`ToSocketAddrs`]: crate::net::ToSocketAddrs /// + /// # Cancel safety + /// + /// This method is cancel safe. If `send_to` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that the message was not sent. + /// /// # Example /// /// ```no_run @@ -938,14 +1044,14 @@ impl UdpSocket { .poll_write_io(cx, || self.io.send_to(buf, target)) } - /// Try to send data on the socket to the given address, but if the send is + /// Tries to send data on the socket to the given address, but if the send is /// blocked this will return right away. /// /// This function is usually paired with `writable()`. /// /// # Returns /// - /// If successfull, returns the number of bytes sent + /// If successful, returns the number of bytes sent /// /// Users should ensure that when the remote cannot receive, the /// [`ErrorKind::WouldBlock`] is properly handled. An error can also occur @@ -1005,6 +1111,13 @@ impl UdpSocket { /// size to hold the message bytes. If a message is too long to fit in the /// supplied buffer, excess bytes may be discarded. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv_from` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// socket. 
+ /// /// # Example /// /// ```no_run @@ -1053,7 +1166,7 @@ impl UdpSocket { buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<SocketAddr>> { let (n, addr) = ready!(self.io.registration().poll_read_io(cx, || { - // Safety: will not read the maybe uinitialized bytes. + // Safety: will not read the maybe uninitialized bytes. let b = unsafe { &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; @@ -1069,7 +1182,7 @@ impl UdpSocket { Poll::Ready(Ok(addr)) } - /// Try to receive a single datagram message on the socket. On success, + /// Tries to receive a single datagram message on the socket. On success, /// returns the number of bytes read and the origin. /// /// The function must be called with valid byte array buf of sufficient size @@ -1123,6 +1236,41 @@ impl UdpSocket { .try_io(Interest::READABLE, || self.io.recv_from(buf)) } + /// Tries to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. 
+ /// + /// The closure should not perform the IO operation using any of the methods + /// defined on the Tokio `UdpSocket` type, as this will mess with the + /// readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: UdpSocket::readable() + /// [`writable()`]: UdpSocket::writable() + /// [`ready()`]: UdpSocket::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } + /// Receives data from the socket, without removing it from the input queue. /// On success, returns the number of bytes read and the address from whence /// the data came. @@ -1192,7 +1340,7 @@ impl UdpSocket { buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<SocketAddr>> { let (n, addr) = ready!(self.io.registration().poll_read_io(cx, || { - // Safety: will not read the maybe uinitialized bytes. + // Safety: will not read the maybe uninitialized bytes. let b = unsafe { &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; diff --git a/src/net/unix/datagram/socket.rs b/src/net/unix/datagram/socket.rs index 6bc5615..d5b6186 100644 --- a/src/net/unix/datagram/socket.rs +++ b/src/net/unix/datagram/socket.rs @@ -96,7 +96,7 @@ cfg_net_unix! { } impl UnixDatagram { - /// Wait for any of the requested ready states. + /// Waits for any of the requested ready states. /// /// This function is usually paired with `try_recv()` or `try_send()`. It /// can be used to concurrently recv / send to the same socket on a single @@ -106,6 +106,13 @@ impl UnixDatagram { /// false-positive and attempting an operation will return with /// `io::ErrorKind::WouldBlock`. /// + /// # Cancel safety + /// + /// This method is cancel safe. 
Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// Concurrently receive from and send to the socket on the same task @@ -162,7 +169,7 @@ impl UnixDatagram { Ok(event.ready) } - /// Wait for the socket to become writable. + /// Waits for the socket to become writable. /// /// This function is equivalent to `ready(Interest::WRITABLE)` and is /// usually paired with `try_send()` or `try_send_to()`. @@ -171,6 +178,13 @@ impl UnixDatagram { /// false-positive and attempting a `try_send()` will return with /// `io::ErrorKind::WouldBlock`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// ```no_run @@ -212,7 +226,40 @@ impl UnixDatagram { Ok(()) } - /// Wait for the socket to become readable. + /// Polls for write/send readiness. + /// + /// If the socket is not currently ready for sending, this method will + /// store a clone of the `Waker` from the provided `Context`. When the socket + /// becomes ready for sending, `Waker::wake` will be called on the + /// waker. + /// + /// Note that on multiple calls to `poll_send_ready` or `poll_send`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_recv_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. 
+ /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready for writing. + /// * `Poll::Ready(Ok(()))` if the socket is ready for writing. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`writable`]: method@Self::writable + pub fn poll_send_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_write_ready(cx).map_ok(|_| ()) + } + + /// Waits for the socket to become readable. /// /// This function is equivalent to `ready(Interest::READABLE)` and is usually /// paired with `try_recv()`. @@ -221,6 +268,13 @@ impl UnixDatagram { /// false-positive and attempting a `try_recv()` will return with /// `io::ErrorKind::WouldBlock`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// ```no_run @@ -268,6 +322,39 @@ impl UnixDatagram { Ok(()) } + /// Polls for read/receive readiness. + /// + /// If the socket is not currently ready for receiving, this method will + /// store a clone of the `Waker` from the provided `Context`. When the + /// socket becomes ready for reading, `Waker::wake` will be called on the + /// waker. + /// + /// Note that on multiple calls to `poll_recv_ready`, `poll_recv` or + /// `poll_peek`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. (However, + /// `poll_send_ready` retains a second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. 
Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the socket is not ready for reading. + /// * `Poll::Ready(Ok(()))` if the socket is ready for reading. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`readable`]: method@Self::readable + pub fn poll_recv_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_read_ready(cx).map_ok(|_| ()) + } + /// Creates a new `UnixDatagram` bound to the specified path. /// /// # Examples @@ -376,7 +463,7 @@ impl UnixDatagram { Ok(UnixDatagram { io }) } - /// Turn a [`tokio::net::UnixDatagram`] into a [`std::os::unix::net::UnixDatagram`]. + /// Turns a [`tokio::net::UnixDatagram`] into a [`std::os::unix::net::UnixDatagram`]. /// /// The returned [`std::os::unix::net::UnixDatagram`] will have nonblocking /// mode set as `true`. Use [`set_nonblocking`] to change the blocking mode @@ -490,6 +577,12 @@ impl UnixDatagram { /// Sends data on the socket to the socket's peer. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `send` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that the message was not sent. + /// /// # Examples /// ``` /// # use std::error::Error; @@ -521,7 +614,7 @@ impl UnixDatagram { .await } - /// Try to send a datagram to the peer without waiting. + /// Tries to send a datagram to the peer without waiting. /// /// # Examples /// @@ -565,7 +658,7 @@ impl UnixDatagram { .try_io(Interest::WRITABLE, || self.io.send(buf)) } - /// Try to send a datagram to the peer without waiting. + /// Tries to send a datagram to the peer without waiting. 
/// /// # Examples /// @@ -613,6 +706,13 @@ impl UnixDatagram { /// Receives data from the socket. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// socket. + /// /// # Examples /// ``` /// # use std::error::Error; @@ -644,7 +744,7 @@ impl UnixDatagram { .await } - /// Try to receive a datagram from the peer without waiting. + /// Tries to receive a datagram from the peer without waiting. /// /// # Examples /// @@ -695,7 +795,7 @@ impl UnixDatagram { } cfg_io_util! { - /// Try to receive data from the socket without waiting. + /// Tries to receive data from the socket without waiting. /// /// # Examples /// @@ -756,7 +856,7 @@ impl UnixDatagram { Ok((n, SocketAddr(addr))) } - /// Try to read data from the stream into the provided buffer, advancing the + /// Tries to read data from the stream into the provided buffer, advancing the /// buffer's internal cursor, returning how many bytes were read. /// /// # Examples @@ -820,6 +920,12 @@ impl UnixDatagram { /// Sends data on the socket to the specified address. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `send_to` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that the message was not sent. + /// /// # Examples /// ``` /// # use std::error::Error; @@ -863,6 +969,13 @@ impl UnixDatagram { /// Receives data from the socket. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv_from` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// socket. 
+ /// /// # Examples /// ``` /// # use std::error::Error; @@ -927,7 +1040,7 @@ impl UnixDatagram { buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<SocketAddr>> { let (n, addr) = ready!(self.io.registration().poll_read_io(cx, || { - // Safety: will not read the maybe uinitialized bytes. + // Safety: will not read the maybe uninitialized bytes. let b = unsafe { &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; @@ -1028,7 +1141,7 @@ impl UnixDatagram { /// [`connect`]: method@Self::connect pub fn poll_recv(&self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> { let n = ready!(self.io.registration().poll_read_io(cx, || { - // Safety: will not read the maybe uinitialized bytes. + // Safety: will not read the maybe uninitialized bytes. let b = unsafe { &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) }; @@ -1044,7 +1157,7 @@ impl UnixDatagram { Poll::Ready(Ok(())) } - /// Try to receive data from the socket without waiting. + /// Tries to receive data from the socket without waiting. /// /// # Examples /// @@ -1096,6 +1209,41 @@ impl UnixDatagram { Ok((n, SocketAddr(addr))) } + /// Tries to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. 
Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the methods + /// defined on the Tokio `UnixDatagram` type, as this will mess with the + /// readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: UnixDatagram::readable() + /// [`writable()`]: UnixDatagram::writable() + /// [`ready()`]: UnixDatagram::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } + /// Returns the local address that this socket is bound to. /// /// # Examples diff --git a/src/net/unix/listener.rs b/src/net/unix/listener.rs index b5b05a6..1785f8b 100644 --- a/src/net/unix/listener.rs +++ b/src/net/unix/listener.rs @@ -88,7 +88,7 @@ impl UnixListener { Ok(UnixListener { io }) } - /// Turn a [`tokio::net::UnixListener`] into a [`std::os::unix::net::UnixListener`]. + /// Turns a [`tokio::net::UnixListener`] into a [`std::os::unix::net::UnixListener`]. /// /// The returned [`std::os::unix::net::UnixListener`] will have nonblocking mode /// set as `true`. Use [`set_nonblocking`] to change the blocking mode if needed. @@ -128,6 +128,13 @@ impl UnixListener { } /// Accepts a new incoming connection to this listener. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If the method is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no new connections were + /// accepted by this method. 
pub async fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { let (mio, addr) = self .io diff --git a/src/net/unix/mod.rs b/src/net/unix/mod.rs index 19ee34a..14cb456 100644 --- a/src/net/unix/mod.rs +++ b/src/net/unix/mod.rs @@ -1,5 +1,9 @@ -//! Unix domain socket utility types +//! Unix domain socket utility types. +// This module does not currently provide any public API, but it was +// unintentionally defined as a public module. Hide it from the documentation +// instead of changing it to a private module to avoid breakage. +#[doc(hidden)] pub mod datagram; pub(crate) mod listener; diff --git a/src/net/unix/split.rs b/src/net/unix/split.rs index 24a711b..d4686c2 100644 --- a/src/net/unix/split.rs +++ b/src/net/unix/split.rs @@ -8,14 +8,19 @@ //! split has no associated overhead and enforces all invariants at the type //! level. -use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; +use crate::io::{AsyncRead, AsyncWrite, Interest, ReadBuf, Ready}; use crate::net::UnixStream; +use crate::net::unix::SocketAddr; use std::io; use std::net::Shutdown; use std::pin::Pin; use std::task::{Context, Poll}; +cfg_io_util! { + use bytes::BufMut; +} + /// Borrowed read half of a [`UnixStream`], created by [`split`]. /// /// Reading from a `ReadHalf` is usually done using the convenience methods found on the @@ -29,7 +34,7 @@ pub struct ReadHalf<'a>(&'a UnixStream); /// Borrowed write half of a [`UnixStream`], created by [`split`]. /// -/// Note that in the [`AsyncWrite`] implemenation of this type, [`poll_shutdown`] will +/// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will /// shut down the UnixStream stream in the write direction. /// /// Writing to an `WriteHalf` is usually done using the convenience methods found @@ -47,6 +52,206 @@ pub(crate) fn split(stream: &mut UnixStream) -> (ReadHalf<'_>, WriteHalf<'_>) { (ReadHalf(stream), WriteHalf(stream)) } +impl ReadHalf<'_> { + /// Wait for any of the requested ready states. 
+ /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + self.0.ready(interest).await + } + + /// Waits for the socket to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn readable(&self) -> io::Result<()> { + self.0.readable().await + } + + /// Tries to read data from the stream into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. 
+ pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.0.try_read(buf) + } + + cfg_io_util! { + /// Tries to read data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> { + self.0.try_read_buf(buf) + } + } + + /// Tries to read data from the stream into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. 
+ /// + /// [`try_read()`]: Self::try_read() + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + self.0.try_read_vectored(bufs) + } + + /// Returns the socket address of the remote half of this connection. + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.0.peer_addr() + } + + /// Returns the socket address of the local half of this connection. + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.0.local_addr() + } +} + +impl WriteHalf<'_> { + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + self.0.ready(interest).await + } + + /// Waits for the socket to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. 
+ pub async fn writable(&self) -> io::Result<()> { + self.0.writable().await + } + + /// Tries to write a buffer to the stream, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + self.0.try_write(buf) + } + + /// Tries to write several buffers to the stream, returning how many bytes + /// were written. + /// + /// Data is written from each buffer in order, with the final buffer read + /// from possibly being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: Self::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_write_vectored(&self, buf: &[io::IoSlice<'_>]) -> io::Result<usize> { + self.0.try_write_vectored(buf) + } + + /// Returns the socket address of the remote half of this connection. + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.0.peer_addr() + } + + /// Returns the socket address of the local half of this connection. 
+ pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.0.local_addr() + } +} + impl AsyncRead for ReadHalf<'_> { fn poll_read( self: Pin<&mut Self>, diff --git a/src/net/unix/split_owned.rs b/src/net/unix/split_owned.rs index 3d6ac6a..9c3a2a4 100644 --- a/src/net/unix/split_owned.rs +++ b/src/net/unix/split_owned.rs @@ -8,9 +8,10 @@ //! split has no associated overhead and enforces all invariants at the type //! level. -use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; +use crate::io::{AsyncRead, AsyncWrite, Interest, ReadBuf, Ready}; use crate::net::UnixStream; +use crate::net::unix::SocketAddr; use std::error::Error; use std::net::Shutdown; use std::pin::Pin; @@ -18,6 +19,10 @@ use std::sync::Arc; use std::task::{Context, Poll}; use std::{fmt, io}; +cfg_io_util! { + use bytes::BufMut; +} + /// Owned read half of a [`UnixStream`], created by [`into_split`]. /// /// Reading from an `OwnedReadHalf` is usually done using the convenience methods found @@ -102,6 +107,124 @@ impl OwnedReadHalf { pub fn reunite(self, other: OwnedWriteHalf) -> Result<UnixStream, ReuniteError> { reunite(self, other) } + + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + self.inner.ready(interest).await + } + + /// Waits for the socket to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. 
Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn readable(&self) -> io::Result<()> { + self.inner.readable().await + } + + /// Tries to read data from the stream into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.inner.try_read(buf) + } + + cfg_io_util! { + /// Tries to read data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. 
+ /// + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> { + self.inner.try_read_buf(buf) + } + } + + /// Tries to read data from the stream into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`try_read()`]: Self::try_read() + /// [`readable()`]: Self::readable() + /// [`ready()`]: Self::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + self.inner.try_read_vectored(bufs) + } + + /// Returns the socket address of the remote half of this connection. 
+ pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.inner.peer_addr() + } + + /// Returns the socket address of the local half of this connection. + pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.inner.local_addr() + } } impl AsyncRead for OwnedReadHalf { @@ -124,13 +247,92 @@ impl OwnedWriteHalf { reunite(other, self) } - /// Destroy the write half, but don't close the write half of the stream + /// Destroys the write half, but don't close the write half of the stream /// until the read half is dropped. If the read half has already been /// dropped, this closes the stream. pub fn forget(mut self) { self.shutdown_on_drop = false; drop(self); } + + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same socket on a single + /// task without splitting the socket. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + self.inner.ready(interest).await + } + + /// Waits for the socket to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn writable(&self) -> io::Result<()> { + self.inner.writable().await + } + + /// Tries to write a buffer to the stream, returning how many bytes were + /// written. 
+ /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + self.inner.try_write(buf) + } + + /// Tries to write several buffers to the stream, returning how many bytes + /// were written. + /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: Self::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_write_vectored(&self, buf: &[io::IoSlice<'_>]) -> io::Result<usize> { + self.inner.try_write_vectored(buf) + } + + /// Returns the socket address of the remote half of this connection. + pub fn peer_addr(&self) -> io::Result<SocketAddr> { + self.inner.peer_addr() + } + + /// Returns the socket address of the local half of this connection. 
+ pub fn local_addr(&self) -> io::Result<SocketAddr> { + self.inner.local_addr() + } } impl Drop for OwnedWriteHalf { diff --git a/src/net/unix/stream.rs b/src/net/unix/stream.rs index d797aae..4e7ef87 100644 --- a/src/net/unix/stream.rs +++ b/src/net/unix/stream.rs @@ -51,15 +51,27 @@ impl UnixStream { let stream = UnixStream::new(stream)?; poll_fn(|cx| stream.io.registration().poll_write_ready(cx)).await?; + + if let Some(e) = stream.io.take_error()? { + return Err(e); + } + Ok(stream) } - /// Wait for any of the requested ready states. + /// Waits for any of the requested ready states. /// /// This function is usually paired with `try_read()` or `try_write()`. It /// can be used to concurrently read / write to the same socket on a single /// task without splitting the socket. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read or write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// Concurrently read and write to the stream on the same task without @@ -121,11 +133,18 @@ impl UnixStream { Ok(event.ready) } - /// Wait for the socket to become readable. + /// Waits for the socket to become readable. /// /// This function is equivalent to `ready(Interest::READABLE)` and is usually /// paired with `try_read()`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to read that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// ```no_run @@ -271,8 +290,86 @@ impl UnixStream { .try_io(Interest::READABLE, || (&*self.io).read(buf)) } + /// Tries to read data from the stream into the provided buffers, returning + /// how many bytes were read. 
+ /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`try_read()`]: UnixStream::try_read() + /// [`readable()`]: UnixStream::readable() + /// [`ready()`]: UnixStream::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// use std::error::Error; + /// use std::io::{self, IoSliceMut}; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// loop { + /// // Wait for the socket to be readable + /// stream.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf_a = [0; 512]; + /// let mut buf_b = [0; 1024]; + /// let mut bufs = [ + /// IoSliceMut::new(&mut buf_a), + /// IoSliceMut::new(&mut buf_b), + /// ]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. 
+ /// match stream.try_read_vectored(&mut bufs) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read_vectored(bufs)) + } + cfg_io_util! { - /// Try to read data from the stream into the provided buffer, advancing the + /// Tries to read data from the stream into the provided buffer, advancing the /// buffer's internal cursor, returning how many bytes were read. /// /// Receives any pending data from the socket but does not wait for new data @@ -352,11 +449,18 @@ impl UnixStream { } } - /// Wait for the socket to become writable. + /// Waits for the socket to become writable. /// /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually /// paired with `try_write()`. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + /// /// # Examples /// /// ```no_run @@ -431,7 +535,7 @@ impl UnixStream { self.io.registration().poll_write_ready(cx).map_ok(|_| ()) } - /// Try to write a buffer to the stream, returning how many bytes were + /// Tries to write a buffer to the stream, returning how many bytes were /// written. /// /// The function will attempt to write the entire contents of `buf`, but @@ -487,6 +591,103 @@ impl UnixStream { .try_io(Interest::WRITABLE, || (&*self.io).write(buf)) } + /// Tries to write several buffers to the stream, returning how many bytes + /// were written. 
+ /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: UnixStream::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// use std::error::Error; + /// use std::io; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// // Connect to a peer + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// let bufs = [io::IoSlice::new(b"hello "), io::IoSlice::new(b"world")]; + /// + /// loop { + /// // Wait for the socket to be writable + /// stream.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match stream.try_write_vectored(&bufs) { + /// Ok(n) => { + /// break; + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write_vectored(&self, buf: &[io::IoSlice<'_>]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write_vectored(buf)) + } + + /// Tries to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. 
The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the methods + /// defined on the Tokio `UnixStream` type, as this will mess with the + /// readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: UnixStream::readable() + /// [`writable()`]: UnixStream::writable() + /// [`ready()`]: UnixStream::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } + /// Creates new `UnixStream` from a `std::os::unix::net::UnixStream`. /// /// This function is intended to be used to wrap a UnixStream from the @@ -508,7 +709,7 @@ impl UnixStream { Ok(UnixStream { io }) } - /// Turn a [`tokio::net::UnixStream`] into a [`std::os::unix::net::UnixStream`]. + /// Turns a [`tokio::net::UnixStream`] into a [`std::os::unix::net::UnixStream`]. /// /// The returned [`std::os::unix::net::UnixStream`] will have nonblocking /// mode set as `true`. 
Use [`set_nonblocking`] to change the blocking @@ -572,11 +773,41 @@ impl UnixStream { } /// Returns the socket address of the local half of this connection. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// + /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// println!("{:?}", stream.local_addr()?); + /// # Ok(()) + /// # } + /// ``` pub fn local_addr(&self) -> io::Result<SocketAddr> { self.io.local_addr().map(SocketAddr) } /// Returns the socket address of the remote half of this connection. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::UnixStream; + /// + /// # async fn dox() -> Result<(), Box<dyn std::error::Error>> { + /// let dir = tempfile::tempdir().unwrap(); + /// let bind_path = dir.path().join("bind_path"); + /// let stream = UnixStream::connect(bind_path).await?; + /// + /// println!("{:?}", stream.peer_addr()?); + /// # Ok(()) + /// # } + /// ``` pub fn peer_addr(&self) -> io::Result<SocketAddr> { self.io.peer_addr().map(SocketAddr) } @@ -603,7 +834,7 @@ impl UnixStream { // These lifetime markers also appear in the generated documentation, and make // it more clear that this is a *borrowed* split. #[allow(clippy::needless_lifetimes)] - /// Split a `UnixStream` into a read half and a write half, which can be used + /// Splits a `UnixStream` into a read half and a write half, which can be used /// to read and write the stream concurrently. 
/// /// This method is more efficient than [`into_split`], but the halves cannot be @@ -686,14 +917,9 @@ impl AsyncWrite for UnixStream { impl UnixStream { // == Poll IO functions that takes `&self` == // - // They are not public because (taken from the doc of `PollEvented`): - // - // While `PollEvented` is `Sync` (if the underlying I/O type is `Sync`), the - // caller must ensure that there are at most two tasks that use a - // `PollEvented` instance concurrently. One for reading and one for writing. - // While violating this requirement is "safe" from a Rust memory model point - // of view, it will result in unexpected behavior in the form of lost - // notifications and tasks hanging. + // To read or write without mutable access to the `UnixStream`, combine the + // `poll_read_ready` or `poll_write_ready` methods with the `try_read` or + // `try_write` methods. pub(crate) fn poll_read_priv( &self, diff --git a/src/net/unix/ucred.rs b/src/net/unix/ucred.rs index 5c7c198..865303b 100644 --- a/src/net/unix/ucred.rs +++ b/src/net/unix/ucred.rs @@ -1,13 +1,13 @@ use libc::{gid_t, pid_t, uid_t}; -/// Credentials of a process +/// Credentials of a process. #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] pub struct UCred { - /// PID (process ID) of the process + /// PID (process ID) of the process. pid: Option<pid_t>, - /// UID (user ID) of the process + /// UID (user ID) of the process. uid: uid_t, - /// GID (group ID) of the process + /// GID (group ID) of the process. gid: gid_t, } @@ -25,21 +25,19 @@ impl UCred { /// Gets PID (process ID) of the process. /// /// This is only implemented under Linux, Android, iOS, macOS, Solaris and - /// Illumos. On other plaforms this will always return `None`. + /// Illumos. On other platforms this will always return `None`. 
pub fn pid(&self) -> Option<pid_t> { self.pid } } -#[cfg(any(target_os = "linux", target_os = "android"))] +#[cfg(any(target_os = "linux", target_os = "android", target_os = "openbsd"))] pub(crate) use self::impl_linux::get_peer_cred; -#[cfg(any( - target_os = "dragonfly", - target_os = "freebsd", - target_os = "netbsd", - target_os = "openbsd" -))] +#[cfg(any(target_os = "netbsd"))] +pub(crate) use self::impl_netbsd::get_peer_cred; + +#[cfg(any(target_os = "dragonfly", target_os = "freebsd"))] pub(crate) use self::impl_bsd::get_peer_cred; #[cfg(any(target_os = "macos", target_os = "ios"))] @@ -48,13 +46,16 @@ pub(crate) use self::impl_macos::get_peer_cred; #[cfg(any(target_os = "solaris", target_os = "illumos"))] pub(crate) use self::impl_solaris::get_peer_cred; -#[cfg(any(target_os = "linux", target_os = "android"))] +#[cfg(any(target_os = "linux", target_os = "android", target_os = "openbsd"))] pub(crate) mod impl_linux { use crate::net::unix::UnixStream; use libc::{c_void, getsockopt, socklen_t, SOL_SOCKET, SO_PEERCRED}; use std::{io, mem}; + #[cfg(target_os = "openbsd")] + use libc::sockpeercred as ucred; + #[cfg(any(target_os = "linux", target_os = "android"))] use libc::ucred; pub(crate) fn get_peer_cred(sock: &UnixStream) -> io::Result<super::UCred> { @@ -73,7 +74,7 @@ pub(crate) mod impl_linux { // These paranoid checks should be optimized-out assert!(mem::size_of::<u32>() <= mem::size_of::<usize>()); - assert!(ucred_size <= u32::max_value() as usize); + assert!(ucred_size <= u32::MAX as usize); let mut ucred_size = ucred_size as socklen_t; @@ -97,12 +98,49 @@ pub(crate) mod impl_linux { } } -#[cfg(any( - target_os = "dragonfly", - target_os = "freebsd", - target_os = "netbsd", - target_os = "openbsd" -))] +#[cfg(any(target_os = "netbsd"))] +pub(crate) mod impl_netbsd { + use crate::net::unix::UnixStream; + + use libc::{c_void, getsockopt, socklen_t, unpcbid, LOCAL_PEEREID, SOL_SOCKET}; + use std::io; + use std::mem::size_of; + use 
std::os::unix::io::AsRawFd; + + pub(crate) fn get_peer_cred(sock: &UnixStream) -> io::Result<super::UCred> { + unsafe { + let raw_fd = sock.as_raw_fd(); + + let mut unpcbid = unpcbid { + unp_pid: 0, + unp_euid: 0, + unp_egid: 0, + }; + + let unpcbid_size = size_of::<unpcbid>(); + let mut unpcbid_size = unpcbid_size as socklen_t; + + let ret = getsockopt( + raw_fd, + SOL_SOCKET, + LOCAL_PEEREID, + &mut unpcbid as *mut unpcbid as *mut c_void, + &mut unpcbid_size, + ); + if ret == 0 && unpcbid_size as usize == size_of::<unpcbid>() { + Ok(super::UCred { + uid: unpcbid.unp_euid, + gid: unpcbid.unp_egid, + pid: Some(unpcbid.unp_pid), + }) + } else { + Err(io::Error::last_os_error()) + } + } + } +} + +#[cfg(any(target_os = "dragonfly", target_os = "freebsd"))] pub(crate) mod impl_bsd { use crate::net::unix::UnixStream; diff --git a/src/net/windows/mod.rs b/src/net/windows/mod.rs new file mode 100644 index 0000000..060b68e --- /dev/null +++ b/src/net/windows/mod.rs @@ -0,0 +1,3 @@ +//! Windows specific network types. + +pub mod named_pipe; diff --git a/src/net/windows/named_pipe.rs b/src/net/windows/named_pipe.rs new file mode 100644 index 0000000..550fd4d --- /dev/null +++ b/src/net/windows/named_pipe.rs @@ -0,0 +1,2250 @@ +//! Tokio support for [Windows named pipes]. +//! +//! [Windows named pipes]: https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes + +use std::ffi::c_void; +use std::ffi::OsStr; +use std::io::{self, Read, Write}; +use std::pin::Pin; +use std::ptr; +use std::task::{Context, Poll}; + +use crate::io::{AsyncRead, AsyncWrite, Interest, PollEvented, ReadBuf, Ready}; +use crate::os::windows::io::{AsRawHandle, FromRawHandle, RawHandle}; + +// Hide imports which are not used when generating documentation. 
+#[cfg(not(docsrs))] +mod doc { + pub(super) use crate::os::windows::ffi::OsStrExt; + pub(super) use crate::winapi::shared::minwindef::{DWORD, FALSE}; + pub(super) use crate::winapi::um::fileapi; + pub(super) use crate::winapi::um::handleapi; + pub(super) use crate::winapi::um::namedpipeapi; + pub(super) use crate::winapi::um::winbase; + pub(super) use crate::winapi::um::winnt; + + pub(super) use mio::windows as mio_windows; +} + +// NB: none of these shows up in public API, so don't document them. +#[cfg(docsrs)] +mod doc { + pub type DWORD = crate::doc::NotDefinedHere; + + pub(super) mod mio_windows { + pub type NamedPipe = crate::doc::NotDefinedHere; + } +} + +use self::doc::*; + +/// A [Windows named pipe] server. +/// +/// Accepting client connections involves creating a server with +/// [`ServerOptions::create`] and waiting for clients to connect using +/// [`NamedPipeServer::connect`]. +/// +/// To avoid having clients sporadically fail with +/// [`std::io::ErrorKind::NotFound`] when they connect to a server, we must +/// ensure that at least one server instance is available at all times. This +/// means that the typical listen loop for a server is a bit involved, because +/// we have to ensure that we never drop a server accidentally while a client +/// might connect. +/// +/// So a correctly implemented server looks like this: +/// +/// ```no_run +/// use std::io; +/// use tokio::net::windows::named_pipe::ServerOptions; +/// +/// const PIPE_NAME: &str = r"\\.\pipe\named-pipe-idiomatic-server"; +/// +/// # #[tokio::main] async fn main() -> std::io::Result<()> { +/// // The first server needs to be constructed early so that clients can +/// // be correctly connected. Otherwise calling .wait will cause the client to +/// // error. +/// // +/// // Here we also make use of `first_pipe_instance`, which will ensure that +/// // there are no other servers up and running already. 
+/// let mut server = ServerOptions::new() +/// .first_pipe_instance(true) +/// .create(PIPE_NAME)?; +/// +/// // Spawn the server loop. +/// let server = tokio::spawn(async move { +/// loop { +/// // Wait for a client to connect. +/// let connected = server.connect().await?; +/// +/// // Construct the next server to be connected before sending the one +/// // we already have of onto a task. This ensures that the server +/// // isn't closed (after it's done in the task) before a new one is +/// // available. Otherwise the client might error with +/// // `io::ErrorKind::NotFound`. +/// server = ServerOptions::new().create(PIPE_NAME)?; +/// +/// let client = tokio::spawn(async move { +/// /* use the connected client */ +/// # Ok::<_, std::io::Error>(()) +/// }); +/// # if true { break } // needed for type inference to work +/// } +/// +/// Ok::<_, io::Error>(()) +/// }); +/// +/// /* do something else not server related here */ +/// # Ok(()) } +/// ``` +/// +/// [`ERROR_PIPE_BUSY`]: crate::winapi::shared::winerror::ERROR_PIPE_BUSY +/// [Windows named pipe]: https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes +#[derive(Debug)] +pub struct NamedPipeServer { + io: PollEvented<mio_windows::NamedPipe>, +} + +impl NamedPipeServer { + /// Constructs a new named pipe server from the specified raw handle. + /// + /// This function will consume ownership of the handle given, passing + /// responsibility for closing the handle to the returned object. + /// + /// This function is also unsafe as the primitives currently returned have + /// the contract that they are the sole owner of the file descriptor they + /// are wrapping. Usage of this function could accidentally allow violating + /// this contract which can cause memory unsafety in code that relies on it + /// being true. + /// + /// # Errors + /// + /// This errors if called outside of a [Tokio Runtime], or in a runtime that + /// has not [enabled I/O], or if any OS-specific I/O errors occur. 
+ /// + /// [Tokio Runtime]: crate::runtime::Runtime + /// [enabled I/O]: crate::runtime::Builder::enable_io + pub unsafe fn from_raw_handle(handle: RawHandle) -> io::Result<Self> { + let named_pipe = mio_windows::NamedPipe::from_raw_handle(handle); + + Ok(Self { + io: PollEvented::new(named_pipe)?, + }) + } + + /// Retrieves information about the named pipe the server is associated + /// with. + /// + /// ```no_run + /// use tokio::net::windows::named_pipe::{PipeEnd, PipeMode, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-info"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let server = ServerOptions::new() + /// .pipe_mode(PipeMode::Message) + /// .max_instances(5) + /// .create(PIPE_NAME)?; + /// + /// let server_info = server.info()?; + /// + /// assert_eq!(server_info.end, PipeEnd::Server); + /// assert_eq!(server_info.mode, PipeMode::Message); + /// assert_eq!(server_info.max_instances, 5); + /// # Ok(()) } + /// ``` + pub fn info(&self) -> io::Result<PipeInfo> { + // Safety: we're ensuring the lifetime of the named pipe. + unsafe { named_pipe_info(self.io.as_raw_handle()) } + } + + /// Enables a named pipe server process to wait for a client process to + /// connect to an instance of a named pipe. A client process connects by + /// creating a named pipe with the same name. + /// + /// This corresponds to the [`ConnectNamedPipe`] system call. + /// + /// # Cancel safety + /// + /// This method is cancellation safe in the sense that if it is used as the + /// event in a [`select!`](crate::select) statement and some other branch + /// completes first, then no connection events have been lost. 
+ /// + /// [`ConnectNamedPipe`]: https://docs.microsoft.com/en-us/windows/win32/api/namedpipeapi/nf-namedpipeapi-connectnamedpipe + /// + /// # Example + /// + /// ```no_run + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\mynamedpipe"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let pipe = ServerOptions::new().create(PIPE_NAME)?; + /// + /// // Wait for a client to connect. + /// pipe.connect().await?; + /// + /// // Use the connected client... + /// # Ok(()) } + /// ``` + pub async fn connect(&self) -> io::Result<()> { + loop { + match self.io.connect() { + Ok(()) => break, + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + self.io.registration().readiness(Interest::WRITABLE).await?; + } + Err(e) => return Err(e), + } + } + + Ok(()) + } + + /// Disconnects the server end of a named pipe instance from a client + /// process. + /// + /// ``` + /// use tokio::io::AsyncWriteExt; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// use winapi::shared::winerror; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-disconnect"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let server = ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// let mut client = ClientOptions::new() + /// .open(PIPE_NAME)?; + /// + /// // Wait for a client to become connected. + /// server.connect().await?; + /// + /// // Forcibly disconnect the client. + /// server.disconnect()?; + /// + /// // Write fails with an OS-specific error after client has been + /// // disconnected. + /// let e = client.write(b"ping").await.unwrap_err(); + /// assert_eq!(e.raw_os_error(), Some(winerror::ERROR_PIPE_NOT_CONNECTED as i32)); + /// # Ok(()) } + /// ``` + pub fn disconnect(&self) -> io::Result<()> { + self.io.disconnect() + } + + /// Waits for any of the requested ready states. 
+ /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same pipe on a single + /// task without splitting the pipe. + /// + /// # Examples + /// + /// Concurrently read and write to the pipe on the same task without + /// splitting. + /// + /// ```no_run + /// use tokio::io::Interest; + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-ready"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// loop { + /// let ready = server.ready(Interest::READABLE | Interest::WRITABLE).await?; + /// + /// if ready.is_readable() { + /// let mut data = vec![0; 1024]; + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_read(&mut data) { + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// if ready.is_writable() { + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_write(b"hello world") { + /// Ok(n) => { + /// println!("write {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// } + /// } + /// ``` + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + let event = self.io.registration().readiness(interest).await?; + Ok(event.ready) + } + + /// Waits for the pipe to become readable. 
+ /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-readable"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// let mut msg = vec![0; 1024]; + /// + /// loop { + /// // Wait for the pipe to be readable + /// server.readable().await?; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_read(&mut msg) { + /// Ok(n) => { + /// msg.truncate(n); + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// println!("GOT = {:?}", msg); + /// Ok(()) + /// } + /// ``` + pub async fn readable(&self) -> io::Result<()> { + self.ready(Interest::READABLE).await?; + Ok(()) + } + + /// Polls for read readiness. + /// + /// If the pipe is not currently ready for reading, this method will + /// store a clone of the `Waker` from the provided `Context`. When the pipe + /// becomes ready for reading, `Waker::wake` will be called on the waker. + /// + /// Note that on multiple calls to `poll_read_ready` or `poll_read`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_write_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. 
+ /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the pipe is not ready for reading. + /// * `Poll::Ready(Ok(()))` if the pipe is ready for reading. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`readable`]: method@Self::readable + pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_read_ready(cx).map_ok(|_| ()) + } + + /// Tries to read data from the pipe into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the pipe but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: NamedPipeServer::readable() + /// [`ready()`]: NamedPipeServer::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the pipe's read half is closed + /// and will no longer yield data. If the pipe is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. 
+ /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-try-read"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be readable + /// server.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf = [0; 4096]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_read(&mut buf) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read(buf)) + } + + /// Tries to read data from the pipe into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the pipe but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. 
+ /// + /// [`try_read()`]: NamedPipeServer::try_read() + /// [`readable()`]: NamedPipeServer::readable() + /// [`ready()`]: NamedPipeServer::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the pipe's read half is closed + /// and will no longer yield data. If the pipe is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io::{self, IoSliceMut}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-try-read-vectored"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be readable + /// server.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf_a = [0; 512]; + /// let mut buf_b = [0; 1024]; + /// let mut bufs = [ + /// IoSliceMut::new(&mut buf_a), + /// IoSliceMut::new(&mut buf_b), + /// ]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_read_vectored(&mut bufs) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read_vectored(bufs)) + } + + /// Waits for the pipe to become writable. 
+ /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-writable"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be writable + /// server.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn writable(&self) -> io::Result<()> { + self.ready(Interest::WRITABLE).await?; + Ok(()) + } + + /// Polls for write readiness. + /// + /// If the pipe is not currently ready for writing, this method will + /// store a clone of the `Waker` from the provided `Context`. When the pipe + /// becomes ready for writing, `Waker::wake` will be called on the waker. + /// + /// Note that on multiple calls to `poll_write_ready` or `poll_write`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_read_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the pipe is not ready for writing. 
+ /// * `Poll::Ready(Ok(()))` if the pipe is ready for writing. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`writable`]: method@Self::writable + pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_write_ready(cx).map_ok(|_| ()) + } + + /// Tries to write a buffer to the pipe, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the pipe is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-try-write"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be writable + /// server.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. 
+ /// match server.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write(buf)) + } + + /// Tries to write several buffers to the pipe, returning how many bytes + /// were written. + /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: NamedPipeServer::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the pipe is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-server-try-write-vectored"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let server = named_pipe::ServerOptions::new() + /// .create(PIPE_NAME)?; + /// + /// let bufs = [io::IoSlice::new(b"hello "), io::IoSlice::new(b"world")]; + /// + /// loop { + /// // Wait for the pipe to be writable + /// server.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. 
+ /// match server.try_write_vectored(&bufs) { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write_vectored(&self, buf: &[io::IoSlice<'_>]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write_vectored(buf)) + } + + /// Tries to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the + /// methods defined on the Tokio `NamedPipeServer` type, as this will mess with + /// the readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. 
+ /// + /// [`readable()`]: NamedPipeServer::readable() + /// [`writable()`]: NamedPipeServer::writable() + /// [`ready()`]: NamedPipeServer::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } +} + +impl AsyncRead for NamedPipeServer { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + unsafe { self.io.poll_read(cx, buf) } + } +} + +impl AsyncWrite for NamedPipeServer { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.io.poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.io.poll_write_vectored(cx, bufs) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.poll_flush(cx) + } +} + +impl AsRawHandle for NamedPipeServer { + fn as_raw_handle(&self) -> RawHandle { + self.io.as_raw_handle() + } +} + +/// A [Windows named pipe] client. +/// +/// Constructed using [`ClientOptions::open`]. +/// +/// Connecting a client correctly involves a few steps. When connecting through +/// [`ClientOptions::open`], it might error indicating one of two things: +/// +/// * [`std::io::ErrorKind::NotFound`] - There is no server available. +/// * [`ERROR_PIPE_BUSY`] - There is a server available, but it is busy. Sleep +/// for a while and try again. 
+/// +/// So a correctly implemented client looks like this: +/// +/// ```no_run +/// use std::time::Duration; +/// use tokio::net::windows::named_pipe::ClientOptions; +/// use tokio::time; +/// use winapi::shared::winerror; +/// +/// const PIPE_NAME: &str = r"\\.\pipe\named-pipe-idiomatic-client"; +/// +/// # #[tokio::main] async fn main() -> std::io::Result<()> { +/// let client = loop { +/// match ClientOptions::new().open(PIPE_NAME) { +/// Ok(client) => break client, +/// Err(e) if e.raw_os_error() == Some(winerror::ERROR_PIPE_BUSY as i32) => (), +/// Err(e) => return Err(e), +/// } +/// +/// time::sleep(Duration::from_millis(50)).await; +/// }; +/// +/// /* use the connected client */ +/// # Ok(()) } +/// ``` +/// +/// [`ERROR_PIPE_BUSY`]: crate::winapi::shared::winerror::ERROR_PIPE_BUSY +/// [Windows named pipe]: https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes +#[derive(Debug)] +pub struct NamedPipeClient { + io: PollEvented<mio_windows::NamedPipe>, +} + +impl NamedPipeClient { + /// Constructs a new named pipe client from the specified raw handle. + /// + /// This function will consume ownership of the handle given, passing + /// responsibility for closing the handle to the returned object. + /// + /// This function is also unsafe as the primitives currently returned have + /// the contract that they are the sole owner of the file descriptor they + /// are wrapping. Usage of this function could accidentally allow violating + /// this contract which can cause memory unsafety in code that relies on it + /// being true. + /// + /// # Errors + /// + /// This errors if called outside of a [Tokio Runtime], or in a runtime that + /// has not [enabled I/O], or if any OS-specific I/O errors occur. 
+ /// + /// [Tokio Runtime]: crate::runtime::Runtime + /// [enabled I/O]: crate::runtime::Builder::enable_io + pub unsafe fn from_raw_handle(handle: RawHandle) -> io::Result<Self> { + let named_pipe = mio_windows::NamedPipe::from_raw_handle(handle); + + Ok(Self { + io: PollEvented::new(named_pipe)?, + }) + } + + /// Retrieves information about the named pipe the client is associated + /// with. + /// + /// ```no_run + /// use tokio::net::windows::named_pipe::{ClientOptions, PipeEnd, PipeMode}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-info"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let client = ClientOptions::new() + /// .open(PIPE_NAME)?; + /// + /// let client_info = client.info()?; + /// + /// assert_eq!(client_info.end, PipeEnd::Client); + /// assert_eq!(client_info.mode, PipeMode::Message); + /// assert_eq!(client_info.max_instances, 5); + /// # Ok(()) } + /// ``` + pub fn info(&self) -> io::Result<PipeInfo> { + // Safety: we're ensuring the lifetime of the named pipe. + unsafe { named_pipe_info(self.io.as_raw_handle()) } + } + + /// Waits for any of the requested ready states. + /// + /// This function is usually paired with `try_read()` or `try_write()`. It + /// can be used to concurrently read / write to the same pipe on a single + /// task without splitting the pipe. + /// + /// # Examples + /// + /// Concurrently read and write to the pipe on the same task without + /// splitting. 
+ /// + /// ```no_run + /// use tokio::io::Interest; + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-ready"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// loop { + /// let ready = client.ready(Interest::READABLE | Interest::WRITABLE).await?; + /// + /// if ready.is_readable() { + /// let mut data = vec![0; 1024]; + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_read(&mut data) { + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// if ready.is_writable() { + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_write(b"hello world") { + /// Ok(n) => { + /// println!("write {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// } + /// } + /// ``` + pub async fn ready(&self, interest: Interest) -> io::Result<Ready> { + let event = self.io.registration().readiness(interest).await?; + Ok(event.ready) + } + + /// Waits for the pipe to become readable. + /// + /// This function is equivalent to `ready(Interest::READABLE)` and is usually + /// paired with `try_read()`. 
+ /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-readable"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// let mut msg = vec![0; 1024]; + /// + /// loop { + /// // Wait for the pipe to be readable + /// client.readable().await?; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_read(&mut msg) { + /// Ok(n) => { + /// msg.truncate(n); + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// println!("GOT = {:?}", msg); + /// Ok(()) + /// } + /// ``` + pub async fn readable(&self) -> io::Result<()> { + self.ready(Interest::READABLE).await?; + Ok(()) + } + + /// Polls for read readiness. + /// + /// If the pipe is not currently ready for reading, this method will + /// store a clone of the `Waker` from the provided `Context`. When the pipe + /// becomes ready for reading, `Waker::wake` will be called on the waker. + /// + /// Note that on multiple calls to `poll_read_ready` or `poll_read`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_write_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the pipe is not ready for reading. 
+ /// * `Poll::Ready(Ok(()))` if the pipe is ready for reading. + /// * `Poll::Ready(Err(e))` if an error is encountered. + /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`readable`]: method@Self::readable + pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_read_ready(cx).map_ok(|_| ()) + } + + /// Tries to read data from the pipe into the provided buffer, returning how + /// many bytes were read. + /// + /// Receives any pending data from the pipe but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: NamedPipeClient::readable() + /// [`ready()`]: NamedPipeClient::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the pipe's read half is closed + /// and will no longer yield data. If the pipe is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-try-read"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be readable + /// client.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. 
+ /// let mut buf = [0; 4096]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_read(&mut buf) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read(buf)) + } + + /// Tries to read data from the pipe into the provided buffers, returning + /// how many bytes were read. + /// + /// Data is copied to fill each buffer in order, with the final buffer + /// written to possibly being only partially filled. This method behaves + /// equivalently to a single call to [`try_read()`] with concatenated + /// buffers. + /// + /// Receives any pending data from the pipe but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_vectored()` is non-blocking, the buffer does not have to be + /// stored by the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`try_read()`]: NamedPipeClient::try_read() + /// [`readable()`]: NamedPipeClient::readable() + /// [`ready()`]: NamedPipeClient::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the pipe's read half is closed + /// and will no longer yield data. If the pipe is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. 
+ /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io::{self, IoSliceMut}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-try-read-vectored"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be readable + /// client.readable().await?; + /// + /// // Creating the buffer **after** the `await` prevents it from + /// // being stored in the async task. + /// let mut buf_a = [0; 512]; + /// let mut buf_b = [0; 1024]; + /// let mut bufs = [ + /// IoSliceMut::new(&mut buf_a), + /// IoSliceMut::new(&mut buf_b), + /// ]; + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_read_vectored(&mut bufs) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read_vectored(&self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::READABLE, || (&*self.io).read_vectored(bufs)) + } + + /// Waits for the pipe to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. 
+ /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-writable"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be writable + /// client.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn writable(&self) -> io::Result<()> { + self.ready(Interest::WRITABLE).await?; + Ok(()) + } + + /// Polls for write readiness. + /// + /// If the pipe is not currently ready for writing, this method will + /// store a clone of the `Waker` from the provided `Context`. When the pipe + /// becomes ready for writing, `Waker::wake` will be called on the waker. + /// + /// Note that on multiple calls to `poll_write_ready` or `poll_write`, only + /// the `Waker` from the `Context` passed to the most recent call is + /// scheduled to receive a wakeup. (However, `poll_read_ready` retains a + /// second, independent waker.) + /// + /// This function is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// # Return value + /// + /// The function returns: + /// + /// * `Poll::Pending` if the pipe is not ready for writing. + /// * `Poll::Ready(Ok(()))` if the pipe is ready for writing. + /// * `Poll::Ready(Err(e))` if an error is encountered. 
+ /// + /// # Errors + /// + /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`writable`]: method@Self::writable + pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.io.registration().poll_write_ready(cx).map_ok(|_| ()) + } + + /// Tries to write a buffer to the pipe, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the pipe is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-try-write"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the pipe to be writable + /// client.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_write(b"hello world") { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write(buf)) + } + + /// Tries to write several buffers to the pipe, returning how many bytes + /// were written. 
+ /// + /// Data is written from each buffer in order, with the final buffer read + /// from possible being only partially consumed. This method behaves + /// equivalently to a single call to [`try_write()`] with concatenated + /// buffers. + /// + /// This function is usually paired with `writable()`. + /// + /// [`try_write()`]: NamedPipeClient::try_write() + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the pipe is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-try-write-vectored"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box<dyn Error>> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// let bufs = [io::IoSlice::new(b"hello "), io::IoSlice::new(b"world")]; + /// + /// loop { + /// // Wait for the pipe to be writable + /// client.writable().await?; + /// + /// // Try to write data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_write_vectored(&bufs) { + /// Ok(n) => { + /// break; + /// } + /// Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_write_vectored(&self, buf: &[io::IoSlice<'_>]) -> io::Result<usize> { + self.io + .registration() + .try_io(Interest::WRITABLE, || (&*self.io).write_vectored(buf)) + } + + /// Tries to read or write from the socket using a user-provided IO operation. + /// + /// If the socket is ready, the provided closure is called. The closure + /// should attempt to perform IO operation from the socket by manually + /// calling the appropriate syscall. 
If the operation fails because the + /// socket is not actually ready, then the closure should return a + /// `WouldBlock` error and the readiness flag is cleared. The return value + /// of the closure is then returned by `try_io`. + /// + /// If the socket is not ready, then the closure is not called + /// and a `WouldBlock` error is returned. + /// + /// The closure should only return a `WouldBlock` error if it has performed + /// an IO operation on the socket that failed due to the socket not being + /// ready. Returning a `WouldBlock` error in any other situation will + /// incorrectly clear the readiness flag, which can cause the socket to + /// behave incorrectly. + /// + /// The closure should not perform the IO operation using any of the methods + /// defined on the Tokio `NamedPipeClient` type, as this will mess with the + /// readiness flag and can cause the socket to behave incorrectly. + /// + /// Usually, [`readable()`], [`writable()`] or [`ready()`] is used with this function. 
+ /// + /// [`readable()`]: NamedPipeClient::readable() + /// [`writable()`]: NamedPipeClient::writable() + /// [`ready()`]: NamedPipeClient::ready() + pub fn try_io<R>( + &self, + interest: Interest, + f: impl FnOnce() -> io::Result<R>, + ) -> io::Result<R> { + self.io.registration().try_io(interest, f) + } +} + +impl AsyncRead for NamedPipeClient { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + unsafe { self.io.poll_read(cx, buf) } + } +} + +impl AsyncWrite for NamedPipeClient { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.io.poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll<io::Result<usize>> { + self.io.poll_write_vectored(cx, bufs) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + self.poll_flush(cx) + } +} + +impl AsRawHandle for NamedPipeClient { + fn as_raw_handle(&self) -> RawHandle { + self.io.as_raw_handle() + } +} + +// Helper to set a boolean flag as a bitfield. +macro_rules! bool_flag { + ($f:expr, $t:expr, $flag:expr) => {{ + let current = $f; + + if $t { + $f = current | $flag; + } else { + $f = current & !$flag; + }; + }}; +} + +/// A builder structure for construct a named pipe with named pipe-specific +/// options. This is required to use for named pipe servers who wants to modify +/// pipe-related options. +/// +/// See [`ServerOptions::create`]. +#[derive(Debug, Clone)] +pub struct ServerOptions { + open_mode: DWORD, + pipe_mode: DWORD, + max_instances: DWORD, + out_buffer_size: DWORD, + in_buffer_size: DWORD, + default_timeout: DWORD, +} + +impl ServerOptions { + /// Creates a new named pipe builder with the default settings. 
+ /// + /// ``` + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-new"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let server = ServerOptions::new().create(PIPE_NAME)?; + /// # Ok(()) } + /// ``` + pub fn new() -> ServerOptions { + ServerOptions { + open_mode: winbase::PIPE_ACCESS_DUPLEX | winbase::FILE_FLAG_OVERLAPPED, + pipe_mode: winbase::PIPE_TYPE_BYTE | winbase::PIPE_REJECT_REMOTE_CLIENTS, + max_instances: winbase::PIPE_UNLIMITED_INSTANCES, + out_buffer_size: 65536, + in_buffer_size: 65536, + default_timeout: 0, + } + } + + /// The pipe mode. + /// + /// The default pipe mode is [`PipeMode::Byte`]. See [`PipeMode`] for + /// documentation of what each mode means. + /// + /// This corresponding to specifying [`dwPipeMode`]. + /// + /// [`dwPipeMode`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea + pub fn pipe_mode(&mut self, pipe_mode: PipeMode) -> &mut Self { + self.pipe_mode = match pipe_mode { + PipeMode::Byte => winbase::PIPE_TYPE_BYTE, + PipeMode::Message => winbase::PIPE_TYPE_MESSAGE, + }; + + self + } + + /// The flow of data in the pipe goes from client to server only. + /// + /// This corresponds to setting [`PIPE_ACCESS_INBOUND`]. + /// + /// [`PIPE_ACCESS_INBOUND`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#pipe_access_inbound + /// + /// # Errors + /// + /// Server side prevents connecting by denying inbound access, client errors + /// with [`std::io::ErrorKind::PermissionDenied`] when attempting to create + /// the connection. 
+ /// + /// ``` + /// use std::io; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-inbound-err1"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let _server = ServerOptions::new() + /// .access_inbound(false) + /// .create(PIPE_NAME)?; + /// + /// let e = ClientOptions::new() + /// .open(PIPE_NAME) + /// .unwrap_err(); + /// + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// # Ok(()) } + /// ``` + /// + /// Disabling writing allows a client to connect, but errors with + /// [`std::io::ErrorKind::PermissionDenied`] if a write is attempted. + /// + /// ``` + /// use std::io; + /// use tokio::io::AsyncWriteExt; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-inbound-err2"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let server = ServerOptions::new() + /// .access_inbound(false) + /// .create(PIPE_NAME)?; + /// + /// let mut client = ClientOptions::new() + /// .write(false) + /// .open(PIPE_NAME)?; + /// + /// server.connect().await?; + /// + /// let e = client.write(b"ping").await.unwrap_err(); + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// # Ok(()) } + /// ``` + /// + /// # Examples + /// + /// A unidirectional named pipe that only supports server-to-client + /// communication. 
+ /// + /// ``` + /// use std::io; + /// use tokio::io::{AsyncReadExt, AsyncWriteExt}; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-inbound"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let mut server = ServerOptions::new() + /// .access_inbound(false) + /// .create(PIPE_NAME)?; + /// + /// let mut client = ClientOptions::new() + /// .write(false) + /// .open(PIPE_NAME)?; + /// + /// server.connect().await?; + /// + /// let write = server.write_all(b"ping"); + /// + /// let mut buf = [0u8; 4]; + /// let read = client.read_exact(&mut buf); + /// + /// let ((), read) = tokio::try_join!(write, read)?; + /// + /// assert_eq!(read, 4); + /// assert_eq!(&buf[..], b"ping"); + /// # Ok(()) } + /// ``` + pub fn access_inbound(&mut self, allowed: bool) -> &mut Self { + bool_flag!(self.open_mode, allowed, winbase::PIPE_ACCESS_INBOUND); + self + } + + /// The flow of data in the pipe goes from server to client only. + /// + /// This corresponds to setting [`PIPE_ACCESS_OUTBOUND`]. + /// + /// [`PIPE_ACCESS_OUTBOUND`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#pipe_access_outbound + /// + /// # Errors + /// + /// Server side prevents connecting by denying outbound access, client + /// errors with [`std::io::ErrorKind::PermissionDenied`] when attempting to + /// create the connection. 
+ /// + /// ``` + /// use std::io; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-outbound-err1"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let server = ServerOptions::new() + /// .access_outbound(false) + /// .create(PIPE_NAME)?; + /// + /// let e = ClientOptions::new() + /// .open(PIPE_NAME) + /// .unwrap_err(); + /// + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// # Ok(()) } + /// ``` + /// + /// Disabling reading allows a client to connect, but attempting to read + /// will error with [`std::io::ErrorKind::PermissionDenied`]. + /// + /// ``` + /// use std::io; + /// use tokio::io::AsyncReadExt; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-outbound-err2"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let server = ServerOptions::new() + /// .access_outbound(false) + /// .create(PIPE_NAME)?; + /// + /// let mut client = ClientOptions::new() + /// .read(false) + /// .open(PIPE_NAME)?; + /// + /// server.connect().await?; + /// + /// let mut buf = [0u8; 4]; + /// let e = client.read(&mut buf).await.unwrap_err(); + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// # Ok(()) } + /// ``` + /// + /// # Examples + /// + /// A unidirectional named pipe that only supports client-to-server + /// communication. 
+ /// + /// ``` + /// use tokio::io::{AsyncReadExt, AsyncWriteExt}; + /// use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-access-outbound"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let mut server = ServerOptions::new() + /// .access_outbound(false) + /// .create(PIPE_NAME)?; + /// + /// let mut client = ClientOptions::new() + /// .read(false) + /// .open(PIPE_NAME)?; + /// + /// server.connect().await?; + /// + /// let write = client.write_all(b"ping"); + /// + /// let mut buf = [0u8; 4]; + /// let read = server.read_exact(&mut buf); + /// + /// let ((), read) = tokio::try_join!(write, read)?; + /// + /// println!("done reading and writing"); + /// + /// assert_eq!(read, 4); + /// assert_eq!(&buf[..], b"ping"); + /// # Ok(()) } + /// ``` + pub fn access_outbound(&mut self, allowed: bool) -> &mut Self { + bool_flag!(self.open_mode, allowed, winbase::PIPE_ACCESS_OUTBOUND); + self + } + + /// If you attempt to create multiple instances of a pipe with this flag + /// set, creation of the first server instance succeeds, but creation of any + /// subsequent instances will fail with + /// [`std::io::ErrorKind::PermissionDenied`]. + /// + /// This option is intended to be used with servers that want to ensure that + /// they are the only process listening for clients on a given named pipe. + /// This is accomplished by enabling it for the first server instance + /// created in a process. + /// + /// This corresponds to setting [`FILE_FLAG_FIRST_PIPE_INSTANCE`]. + /// + /// # Errors + /// + /// If this option is set and more than one instance of the server for a + /// given named pipe exists, calling [`create`] will fail with + /// [`std::io::ErrorKind::PermissionDenied`]. 
+ /// + /// ``` + /// use std::io; + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-first-instance-error"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let server1 = ServerOptions::new() + /// .first_pipe_instance(true) + /// .create(PIPE_NAME)?; + /// + /// // Second server errs, since it's not the first instance. + /// let e = ServerOptions::new() + /// .first_pipe_instance(true) + /// .create(PIPE_NAME) + /// .unwrap_err(); + /// + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// # Ok(()) } + /// ``` + /// + /// # Examples + /// + /// ``` + /// use std::io; + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-first-instance"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let mut builder = ServerOptions::new(); + /// builder.first_pipe_instance(true); + /// + /// let server = builder.create(PIPE_NAME)?; + /// let e = builder.create(PIPE_NAME).unwrap_err(); + /// assert_eq!(e.kind(), io::ErrorKind::PermissionDenied); + /// drop(server); + /// + /// // OK: since, we've closed the other instance. + /// let _server2 = builder.create(PIPE_NAME)?; + /// # Ok(()) } + /// ``` + /// + /// [`create`]: ServerOptions::create + /// [`FILE_FLAG_FIRST_PIPE_INSTANCE`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#pipe_first_pipe_instance + pub fn first_pipe_instance(&mut self, first: bool) -> &mut Self { + bool_flag!( + self.open_mode, + first, + winbase::FILE_FLAG_FIRST_PIPE_INSTANCE + ); + self + } + + /// Indicates whether this server can accept remote clients or not. Remote + /// clients are disabled by default. + /// + /// This corresponds to setting [`PIPE_REJECT_REMOTE_CLIENTS`]. 
+ /// + /// [`PIPE_REJECT_REMOTE_CLIENTS`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea#pipe_reject_remote_clients + pub fn reject_remote_clients(&mut self, reject: bool) -> &mut Self { + bool_flag!(self.pipe_mode, reject, winbase::PIPE_REJECT_REMOTE_CLIENTS); + self + } + + /// The maximum number of instances that can be created for this pipe. The + /// first instance of the pipe can specify this value; the same number must + /// be specified for other instances of the pipe. Acceptable values are in + /// the range 1 through 254. The default value is unlimited. + /// + /// This corresponds to specifying [`nMaxInstances`]. + /// + /// [`nMaxInstances`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea + /// + /// # Errors + /// + /// The same numbers of `max_instances` have to be used by all servers. Any + /// additional servers trying to be built which uses a mismatching value + /// might error. + /// + /// ``` + /// use std::io; + /// use tokio::net::windows::named_pipe::{ServerOptions, ClientOptions}; + /// use winapi::shared::winerror; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-max-instances"; + /// + /// # #[tokio::main] async fn main() -> io::Result<()> { + /// let mut server = ServerOptions::new(); + /// server.max_instances(2); + /// + /// let s1 = server.create(PIPE_NAME)?; + /// let c1 = ClientOptions::new().open(PIPE_NAME); + /// + /// let s2 = server.create(PIPE_NAME)?; + /// let c2 = ClientOptions::new().open(PIPE_NAME); + /// + /// // Too many servers! + /// let e = server.create(PIPE_NAME).unwrap_err(); + /// assert_eq!(e.raw_os_error(), Some(winerror::ERROR_PIPE_BUSY as i32)); + /// + /// // Still too many servers even if we specify a higher value! 
+ /// let e = server.max_instances(100).create(PIPE_NAME).unwrap_err(); + /// assert_eq!(e.raw_os_error(), Some(winerror::ERROR_PIPE_BUSY as i32)); + /// # Ok(()) } + /// ``` + /// + /// # Panics + /// + /// This function will panic if more than 254 instances are specified. If + /// you do not wish to set an instance limit, leave it unspecified. + /// + /// ```should_panic + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let builder = ServerOptions::new().max_instances(255); + /// # Ok(()) } + /// ``` + pub fn max_instances(&mut self, instances: usize) -> &mut Self { + assert!(instances < 255, "cannot specify more than 254 instances"); + self.max_instances = instances as DWORD; + self + } + + /// The number of bytes to reserve for the output buffer. + /// + /// This corresponds to specifying [`nOutBufferSize`]. + /// + /// [`nOutBufferSize`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea + pub fn out_buffer_size(&mut self, buffer: u32) -> &mut Self { + self.out_buffer_size = buffer as DWORD; + self + } + + /// The number of bytes to reserve for the input buffer. + /// + /// This corresponds to specifying [`nInBufferSize`]. + /// + /// [`nInBufferSize`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea + pub fn in_buffer_size(&mut self, buffer: u32) -> &mut Self { + self.in_buffer_size = buffer as DWORD; + self + } + + /// Creates the named pipe identified by `addr` for use as a server. + /// + /// This uses the [`CreateNamedPipe`] function. + /// + /// [`CreateNamedPipe`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea + /// + /// # Errors + /// + /// This errors if called outside of a [Tokio Runtime], or in a runtime that + /// has not [enabled I/O], or if any OS-specific I/O errors occur. 
+ /// + /// [Tokio Runtime]: crate::runtime::Runtime + /// [enabled I/O]: crate::runtime::Builder::enable_io + /// + /// # Examples + /// + /// ``` + /// use tokio::net::windows::named_pipe::ServerOptions; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-create"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let server = ServerOptions::new().create(PIPE_NAME)?; + /// # Ok(()) } + /// ``` + pub fn create(&self, addr: impl AsRef<OsStr>) -> io::Result<NamedPipeServer> { + // Safety: We're calling create_with_security_attributes_raw w/ a null + // pointer which disables it. + unsafe { self.create_with_security_attributes_raw(addr, ptr::null_mut()) } + } + + /// Creates the named pipe identified by `addr` for use as a server. + /// + /// This is the same as [`create`] except that it supports providing the raw + /// pointer to a structure of [`SECURITY_ATTRIBUTES`] which will be passed + /// as the `lpSecurityAttributes` argument to [`CreateFile`]. + /// + /// # Errors + /// + /// This errors if called outside of a [Tokio Runtime], or in a runtime that + /// has not [enabled I/O], or if any OS-specific I/O errors occur. + /// + /// [Tokio Runtime]: crate::runtime::Runtime + /// [enabled I/O]: crate::runtime::Builder::enable_io + /// + /// # Safety + /// + /// The `attrs` argument must either be null or point at a valid instance of + /// the [`SECURITY_ATTRIBUTES`] structure. If the argument is null, the + /// behavior is identical to calling the [`create`] method. 
+ /// + /// [`create`]: ServerOptions::create + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew + /// [`SECURITY_ATTRIBUTES`]: crate::winapi::um::minwinbase::SECURITY_ATTRIBUTES + pub unsafe fn create_with_security_attributes_raw( + &self, + addr: impl AsRef<OsStr>, + attrs: *mut c_void, + ) -> io::Result<NamedPipeServer> { + let addr = encode_addr(addr); + + let h = namedpipeapi::CreateNamedPipeW( + addr.as_ptr(), + self.open_mode, + self.pipe_mode, + self.max_instances, + self.out_buffer_size, + self.in_buffer_size, + self.default_timeout, + attrs as *mut _, + ); + + if h == handleapi::INVALID_HANDLE_VALUE { + return Err(io::Error::last_os_error()); + } + + NamedPipeServer::from_raw_handle(h) + } +} + +/// A builder suitable for building and interacting with named pipes from the +/// client side. +/// +/// See [`ClientOptions::open`]. +#[derive(Debug, Clone)] +pub struct ClientOptions { + desired_access: DWORD, + security_qos_flags: DWORD, +} + +impl ClientOptions { + /// Creates a new named pipe builder with the default settings. + /// + /// ``` + /// use tokio::net::windows::named_pipe::{ServerOptions, ClientOptions}; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-new"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// // Server must be created in order for the client creation to succeed. + /// let server = ServerOptions::new().create(PIPE_NAME)?; + /// let client = ClientOptions::new().open(PIPE_NAME)?; + /// # Ok(()) } + /// ``` + pub fn new() -> Self { + Self { + desired_access: winnt::GENERIC_READ | winnt::GENERIC_WRITE, + security_qos_flags: winbase::SECURITY_IDENTIFICATION | winbase::SECURITY_SQOS_PRESENT, + } + } + + /// If the client supports reading data. This is enabled by default. + /// + /// This corresponds to setting [`GENERIC_READ`] in the call to [`CreateFile`]. 
+ /// + /// [`GENERIC_READ`]: https://docs.microsoft.com/en-us/windows/win32/secauthz/generic-access-rights + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew + pub fn read(&mut self, allowed: bool) -> &mut Self { + bool_flag!(self.desired_access, allowed, winnt::GENERIC_READ); + self + } + + /// If the created pipe supports writing data. This is enabled by default. + /// + /// This corresponds to setting [`GENERIC_WRITE`] in the call to [`CreateFile`]. + /// + /// [`GENERIC_WRITE`]: https://docs.microsoft.com/en-us/windows/win32/secauthz/generic-access-rights + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew + pub fn write(&mut self, allowed: bool) -> &mut Self { + bool_flag!(self.desired_access, allowed, winnt::GENERIC_WRITE); + self + } + + /// Sets qos flags which are combined with other flags and attributes in the + /// call to [`CreateFile`]. + /// + /// By default `security_qos_flags` is set to [`SECURITY_IDENTIFICATION`], + /// calling this function would override that value completely with the + /// argument specified. + /// + /// When `security_qos_flags` is not set, a malicious program can gain the + /// elevated privileges of a privileged Rust process when it allows opening + /// user-specified paths, by tricking it into opening a named pipe. So + /// arguably `security_qos_flags` should also be set when opening arbitrary + /// paths. However the bits can then conflict with other flags, specifically + /// `FILE_FLAG_OPEN_NO_RECALL`. + /// + /// For information about possible values, see [Impersonation Levels] on the + /// Windows Dev Center site. The `SECURITY_SQOS_PRESENT` flag is set + /// automatically when using this method. 
+ /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + /// [`SECURITY_IDENTIFICATION`]: crate::winapi::um::winbase::SECURITY_IDENTIFICATION + /// [Impersonation Levels]: https://docs.microsoft.com/en-us/windows/win32/api/winnt/ne-winnt-security_impersonation_level + pub fn security_qos_flags(&mut self, flags: u32) -> &mut Self { + // See: https://github.com/rust-lang/rust/pull/58216 + self.security_qos_flags = flags | winbase::SECURITY_SQOS_PRESENT; + self + } + + /// Opens the named pipe identified by `addr`. + /// + /// This opens the client using [`CreateFile`] with the + /// `dwCreationDisposition` option set to `OPEN_EXISTING`. + /// + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilea + /// + /// # Errors + /// + /// This errors if called outside of a [Tokio Runtime], or in a runtime that + /// has not [enabled I/O], or if any OS-specific I/O errors occur. + /// + /// There are a few errors you need to take into account when creating a + /// named pipe on the client side: + /// + /// * [`std::io::ErrorKind::NotFound`] - This indicates that the named pipe + /// does not exist. Presumably the server is not up. + /// * [`ERROR_PIPE_BUSY`] - This error is raised when the named pipe exists, + /// but the server is not currently waiting for a connection. Please see the + /// examples for how to check for this error. 
+ /// + /// [`ERROR_PIPE_BUSY`]: crate::winapi::shared::winerror::ERROR_PIPE_BUSY + /// [`winapi`]: crate::winapi + /// [enabled I/O]: crate::runtime::Builder::enable_io + /// [Tokio Runtime]: crate::runtime::Runtime + /// + /// A connect loop that waits until a pipe becomes available looks like + /// this: + /// + /// ```no_run + /// use std::time::Duration; + /// use tokio::net::windows::named_pipe::ClientOptions; + /// use tokio::time; + /// use winapi::shared::winerror; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\mynamedpipe"; + /// + /// # #[tokio::main] async fn main() -> std::io::Result<()> { + /// let client = loop { + /// match ClientOptions::new().open(PIPE_NAME) { + /// Ok(client) => break client, + /// Err(e) if e.raw_os_error() == Some(winerror::ERROR_PIPE_BUSY as i32) => (), + /// Err(e) => return Err(e), + /// } + /// + /// time::sleep(Duration::from_millis(50)).await; + /// }; + /// + /// // use the connected client. + /// # Ok(()) } + /// ``` + pub fn open(&self, addr: impl AsRef<OsStr>) -> io::Result<NamedPipeClient> { + // Safety: We're calling open_with_security_attributes_raw w/ a null + // pointer which disables it. + unsafe { self.open_with_security_attributes_raw(addr, ptr::null_mut()) } + } + + /// Opens the named pipe identified by `addr`. + /// + /// This is the same as [`open`] except that it supports providing the raw + /// pointer to a structure of [`SECURITY_ATTRIBUTES`] which will be passed + /// as the `lpSecurityAttributes` argument to [`CreateFile`]. + /// + /// # Safety + /// + /// The `attrs` argument must either be null or point at a valid instance of + /// the [`SECURITY_ATTRIBUTES`] structure. If the argument is null, the + /// behavior is identical to calling the [`open`] method. 
+ /// + /// [`open`]: ClientOptions::open + /// [`CreateFile`]: https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-createfilew + /// [`SECURITY_ATTRIBUTES`]: crate::winapi::um::minwinbase::SECURITY_ATTRIBUTES + pub unsafe fn open_with_security_attributes_raw( + &self, + addr: impl AsRef<OsStr>, + attrs: *mut c_void, + ) -> io::Result<NamedPipeClient> { + let addr = encode_addr(addr); + + // NB: We could use a platform specialized `OpenOptions` here, but since + // we have access to winapi it ultimately doesn't hurt to use + // `CreateFile` explicitly since it allows the use of our already + // well-structured wide `addr` to pass into CreateFileW. + let h = fileapi::CreateFileW( + addr.as_ptr(), + self.desired_access, + 0, + attrs as *mut _, + fileapi::OPEN_EXISTING, + self.get_flags(), + ptr::null_mut(), + ); + + if h == handleapi::INVALID_HANDLE_VALUE { + return Err(io::Error::last_os_error()); + } + + NamedPipeClient::from_raw_handle(h) + } + + fn get_flags(&self) -> u32 { + self.security_qos_flags | winbase::FILE_FLAG_OVERLAPPED + } +} + +/// The pipe mode of a named pipe. +/// +/// Set through [`ServerOptions::pipe_mode`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum PipeMode { + /// Data is written to the pipe as a stream of bytes. The pipe does not + /// distinguish bytes written during different write operations. + /// + /// Corresponds to [`PIPE_TYPE_BYTE`][crate::winapi::um::winbase::PIPE_TYPE_BYTE]. + Byte, + /// Data is written to the pipe as a stream of messages. The pipe treats the + /// bytes written during each write operation as a message unit. Any reading + /// on a named pipe returns [`ERROR_MORE_DATA`] when a message is not read + /// completely. + /// + /// Corresponds to [`PIPE_TYPE_MESSAGE`][crate::winapi::um::winbase::PIPE_TYPE_MESSAGE]. + /// + /// [`ERROR_MORE_DATA`]: crate::winapi::shared::winerror::ERROR_MORE_DATA + Message, +} + +/// Indicates the end of a named pipe. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum PipeEnd { + /// The named pipe refers to the client end of a named pipe instance. + /// + /// Corresponds to [`PIPE_CLIENT_END`][crate::winapi::um::winbase::PIPE_CLIENT_END]. + Client, + /// The named pipe refers to the server end of a named pipe instance. + /// + /// Corresponds to [`PIPE_SERVER_END`][crate::winapi::um::winbase::PIPE_SERVER_END]. + Server, +} + +/// Information about a named pipe. +/// +/// Constructed through [`NamedPipeServer::info`] or [`NamedPipeClient::info`]. +#[derive(Debug)] +#[non_exhaustive] +pub struct PipeInfo { + /// Indicates the mode of a named pipe. + pub mode: PipeMode, + /// Indicates the end of a named pipe. + pub end: PipeEnd, + /// The maximum number of instances that can be created for this pipe. + pub max_instances: u32, + /// The number of bytes to reserve for the output buffer. + pub out_buffer_size: u32, + /// The number of bytes to reserve for the input buffer. + pub in_buffer_size: u32, +} + +/// Encodes an address so that it is a null-terminated wide string. +fn encode_addr(addr: impl AsRef<OsStr>) -> Box<[u16]> { + let len = addr.as_ref().encode_wide().count(); + let mut vec = Vec::with_capacity(len + 1); + vec.extend(addr.as_ref().encode_wide()); + vec.push(0); + vec.into_boxed_slice() +} + +/// Internal function to get the info out of a raw named pipe. 
+unsafe fn named_pipe_info(handle: RawHandle) -> io::Result<PipeInfo> { + let mut flags = 0; + let mut out_buffer_size = 0; + let mut in_buffer_size = 0; + let mut max_instances = 0; + + let result = namedpipeapi::GetNamedPipeInfo( + handle, + &mut flags, + &mut out_buffer_size, + &mut in_buffer_size, + &mut max_instances, + ); + + if result == FALSE { + return Err(io::Error::last_os_error()); + } + + let mut end = PipeEnd::Client; + let mut mode = PipeMode::Byte; + + if flags & winbase::PIPE_SERVER_END != 0 { + end = PipeEnd::Server; + } + + if flags & winbase::PIPE_TYPE_MESSAGE != 0 { + mode = PipeMode::Message; + } + + Ok(PipeInfo { + end, + mode, + out_buffer_size, + in_buffer_size, + max_instances, + }) +} diff --git a/src/park/mod.rs b/src/park/mod.rs index edd9371..87d04ff 100644 --- a/src/park/mod.rs +++ b/src/park/mod.rs @@ -45,12 +45,12 @@ use std::fmt::Debug; use std::sync::Arc; use std::time::Duration; -/// Block the current thread. +/// Blocks the current thread. pub(crate) trait Park { /// Unpark handle type for the `Park` implementation. type Unpark: Unpark; - /// Error returned by `park` + /// Error returned by `park`. type Error: Debug; /// Gets a new `Unpark` handle associated with this `Park` instance. @@ -66,7 +66,7 @@ pub(crate) trait Park { /// /// This function **should** not panic, but ultimately, panics are left as /// an implementation detail. Refer to the documentation for the specific - /// `Park` implementation + /// `Park` implementation. fn park(&mut self) -> Result<(), Self::Error>; /// Parks the current thread for at most `duration`. @@ -82,10 +82,10 @@ pub(crate) trait Park { /// /// This function **should** not panic, but ultimately, panics are left as /// an implementation detail. Refer to the documentation for the specific - /// `Park` implementation + /// `Park` implementation. 
fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error>; - /// Release all resources holded by the parker for proper leak-free shutdown + /// Releases all resources holded by the parker for proper leak-free shutdown. fn shutdown(&mut self); } @@ -100,7 +100,7 @@ pub(crate) trait Unpark: Sync + Send + 'static { /// /// This function **should** not panic, but ultimately, panics are left as /// an implementation detail. Refer to the documentation for the specific - /// `Unpark` implementation + /// `Unpark` implementation. fn unpark(&self); } diff --git a/src/park/thread.rs b/src/park/thread.rs index 2725e45..27ce202 100644 --- a/src/park/thread.rs +++ b/src/park/thread.rs @@ -76,7 +76,7 @@ impl Park for ParkThread { // ==== impl Inner ==== impl Inner { - /// Park the current thread for at most `dur`. + /// Parks the current thread for at most `dur`. fn park(&self) { // If we were previously notified then we consume this notification and // return quickly. @@ -227,7 +227,7 @@ pub(crate) struct CachedParkThread { } impl CachedParkThread { - /// Create a new `ParkThread` handle for the current thread. + /// Creates a new `ParkThread` handle for the current thread. /// /// This type cannot be moved to other threads, so it should be created on /// the thread that the caller intends to park. @@ -241,7 +241,7 @@ impl CachedParkThread { self.with_current(|park_thread| park_thread.unpark()) } - /// Get a reference to the `ParkThread` handle for this thread. + /// Gets a reference to the `ParkThread` handle for this thread. 
fn with_current<F, R>(&self, f: F) -> Result<R, ParkError> where F: FnOnce(&ParkThread) -> R, diff --git a/src/process/mod.rs b/src/process/mod.rs index 00e39b0..6eeefdb 100644 --- a/src/process/mod.rs +++ b/src/process/mod.rs @@ -199,6 +199,8 @@ use std::io; #[cfg(unix)] use std::os::unix::process::CommandExt; #[cfg(windows)] +use std::os::windows::io::{AsRawHandle, RawHandle}; +#[cfg(windows)] use std::os::windows::process::CommandExt; use std::path::Path; use std::pin::Pin; @@ -223,9 +225,9 @@ pub struct Command { pub(crate) struct SpawnedChild { child: imp::Child, - stdin: Option<imp::ChildStdin>, - stdout: Option<imp::ChildStdout>, - stderr: Option<imp::ChildStderr>, + stdin: Option<imp::ChildStdio>, + stdout: Option<imp::ChildStdio>, + stderr: Option<imp::ChildStdio>, } impl Command { @@ -551,6 +553,7 @@ impl Command { /// /// [1]: https://msdn.microsoft.com/en-us/library/windows/desktop/ms684863(v=vs.85).aspx #[cfg(windows)] + #[cfg_attr(docsrs, doc(cfg(windows)))] pub fn creation_flags(&mut self, flags: u32) -> &mut Command { self.std.creation_flags(flags); self @@ -560,6 +563,7 @@ impl Command { /// `setuid` call in the child process. Failure in the `setuid` /// call will cause the spawn to fail. #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn uid(&mut self, id: u32) -> &mut Command { self.std.uid(id); self @@ -568,11 +572,26 @@ impl Command { /// Similar to `uid` but sets the group ID of the child process. This has /// the same semantics as the `uid` field. #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] pub fn gid(&mut self, id: u32) -> &mut Command { self.std.gid(id); self } + /// Sets executable argument. + /// + /// Set the first process argument, `argv[0]`, to something other than the + /// default executable path. 
+ #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] + pub fn arg0<S>(&mut self, arg: S) -> &mut Command + where + S: AsRef<OsStr>, + { + self.std.arg0(arg); + self + } + /// Schedules a closure to be run just before the `exec` function is /// invoked. /// @@ -603,6 +622,7 @@ impl Command { /// working directory have successfully been changed, so output to these /// locations may not appear where intended. #[cfg(unix)] + #[cfg_attr(docsrs, doc(cfg(unix)))] pub unsafe fn pre_exec<F>(&mut self, f: F) -> &mut Command where F: FnMut() -> io::Result<()> + Send + Sync + 'static, @@ -934,6 +954,16 @@ impl Child { } } + /// Extracts the raw handle of the process associated with this child while + /// it is still running. Returns `None` if the child has exited. + #[cfg(windows)] + pub fn raw_handle(&self) -> Option<RawHandle> { + match &self.child { + FusedChild::Child(c) => Some(c.inner.as_raw_handle()), + FusedChild::Done(_) => None, + } + } + /// Attempts to force the child to exit, but does not wait for the request /// to take effect. /// @@ -994,13 +1024,22 @@ impl Child { /// If the caller wishes to explicitly control when the child's stdin /// handle is closed, they may `.take()` it before calling `.wait()`: /// - /// ```no_run + /// ``` + /// # #[cfg(not(unix))]fn main(){} + /// # #[cfg(unix)] /// use tokio::io::AsyncWriteExt; + /// # #[cfg(unix)] /// use tokio::process::Command; + /// # #[cfg(unix)] + /// use std::process::Stdio; /// + /// # #[cfg(unix)] /// #[tokio::main] /// async fn main() { - /// let mut child = Command::new("cat").spawn().unwrap(); + /// let mut child = Command::new("cat") + /// .stdin(Stdio::piped()) + /// .spawn() + /// .unwrap(); /// /// let mut stdin = child.stdin.take().unwrap(); /// tokio::spawn(async move { @@ -1112,7 +1151,7 @@ impl Child { /// handle of a child process asynchronously. #[derive(Debug)] pub struct ChildStdin { - inner: imp::ChildStdin, + inner: imp::ChildStdio, } /// The standard output stream for spawned children. 
@@ -1121,7 +1160,7 @@ pub struct ChildStdin { /// handle of a child process asynchronously. #[derive(Debug)] pub struct ChildStdout { - inner: imp::ChildStdout, + inner: imp::ChildStdio, } /// The standard error stream for spawned children. @@ -1130,7 +1169,52 @@ pub struct ChildStdout { /// handle of a child process asynchronously. #[derive(Debug)] pub struct ChildStderr { - inner: imp::ChildStderr, + inner: imp::ChildStdio, +} + +impl ChildStdin { + /// Creates an asynchronous `ChildStdin` from a synchronous one. + /// + /// # Errors + /// + /// This method may fail if an error is encountered when setting the pipe to + /// non-blocking mode, or when registering the pipe with the runtime's IO + /// driver. + pub fn from_std(inner: std::process::ChildStdin) -> io::Result<Self> { + Ok(Self { + inner: imp::stdio(inner)?, + }) + } +} + +impl ChildStdout { + /// Creates an asynchronous `ChildStderr` from a synchronous one. + /// + /// # Errors + /// + /// This method may fail if an error is encountered when setting the pipe to + /// non-blocking mode, or when registering the pipe with the runtime's IO + /// driver. + pub fn from_std(inner: std::process::ChildStdout) -> io::Result<Self> { + Ok(Self { + inner: imp::stdio(inner)?, + }) + } +} + +impl ChildStderr { + /// Creates an asynchronous `ChildStderr` from a synchronous one. + /// + /// # Errors + /// + /// This method may fail if an error is encountered when setting the pipe to + /// non-blocking mode, or when registering the pipe with the runtime's IO + /// driver. + pub fn from_std(inner: std::process::ChildStderr) -> io::Result<Self> { + Ok(Self { + inner: imp::stdio(inner)?, + }) + } } impl AsyncWrite for ChildStdin { diff --git a/src/process/unix/driver.rs b/src/process/unix/driver.rs index 110b484..84dc8fb 100644 --- a/src/process/unix/driver.rs +++ b/src/process/unix/driver.rs @@ -1,13 +1,10 @@ #![cfg_attr(not(feature = "rt"), allow(dead_code))] -//! Process driver +//! Process driver. 
use crate::park::Park; -use crate::process::unix::orphan::ReapOrphanQueue; use crate::process::unix::GlobalOrphanQueue; -use crate::signal::unix::driver::Driver as SignalDriver; -use crate::signal::unix::{signal_with_handle, SignalKind}; -use crate::sync::watch; +use crate::signal::unix::driver::{Driver as SignalDriver, Handle as SignalHandle}; use std::io; use std::time::Duration; @@ -16,51 +13,20 @@ use std::time::Duration; #[derive(Debug)] pub(crate) struct Driver { park: SignalDriver, - inner: CoreDriver<watch::Receiver<()>, GlobalOrphanQueue>, -} - -#[derive(Debug)] -struct CoreDriver<S, Q> { - sigchild: S, - orphan_queue: Q, -} - -trait HasChanged { - fn has_changed(&mut self) -> bool; -} - -impl<T> HasChanged for watch::Receiver<T> { - fn has_changed(&mut self) -> bool { - self.try_has_changed().and_then(Result::ok).is_some() - } -} - -// ===== impl CoreDriver ===== - -impl<S, Q> CoreDriver<S, Q> -where - S: HasChanged, - Q: ReapOrphanQueue, -{ - fn process(&mut self) { - if self.sigchild.has_changed() { - self.orphan_queue.reap_orphans(); - } - } + signal_handle: SignalHandle, } // ===== impl Driver ===== impl Driver { /// Creates a new signal `Driver` instance that delegates wakeups to `park`. 
- pub(crate) fn new(park: SignalDriver) -> io::Result<Self> { - let sigchild = signal_with_handle(SignalKind::child(), park.handle())?; - let inner = CoreDriver { - sigchild, - orphan_queue: GlobalOrphanQueue, - }; + pub(crate) fn new(park: SignalDriver) -> Self { + let signal_handle = park.handle(); - Ok(Self { park, inner }) + Self { + park, + signal_handle, + } } } @@ -76,13 +42,13 @@ impl Park for Driver { fn park(&mut self) -> Result<(), Self::Error> { self.park.park()?; - self.inner.process(); + GlobalOrphanQueue::reap_orphans(&self.signal_handle); Ok(()) } fn park_timeout(&mut self, duration: Duration) -> Result<(), Self::Error> { self.park.park_timeout(duration)?; - self.inner.process(); + GlobalOrphanQueue::reap_orphans(&self.signal_handle); Ok(()) } @@ -90,43 +56,3 @@ impl Park for Driver { self.park.shutdown() } } - -#[cfg(test)] -mod test { - use super::*; - use crate::process::unix::orphan::test::MockQueue; - - struct MockStream { - total_try_recv: usize, - values: Vec<Option<()>>, - } - - impl MockStream { - fn new(values: Vec<Option<()>>) -> Self { - Self { - total_try_recv: 0, - values, - } - } - } - - impl HasChanged for MockStream { - fn has_changed(&mut self) -> bool { - self.total_try_recv += 1; - self.values.remove(0).is_some() - } - } - - #[test] - fn no_reap_if_no_signal() { - let mut driver = CoreDriver { - sigchild: MockStream::new(vec![None]), - orphan_queue: MockQueue::<()>::new(), - }; - - driver.process(); - - assert_eq!(1, driver.sigchild.total_try_recv); - assert_eq!(0, driver.orphan_queue.total_reaps.get()); - } -} diff --git a/src/process/unix/mod.rs b/src/process/unix/mod.rs index 852a191..576fe6c 100644 --- a/src/process/unix/mod.rs +++ b/src/process/unix/mod.rs @@ -1,4 +1,4 @@ -//! Unix handling of child processes +//! Unix handling of child processes. //! //! Right now the only "fancy" thing about this is how we implement the //! `Future` implementation on `Child` to get the exit status. 
Unix offers @@ -24,7 +24,7 @@ pub(crate) mod driver; pub(crate) mod orphan; -use orphan::{OrphanQueue, OrphanQueueImpl, ReapOrphanQueue, Wait}; +use orphan::{OrphanQueue, OrphanQueueImpl, Wait}; mod reap; use reap::Reaper; @@ -32,6 +32,7 @@ use reap::Reaper; use crate::io::PollEvented; use crate::process::kill::Kill; use crate::process::SpawnedChild; +use crate::signal::unix::driver::Handle as SignalHandle; use crate::signal::unix::{signal, Signal, SignalKind}; use mio::event::Source; @@ -73,9 +74,9 @@ impl fmt::Debug for GlobalOrphanQueue { } } -impl ReapOrphanQueue for GlobalOrphanQueue { - fn reap_orphans(&self) { - ORPHAN_QUEUE.reap_orphans() +impl GlobalOrphanQueue { + fn reap_orphans(handle: &SignalHandle) { + ORPHAN_QUEUE.reap_orphans(handle) } } @@ -100,9 +101,9 @@ impl fmt::Debug for Child { pub(crate) fn spawn_child(cmd: &mut std::process::Command) -> io::Result<SpawnedChild> { let mut child = cmd.spawn()?; - let stdin = stdio(child.stdin.take())?; - let stdout = stdio(child.stdout.take())?; - let stderr = stdio(child.stderr.take())?; + let stdin = child.stdin.take().map(stdio).transpose()?; + let stdout = child.stdout.take().map(stdio).transpose()?; + let stderr = child.stderr.take().map(stdio).transpose()?; let signal = signal(SignalKind::child())?; @@ -212,9 +213,7 @@ impl Source for Pipe { } } -pub(crate) type ChildStdin = PollEvented<Pipe>; -pub(crate) type ChildStdout = PollEvented<Pipe>; -pub(crate) type ChildStderr = PollEvented<Pipe>; +pub(crate) type ChildStdio = PollEvented<Pipe>; fn set_nonblocking<T: AsRawFd>(fd: &mut T, nonblocking: bool) -> io::Result<()> { unsafe { @@ -239,18 +238,13 @@ fn set_nonblocking<T: AsRawFd>(fd: &mut T, nonblocking: bool) -> io::Result<()> Ok(()) } -fn stdio<T>(option: Option<T>) -> io::Result<Option<PollEvented<Pipe>>> +pub(super) fn stdio<T>(io: T) -> io::Result<PollEvented<Pipe>> where T: IntoRawFd, { - let io = match option { - Some(io) => io, - None => return Ok(None), - }; - // Set the fd to nonblocking 
before we pass it to the event loop let mut pipe = Pipe::from(io); set_nonblocking(&mut pipe, true)?; - Ok(Some(PollEvented::new(pipe)?)) + PollEvented::new(pipe) } diff --git a/src/process/unix/orphan.rs b/src/process/unix/orphan.rs index 8a1e127..1b0022c 100644 --- a/src/process/unix/orphan.rs +++ b/src/process/unix/orphan.rs @@ -1,6 +1,9 @@ +use crate::loom::sync::{Mutex, MutexGuard}; +use crate::signal::unix::driver::Handle as SignalHandle; +use crate::signal::unix::{signal_with_handle, SignalKind}; +use crate::sync::watch; use std::io; use std::process::ExitStatus; -use std::sync::Mutex; /// An interface for waiting on a process to exit. pub(crate) trait Wait { @@ -20,21 +23,8 @@ impl<T: Wait> Wait for &mut T { } } -/// An interface for reaping a set of orphaned processes. -pub(crate) trait ReapOrphanQueue { - /// Attempts to reap every process in the queue, ignoring any errors and - /// enqueueing any orphans which have not yet exited. - fn reap_orphans(&self); -} - -impl<T: ReapOrphanQueue> ReapOrphanQueue for &T { - fn reap_orphans(&self) { - (**self).reap_orphans() - } -} - /// An interface for queueing up an orphaned process so that it can be reaped. -pub(crate) trait OrphanQueue<T>: ReapOrphanQueue { +pub(crate) trait OrphanQueue<T> { /// Adds an orphan to the queue. fn push_orphan(&self, orphan: T); } @@ -48,50 +38,91 @@ impl<T, O: OrphanQueue<T>> OrphanQueue<T> for &O { /// An implementation of `OrphanQueue`. 
#[derive(Debug)] pub(crate) struct OrphanQueueImpl<T> { + sigchild: Mutex<Option<watch::Receiver<()>>>, queue: Mutex<Vec<T>>, } impl<T> OrphanQueueImpl<T> { pub(crate) fn new() -> Self { Self { + sigchild: Mutex::new(None), queue: Mutex::new(Vec::new()), } } #[cfg(test)] fn len(&self) -> usize { - self.queue.lock().unwrap().len() + self.queue.lock().len() } -} -impl<T: Wait> OrphanQueue<T> for OrphanQueueImpl<T> { - fn push_orphan(&self, orphan: T) { - self.queue.lock().unwrap().push(orphan) + pub(crate) fn push_orphan(&self, orphan: T) + where + T: Wait, + { + self.queue.lock().push(orphan) } -} -impl<T: Wait> ReapOrphanQueue for OrphanQueueImpl<T> { - fn reap_orphans(&self) { - let mut queue = self.queue.lock().unwrap(); - let queue = &mut *queue; - - for i in (0..queue.len()).rev() { - match queue[i].try_wait() { - Ok(None) => {} - Ok(Some(_)) | Err(_) => { - // The stdlib handles interruption errors (EINTR) when polling a child process. - // All other errors represent invalid inputs or pids that have already been - // reaped, so we can drop the orphan in case an error is raised. - queue.swap_remove(i); + /// Attempts to reap every process in the queue, ignoring any errors and + /// enqueueing any orphans which have not yet exited. + pub(crate) fn reap_orphans(&self, handle: &SignalHandle) + where + T: Wait, + { + // If someone else is holding the lock, they will be responsible for draining + // the queue as necessary, so we can safely bail if that happens + if let Some(mut sigchild_guard) = self.sigchild.try_lock() { + match &mut *sigchild_guard { + Some(sigchild) => { + if sigchild.try_has_changed().and_then(Result::ok).is_some() { + drain_orphan_queue(self.queue.lock()); + } + } + None => { + let queue = self.queue.lock(); + + // Be lazy and only initialize the SIGCHLD listener if there + // are any orphaned processes in the queue. 
+ if !queue.is_empty() { + // An errors shouldn't really happen here, but if it does it + // means that the signal driver isn't running, in + // which case there isn't anything we can + // register/initialize here, so we can try again later + if let Ok(sigchild) = signal_with_handle(SignalKind::child(), handle) { + *sigchild_guard = Some(sigchild); + drain_orphan_queue(queue); + } + } } } } } } +fn drain_orphan_queue<T>(mut queue: MutexGuard<'_, Vec<T>>) +where + T: Wait, +{ + for i in (0..queue.len()).rev() { + match queue[i].try_wait() { + Ok(None) => {} + Ok(Some(_)) | Err(_) => { + // The stdlib handles interruption errors (EINTR) when polling a child process. + // All other errors represent invalid inputs or pids that have already been + // reaped, so we can drop the orphan in case an error is raised. + queue.swap_remove(i); + } + } + } + + drop(queue); +} + #[cfg(all(test, not(loom)))] pub(crate) mod test { use super::*; + use crate::io::driver::Driver as IoDriver; + use crate::signal::unix::driver::{Driver as SignalDriver, Handle as SignalHandle}; + use crate::sync::watch; use std::cell::{Cell, RefCell}; use std::io; use std::os::unix::process::ExitStatusExt; @@ -100,14 +131,12 @@ pub(crate) mod test { pub(crate) struct MockQueue<W> { pub(crate) all_enqueued: RefCell<Vec<W>>, - pub(crate) total_reaps: Cell<usize>, } impl<W> MockQueue<W> { pub(crate) fn new() -> Self { Self { all_enqueued: RefCell::new(Vec::new()), - total_reaps: Cell::new(0), } } } @@ -118,12 +147,6 @@ pub(crate) mod test { } } - impl<W> ReapOrphanQueue for MockQueue<W> { - fn reap_orphans(&self) { - self.total_reaps.set(self.total_reaps.get() + 1); - } - } - struct MockWait { total_waits: Rc<Cell<usize>>, num_wait_until_status: usize, @@ -191,27 +214,107 @@ pub(crate) mod test { assert_eq!(orphanage.len(), 4); - orphanage.reap_orphans(); + drain_orphan_queue(orphanage.queue.lock()); assert_eq!(orphanage.len(), 2); assert_eq!(first_waits.get(), 1); assert_eq!(second_waits.get(), 1); 
assert_eq!(third_waits.get(), 1); assert_eq!(fourth_waits.get(), 1); - orphanage.reap_orphans(); + drain_orphan_queue(orphanage.queue.lock()); assert_eq!(orphanage.len(), 1); assert_eq!(first_waits.get(), 1); assert_eq!(second_waits.get(), 2); assert_eq!(third_waits.get(), 2); assert_eq!(fourth_waits.get(), 1); - orphanage.reap_orphans(); + drain_orphan_queue(orphanage.queue.lock()); assert_eq!(orphanage.len(), 0); assert_eq!(first_waits.get(), 1); assert_eq!(second_waits.get(), 2); assert_eq!(third_waits.get(), 3); assert_eq!(fourth_waits.get(), 1); - orphanage.reap_orphans(); // Safe to reap when empty + // Safe to reap when empty + drain_orphan_queue(orphanage.queue.lock()); + } + + #[test] + fn no_reap_if_no_signal_received() { + let (tx, rx) = watch::channel(()); + + let handle = SignalHandle::default(); + + let orphanage = OrphanQueueImpl::new(); + *orphanage.sigchild.lock() = Some(rx); + + let orphan = MockWait::new(2); + let waits = orphan.total_waits.clone(); + orphanage.push_orphan(orphan); + + orphanage.reap_orphans(&handle); + assert_eq!(waits.get(), 0); + + orphanage.reap_orphans(&handle); + assert_eq!(waits.get(), 0); + + tx.send(()).unwrap(); + orphanage.reap_orphans(&handle); + assert_eq!(waits.get(), 1); + } + + #[test] + fn no_reap_if_signal_lock_held() { + let handle = SignalHandle::default(); + + let orphanage = OrphanQueueImpl::new(); + let signal_guard = orphanage.sigchild.lock(); + + let orphan = MockWait::new(2); + let waits = orphan.total_waits.clone(); + orphanage.push_orphan(orphan); + + orphanage.reap_orphans(&handle); + assert_eq!(waits.get(), 0); + + drop(signal_guard); + } + + #[test] + fn does_not_register_signal_if_queue_empty() { + let signal_driver = IoDriver::new().and_then(SignalDriver::new).unwrap(); + let handle = signal_driver.handle(); + + let orphanage = OrphanQueueImpl::new(); + assert!(orphanage.sigchild.lock().is_none()); // Sanity + + // No register when queue empty + orphanage.reap_orphans(&handle); + 
assert!(orphanage.sigchild.lock().is_none()); + + let orphan = MockWait::new(2); + let waits = orphan.total_waits.clone(); + orphanage.push_orphan(orphan); + + orphanage.reap_orphans(&handle); + assert!(orphanage.sigchild.lock().is_some()); + assert_eq!(waits.get(), 1); // Eager reap when registering listener + } + + #[test] + fn does_nothing_if_signal_could_not_be_registered() { + let handle = SignalHandle::default(); + + let orphanage = OrphanQueueImpl::new(); + assert!(orphanage.sigchild.lock().is_none()); + + let orphan = MockWait::new(2); + let waits = orphan.total_waits.clone(); + orphanage.push_orphan(orphan); + + // Signal handler has "gone away", nothing to register or reap + orphanage.reap_orphans(&handle); + assert!(orphanage.sigchild.lock().is_none()); + assert_eq!(waits.get(), 0); } } diff --git a/src/process/unix/reap.rs b/src/process/unix/reap.rs index 5dc95e5..f7f4d3c 100644 --- a/src/process/unix/reap.rs +++ b/src/process/unix/reap.rs @@ -224,7 +224,6 @@ mod test { assert!(grim.poll_unpin(&mut context).is_pending()); assert_eq!(1, grim.signal.total_polls); assert_eq!(1, grim.total_waits); - assert_eq!(0, grim.orphan_queue.total_reaps.get()); assert!(grim.orphan_queue.all_enqueued.borrow().is_empty()); // Not yet exited, couldn't register interest the first time @@ -232,7 +231,6 @@ mod test { assert!(grim.poll_unpin(&mut context).is_pending()); assert_eq!(3, grim.signal.total_polls); assert_eq!(3, grim.total_waits); - assert_eq!(0, grim.orphan_queue.total_reaps.get()); assert!(grim.orphan_queue.all_enqueued.borrow().is_empty()); // Exited @@ -245,7 +243,6 @@ mod test { } assert_eq!(4, grim.signal.total_polls); assert_eq!(4, grim.total_waits); - assert_eq!(0, grim.orphan_queue.total_reaps.get()); assert!(grim.orphan_queue.all_enqueued.borrow().is_empty()); } @@ -260,7 +257,6 @@ mod test { grim.kill().unwrap(); assert_eq!(1, grim.total_kills); - assert_eq!(0, grim.orphan_queue.total_reaps.get()); 
assert!(grim.orphan_queue.all_enqueued.borrow().is_empty()); } @@ -276,7 +272,6 @@ mod test { drop(grim); - assert_eq!(0, queue.total_reaps.get()); assert!(queue.all_enqueued.borrow().is_empty()); } @@ -294,7 +289,6 @@ mod test { let grim = Reaper::new(&mut mock, &queue, MockStream::new(vec![])); drop(grim); - assert_eq!(0, queue.total_reaps.get()); assert_eq!(1, queue.all_enqueued.borrow().len()); } diff --git a/src/process/windows.rs b/src/process/windows.rs index 7237525..136d5b0 100644 --- a/src/process/windows.rs +++ b/src/process/windows.rs @@ -24,7 +24,7 @@ use mio::windows::NamedPipe; use std::fmt; use std::future::Future; use std::io; -use std::os::windows::prelude::{AsRawHandle, FromRawHandle, IntoRawHandle}; +use std::os::windows::prelude::{AsRawHandle, FromRawHandle, IntoRawHandle, RawHandle}; use std::pin::Pin; use std::process::Stdio; use std::process::{Child as StdChild, Command as StdCommand, ExitStatus}; @@ -67,9 +67,9 @@ unsafe impl Send for Waiting {} pub(crate) fn spawn_child(cmd: &mut StdCommand) -> io::Result<SpawnedChild> { let mut child = cmd.spawn()?; - let stdin = stdio(child.stdin.take()); - let stdout = stdio(child.stdout.take()); - let stderr = stdio(child.stderr.take()); + let stdin = child.stdin.take().map(stdio).transpose()?; + let stdout = child.stdout.take().map(stdio).transpose()?; + let stderr = child.stderr.take().map(stdio).transpose()?; Ok(SpawnedChild { child: Child { @@ -144,6 +144,12 @@ impl Future for Child { } } +impl AsRawHandle for Child { + fn as_raw_handle(&self) -> RawHandle { + self.child.as_raw_handle() + } +} + impl Drop for Waiting { fn drop(&mut self) { unsafe { @@ -161,20 +167,14 @@ unsafe extern "system" fn callback(ptr: PVOID, _timer_fired: BOOLEAN) { let _ = complete.take().unwrap().send(()); } -pub(crate) type ChildStdin = PollEvented<NamedPipe>; -pub(crate) type ChildStdout = PollEvented<NamedPipe>; -pub(crate) type ChildStderr = PollEvented<NamedPipe>; +pub(crate) type ChildStdio = PollEvented<NamedPipe>; 
-fn stdio<T>(option: Option<T>) -> Option<PollEvented<NamedPipe>> +pub(super) fn stdio<T>(io: T) -> io::Result<PollEvented<NamedPipe>> where T: IntoRawHandle, { - let io = match option { - Some(io) => io, - None => return None, - }; let pipe = unsafe { NamedPipe::from_raw_handle(io.into_raw_handle()) }; - PollEvented::new(pipe).ok() + PollEvented::new(pipe) } pub(crate) fn convert_to_stdio(io: PollEvented<NamedPipe>) -> io::Result<Stdio> { diff --git a/src/runtime/basic_scheduler.rs b/src/runtime/basic_scheduler.rs index ffe0bca..872d0d5 100644 --- a/src/runtime/basic_scheduler.rs +++ b/src/runtime/basic_scheduler.rs @@ -2,17 +2,18 @@ use crate::future::poll_fn; use crate::loom::sync::atomic::AtomicBool; use crate::loom::sync::Mutex; use crate::park::{Park, Unpark}; -use crate::runtime::task::{self, JoinHandle, Schedule, Task}; +use crate::runtime::context::EnterGuard; +use crate::runtime::stats::{RuntimeStats, WorkerStatsBatcher}; +use crate::runtime::task::{self, JoinHandle, OwnedTasks, Schedule, Task}; +use crate::runtime::Callback; use crate::sync::notify::Notify; -use crate::util::linked_list::{Link, LinkedList}; use crate::util::{waker_ref, Wake, WakerRef}; use std::cell::RefCell; use std::collections::VecDeque; use std::fmt; use std::future::Future; -use std::ptr::NonNull; -use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; +use std::sync::atomic::Ordering::{AcqRel, Release}; use std::sync::Arc; use std::task::Poll::{Pending, Ready}; use std::time::Duration; @@ -29,6 +30,12 @@ pub(crate) struct BasicScheduler<P: Park> { /// Sendable task spawner spawner: Spawner, + + /// This is usually None, but right before dropping the BasicScheduler, it + /// is changed to `Some` with the context being the runtime's own context. + /// This ensures that any tasks dropped in the `BasicScheduler`s destructor + /// run in that runtime's context. + context_guard: Option<EnterGuard>, } /// The inner scheduler that owns the task queue and the main parker P. 
@@ -49,6 +56,14 @@ struct Inner<P: Park> { /// Thread park handle park: P, + + /// Callback for a worker parking itself + before_park: Option<Callback>, + /// Callback for a worker unparking itself + after_unpark: Option<Callback>, + + /// Stats batcher + stats: WorkerStatsBatcher, } #[derive(Clone)] @@ -57,9 +72,6 @@ pub(crate) struct Spawner { } struct Tasks { - /// Collection of all active tasks spawned onto this executor. - owned: LinkedList<Task<Arc<Shared>>, <Task<Arc<Shared>> as Link>::Target>, - /// Local run queue. /// /// Tasks notified from the current thread are pushed into this queue. @@ -69,29 +81,32 @@ struct Tasks { /// A remote scheduler entry. /// /// These are filled in by remote threads sending instructions to the scheduler. -enum Entry { +enum RemoteMsg { /// A remote thread wants to spawn a task. Schedule(task::Notified<Arc<Shared>>), - /// A remote thread wants a task to be released by the scheduler. We only - /// have access to its header. - Release(NonNull<task::Header>), } // Safety: Used correctly, the task header is "thread safe". Ultimately the task // is owned by the current thread executor, for which this instruction is being // sent. -unsafe impl Send for Entry {} +unsafe impl Send for RemoteMsg {} /// Scheduler state shared between threads. struct Shared { - /// Remote run queue - queue: Mutex<VecDeque<Entry>>, + /// Remote run queue. None if the `Runtime` has been dropped. + queue: Mutex<Option<VecDeque<RemoteMsg>>>, - /// Unpark the blocked thread + /// Collection of all active tasks spawned onto this executor. + owned: OwnedTasks<Arc<Shared>>, + + /// Unpark the blocked thread. unpark: Box<dyn Unpark>, - // indicates whether the blocked on thread was woken + /// Indicates whether the blocked on thread was woken. woken: AtomicBool, + + /// Keeps track of various runtime stats. + stats: RuntimeStats, } /// Thread-local context. 
@@ -119,31 +134,40 @@ const REMOTE_FIRST_INTERVAL: u8 = 31; scoped_thread_local!(static CURRENT: Context); impl<P: Park> BasicScheduler<P> { - pub(crate) fn new(park: P) -> BasicScheduler<P> { + pub(crate) fn new( + park: P, + before_park: Option<Callback>, + after_unpark: Option<Callback>, + ) -> BasicScheduler<P> { let unpark = Box::new(park.unpark()); let spawner = Spawner { shared: Arc::new(Shared { - queue: Mutex::new(VecDeque::with_capacity(INITIAL_CAPACITY)), + queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))), + owned: OwnedTasks::new(), unpark: unpark as Box<dyn Unpark>, woken: AtomicBool::new(false), + stats: RuntimeStats::new(1), }), }; let inner = Mutex::new(Some(Inner { tasks: Some(Tasks { - owned: LinkedList::new(), queue: VecDeque::with_capacity(INITIAL_CAPACITY), }), spawner: spawner.clone(), tick: 0, park, + before_park, + after_unpark, + stats: WorkerStatsBatcher::new(0), })); BasicScheduler { inner, notify: Notify::new(), spawner, + context_guard: None, } } @@ -191,25 +215,28 @@ impl<P: Park> BasicScheduler<P> { Some(InnerGuard { inner: Some(inner), - basic_scheduler: &self, + basic_scheduler: self, }) } + + pub(super) fn set_context_guard(&mut self, guard: EnterGuard) { + self.context_guard = Some(guard); + } } impl<P: Park> Inner<P> { - /// Block on the future provided and drive the runtime's driver. + /// Blocks on the provided future and drives the runtime's driver. 
fn block_on<F: Future>(&mut self, future: F) -> F::Output { enter(self, |scheduler, context| { let _enter = crate::runtime::enter(false); let waker = scheduler.spawner.waker_ref(); let mut cx = std::task::Context::from_waker(&waker); - let mut polled = false; pin!(future); 'outer: loop { - if scheduler.spawner.was_woken() || !polled { - polled = true; + if scheduler.spawner.reset_woken() { + scheduler.stats.incr_poll_count(); if let Ready(v) = crate::coop::budget(|| future.as_mut().poll(&mut cx)) { return v; } @@ -227,7 +254,7 @@ impl<P: Park> Inner<P> { .borrow_mut() .queue .pop_front() - .map(Entry::Schedule) + .map(RemoteMsg::Schedule) }) } else { context @@ -235,15 +262,28 @@ impl<P: Park> Inner<P> { .borrow_mut() .queue .pop_front() - .map(Entry::Schedule) + .map(RemoteMsg::Schedule) .or_else(|| scheduler.spawner.pop()) }; let entry = match entry { Some(entry) => entry, None => { - // Park until the thread is signaled - scheduler.park.park().expect("failed to park"); + if let Some(f) = &scheduler.before_park { + f(); + } + // This check will fail if `before_park` spawns a task for us to run + // instead of parking the thread + if context.tasks.borrow_mut().queue.is_empty() { + // Park until the thread is signaled + scheduler.stats.about_to_park(); + scheduler.stats.submit(&scheduler.spawner.shared.stats); + scheduler.park.park().expect("failed to park"); + scheduler.stats.returned_from_park(); + } + if let Some(f) = &scheduler.after_unpark { + f(); + } // Try polling the `block_on` future next continue 'outer; @@ -251,31 +291,17 @@ impl<P: Park> Inner<P> { }; match entry { - Entry::Schedule(task) => crate::coop::budget(|| task.run()), - Entry::Release(ptr) => { - // Safety: the task header is only legally provided - // internally in the header, so we know that it is a - // valid (or in particular *allocated*) header that - // is part of the linked list. 
- unsafe { - let removed = context.tasks.borrow_mut().owned.remove(ptr); - - // TODO: This seems like it should hold, because - // there doesn't seem to be an avenue for anyone - // else to fiddle with the owned tasks - // collection *after* a remote thread has marked - // it as released, and at that point, the only - // location at which it can be removed is here - // or in the Drop implementation of the - // scheduler. - debug_assert!(removed.is_some()); - } + RemoteMsg::Schedule(task) => { + scheduler.stats.incr_poll_count(); + let task = context.shared.owned.assert_owner(task); + crate::coop::budget(|| task.run()) } } } // Yield to the park, this drives the timer and pulls any pending // I/O events. + scheduler.stats.submit(&scheduler.spawner.shared.stats); scheduler .park .park_timeout(Duration::from_millis(0)) @@ -285,8 +311,8 @@ impl<P: Park> Inner<P> { } } -/// Enter the scheduler context. This sets the queue and other necessary -/// scheduler state in the thread-local +/// Enters the scheduler context. This sets the queue and other necessary +/// scheduler state in the thread-local. fn enter<F, R, P>(scheduler: &mut Inner<P>, f: F) -> R where F: FnOnce(&mut Inner<P>, &Context) -> R, @@ -335,36 +361,33 @@ impl<P: Park> Drop for BasicScheduler<P> { }; enter(&mut inner, |scheduler, context| { - // Loop required here to ensure borrow is dropped between iterations - #[allow(clippy::while_let_loop)] - loop { - let task = match context.tasks.borrow_mut().owned.pop_back() { - Some(task) => task, - None => break, - }; - - task.shutdown(); - } + // Drain the OwnedTasks collection. This call also closes the + // collection, ensuring that no tasks are ever pushed after this + // call returns. + context.shared.owned.close_and_shutdown_all(); // Drain local queue + // We already shut down every task, so we just need to drop the task. for task in context.tasks.borrow_mut().queue.drain(..) 
{ - task.shutdown(); + drop(task); } - // Drain remote queue - for entry in scheduler.spawner.shared.queue.lock().drain(..) { - match entry { - Entry::Schedule(task) => { - task.shutdown(); - } - Entry::Release(..) => { - // Do nothing, each entry in the linked list was *just* - // dropped by the scheduler above. + // Drain remote queue and set it to None + let remote_queue = scheduler.spawner.shared.queue.lock().take(); + + // Using `Option::take` to replace the shared queue with `None`. + // We already shut down every task, so we just need to drop the task. + if let Some(remote_queue) = remote_queue { + for entry in remote_queue { + match entry { + RemoteMsg::Schedule(task) => { + drop(task); + } } } } - assert!(context.tasks.borrow().owned.is_empty()); + assert!(context.shared.owned.is_empty()); }); } } @@ -378,29 +401,42 @@ impl<P: Park> fmt::Debug for BasicScheduler<P> { // ===== impl Spawner ===== impl Spawner { - /// Spawns a future onto the thread pool + /// Spawns a future onto the basic scheduler pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> where - F: Future + Send + 'static, + F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { - let (task, handle) = task::joinable(future); - self.shared.schedule(task); + let (handle, notified) = self.shared.owned.bind(future, self.shared.clone()); + + if let Some(notified) = notified { + self.shared.schedule(notified); + } + handle } - fn pop(&self) -> Option<Entry> { - self.shared.queue.lock().pop_front() + pub(crate) fn stats(&self) -> &RuntimeStats { + &self.shared.stats + } + + fn pop(&self) -> Option<RemoteMsg> { + match self.shared.queue.lock().as_mut() { + Some(queue) => queue.pop_front(), + None => None, + } } fn waker_ref(&self) -> WakerRef<'_> { - // clear the woken bit - self.shared.woken.swap(false, AcqRel); + // Set woken to true when enter block_on, ensure outer future + // be polled for the first time when enter loop + self.shared.woken.store(true, Release); 
waker_ref(&self.shared) } - fn was_woken(&self) -> bool { - self.shared.woken.load(Acquire) + // reset woken to false and return original value + pub(crate) fn reset_woken(&self) -> bool { + self.shared.woken.swap(false, AcqRel) } } @@ -413,30 +449,8 @@ impl fmt::Debug for Spawner { // ===== impl Shared ===== impl Schedule for Arc<Shared> { - fn bind(task: Task<Self>) -> Arc<Shared> { - CURRENT.with(|maybe_cx| { - let cx = maybe_cx.expect("scheduler context missing"); - cx.tasks.borrow_mut().owned.push_front(task); - cx.shared.clone() - }) - } - fn release(&self, task: &Task<Self>) -> Option<Task<Self>> { - CURRENT.with(|maybe_cx| { - let ptr = NonNull::from(task.header()); - - if let Some(cx) = maybe_cx { - // safety: the task is inserted in the list in `bind`. - unsafe { cx.tasks.borrow_mut().owned.remove(ptr) } - } else { - self.queue.lock().push_back(Entry::Release(ptr)); - self.unpark.unpark(); - // Returning `None` here prevents the task plumbing from being - // freed. It is then up to the scheduler through the queue we - // just added to, or its Drop impl to free the task. - None - } - }) + self.owned.remove(task) } fn schedule(&self, task: task::Notified<Self>) { @@ -445,8 +459,14 @@ impl Schedule for Arc<Shared> { cx.tasks.borrow_mut().queue.push_back(task); } _ => { - self.queue.lock().push_back(Entry::Schedule(task)); - self.unpark.unpark(); + // If the queue is None, then the runtime has shut down. We + // don't need to do anything with the notification in that case. 
+ let mut guard = self.queue.lock(); + if let Some(queue) = guard.as_mut() { + queue.push_back(RemoteMsg::Schedule(task)); + drop(guard); + self.unpark.unpark(); + } } }); } diff --git a/src/runtime/blocking/mod.rs b/src/runtime/blocking/mod.rs index fece3c2..670ec3a 100644 --- a/src/runtime/blocking/mod.rs +++ b/src/runtime/blocking/mod.rs @@ -8,7 +8,9 @@ pub(crate) use pool::{spawn_blocking, BlockingPool, Spawner}; mod schedule; mod shutdown; -pub(crate) mod task; +mod task; +pub(crate) use schedule::NoopSchedule; +pub(crate) use task::BlockingTask; use crate::runtime::Builder; diff --git a/src/runtime/blocking/pool.rs b/src/runtime/blocking/pool.rs index 791e405..77ab495 100644 --- a/src/runtime/blocking/pool.rs +++ b/src/runtime/blocking/pool.rs @@ -4,12 +4,10 @@ use crate::loom::sync::{Arc, Condvar, Mutex}; use crate::loom::thread; use crate::runtime::blocking::schedule::NoopSchedule; use crate::runtime::blocking::shutdown; -use crate::runtime::blocking::task::BlockingTask; use crate::runtime::builder::ThreadNameFn; use crate::runtime::context; use crate::runtime::task::{self, JoinHandle}; use crate::runtime::{Builder, Callback, Handle}; -use crate::util::error::CONTEXT_MISSING_ERROR; use std::collections::{HashMap, VecDeque}; use std::fmt; @@ -26,28 +24,28 @@ pub(crate) struct Spawner { } struct Inner { - /// State shared between worker threads + /// State shared between worker threads. shared: Mutex<Shared>, /// Pool threads wait on this. condvar: Condvar, - /// Spawned threads use this name + /// Spawned threads use this name. thread_name: ThreadNameFn, - /// Spawned thread stack size + /// Spawned thread stack size. stack_size: Option<usize>, - /// Call after a thread starts + /// Call after a thread starts. after_start: Option<Callback>, - /// Call before a thread stops + /// Call before a thread stops. before_stop: Option<Callback>, - // Maximum number of threads + // Maximum number of threads. 
thread_cap: usize, - // Customizable wait timeout + // Customizable wait timeout. keep_alive: Duration, } @@ -61,43 +59,31 @@ struct Shared { /// Prior to shutdown, we clean up JoinHandles by having each timed-out /// thread join on the previous timed-out thread. This is not strictly /// necessary but helps avoid Valgrind false positives, see - /// https://github.com/tokio-rs/tokio/commit/646fbae76535e397ef79dbcaacb945d4c829f666 + /// <https://github.com/tokio-rs/tokio/commit/646fbae76535e397ef79dbcaacb945d4c829f666> /// for more information. last_exiting_thread: Option<thread::JoinHandle<()>>, /// This holds the JoinHandles for all running threads; on shutdown, the thread /// calling shutdown handles joining on these. worker_threads: HashMap<usize, thread::JoinHandle<()>>, /// This is a counter used to iterate worker_threads in a consistent order (for loom's - /// benefit) + /// benefit). worker_thread_index: usize, } -type Task = task::Notified<NoopSchedule>; +type Task = task::UnownedTask<NoopSchedule>; const KEEP_ALIVE: Duration = Duration::from_secs(10); -/// Run the provided function on an executor dedicated to blocking operations. +/// Runs the provided function on an executor dedicated to blocking operations. 
pub(crate) fn spawn_blocking<F, R>(func: F) -> JoinHandle<R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static, { - let rt = context::current().expect(CONTEXT_MISSING_ERROR); + let rt = context::current(); rt.spawn_blocking(func) } -#[allow(dead_code)] -pub(crate) fn try_spawn_blocking<F, R>(func: F) -> Result<(), ()> -where - F: FnOnce() -> R + Send + 'static, - R: Send + 'static, -{ - let rt = context::current().expect(CONTEXT_MISSING_ERROR); - - let (task, _handle) = task::joinable(BlockingTask::new(func)); - rt.blocking_spawner.spawn(task, &rt) -} - // ===== impl BlockingPool ===== impl BlockingPool { @@ -151,7 +137,7 @@ impl BlockingPool { self.spawner.inner.condvar.notify_all(); let last_exited_thread = std::mem::take(&mut shared.last_exiting_thread); - let workers = std::mem::replace(&mut shared.worker_threads, HashMap::new()); + let workers = std::mem::take(&mut shared.worker_threads); drop(shared); diff --git a/src/runtime/blocking/schedule.rs b/src/runtime/blocking/schedule.rs index 4e044ab..5425224 100644 --- a/src/runtime/blocking/schedule.rs +++ b/src/runtime/blocking/schedule.rs @@ -9,11 +9,6 @@ use crate::runtime::task::{self, Task}; pub(crate) struct NoopSchedule; impl task::Schedule for NoopSchedule { - fn bind(_task: Task<Self>) -> NoopSchedule { - // Do nothing w/ the task - NoopSchedule - } - fn release(&self, _task: &Task<Self>) -> Option<Task<Self>> { None } diff --git a/src/runtime/blocking/shutdown.rs b/src/runtime/blocking/shutdown.rs index 0cf2285..e6f4674 100644 --- a/src/runtime/blocking/shutdown.rs +++ b/src/runtime/blocking/shutdown.rs @@ -10,7 +10,7 @@ use std::time::Duration; #[derive(Debug, Clone)] pub(super) struct Sender { - tx: Arc<oneshot::Sender<()>>, + _tx: Arc<oneshot::Sender<()>>, } #[derive(Debug)] @@ -20,7 +20,7 @@ pub(super) struct Receiver { pub(super) fn channel() -> (Sender, Receiver) { let (tx, rx) = oneshot::channel(); - let tx = Sender { tx: Arc::new(tx) }; + let tx = Sender { _tx: Arc::new(tx) }; let rx = 
Receiver { rx }; (tx, rx) diff --git a/src/runtime/blocking/task.rs b/src/runtime/blocking/task.rs index ee2d8d6..0b7803a 100644 --- a/src/runtime/blocking/task.rs +++ b/src/runtime/blocking/task.rs @@ -2,13 +2,13 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -/// Converts a function to a future that completes on poll +/// Converts a function to a future that completes on poll. pub(crate) struct BlockingTask<T> { func: Option<T>, } impl<T> BlockingTask<T> { - /// Initializes a new blocking task from the given function + /// Initializes a new blocking task from the given function. pub(crate) fn new(func: T) -> BlockingTask<T> { BlockingTask { func: Some(func) } } diff --git a/src/runtime/builder.rs b/src/runtime/builder.rs index 0249266..91c365f 100644 --- a/src/runtime/builder.rs +++ b/src/runtime/builder.rs @@ -70,6 +70,12 @@ pub struct Builder { /// To run before each worker thread stops pub(super) before_stop: Option<Callback>, + /// To run before each worker thread is parked. + pub(super) before_park: Option<Callback>, + + /// To run after each thread is unparked. + pub(super) after_unpark: Option<Callback>, + /// Customizable keep alive timeout for BlockingPool pub(super) keep_alive: Option<Duration>, } @@ -135,6 +141,8 @@ impl Builder { // No worker thread callbacks after_start: None, before_stop: None, + before_park: None, + after_unpark: None, keep_alive: None, } @@ -374,6 +382,120 @@ impl Builder { self } + /// Executes function `f` just before a thread is parked (goes idle). + /// `f` is called within the Tokio context, so functions like [`tokio::spawn`](crate::spawn) + /// can be called, and may result in this thread being unparked immediately. + /// + /// This can be used to start work only when the executor is idle, or for bookkeeping + /// and monitoring purposes. 
+ /// + /// Note: There can only be one park callback for a runtime; calling this function + /// more than once replaces the last callback defined, rather than adding to it. + /// + /// # Examples + /// + /// ## Multithreaded executor + /// ``` + /// # use std::sync::Arc; + /// # use std::sync::atomic::{AtomicBool, Ordering}; + /// # use tokio::runtime; + /// # use tokio::sync::Barrier; + /// # pub fn main() { + /// let once = AtomicBool::new(true); + /// let barrier = Arc::new(Barrier::new(2)); + /// + /// let runtime = runtime::Builder::new_multi_thread() + /// .worker_threads(1) + /// .on_thread_park({ + /// let barrier = barrier.clone(); + /// move || { + /// let barrier = barrier.clone(); + /// if once.swap(false, Ordering::Relaxed) { + /// tokio::spawn(async move { barrier.wait().await; }); + /// } + /// } + /// }) + /// .build() + /// .unwrap(); + /// + /// runtime.block_on(async { + /// barrier.wait().await; + /// }) + /// # } + /// ``` + /// ## Current thread executor + /// ``` + /// # use std::sync::Arc; + /// # use std::sync::atomic::{AtomicBool, Ordering}; + /// # use tokio::runtime; + /// # use tokio::sync::Barrier; + /// # pub fn main() { + /// let once = AtomicBool::new(true); + /// let barrier = Arc::new(Barrier::new(2)); + /// + /// let runtime = runtime::Builder::new_current_thread() + /// .on_thread_park({ + /// let barrier = barrier.clone(); + /// move || { + /// let barrier = barrier.clone(); + /// if once.swap(false, Ordering::Relaxed) { + /// tokio::spawn(async move { barrier.wait().await; }); + /// } + /// } + /// }) + /// .build() + /// .unwrap(); + /// + /// runtime.block_on(async { + /// barrier.wait().await; + /// }) + /// # } + /// ``` + #[cfg(not(loom))] + pub fn on_thread_park<F>(&mut self, f: F) -> &mut Self + where + F: Fn() + Send + Sync + 'static, + { + self.before_park = Some(std::sync::Arc::new(f)); + self + } + + /// Executes function `f` just after a thread unparks (starts executing tasks). 
+ /// + /// This is intended for bookkeeping and monitoring use cases; note that work + /// in this callback will increase latencies when the application has allowed one or + /// more runtime threads to go idle. + /// + /// Note: There can only be one unpark callback for a runtime; calling this function + /// more than once replaces the last callback defined, rather than adding to it. + /// + /// # Examples + /// + /// ``` + /// # use tokio::runtime; + /// + /// # pub fn main() { + /// let runtime = runtime::Builder::new_multi_thread() + /// .on_thread_unpark(|| { + /// println!("thread unparking"); + /// }) + /// .build(); + /// + /// runtime.unwrap().block_on(async { + /// tokio::task::yield_now().await; + /// println!("Hello from Tokio!"); + /// }) + /// # } + /// ``` + #[cfg(not(loom))] + pub fn on_thread_unpark<F>(&mut self, f: F) -> &mut Self + where + F: Fn() + Send + Sync + 'static, + { + self.after_unpark = Some(std::sync::Arc::new(f)); + self + } + /// Creates the configured `Runtime`. /// /// The returned `Runtime` instance is ready to spawn tasks. @@ -413,7 +535,7 @@ impl Builder { /// Sets a custom timeout for a thread in the blocking pool. /// /// By default, the timeout for a thread is set to 10 seconds. This can - /// be overriden using .thread_keep_alive(). + /// be overridden using .thread_keep_alive(). /// /// # Example /// @@ -441,7 +563,8 @@ impl Builder { // there are no futures ready to do something, it'll let the timer or // the reactor to generate some new stimuli for the futures to continue // in their life. - let scheduler = BasicScheduler::new(driver); + let scheduler = + BasicScheduler::new(driver, self.before_park.clone(), self.after_unpark.clone()); let spawner = Spawner::Basic(scheduler.spawner().clone()); // Blocking pool @@ -546,7 +669,7 @@ cfg_rt_multi_thread! 
{ let (driver, resources) = driver::Driver::new(self.get_cfg())?; - let (scheduler, launch) = ThreadPool::new(core_threads, Parker::new(driver)); + let (scheduler, launch) = ThreadPool::new(core_threads, Parker::new(driver), self.before_park.clone(), self.after_unpark.clone()); let spawner = Spawner::ThreadPool(scheduler.spawner().clone()); // Create the blocking pool @@ -587,7 +710,9 @@ impl fmt::Debug for Builder { ) .field("thread_stack_size", &self.thread_stack_size) .field("after_start", &self.after_start.as_ref().map(|_| "...")) - .field("before_stop", &self.after_start.as_ref().map(|_| "...")) + .field("before_stop", &self.before_stop.as_ref().map(|_| "...")) + .field("before_park", &self.before_park.as_ref().map(|_| "...")) + .field("after_unpark", &self.after_unpark.as_ref().map(|_| "...")) .finish() } } diff --git a/src/runtime/context.rs b/src/runtime/context.rs index a727ed4..1f44a53 100644 --- a/src/runtime/context.rs +++ b/src/runtime/context.rs @@ -1,5 +1,5 @@ //! Thread local runtime context -use crate::runtime::Handle; +use crate::runtime::{Handle, TryCurrentError}; use std::cell::RefCell; @@ -7,58 +7,96 @@ thread_local! { static CONTEXT: RefCell<Option<Handle>> = RefCell::new(None) } -pub(crate) fn current() -> Option<Handle> { - CONTEXT.with(|ctx| ctx.borrow().clone()) +pub(crate) fn try_current() -> Result<Handle, crate::runtime::TryCurrentError> { + match CONTEXT.try_with(|ctx| ctx.borrow().clone()) { + Ok(Some(handle)) => Ok(handle), + Ok(None) => Err(TryCurrentError::new_no_context()), + Err(_access_error) => Err(TryCurrentError::new_thread_local_destroyed()), + } +} + +pub(crate) fn current() -> Handle { + match try_current() { + Ok(handle) => handle, + Err(e) => panic!("{}", e), + } } cfg_io_driver! 
{ pub(crate) fn io_handle() -> crate::runtime::driver::IoHandle { - CONTEXT.with(|ctx| { + match CONTEXT.try_with(|ctx| { let ctx = ctx.borrow(); ctx.as_ref().expect(crate::util::error::CONTEXT_MISSING_ERROR).io_handle.clone() - }) + }) { + Ok(io_handle) => io_handle, + Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } } } cfg_signal_internal! { #[cfg(unix)] pub(crate) fn signal_handle() -> crate::runtime::driver::SignalHandle { - CONTEXT.with(|ctx| { + match CONTEXT.try_with(|ctx| { let ctx = ctx.borrow(); ctx.as_ref().expect(crate::util::error::CONTEXT_MISSING_ERROR).signal_handle.clone() - }) + }) { + Ok(signal_handle) => signal_handle, + Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } } } cfg_time! { pub(crate) fn time_handle() -> crate::runtime::driver::TimeHandle { - CONTEXT.with(|ctx| { + match CONTEXT.try_with(|ctx| { let ctx = ctx.borrow(); ctx.as_ref().expect(crate::util::error::CONTEXT_MISSING_ERROR).time_handle.clone() - }) + }) { + Ok(time_handle) => time_handle, + Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } } cfg_test_util! { pub(crate) fn clock() -> Option<crate::runtime::driver::Clock> { - CONTEXT.with(|ctx| (*ctx.borrow()).as_ref().map(|ctx| ctx.clock.clone())) + match CONTEXT.try_with(|ctx| (*ctx.borrow()).as_ref().map(|ctx| ctx.clock.clone())) { + Ok(clock) => clock, + Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } } } } cfg_rt! { pub(crate) fn spawn_handle() -> Option<crate::runtime::Spawner> { - CONTEXT.with(|ctx| (*ctx.borrow()).as_ref().map(|ctx| ctx.spawner.clone())) + match CONTEXT.try_with(|ctx| (*ctx.borrow()).as_ref().map(|ctx| ctx.spawner.clone())) { + Ok(spawner) => spawner, + Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } } } -/// Set this [`Handle`] as the current active [`Handle`]. +/// Sets this [`Handle`] as the current active [`Handle`]. 
/// /// [`Handle`]: Handle pub(crate) fn enter(new: Handle) -> EnterGuard { - CONTEXT.with(|ctx| { - let old = ctx.borrow_mut().replace(new); - EnterGuard(old) - }) + match try_enter(new) { + Some(guard) => guard, + None => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), + } +} + +/// Sets this [`Handle`] as the current active [`Handle`]. +/// +/// [`Handle`]: Handle +pub(crate) fn try_enter(new: Handle) -> Option<EnterGuard> { + CONTEXT + .try_with(|ctx| { + let old = ctx.borrow_mut().replace(new); + EnterGuard(old) + }) + .ok() } #[derive(Debug)] diff --git a/src/runtime/driver.rs b/src/runtime/driver.rs index a0e8e23..7e45977 100644 --- a/src/runtime/driver.rs +++ b/src/runtime/driver.rs @@ -23,7 +23,7 @@ cfg_io_driver! { let io_handle = io_driver.handle(); let (signal_driver, signal_handle) = create_signal_driver(io_driver)?; - let process_driver = create_process_driver(signal_driver)?; + let process_driver = create_process_driver(signal_driver); (Either::A(process_driver), Some(io_handle), signal_handle) } else { @@ -80,7 +80,7 @@ cfg_not_signal_internal! { cfg_process_driver! { type ProcessDriver = crate::process::unix::driver::Driver; - fn create_process_driver(signal_driver: SignalDriver) -> io::Result<ProcessDriver> { + fn create_process_driver(signal_driver: SignalDriver) -> ProcessDriver { crate::process::unix::driver::Driver::new(signal_driver) } } @@ -89,8 +89,8 @@ cfg_not_process_driver! { cfg_io_driver! { type ProcessDriver = SignalDriver; - fn create_process_driver(signal_driver: SignalDriver) -> io::Result<ProcessDriver> { - Ok(signal_driver) + fn create_process_driver(signal_driver: SignalDriver) -> ProcessDriver { + signal_driver } } } diff --git a/src/runtime/enter.rs b/src/runtime/enter.rs index 4dd8dd0..3f14cb5 100644 --- a/src/runtime/enter.rs +++ b/src/runtime/enter.rs @@ -64,7 +64,7 @@ cfg_rt! { // # Warning // // This is hidden for a reason. Do not use without fully understanding -// executors. 
Misuing can easily cause your program to deadlock. +// executors. Misusing can easily cause your program to deadlock. cfg_rt_multi_thread! { pub(crate) fn exit<F: FnOnce() -> R, R>(f: F) -> R { // Reset in case the closure panics @@ -92,7 +92,7 @@ cfg_rt_multi_thread! { } cfg_rt! { - /// Disallow blocking in the current runtime context until the guard is dropped. + /// Disallows blocking in the current runtime context until the guard is dropped. pub(crate) fn disallow_blocking() -> DisallowBlockingGuard { let reset = ENTERED.with(|c| { if let EnterContext::Entered { diff --git a/src/runtime/handle.rs b/src/runtime/handle.rs index 4f1b4c5..cd1cb76 100644 --- a/src/runtime/handle.rs +++ b/src/runtime/handle.rs @@ -1,9 +1,10 @@ -use crate::runtime::blocking::task::BlockingTask; +use crate::runtime::blocking::{BlockingTask, NoopSchedule}; use crate::runtime::task::{self, JoinHandle}; use crate::runtime::{blocking, context, driver, Spawner}; -use crate::util::error::CONTEXT_MISSING_ERROR; +use crate::util::error::{CONTEXT_MISSING_ERROR, THREAD_LOCAL_DESTROYED_ERROR}; use std::future::Future; +use std::marker::PhantomData; use std::{error, fmt}; /// Handle to the runtime. 
@@ -17,15 +18,25 @@ pub struct Handle { pub(super) spawner: Spawner, /// Handles to the I/O drivers + #[cfg_attr( + not(any(feature = "net", feature = "process", all(unix, feature = "signal"))), + allow(dead_code) + )] pub(super) io_handle: driver::IoHandle, /// Handles to the signal drivers + #[cfg_attr( + not(any(feature = "signal", all(unix, feature = "process"))), + allow(dead_code) + )] pub(super) signal_handle: driver::SignalHandle, /// Handles to the time drivers + #[cfg_attr(not(feature = "time"), allow(dead_code))] pub(super) time_handle: driver::TimeHandle, /// Source of `Instant::now()` + #[cfg_attr(not(all(feature = "time", feature = "test-util")), allow(dead_code))] pub(super) clock: driver::Clock, /// Blocking pool spawner @@ -41,12 +52,12 @@ pub struct Handle { #[derive(Debug)] #[must_use = "Creating and dropping a guard does nothing"] pub struct EnterGuard<'a> { - handle: &'a Handle, - guard: context::EnterGuard, + _guard: context::EnterGuard, + _handle_lifetime: PhantomData<&'a Handle>, } impl Handle { - /// Enter the runtime context. This allows you to construct types that must + /// Enters the runtime context. This allows you to construct types that must /// have an executor available on creation such as [`Sleep`] or [`TcpStream`]. /// It will also allow you to call methods such as [`tokio::spawn`]. /// @@ -55,12 +66,12 @@ impl Handle { /// [`tokio::spawn`]: fn@crate::spawn pub fn enter(&self) -> EnterGuard<'_> { EnterGuard { - handle: self, - guard: context::enter(self.clone()), + _guard: context::enter(self.clone()), + _handle_lifetime: PhantomData, } } - /// Returns a `Handle` view over the currently running `Runtime` + /// Returns a `Handle` view over the currently running `Runtime`. 
/// /// # Panic /// @@ -99,7 +110,7 @@ impl Handle { /// # } /// ``` pub fn current() -> Self { - context::current().expect(CONTEXT_MISSING_ERROR) + context::current() } /// Returns a Handle view over the currently running Runtime @@ -108,10 +119,18 @@ impl Handle { /// /// Contrary to `current`, this never panics pub fn try_current() -> Result<Self, TryCurrentError> { - context::current().ok_or(TryCurrentError(())) + context::try_current() } - /// Spawn a future onto the Tokio runtime. + cfg_stats! { + /// Returns a view that lets you get information about how the runtime + /// is performing. + pub fn stats(&self) -> &crate::runtime::stats::RuntimeStats { + self.spawner.stats() + } + } + + /// Spawns a future onto the Tokio runtime. /// /// This spawns the given future onto the runtime's executor, usually a /// thread pool. The thread pool is then responsible for polling the future @@ -145,11 +164,11 @@ impl Handle { F::Output: Send + 'static, { #[cfg(all(tokio_unstable, feature = "tracing"))] - let future = crate::util::trace::task(future, "task"); + let future = crate::util::trace::task(future, "task", None); self.spawner.spawn(future) } - /// Run the provided function on an executor dedicated to blocking + /// Runs the provided function on an executor dedicated to blocking. /// operations. 
/// /// # Examples @@ -174,36 +193,55 @@ impl Handle { F: FnOnce() -> R + Send + 'static, R: Send + 'static, { + if cfg!(debug_assertions) && std::mem::size_of::<F>() > 2048 { + self.spawn_blocking_inner(Box::new(func), None) + } else { + self.spawn_blocking_inner(func, None) + } + } + + #[cfg_attr(tokio_track_caller, track_caller)] + pub(crate) fn spawn_blocking_inner<F, R>(&self, func: F, name: Option<&str>) -> JoinHandle<R> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let fut = BlockingTask::new(func); + #[cfg(all(tokio_unstable, feature = "tracing"))] - let func = { + let fut = { + use tracing::Instrument; #[cfg(tokio_track_caller)] let location = std::panic::Location::caller(); #[cfg(tokio_track_caller)] let span = tracing::trace_span!( - target: "tokio::task", - "task", + target: "tokio::task::blocking", + "runtime.spawn", kind = %"blocking", - function = %std::any::type_name::<F>(), + task.name = %name.unwrap_or_default(), + "fn" = %std::any::type_name::<F>(), spawn.location = %format_args!("{}:{}:{}", location.file(), location.line(), location.column()), ); #[cfg(not(tokio_track_caller))] let span = tracing::trace_span!( - target: "tokio::task", - "task", + target: "tokio::task::blocking", + "runtime.spawn", kind = %"blocking", - function = %std::any::type_name::<F>(), + task.name = %name.unwrap_or_default(), + "fn" = %std::any::type_name::<F>(), ); - move || { - let _g = span.enter(); - func() - } + fut.instrument(span) }; - let (task, handle) = task::joinable(BlockingTask::new(func)); - let _ = self.blocking_spawner.spawn(task, &self); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let _ = name; + + let (task, handle) = task::unowned(fut, NoopSchedule); + let _ = self.blocking_spawner.spawn(task, self); handle } - /// Run a future to completion on this `Handle`'s associated `Runtime`. + /// Runs a future to completion on this `Handle`'s associated `Runtime`. 
/// /// This runs the given future on the current thread, blocking until it is /// complete, and yielding its resolved result. Any tasks or timers which @@ -273,7 +311,11 @@ impl Handle { /// [`tokio::fs`]: crate::fs /// [`tokio::net`]: crate::net /// [`tokio::time`]: crate::time + #[cfg_attr(tokio_track_caller, track_caller)] pub fn block_on<F: Future>(&self, future: F) -> F::Output { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let future = crate::util::trace::task(future, "block_on", None); + // Enter the **runtime** context. This configures spawning, the current I/O driver, ... let _rt_enter = self.enter(); @@ -292,17 +334,60 @@ impl Handle { } /// Error returned by `try_current` when no Runtime has been started -pub struct TryCurrentError(()); +#[derive(Debug)] +pub struct TryCurrentError { + kind: TryCurrentErrorKind, +} + +impl TryCurrentError { + pub(crate) fn new_no_context() -> Self { + Self { + kind: TryCurrentErrorKind::NoContext, + } + } -impl fmt::Debug for TryCurrentError { + pub(crate) fn new_thread_local_destroyed() -> Self { + Self { + kind: TryCurrentErrorKind::ThreadLocalDestroyed, + } + } + + /// Returns true if the call failed because there is currently no runtime in + /// the Tokio context. + pub fn is_missing_context(&self) -> bool { + matches!(self.kind, TryCurrentErrorKind::NoContext) + } + + /// Returns true if the call failed because the Tokio context thread-local + /// had been destroyed. This can usually only happen if in the destructor of + /// other thread-locals. 
+ pub fn is_thread_local_destroyed(&self) -> bool { + matches!(self.kind, TryCurrentErrorKind::ThreadLocalDestroyed) + } +} + +enum TryCurrentErrorKind { + NoContext, + ThreadLocalDestroyed, +} + +impl fmt::Debug for TryCurrentErrorKind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TryCurrentError").finish() + use TryCurrentErrorKind::*; + match self { + NoContext => f.write_str("NoContext"), + ThreadLocalDestroyed => f.write_str("ThreadLocalDestroyed"), + } } } impl fmt::Display for TryCurrentError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(CONTEXT_MISSING_ERROR) + use TryCurrentErrorKind::*; + match self.kind { + NoContext => f.write_str(CONTEXT_MISSING_ERROR), + ThreadLocalDestroyed => f.write_str(THREAD_LOCAL_DESTROYED_ERROR), + } } } diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index 52532ec..96bb47c 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -181,6 +181,13 @@ pub(crate) mod enter; pub(crate) mod task; +cfg_stats! { + pub mod stats; +} +cfg_not_stats! { + pub(crate) mod stats; +} + cfg_rt! { mod basic_scheduler; use basic_scheduler::BasicScheduler; @@ -198,7 +205,7 @@ cfg_rt! { use self::enter::enter; mod handle; - pub use handle::{EnterGuard, Handle}; + pub use handle::{EnterGuard, Handle, TryCurrentError}; mod spawner; use self::spawner::Spawner; @@ -287,7 +294,7 @@ cfg_rt! { type Callback = std::sync::Arc<dyn Fn() + Send + Sync>; impl Runtime { - /// Create a new runtime instance with default configuration values. + /// Creates a new runtime instance with default configuration values. /// /// This results in the multi threaded scheduler, I/O driver, and time driver being /// initialized. @@ -322,7 +329,7 @@ cfg_rt! { Builder::new_multi_thread().enable_all().build() } - /// Return a handle to the runtime's spawner. + /// Returns a handle to the runtime's spawner. 
/// /// The returned handle can be used to spawn tasks that run on this runtime, and can /// be cloned to allow moving the `Handle` to other threads. @@ -343,7 +350,7 @@ cfg_rt! { &self.handle } - /// Spawn a future onto the Tokio runtime. + /// Spawns a future onto the Tokio runtime. /// /// This spawns the given future onto the runtime's executor, usually a /// thread pool. The thread pool is then responsible for polling the future @@ -377,7 +384,7 @@ cfg_rt! { self.handle.spawn(future) } - /// Run the provided function on an executor dedicated to blocking operations. + /// Runs the provided function on an executor dedicated to blocking operations. /// /// # Examples /// @@ -402,7 +409,7 @@ cfg_rt! { self.handle.spawn_blocking(func) } - /// Run a future to completion on the Tokio runtime. This is the + /// Runs a future to completion on the Tokio runtime. This is the /// runtime's entry point. /// /// This runs the given future on the current thread, blocking until it is @@ -443,7 +450,11 @@ cfg_rt! { /// ``` /// /// [handle]: fn@Handle::block_on + #[cfg_attr(tokio_track_caller, track_caller)] pub fn block_on<F: Future>(&self, future: F) -> F::Output { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let future = crate::util::trace::task(future, "block_on", None); + let _enter = self.enter(); match &self.kind { @@ -453,7 +464,7 @@ cfg_rt! { } } - /// Enter the runtime context. + /// Enters the runtime context. /// /// This allows you to construct types that must have an executor /// available on creation such as [`Sleep`] or [`TcpStream`]. It will @@ -489,7 +500,7 @@ cfg_rt! { self.handle.enter() } - /// Shutdown the runtime, waiting for at most `duration` for all spawned + /// Shuts down the runtime, waiting for at most `duration` for all spawned /// task to shutdown. /// /// Usually, dropping a `Runtime` handle is sufficient as tasks are able to @@ -526,11 +537,11 @@ cfg_rt! 
{ /// ``` pub fn shutdown_timeout(mut self, duration: Duration) { // Wakeup and shutdown all the worker threads - self.handle.shutdown(); + self.handle.clone().shutdown(); self.blocking_pool.shutdown(Some(duration)); } - /// Shutdown the runtime, without waiting for any spawned tasks to shutdown. + /// Shuts down the runtime, without waiting for any spawned tasks to shutdown. /// /// This can be useful if you want to drop a runtime from within another runtime. /// Normally, dropping a runtime will block indefinitely for spawned blocking tasks @@ -560,4 +571,30 @@ cfg_rt! { self.shutdown_timeout(Duration::from_nanos(0)) } } + + #[allow(clippy::single_match)] // there are comments in the error branch, so we don't want if-let + impl Drop for Runtime { + fn drop(&mut self) { + match &mut self.kind { + Kind::CurrentThread(basic) => { + // This ensures that tasks spawned on the basic runtime are dropped inside the + // runtime's context. + match self::context::try_enter(self.handle.clone()) { + Some(guard) => basic.set_context_guard(guard), + None => { + // The context thread-local has alread been destroyed. + // + // We don't set the guard in this case. Calls to tokio::spawn in task + // destructors would fail regardless if this happens. + }, + } + }, + #[cfg(feature = "rt-multi-thread")] + Kind::ThreadPool(_) => { + // The threaded scheduler drops its tasks on its worker threads, which is + // already in the runtime's context. + }, + } + } + } } diff --git a/src/runtime/queue.rs b/src/runtime/queue.rs index 6ea23c9..a88dffc 100644 --- a/src/runtime/queue.rs +++ b/src/runtime/queue.rs @@ -1,13 +1,13 @@ //! 
Run-queue structures to support a work-stealing scheduler use crate::loom::cell::UnsafeCell; -use crate::loom::sync::atomic::{AtomicU16, AtomicU32, AtomicUsize}; -use crate::loom::sync::{Arc, Mutex}; -use crate::runtime::task; +use crate::loom::sync::atomic::{AtomicU16, AtomicU32}; +use crate::loom::sync::Arc; +use crate::runtime::stats::WorkerStatsBatcher; +use crate::runtime::task::{self, Inject}; -use std::marker::PhantomData; use std::mem::MaybeUninit; -use std::ptr::{self, NonNull}; +use std::ptr; use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release}; /// Producer handle. May only be used from a single thread. @@ -18,19 +18,6 @@ pub(super) struct Local<T: 'static> { /// Consumer handle. May be used from many threads. pub(super) struct Steal<T: 'static>(Arc<Inner<T>>); -/// Growable, MPMC queue used to inject new tasks into the scheduler and as an -/// overflow queue when the local, fixed-size, array queue overflows. -pub(super) struct Inject<T: 'static> { - /// Pointers to the head and tail of the queue - pointers: Mutex<Pointers>, - - /// Number of pending tasks in the queue. This helps prevent unnecessary - /// locking in the hot path. - len: AtomicUsize, - - _p: PhantomData<T>, -} - pub(super) struct Inner<T: 'static> { /// Concurrently updated by many threads. 
/// @@ -49,24 +36,11 @@ pub(super) struct Inner<T: 'static> { tail: AtomicU16, /// Elements - buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>]>, -} - -struct Pointers { - /// True if the queue is closed - is_closed: bool, - - /// Linked-list head - head: Option<NonNull<task::Header>>, - - /// Linked-list tail - tail: Option<NonNull<task::Header>>, + buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>, } unsafe impl<T> Send for Inner<T> {} unsafe impl<T> Sync for Inner<T> {} -unsafe impl<T> Send for Inject<T> {} -unsafe impl<T> Sync for Inject<T> {} #[cfg(not(loom))] const LOCAL_QUEUE_CAPACITY: usize = 256; @@ -79,6 +53,17 @@ const LOCAL_QUEUE_CAPACITY: usize = 4; const MASK: usize = LOCAL_QUEUE_CAPACITY - 1; +// Constructing the fixed size array directly is very awkward. The only way to +// do it is to repeat `UnsafeCell::new(MaybeUninit::uninit())` 256 times, as +// the contents are not Copy. The trick with defining a const doesn't work for +// generic types. +fn make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> { + assert_eq!(buffer.len(), LOCAL_QUEUE_CAPACITY); + + // safety: We check that the length is correct. 
+ unsafe { Box::from_raw(Box::into_raw(buffer).cast()) } +} + /// Create a new local run-queue pub(super) fn local<T: 'static>() -> (Steal<T>, Local<T>) { let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY); @@ -90,7 +75,7 @@ pub(super) fn local<T: 'static>() -> (Steal<T>, Local<T>) { let inner = Arc::new(Inner { head: AtomicU32::new(0), tail: AtomicU16::new(0), - buffer: buffer.into(), + buffer: make_fixed_size(buffer.into_boxed_slice()), }); let local = Local { @@ -108,6 +93,14 @@ impl<T> Local<T> { !self.inner.is_empty() } + /// Returns false if there are any entries in the queue + /// + /// Separate to is_stealable so that refactors of is_stealable to "protect" + /// some tasks from stealing won't affect this + pub(super) fn has_tasks(&self) -> bool { + !self.inner.is_empty() + } + /// Pushes a task to the back of the local queue, skipping the LIFO slot. pub(super) fn push_back(&mut self, mut task: task::Notified<T>, inject: &Inject<T>) { let tail = loop { @@ -121,8 +114,8 @@ impl<T> Local<T> { // There is capacity for the task break tail; } else if steal != real { - // Concurrently stealing, this will free up capacity, so - // only push the new task onto the inject queue + // Concurrently stealing, this will free up capacity, so only + // push the task onto the inject queue inject.push(task); return; } else { @@ -171,9 +164,12 @@ impl<T> Local<T> { tail: u16, inject: &Inject<T>, ) -> Result<(), task::Notified<T>> { - const BATCH_LEN: usize = LOCAL_QUEUE_CAPACITY / 2 + 1; + /// How many elements are we taking from the local queue. + /// + /// This is one less than the number of tasks pushed to the inject + /// queue as we are also inserting the `task` argument. 
+ const NUM_TASKS_TAKEN: u16 = (LOCAL_QUEUE_CAPACITY / 2) as u16; - let n = (LOCAL_QUEUE_CAPACITY / 2) as u16; assert_eq!( tail.wrapping_sub(head) as usize, LOCAL_QUEUE_CAPACITY, @@ -199,7 +195,10 @@ impl<T> Local<T> { .head .compare_exchange( prev, - pack(head.wrapping_add(n), head.wrapping_add(n)), + pack( + head.wrapping_add(NUM_TASKS_TAKEN), + head.wrapping_add(NUM_TASKS_TAKEN), + ), Release, Relaxed, ) @@ -211,41 +210,41 @@ impl<T> Local<T> { return Err(task); } - // link the tasks - for i in 0..n { - let j = i + 1; - - let i_idx = i.wrapping_add(head) as usize & MASK; - let j_idx = j.wrapping_add(head) as usize & MASK; - - // Get the next pointer - let next = if j == n { - // The last task in the local queue being moved - task.header().into() - } else { - // safety: The above CAS prevents a stealer from accessing these - // tasks and we are the only producer. - self.inner.buffer[j_idx].with(|ptr| unsafe { - let value = (*ptr).as_ptr(); - (*value).header().into() - }) - }; - - // safety: the above CAS prevents a stealer from accessing these - // tasks and we are the only producer. - self.inner.buffer[i_idx].with_mut(|ptr| unsafe { - let ptr = (*ptr).as_ptr(); - (*ptr).header().set_next(Some(next)) - }); + /// An iterator that takes elements out of the run queue. + struct BatchTaskIter<'a, T: 'static> { + buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY], + head: u32, + i: u32, + } + impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> { + type Item = task::Notified<T>; + + #[inline] + fn next(&mut self) -> Option<task::Notified<T>> { + if self.i == u32::from(NUM_TASKS_TAKEN) { + None + } else { + let i_idx = self.i.wrapping_add(self.head) as usize & MASK; + let slot = &self.buffer[i_idx]; + + // safety: Our CAS from before has assumed exclusive ownership + // of the task pointers in this range. 
+ let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) }); + + self.i += 1; + Some(task) + } + } } - // safety: the above CAS prevents a stealer from accessing these tasks - // and we are the only producer. - let head = self.inner.buffer[head as usize & MASK] - .with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) }); - - // Push the tasks onto the inject queue - inject.push_batch(head, task, BATCH_LEN); + // safety: The CAS above ensures that no consumer will look at these + // values again, and we are the only producer. + let batch_iter = BatchTaskIter { + buffer: &*self.inner.buffer, + head: head as u32, + i: 0, + }; + inject.push_batch(batch_iter.chain(std::iter::once(task))); Ok(()) } @@ -298,7 +297,11 @@ impl<T> Steal<T> { } /// Steals half the tasks from self and place them into `dst`. - pub(super) fn steal_into(&self, dst: &mut Local<T>) -> Option<task::Notified<T>> { + pub(super) fn steal_into( + &self, + dst: &mut Local<T>, + stats: &mut WorkerStatsBatcher, + ) -> Option<task::Notified<T>> { // Safety: the caller is the only thread that mutates `dst.tail` and // holds a mutable reference. let dst_tail = unsafe { dst.inner.tail.unsync_load() }; @@ -317,6 +320,7 @@ impl<T> Steal<T> { // Steal the tasks into `dst`'s buffer. This does not yet expose the // tasks in `dst`. let mut n = self.steal_into2(dst, dst_tail); + stats.incr_steal_count(n); if n == 0 { // No tasks were stolen @@ -465,159 +469,10 @@ impl<T> Inner<T> { } } -impl<T: 'static> Inject<T> { - pub(super) fn new() -> Inject<T> { - Inject { - pointers: Mutex::new(Pointers { - is_closed: false, - head: None, - tail: None, - }), - len: AtomicUsize::new(0), - _p: PhantomData, - } - } - - pub(super) fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Close the injection queue, returns `true` if the queue is open when the - /// transition is made. 
- pub(super) fn close(&self) -> bool { - let mut p = self.pointers.lock(); - - if p.is_closed { - return false; - } - - p.is_closed = true; - true - } - - pub(super) fn is_closed(&self) -> bool { - self.pointers.lock().is_closed - } - - pub(super) fn len(&self) -> usize { - self.len.load(Acquire) - } - - /// Pushes a value into the queue. - pub(super) fn push(&self, task: task::Notified<T>) { - // Acquire queue lock - let mut p = self.pointers.lock(); - - if p.is_closed { - // Drop the mutex to avoid a potential deadlock when - // re-entering. - drop(p); - drop(task); - return; - } - - // safety: only mutated with the lock held - let len = unsafe { self.len.unsync_load() }; - let task = task.into_raw(); - - // The next pointer should already be null - debug_assert!(get_next(task).is_none()); - - if let Some(tail) = p.tail { - set_next(tail, Some(task)); - } else { - p.head = Some(task); - } - - p.tail = Some(task); - - self.len.store(len + 1, Release); - } - - pub(super) fn push_batch( - &self, - batch_head: task::Notified<T>, - batch_tail: task::Notified<T>, - num: usize, - ) { - let batch_head = batch_head.into_raw(); - let batch_tail = batch_tail.into_raw(); - - debug_assert!(get_next(batch_tail).is_none()); - - let mut p = self.pointers.lock(); - - if let Some(tail) = p.tail { - set_next(tail, Some(batch_head)); - } else { - p.head = Some(batch_head); - } - - p.tail = Some(batch_tail); - - // Increment the count. - // - // safety: All updates to the len atomic are guarded by the mutex. As - // such, a non-atomic load followed by a store is safe. - let len = unsafe { self.len.unsync_load() }; - - self.len.store(len + num, Release); - } - - pub(super) fn pop(&self) -> Option<task::Notified<T>> { - // Fast path, if len == 0, then there are no values - if self.is_empty() { - return None; - } - - let mut p = self.pointers.lock(); - - // It is possible to hit null here if another thread popped the last - // task between us checking `len` and acquiring the lock. 
- let task = p.head?; - - p.head = get_next(task); - - if p.head.is_none() { - p.tail = None; - } - - set_next(task, None); - - // Decrement the count. - // - // safety: All updates to the len atomic are guarded by the mutex. As - // such, a non-atomic load followed by a store is safe. - self.len - .store(unsafe { self.len.unsync_load() } - 1, Release); - - // safety: a `Notified` is pushed into the queue and now it is popped! - Some(unsafe { task::Notified::from_raw(task) }) - } -} - -impl<T: 'static> Drop for Inject<T> { - fn drop(&mut self) { - if !std::thread::panicking() { - assert!(self.pop().is_none(), "queue not empty"); - } - } -} - -fn get_next(header: NonNull<task::Header>) -> Option<NonNull<task::Header>> { - unsafe { header.as_ref().queue_next.with(|ptr| *ptr) } -} - -fn set_next(header: NonNull<task::Header>, val: Option<NonNull<task::Header>>) { - unsafe { - header.as_ref().set_next(val); - } -} - /// Split the head value into the real head and the index a stealer is working /// on. 
fn unpack(n: u32) -> (u16, u16) { - let real = n & u16::max_value() as u32; + let real = n & u16::MAX as u32; let steal = n >> 16; (steal as u16, real as u16) @@ -630,5 +485,5 @@ fn pack(steal: u16, real: u16) -> u32 { #[test] fn test_local_queue_capacity() { - assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::max_value() as usize); + assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize); } diff --git a/src/runtime/shell.rs b/src/runtime/shell.rs deleted file mode 100644 index 486d4fa..0000000 --- a/src/runtime/shell.rs +++ /dev/null @@ -1,132 +0,0 @@ -#![allow(clippy::redundant_clone)] - -use crate::future::poll_fn; -use crate::park::{Park, Unpark}; -use crate::runtime::driver::Driver; -use crate::sync::Notify; -use crate::util::{waker_ref, Wake}; - -use std::sync::{Arc, Mutex}; -use std::task::Context; -use std::task::Poll::{Pending, Ready}; -use std::{future::Future, sync::PoisonError}; - -#[derive(Debug)] -pub(super) struct Shell { - driver: Mutex<Option<Driver>>, - - notify: Notify, - - /// TODO: don't store this - unpark: Arc<Handle>, -} - -#[derive(Debug)] -struct Handle(<Driver as Park>::Unpark); - -impl Shell { - pub(super) fn new(driver: Driver) -> Shell { - let unpark = Arc::new(Handle(driver.unpark())); - - Shell { - driver: Mutex::new(Some(driver)), - notify: Notify::new(), - unpark, - } - } - - pub(super) fn block_on<F>(&self, f: F) -> F::Output - where - F: Future, - { - let mut enter = crate::runtime::enter(true); - - pin!(f); - - loop { - if let Some(driver) = &mut self.take_driver() { - return driver.block_on(f); - } else { - let notified = self.notify.notified(); - pin!(notified); - - if let Some(out) = enter - .block_on(poll_fn(|cx| { - if notified.as_mut().poll(cx).is_ready() { - return Ready(None); - } - - if let Ready(out) = f.as_mut().poll(cx) { - return Ready(Some(out)); - } - - Pending - })) - .expect("Failed to `Enter::block_on`") - { - return out; - } - } - } - } - - fn take_driver(&self) -> Option<DriverGuard<'_>> { - let mut lock = 
self.driver.lock().unwrap(); - let driver = lock.take()?; - - Some(DriverGuard { - inner: Some(driver), - shell: &self, - }) - } -} - -impl Wake for Handle { - /// Wake by value - fn wake(self: Arc<Self>) { - Wake::wake_by_ref(&self); - } - - /// Wake by reference - fn wake_by_ref(arc_self: &Arc<Self>) { - arc_self.0.unpark(); - } -} - -struct DriverGuard<'a> { - inner: Option<Driver>, - shell: &'a Shell, -} - -impl DriverGuard<'_> { - fn block_on<F: Future>(&mut self, f: F) -> F::Output { - let driver = self.inner.as_mut().unwrap(); - - pin!(f); - - let waker = waker_ref(&self.shell.unpark); - let mut cx = Context::from_waker(&waker); - - loop { - if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) { - return v; - } - - driver.park().unwrap(); - } - } -} - -impl Drop for DriverGuard<'_> { - fn drop(&mut self) { - if let Some(inner) = self.inner.take() { - self.shell - .driver - .lock() - .unwrap_or_else(PoisonError::into_inner) - .replace(inner); - - self.shell.notify.notify_one(); - } - } -} diff --git a/src/runtime/spawner.rs b/src/runtime/spawner.rs index a37c667..9a3d465 100644 --- a/src/runtime/spawner.rs +++ b/src/runtime/spawner.rs @@ -1,9 +1,7 @@ -cfg_rt! { - use crate::runtime::basic_scheduler; - use crate::task::JoinHandle; - - use std::future::Future; -} +use crate::future::Future; +use crate::runtime::basic_scheduler; +use crate::runtime::stats::RuntimeStats; +use crate::task::JoinHandle; cfg_rt_multi_thread! { use crate::runtime::thread_pool; @@ -11,7 +9,6 @@ cfg_rt_multi_thread! { #[derive(Debug, Clone)] pub(crate) enum Spawner { - #[cfg(feature = "rt")] Basic(basic_scheduler::Spawner), #[cfg(feature = "rt-multi-thread")] ThreadPool(thread_pool::Spawner), @@ -26,21 +23,25 @@ impl Spawner { } } } -} -cfg_rt! 
{ - impl Spawner { - pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> - where - F: Future + Send + 'static, - F::Output: Send + 'static, - { - match self { - #[cfg(feature = "rt")] - Spawner::Basic(spawner) => spawner.spawn(future), - #[cfg(feature = "rt-multi-thread")] - Spawner::ThreadPool(spawner) => spawner.spawn(future), - } + pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + match self { + Spawner::Basic(spawner) => spawner.spawn(future), + #[cfg(feature = "rt-multi-thread")] + Spawner::ThreadPool(spawner) => spawner.spawn(future), + } + } + + #[cfg_attr(not(all(tokio_unstable, feature = "stats")), allow(dead_code))] + pub(crate) fn stats(&self) -> &RuntimeStats { + match self { + Spawner::Basic(spawner) => spawner.stats(), + #[cfg(feature = "rt-multi-thread")] + Spawner::ThreadPool(spawner) => spawner.stats(), } } } diff --git a/src/runtime/stats/mock.rs b/src/runtime/stats/mock.rs new file mode 100644 index 0000000..3bda8bf --- /dev/null +++ b/src/runtime/stats/mock.rs @@ -0,0 +1,27 @@ +//! This file contains mocks of the types in src/runtime/stats/stats.rs + +pub(crate) struct RuntimeStats {} + +impl RuntimeStats { + pub(crate) fn new(_worker_threads: usize) -> Self { + Self {} + } +} + +pub(crate) struct WorkerStatsBatcher {} + +impl WorkerStatsBatcher { + pub(crate) fn new(_my_index: usize) -> Self { + Self {} + } + + pub(crate) fn submit(&mut self, _to: &RuntimeStats) {} + + pub(crate) fn about_to_park(&mut self) {} + pub(crate) fn returned_from_park(&mut self) {} + + #[cfg(feature = "rt-multi-thread")] + pub(crate) fn incr_steal_count(&mut self, _by: u16) {} + + pub(crate) fn incr_poll_count(&mut self) {} +} diff --git a/src/runtime/stats/mod.rs b/src/runtime/stats/mod.rs new file mode 100644 index 0000000..5e08e8e --- /dev/null +++ b/src/runtime/stats/mod.rs @@ -0,0 +1,17 @@ +//! 
This module contains information need to view information about how the +//! runtime is performing. +#![allow(clippy::module_inception)] + +cfg_stats! { + mod stats; + + pub use self::stats::{RuntimeStats, WorkerStats}; + pub(crate) use self::stats::WorkerStatsBatcher; +} + +cfg_not_stats! { + #[path = "mock.rs"] + mod stats; + + pub(crate) use self::stats::{RuntimeStats, WorkerStatsBatcher}; +} diff --git a/src/runtime/stats/stats.rs b/src/runtime/stats/stats.rs new file mode 100644 index 0000000..b2bcacc --- /dev/null +++ b/src/runtime/stats/stats.rs @@ -0,0 +1,122 @@ +//! This file contains the types necessary to collect various types of stats. +use crate::loom::sync::atomic::{AtomicU64, Ordering::Relaxed}; + +use std::convert::TryFrom; +use std::time::{Duration, Instant}; + +/// This type contains methods to retrieve stats from a Tokio runtime. +#[derive(Debug)] +pub struct RuntimeStats { + workers: Box<[WorkerStats]>, +} + +/// This type contains methods to retrieve stats from a worker thread on a Tokio runtime. +#[derive(Debug)] +#[repr(align(128))] +pub struct WorkerStats { + park_count: AtomicU64, + steal_count: AtomicU64, + poll_count: AtomicU64, + busy_duration_total: AtomicU64, +} + +impl RuntimeStats { + pub(crate) fn new(worker_threads: usize) -> Self { + let mut workers = Vec::with_capacity(worker_threads); + for _ in 0..worker_threads { + workers.push(WorkerStats { + park_count: AtomicU64::new(0), + steal_count: AtomicU64::new(0), + poll_count: AtomicU64::new(0), + busy_duration_total: AtomicU64::new(0), + }); + } + + Self { + workers: workers.into_boxed_slice(), + } + } + + /// Returns a slice containing the worker stats for each worker thread. + pub fn workers(&self) -> impl Iterator<Item = &WorkerStats> { + self.workers.iter() + } +} + +impl WorkerStats { + /// Returns the total number of times this worker thread has parked. 
+ pub fn park_count(&self) -> u64 { + self.park_count.load(Relaxed) + } + + /// Returns the number of tasks this worker has stolen from other worker + /// threads. + pub fn steal_count(&self) -> u64 { + self.steal_count.load(Relaxed) + } + + /// Returns the number of times this worker has polled a task. + pub fn poll_count(&self) -> u64 { + self.poll_count.load(Relaxed) + } + + /// Returns the total amount of time this worker has been busy for. + pub fn total_busy_duration(&self) -> Duration { + Duration::from_nanos(self.busy_duration_total.load(Relaxed)) + } +} + +pub(crate) struct WorkerStatsBatcher { + my_index: usize, + park_count: u64, + steal_count: u64, + poll_count: u64, + /// The total busy duration in nanoseconds. + busy_duration_total: u64, + last_resume_time: Instant, +} + +impl WorkerStatsBatcher { + pub(crate) fn new(my_index: usize) -> Self { + Self { + my_index, + park_count: 0, + steal_count: 0, + poll_count: 0, + busy_duration_total: 0, + last_resume_time: Instant::now(), + } + } + pub(crate) fn submit(&mut self, to: &RuntimeStats) { + let worker = &to.workers[self.my_index]; + + worker.park_count.store(self.park_count, Relaxed); + worker.steal_count.store(self.steal_count, Relaxed); + worker.poll_count.store(self.poll_count, Relaxed); + + worker + .busy_duration_total + .store(self.busy_duration_total, Relaxed); + } + + pub(crate) fn about_to_park(&mut self) { + self.park_count += 1; + + let busy_duration = self.last_resume_time.elapsed(); + let busy_duration = u64::try_from(busy_duration.as_nanos()).unwrap_or(u64::MAX); + self.busy_duration_total += busy_duration; + } + + pub(crate) fn returned_from_park(&mut self) { + self.last_resume_time = Instant::now(); + } + + #[cfg(feature = "rt-multi-thread")] + pub(crate) fn incr_steal_count(&mut self, by: u16) { + self.steal_count += u64::from(by); + } + + pub(crate) fn incr_poll_count(&mut self) { + self.poll_count += 1; + } +} diff --git a/src/runtime/task/core.rs b/src/runtime/task/core.rs index 
9f7ff55..776e834 100644 --- a/src/runtime/task/core.rs +++ b/src/runtime/task/core.rs @@ -9,13 +9,13 @@ //! Make sure to consult the relevant safety section of each function before //! use. +use crate::future::Future; use crate::loom::cell::UnsafeCell; use crate::runtime::task::raw::{self, Vtable}; use crate::runtime::task::state::State; -use crate::runtime::task::{Notified, Schedule, Task}; +use crate::runtime::task::Schedule; use crate::util::linked_list; -use std::future::Future; use std::pin::Pin; use std::ptr::NonNull; use std::task::{Context, Poll, Waker}; @@ -36,10 +36,6 @@ pub(super) struct Cell<T: Future, S> { pub(super) trailer: Trailer, } -pub(super) struct Scheduler<S> { - scheduler: UnsafeCell<Option<S>>, -} - pub(super) struct CoreStage<T: Future> { stage: UnsafeCell<Stage<T>>, } @@ -48,29 +44,43 @@ pub(super) struct CoreStage<T: Future> { /// /// Holds the future or output, depending on the stage of execution. pub(super) struct Core<T: Future, S> { - /// Scheduler used to drive this future - pub(super) scheduler: Scheduler<S>, + /// Scheduler used to drive this future. + pub(super) scheduler: S, - /// Either the future or the output + /// Either the future or the output. pub(super) stage: CoreStage<T>, } /// Crate public as this is also needed by the pool. #[repr(C)] pub(crate) struct Header { - /// Task state + /// Task state. pub(super) state: State, - pub(crate) owned: UnsafeCell<linked_list::Pointers<Header>>, - - /// Pointer to next task, used with the injection queue - pub(crate) queue_next: UnsafeCell<Option<NonNull<Header>>>, + pub(super) owned: UnsafeCell<linked_list::Pointers<Header>>, - /// Pointer to the next task in the transfer stack - pub(super) stack_next: UnsafeCell<Option<NonNull<Header>>>, + /// Pointer to next task, used with the injection queue. + pub(super) queue_next: UnsafeCell<Option<NonNull<Header>>>, /// Table of function pointers for executing actions on the task. 
pub(super) vtable: &'static Vtable, + + /// This integer contains the id of the OwnedTasks or LocalOwnedTasks that + /// this task is stored in. If the task is not in any list, should be the + /// id of the list that it was previously in, or zero if it has never been + /// in any list. + /// + /// Once a task has been bound to a list, it can never be bound to another + /// list, even if removed from the first list. + /// + /// The id is not unset when removed from a list because we want to be able + /// to read the id without synchronization, even if it is concurrently being + /// removed from the list. + pub(super) owner_id: UnsafeCell<u64>, + + /// The tracing ID for this instrumented task. + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) id: Option<tracing::Id>, } unsafe impl Send for Header {} @@ -92,19 +102,21 @@ pub(super) enum Stage<T: Future> { impl<T: Future, S: Schedule> Cell<T, S> { /// Allocates a new task cell, containing the header, trailer, and core /// structures. - pub(super) fn new(future: T, state: State) -> Box<Cell<T, S>> { + pub(super) fn new(future: T, scheduler: S, state: State) -> Box<Cell<T, S>> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let id = future.id(); Box::new(Cell { header: Header { state, owned: UnsafeCell::new(linked_list::Pointers::new()), queue_next: UnsafeCell::new(None), - stack_next: UnsafeCell::new(None), vtable: raw::vtable::<T, S>(), + owner_id: UnsafeCell::new(0), + #[cfg(all(tokio_unstable, feature = "tracing"))] + id, }, core: Core { - scheduler: Scheduler { - scheduler: UnsafeCell::new(None), - }, + scheduler, stage: CoreStage { stage: UnsafeCell::new(Stage::Running(future)), }, @@ -116,103 +128,17 @@ impl<T: Future, S: Schedule> Cell<T, S> { } } -impl<S: Schedule> Scheduler<S> { - pub(super) fn with_mut<R>(&self, f: impl FnOnce(*mut Option<S>) -> R) -> R { - self.scheduler.with_mut(f) - } - - /// Bind a scheduler to the task. 
- /// - /// This only happens on the first poll and must be preceeded by a call to - /// `is_bound` to determine if binding is appropriate or not. - /// - /// # Safety - /// - /// Binding must not be done concurrently since it will mutate the task - /// core through a shared reference. - pub(super) fn bind_scheduler(&self, task: Task<S>) { - // This function may be called concurrently, but the __first__ time it - // is called, the caller has unique access to this field. All subsequent - // concurrent calls will be via the `Waker`, which will "happens after" - // the first poll. - // - // In other words, it is always safe to read the field and it is safe to - // write to the field when it is `None`. - debug_assert!(!self.is_bound()); - - // Bind the task to the scheduler - let scheduler = S::bind(task); - - // Safety: As `scheduler` is not set, this is the first poll - self.scheduler.with_mut(|ptr| unsafe { - *ptr = Some(scheduler); - }); - } - - /// Returns true if the task is bound to a scheduler. - pub(super) fn is_bound(&self) -> bool { - // Safety: never called concurrently w/ a mutation. - self.scheduler.with(|ptr| unsafe { (*ptr).is_some() }) - } - - /// Schedule the future for execution - pub(super) fn schedule(&self, task: Notified<S>) { - self.scheduler.with(|ptr| { - // Safety: Can only be called after initial `poll`, which is the - // only time the field is mutated. - match unsafe { &*ptr } { - Some(scheduler) => scheduler.schedule(task), - None => panic!("no scheduler set"), - } - }); - } - - /// Schedule the future for execution in the near future, yielding the - /// thread to other tasks. - pub(super) fn yield_now(&self, task: Notified<S>) { - self.scheduler.with(|ptr| { - // Safety: Can only be called after initial `poll`, which is the - // only time the field is mutated. 
- match unsafe { &*ptr } { - Some(scheduler) => scheduler.yield_now(task), - None => panic!("no scheduler set"), - } - }); - } - - /// Release the task - /// - /// If the `Scheduler` implementation is able to, it returns the `Task` - /// handle immediately. The caller of this function will batch a ref-dec - /// with a state change. - pub(super) fn release(&self, task: Task<S>) -> Option<Task<S>> { - use std::mem::ManuallyDrop; - - let task = ManuallyDrop::new(task); - - self.scheduler.with(|ptr| { - // Safety: Can only be called after initial `poll`, which is the - // only time the field is mutated. - match unsafe { &*ptr } { - Some(scheduler) => scheduler.release(&*task), - // Task was never polled - None => None, - } - }) - } -} - impl<T: Future> CoreStage<T> { pub(super) fn with_mut<R>(&self, f: impl FnOnce(*mut Stage<T>) -> R) -> R { self.stage.with_mut(f) } - /// Poll the future + /// Polls the future. /// /// # Safety /// /// The caller must ensure it is safe to mutate the `state` field. This - /// requires ensuring mutal exclusion between any concurrent thread that + /// requires ensuring mutual exclusion between any concurrent thread that /// might modify the future or output field. /// /// The mutual exclusion is implemented by `Harness` and the `Lifecycle` @@ -243,7 +169,7 @@ impl<T: Future> CoreStage<T> { res } - /// Drop the future + /// Drops the future. /// /// # Safety /// @@ -255,7 +181,7 @@ impl<T: Future> CoreStage<T> { } } - /// Store the task output + /// Stores the task output. /// /// # Safety /// @@ -267,7 +193,7 @@ impl<T: Future> CoreStage<T> { } } - /// Take the task output + /// Takes the task output. /// /// # Safety /// @@ -276,10 +202,10 @@ impl<T: Future> CoreStage<T> { use std::mem; self.stage.with_mut(|ptr| { - // Safety:: the caller ensures mutal exclusion to the field. + // Safety:: the caller ensures mutual exclusion to the field. 
match mem::replace(unsafe { &mut *ptr }, Stage::Consumed) { Stage::Finished(output) => output, - _ => panic!("unexpected task state"), + _ => panic!("JoinHandle polled after completion"), } }) } @@ -291,32 +217,40 @@ impl<T: Future> CoreStage<T> { cfg_rt_multi_thread! { impl Header { - pub(crate) fn shutdown(&self) { - use crate::runtime::task::RawTask; - - let task = unsafe { RawTask::from_raw(self.into()) }; - task.shutdown(); - } - - pub(crate) unsafe fn set_next(&self, next: Option<NonNull<Header>>) { + pub(super) unsafe fn set_next(&self, next: Option<NonNull<Header>>) { self.queue_next.with_mut(|ptr| *ptr = next); } } } +impl Header { + // safety: The caller must guarantee exclusive access to this field, and + // must ensure that the id is either 0 or the id of the OwnedTasks + // containing this task. + pub(super) unsafe fn set_owner_id(&self, owner: u64) { + self.owner_id.with_mut(|ptr| *ptr = owner); + } + + pub(super) fn get_owner_id(&self) -> u64 { + // safety: If there are concurrent writes, then that write has violated + // the safety requirements on `set_owner_id`. 
+ unsafe { self.owner_id.with(|ptr| *ptr) } + } +} + impl Trailer { - pub(crate) unsafe fn set_waker(&self, waker: Option<Waker>) { + pub(super) unsafe fn set_waker(&self, waker: Option<Waker>) { self.waker.with_mut(|ptr| { *ptr = waker; }); } - pub(crate) unsafe fn will_wake(&self, waker: &Waker) -> bool { + pub(super) unsafe fn will_wake(&self, waker: &Waker) -> bool { self.waker .with(|ptr| (*ptr).as_ref().unwrap().will_wake(waker)) } - pub(crate) fn wake_join(&self) { + pub(super) fn wake_join(&self) { self.waker.with(|ptr| match unsafe { &*ptr } { Some(waker) => waker.wake_by_ref(), None => panic!("waker missing"), diff --git a/src/runtime/task/error.rs b/src/runtime/task/error.rs index 177fe65..1a8129b 100644 --- a/src/runtime/task/error.rs +++ b/src/runtime/task/error.rs @@ -1,7 +1,8 @@ use std::any::Any; use std::fmt; use std::io; -use std::sync::Mutex; + +use crate::util::SyncWrapper; cfg_rt! { /// Task failed to execute to completion. @@ -12,7 +13,7 @@ cfg_rt! { enum Repr { Cancelled, - Panic(Mutex<Box<dyn Any + Send + 'static>>), + Panic(SyncWrapper<Box<dyn Any + Send + 'static>>), } impl JoinError { @@ -24,16 +25,16 @@ impl JoinError { pub(crate) fn panic(err: Box<dyn Any + Send + 'static>) -> JoinError { JoinError { - repr: Repr::Panic(Mutex::new(err)), + repr: Repr::Panic(SyncWrapper::new(err)), } } - /// Returns true if the error was caused by the task being cancelled + /// Returns true if the error was caused by the task being cancelled. pub fn is_cancelled(&self) -> bool { matches!(&self.repr, Repr::Cancelled) } - /// Returns true if the error was caused by the task panicking + /// Returns true if the error was caused by the task panicking. 
/// /// # Examples /// @@ -106,7 +107,7 @@ impl JoinError { /// ``` pub fn try_into_panic(self) -> Result<Box<dyn Any + Send + 'static>, JoinError> { match self.repr { - Repr::Panic(p) => Ok(p.into_inner().expect("Extracting panic from mutex")), + Repr::Panic(p) => Ok(p.into_inner()), _ => Err(self), } } diff --git a/src/runtime/task/harness.rs b/src/runtime/task/harness.rs index 7d596e3..0996e52 100644 --- a/src/runtime/task/harness.rs +++ b/src/runtime/task/harness.rs @@ -1,15 +1,16 @@ -use crate::runtime::task::core::{Cell, Core, CoreStage, Header, Scheduler, Trailer}; +use crate::future::Future; +use crate::runtime::task::core::{Cell, Core, CoreStage, Header, Trailer}; use crate::runtime::task::state::Snapshot; use crate::runtime::task::waker::waker_ref; use crate::runtime::task::{JoinError, Notified, Schedule, Task}; -use std::future::Future; use std::mem; +use std::mem::ManuallyDrop; use std::panic; use std::ptr::NonNull; use std::task::{Context, Poll, Waker}; -/// Typed raw task handle +/// Typed raw task handle. pub(super) struct Harness<T: Future, S: 'static> { cell: NonNull<Cell<T, S>>, } @@ -36,13 +37,6 @@ where fn core(&self) -> &Core<T, S> { unsafe { &self.cell.as_ref().core } } - - fn scheduler_view(&self) -> SchedulerView<'_, S> { - SchedulerView { - header: self.header(), - scheduler: &self.core().scheduler, - } - } } impl<T, S> Harness<T, S> @@ -50,43 +44,103 @@ where T: Future, S: Schedule, { - /// Polls the inner future. + /// Polls the inner future. A ref-count is consumed. /// /// All necessary state checks and transitions are performed. - /// /// Panics raised while polling the future are handled. pub(super) fn poll(self) { + // We pass our ref-count to `poll_inner`. match self.poll_inner() { PollFuture::Notified => { - // Signal yield - self.core().scheduler.yield_now(Notified(self.to_task())); - // The ref-count was incremented as part of - // `transition_to_idle`. + // The `poll_inner` call has given us two ref-counts back. 
+ // We give one of them to a new task and call `yield_now`. + self.core() + .scheduler + .yield_now(Notified(self.get_new_task())); + + // The remaining ref-count is now dropped. We kept the extra + // ref-count until now to ensure that even if the `yield_now` + // call drops the provided task, the task isn't deallocated + // before after `yield_now` returns. self.drop_reference(); } - PollFuture::DropReference => { - self.drop_reference(); + PollFuture::Complete => { + self.complete(); } - PollFuture::Complete(out, is_join_interested) => { - self.complete(out, is_join_interested); + PollFuture::Dealloc => { + self.dealloc(); } - PollFuture::None => (), + PollFuture::Done => (), } } - fn poll_inner(&self) -> PollFuture<T::Output> { - let snapshot = match self.scheduler_view().transition_to_running() { - TransitionToRunning::Ok(snapshot) => snapshot, - TransitionToRunning::DropReference => return PollFuture::DropReference, - }; + /// Polls the task and cancel it if necessary. This takes ownership of a + /// ref-count. + /// + /// If the return value is Notified, the caller is given ownership of two + /// ref-counts. + /// + /// If the return value is Complete, the caller is given ownership of a + /// single ref-count, which should be passed on to `complete`. + /// + /// If the return value is Dealloc, then this call consumed the last + /// ref-count and the caller should call `dealloc`. + /// + /// Otherwise the ref-count is consumed and the caller should not access + /// `self` again. + fn poll_inner(&self) -> PollFuture { + use super::state::{TransitionToIdle, TransitionToRunning}; + + match self.header().state.transition_to_running() { + TransitionToRunning::Success => { + let waker_ref = waker_ref::<T, S>(self.header()); + let cx = Context::from_waker(&*waker_ref); + let res = poll_future(&self.core().stage, cx); + + if res == Poll::Ready(()) { + // The future completed. Move on to complete the task. 
+ return PollFuture::Complete; + } + + match self.header().state.transition_to_idle() { + TransitionToIdle::Ok => PollFuture::Done, + TransitionToIdle::OkNotified => PollFuture::Notified, + TransitionToIdle::OkDealloc => PollFuture::Dealloc, + TransitionToIdle::Cancelled => { + // The transition to idle failed because the task was + // cancelled during the poll. + + cancel_task(&self.core().stage); + PollFuture::Complete + } + } + } + TransitionToRunning::Cancelled => { + cancel_task(&self.core().stage); + PollFuture::Complete + } + TransitionToRunning::Failed => PollFuture::Done, + TransitionToRunning::Dealloc => PollFuture::Dealloc, + } + } - // The transition to `Running` done above ensures that a lock on the - // future has been obtained. This also ensures the `*mut T` pointer - // contains the future (as opposed to the output) and is initialized. + /// Forcibly shuts down the task. + /// + /// Attempt to transition to `Running` in order to forcibly shutdown the + /// task. If the task is currently running or in a state of completion, then + /// there is nothing further to do. When the task completes running, it will + /// notice the `CANCELLED` bit and finalize the task. + pub(super) fn shutdown(self) { + if !self.header().state.transition_to_shutdown() { + // The task is concurrently running. No further work needed. + self.drop_reference(); + return; + } - let waker_ref = waker_ref::<T, S>(self.header()); - let cx = Context::from_waker(&*waker_ref); - poll_future(self.header(), &self.core().stage, snapshot, cx) + // By transitioning the lifecycle to `Running`, we have permission to + // drop the future. 
+ cancel_task(&self.core().stage); + self.complete(); } pub(super) fn dealloc(self) { @@ -95,7 +149,6 @@ where // Check causality self.core().stage.with_mut(drop); - self.core().scheduler.with_mut(drop); unsafe { drop(Box::from_raw(self.cell.as_ptr())); @@ -112,6 +165,8 @@ where } pub(super) fn drop_join_handle_slow(self) { + let mut maybe_panic = None; + // Try to unset `JOIN_INTEREST`. This must be done as a first step in // case the task concurrently completed. if self.header().state.unset_join_interested().is_err() { @@ -120,23 +175,95 @@ where // the scheduler or `JoinHandle`. i.e. if the output remains in the // task structure until the task is deallocated, it may be dropped // by a Waker on any arbitrary thread. - self.core().stage.drop_future_or_output(); + let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + self.core().stage.drop_future_or_output(); + })); + + if let Err(panic) = panic { + maybe_panic = Some(panic); + } } // Drop the `JoinHandle` reference, possibly deallocating the task self.drop_reference(); + + if let Some(panic) = maybe_panic { + panic::resume_unwind(panic); + } + } + + /// Remotely aborts the task. + /// + /// The caller should hold a ref-count, but we do not consume it. + /// + /// This is similar to `shutdown` except that it asks the runtime to perform + /// the shutdown. This is necessary to avoid the shutdown happening in the + /// wrong thread for non-Send tasks. + pub(super) fn remote_abort(self) { + if self.header().state.transition_to_notified_and_cancel() { + // The transition has created a new ref-count, which we turn into + // a Notified and pass to the task. + // + // Since the caller holds a ref-count, the task cannot be destroyed + // before the call to `schedule` returns even if the call drops the + // `Notified` internally. + self.core() + .scheduler + .schedule(Notified(self.get_new_task())); + } } // ===== waker behavior ===== + /// This call consumes a ref-count and notifies the task. 
This will create a + /// new Notified and submit it if necessary. + /// + /// The caller does not need to hold a ref-count besides the one that was + /// passed to this call. pub(super) fn wake_by_val(self) { - self.wake_by_ref(); - self.drop_reference(); + use super::state::TransitionToNotifiedByVal; + + match self.header().state.transition_to_notified_by_val() { + TransitionToNotifiedByVal::Submit => { + // The caller has given us a ref-count, and the transition has + // created a new ref-count, so we now hold two. We turn the new + // ref-count Notified and pass it to the call to `schedule`. + // + // The old ref-count is retained for now to ensure that the task + // is not dropped during the call to `schedule` if the call + // drops the task it was given. + self.core() + .scheduler + .schedule(Notified(self.get_new_task())); + + // Now that we have completed the call to schedule, we can + // release our ref-count. + self.drop_reference(); + } + TransitionToNotifiedByVal::Dealloc => { + self.dealloc(); + } + TransitionToNotifiedByVal::DoNothing => {} + } } + /// This call notifies the task. It will not consume any ref-counts, but the + /// caller should hold a ref-count. This will create a new Notified and + /// submit it if necessary. pub(super) fn wake_by_ref(&self) { - if self.header().state.transition_to_notified() { - self.core().scheduler.schedule(Notified(self.to_task())); + use super::state::TransitionToNotifiedByRef; + + match self.header().state.transition_to_notified_by_ref() { + TransitionToNotifiedByRef::Submit => { + // The transition above incremented the ref-count for a new task + // and the caller also holds a ref-count. The caller's ref-count + // ensures that the task is not destroyed even if the new task + // is dropped before `schedule` returns. 
+ self.core() + .scheduler + .schedule(Notified(self.get_new_task())); + } + TransitionToNotifiedByRef::DoNothing => {} } } @@ -146,153 +273,70 @@ where } } - /// Forcibly shutdown the task - /// - /// Attempt to transition to `Running` in order to forcibly shutdown the - /// task. If the task is currently running or in a state of completion, then - /// there is nothing further to do. When the task completes running, it will - /// notice the `CANCELLED` bit and finalize the task. - pub(super) fn shutdown(self) { - if !self.header().state.transition_to_shutdown() { - // The task is concurrently running. No further work needed. - return; - } - - // By transitioning the lifcycle to `Running`, we have permission to - // drop the future. - let err = cancel_task(&self.core().stage); - self.complete(Err(err), true) + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(super) fn id(&self) -> Option<&tracing::Id> { + self.header().id.as_ref() } // ====== internal ====== - fn complete(self, output: super::Result<T::Output>, is_join_interested: bool) { - if is_join_interested { - // Store the output. The future has already been dropped - // - // Safety: Mutual exclusion is obtained by having transitioned the task - // state -> Running - let stage = &self.core().stage; - stage.store_output(output); - - // Transition to `Complete`, notifying the `JoinHandle` if necessary. - transition_to_complete(self.header(), stage, &self.trailer()); - } + /// Completes the task. This method assumes that the state is RUNNING. + fn complete(self) { + // The future has completed and its output has been written to the task + // stage. We transition from running to complete. + + let snapshot = self.header().state.transition_to_complete(); + + // We catch panics here in case dropping the future or waking the + // JoinHandle panics. 
+ let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + if !snapshot.is_join_interested() { + // The `JoinHandle` is not interested in the output of + // this task. It is our responsibility to drop the + // output. + self.core().stage.drop_future_or_output(); + } else if snapshot.has_join_waker() { + // Notify the join handle. The previous transition obtains the + // lock on the waker cell. + self.trailer().wake_join(); + } + })); // The task has completed execution and will no longer be scheduled. - // - // Attempts to batch a ref-dec with the state transition below. - - if self - .scheduler_view() - .transition_to_terminal(is_join_interested) - { - self.dealloc() - } - } + let num_release = self.release(); - fn to_task(&self) -> Task<S> { - self.scheduler_view().to_task() + if self.header().state.transition_to_terminal(num_release) { + self.dealloc(); + } } -} -enum TransitionToRunning { - Ok(Snapshot), - DropReference, -} + /// Releases the task from the scheduler. Returns the number of ref-counts + /// that should be decremented. + fn release(&self) -> usize { + // We don't actually increment the ref-count here, but the new task is + // never destroyed, so that's ok. + let me = ManuallyDrop::new(self.get_new_task()); -struct SchedulerView<'a, S> { - header: &'a Header, - scheduler: &'a Scheduler<S>, -} - -impl<'a, S> SchedulerView<'a, S> -where - S: Schedule, -{ - fn to_task(&self) -> Task<S> { - // SAFETY The header is from the same struct containing the scheduler `S` so the cast is safe - unsafe { Task::from_raw(self.header.into()) } - } - - /// Returns true if the task should be deallocated. 
- fn transition_to_terminal(&self, is_join_interested: bool) -> bool { - let ref_dec = if self.scheduler.is_bound() { - if let Some(task) = self.scheduler.release(self.to_task()) { - mem::forget(task); - true - } else { - false - } + if let Some(task) = self.core().scheduler.release(&me) { + mem::forget(task); + 2 } else { - false - }; - - // This might deallocate - let snapshot = self - .header - .state - .transition_to_terminal(!is_join_interested, ref_dec); - - snapshot.ref_count() == 0 - } - - fn transition_to_running(&self) -> TransitionToRunning { - // If this is the first time the task is polled, the task will be bound - // to the scheduler, in which case the task ref count must be - // incremented. - let is_not_bound = !self.scheduler.is_bound(); - - // Transition the task to the running state. - // - // A failure to transition here indicates the task has been cancelled - // while in the run queue pending execution. - let snapshot = match self.header.state.transition_to_running(is_not_bound) { - Ok(snapshot) => snapshot, - Err(_) => { - // The task was shutdown while in the run queue. At this point, - // we just hold a ref counted reference. Since we do not have access to it here - // return `DropReference` so the caller drops it. - return TransitionToRunning::DropReference; - } - }; - - if is_not_bound { - // Ensure the task is bound to a scheduler instance. Since this is - // the first time polling the task, a scheduler instance is pulled - // from the local context and assigned to the task. - // - // The scheduler maintains ownership of the task and responds to - // `wake` calls. - // - // The task reference count has been incremented. - // - // Safety: Since we have unique access to the task so that we can - // safely call `bind_scheduler`. - self.scheduler.bind_scheduler(self.to_task()); + 1 } - TransitionToRunning::Ok(snapshot) } -} - -/// Transitions the task's lifecycle to `Complete`. 
Notifies the -/// `JoinHandle` if it still has interest in the completion. -fn transition_to_complete<T>(header: &Header, stage: &CoreStage<T>, trailer: &Trailer) -where - T: Future, -{ - // Transition the task's lifecycle to `Complete` and get a snapshot of - // the task's sate. - let snapshot = header.state.transition_to_complete(); - if !snapshot.is_join_interested() { - // The `JoinHandle` is not interested in the output of this task. It - // is our responsibility to drop the output. - stage.drop_future_or_output(); - } else if snapshot.has_join_waker() { - // Notify the join handle. The previous transition obtains the - // lock on the waker cell. - trailer.wake_join(); + /// Creates a new task that holds its own ref-count. + /// + /// # Safety + /// + /// Any use of `self` after this call must ensure that a ref-count to the + /// task holds the task alive until after the use of `self`. Passing the + /// returned Task to any method on `self` is unsound if dropping the Task + /// could drop `self` before the call on `self` returned. + fn get_new_task(&self) -> Task<S> { + // safety: The header is at the beginning of the cell, so this cast is + // safe. + unsafe { Task::from_raw(self.cell.cast()) } } } @@ -374,73 +418,62 @@ fn set_join_waker( res } -enum PollFuture<T> { - Complete(Result<T, JoinError>, bool), - DropReference, +enum PollFuture { + Complete, Notified, - None, + Done, + Dealloc, } -fn cancel_task<T: Future>(stage: &CoreStage<T>) -> JoinError { +/// Cancels the task and store the appropriate error in the stage field. +fn cancel_task<T: Future>(stage: &CoreStage<T>) { // Drop the future from a panic guard. let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { stage.drop_future_or_output(); })); - if let Err(err) = res { - // Dropping the future panicked, complete the join - // handle with the panic to avoid dropping the panic - // on the ground. 
- JoinError::panic(err) - } else { - JoinError::cancelled() + match res { + Ok(()) => { + stage.store_output(Err(JoinError::cancelled())); + } + Err(panic) => { + stage.store_output(Err(JoinError::panic(panic))); + } } } -fn poll_future<T: Future>( - header: &Header, - core: &CoreStage<T>, - snapshot: Snapshot, - cx: Context<'_>, -) -> PollFuture<T::Output> { - if snapshot.is_cancelled() { - PollFuture::Complete(Err(JoinError::cancelled()), snapshot.is_join_interested()) - } else { - let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { - struct Guard<'a, T: Future> { - core: &'a CoreStage<T>, - } - - impl<T: Future> Drop for Guard<'_, T> { - fn drop(&mut self) { - self.core.drop_future_or_output(); - } +/// Polls the future. If the future completes, the output is written to the +/// stage field. +fn poll_future<T: Future>(core: &CoreStage<T>, cx: Context<'_>) -> Poll<()> { + // Poll the future. + let output = panic::catch_unwind(panic::AssertUnwindSafe(|| { + struct Guard<'a, T: Future> { + core: &'a CoreStage<T>, + } + impl<'a, T: Future> Drop for Guard<'a, T> { + fn drop(&mut self) { + // If the future panics on poll, we drop it inside the panic + // guard. + self.core.drop_future_or_output(); } + } + let guard = Guard { core }; + let res = guard.core.poll(cx); + mem::forget(guard); + res + })); - let guard = Guard { core }; - - let res = guard.core.poll(cx); + // Prepare output for being placed in the core stage. + let output = match output { + Ok(Poll::Pending) => return Poll::Pending, + Ok(Poll::Ready(output)) => Ok(output), + Err(panic) => Err(JoinError::panic(panic)), + }; - // prevent the guard from dropping the future - mem::forget(guard); + // Catch and ignore panics if the future panics on drop. 
+ let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + core.store_output(output); + })); - res - })); - match res { - Ok(Poll::Pending) => match header.state.transition_to_idle() { - Ok(snapshot) => { - if snapshot.is_notified() { - PollFuture::Notified - } else { - PollFuture::None - } - } - Err(_) => PollFuture::Complete(Err(cancel_task(core)), true), - }, - Ok(Poll::Ready(ok)) => PollFuture::Complete(Ok(ok), snapshot.is_join_interested()), - Err(err) => { - PollFuture::Complete(Err(JoinError::panic(err)), snapshot.is_join_interested()) - } - } - } + Poll::Ready(()) } diff --git a/src/runtime/task/inject.rs b/src/runtime/task/inject.rs new file mode 100644 index 0000000..1585e13 --- /dev/null +++ b/src/runtime/task/inject.rs @@ -0,0 +1,220 @@ +//! Inject queue used to send wakeups to a work-stealing scheduler + +use crate::loom::sync::atomic::AtomicUsize; +use crate::loom::sync::Mutex; +use crate::runtime::task; + +use std::marker::PhantomData; +use std::ptr::NonNull; +use std::sync::atomic::Ordering::{Acquire, Release}; + +/// Growable, MPMC queue used to inject new tasks into the scheduler and as an +/// overflow queue when the local, fixed-size, array queue overflows. +pub(crate) struct Inject<T: 'static> { + /// Pointers to the head and tail of the queue. + pointers: Mutex<Pointers>, + + /// Number of pending tasks in the queue. This helps prevent unnecessary + /// locking in the hot path. + len: AtomicUsize, + + _p: PhantomData<T>, +} + +struct Pointers { + /// True if the queue is closed. + is_closed: bool, + + /// Linked-list head. + head: Option<NonNull<task::Header>>, + + /// Linked-list tail. 
+ tail: Option<NonNull<task::Header>>, +} + +unsafe impl<T> Send for Inject<T> {} +unsafe impl<T> Sync for Inject<T> {} + +impl<T: 'static> Inject<T> { + pub(crate) fn new() -> Inject<T> { + Inject { + pointers: Mutex::new(Pointers { + is_closed: false, + head: None, + tail: None, + }), + len: AtomicUsize::new(0), + _p: PhantomData, + } + } + + pub(crate) fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Closes the injection queue, returns `true` if the queue is open when the + /// transition is made. + pub(crate) fn close(&self) -> bool { + let mut p = self.pointers.lock(); + + if p.is_closed { + return false; + } + + p.is_closed = true; + true + } + + pub(crate) fn is_closed(&self) -> bool { + self.pointers.lock().is_closed + } + + pub(crate) fn len(&self) -> usize { + self.len.load(Acquire) + } + + /// Pushes a value into the queue. + /// + /// This does nothing if the queue is closed. + pub(crate) fn push(&self, task: task::Notified<T>) { + // Acquire queue lock + let mut p = self.pointers.lock(); + + if p.is_closed { + return; + } + + // safety: only mutated with the lock held + let len = unsafe { self.len.unsync_load() }; + let task = task.into_raw(); + + // The next pointer should already be null + debug_assert!(get_next(task).is_none()); + + if let Some(tail) = p.tail { + // safety: Holding the Notified for a task guarantees exclusive + // access to the `queue_next` field. + set_next(tail, Some(task)); + } else { + p.head = Some(task); + } + + p.tail = Some(task); + + self.len.store(len + 1, Release); + } + + /// Pushes several values into the queue. + #[inline] + pub(crate) fn push_batch<I>(&self, mut iter: I) + where + I: Iterator<Item = task::Notified<T>>, + { + let first = match iter.next() { + Some(first) => first.into_raw(), + None => return, + }; + + // Link up all the tasks. 
+ let mut prev = first; + let mut counter = 1; + + // We are going to be called with an `std::iter::Chain`, and that + // iterator overrides `for_each` to something that is easier for the + // compiler to optimize than a loop. + iter.for_each(|next| { + let next = next.into_raw(); + + // safety: Holding the Notified for a task guarantees exclusive + // access to the `queue_next` field. + set_next(prev, Some(next)); + prev = next; + counter += 1; + }); + + // Now that the tasks are linked together, insert them into the + // linked list. + self.push_batch_inner(first, prev, counter); + } + + /// Inserts several tasks that have been linked together into the queue. + /// + /// The provided head and tail may be be the same task. In this case, a + /// single task is inserted. + #[inline] + fn push_batch_inner( + &self, + batch_head: NonNull<task::Header>, + batch_tail: NonNull<task::Header>, + num: usize, + ) { + debug_assert!(get_next(batch_tail).is_none()); + + let mut p = self.pointers.lock(); + + if let Some(tail) = p.tail { + set_next(tail, Some(batch_head)); + } else { + p.head = Some(batch_head); + } + + p.tail = Some(batch_tail); + + // Increment the count. + // + // safety: All updates to the len atomic are guarded by the mutex. As + // such, a non-atomic load followed by a store is safe. + let len = unsafe { self.len.unsync_load() }; + + self.len.store(len + num, Release); + } + + pub(crate) fn pop(&self) -> Option<task::Notified<T>> { + // Fast path, if len == 0, then there are no values + if self.is_empty() { + return None; + } + + let mut p = self.pointers.lock(); + + // It is possible to hit null here if another thread popped the last + // task between us checking `len` and acquiring the lock. + let task = p.head?; + + p.head = get_next(task); + + if p.head.is_none() { + p.tail = None; + } + + set_next(task, None); + + // Decrement the count. + // + // safety: All updates to the len atomic are guarded by the mutex. 
As + // such, a non-atomic load followed by a store is safe. + self.len + .store(unsafe { self.len.unsync_load() } - 1, Release); + + // safety: a `Notified` is pushed into the queue and now it is popped! + Some(unsafe { task::Notified::from_raw(task) }) + } +} + +impl<T: 'static> Drop for Inject<T> { + fn drop(&mut self) { + if !std::thread::panicking() { + assert!(self.pop().is_none(), "queue not empty"); + } + } +} + +fn get_next(header: NonNull<task::Header>) -> Option<NonNull<task::Header>> { + unsafe { header.as_ref().queue_next.with(|ptr| *ptr) } +} + +fn set_next(header: NonNull<task::Header>, val: Option<NonNull<task::Header>>) { + unsafe { + header.as_ref().set_next(val); + } +} diff --git a/src/runtime/task/join.rs b/src/runtime/task/join.rs index dedfb38..0abbff2 100644 --- a/src/runtime/task/join.rs +++ b/src/runtime/task/join.rs @@ -162,7 +162,7 @@ impl<T> JoinHandle<T> { /// /// Awaiting a cancelled task might complete as usual if the task was /// already completed at the time it was cancelled, but most likely it - /// will complete with a `Err(JoinError::Cancelled)`. + /// will fail with a [cancelled] `JoinError`. /// /// ```rust /// use tokio::time; @@ -190,9 +190,10 @@ impl<T> JoinHandle<T> { /// } /// } /// ``` + /// [cancelled]: method@super::error::JoinError::is_cancelled pub fn abort(&self) { if let Some(raw) = self.raw { - raw.shutdown(); + raw.remote_abort(); } } } diff --git a/src/runtime/task/list.rs b/src/runtime/task/list.rs new file mode 100644 index 0000000..7758f8d --- /dev/null +++ b/src/runtime/task/list.rs @@ -0,0 +1,297 @@ +//! This module has containers for storing the tasks spawned on a scheduler. The +//! `OwnedTasks` container is thread-safe but can only store tasks that +//! implement Send. The `LocalOwnedTasks` container is not thread safe, but can +//! store non-Send tasks. +//! +//! The collections can be closed to prevent adding new tasks during shutdown of +//! the scheduler with the collection. 
+ +use crate::future::Future; +use crate::loom::cell::UnsafeCell; +use crate::loom::sync::Mutex; +use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task}; +use crate::util::linked_list::{Link, LinkedList}; + +use std::marker::PhantomData; + +// The id from the module below is used to verify whether a given task is stored +// in this OwnedTasks, or some other task. The counter starts at one so we can +// use zero for tasks not owned by any list. +// +// The safety checks in this file can technically be violated if the counter is +// overflown, but the checks are not supposed to ever fail unless there is a +// bug in Tokio, so we accept that certain bugs would not be caught if the two +// mixed up runtimes happen to have the same id. + +cfg_has_atomic_u64! { + use std::sync::atomic::{AtomicU64, Ordering}; + + static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1); + + fn get_next_id() -> u64 { + loop { + let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); + if id != 0 { + return id; + } + } + } +} + +cfg_not_has_atomic_u64! { + use std::sync::atomic::{AtomicU32, Ordering}; + + static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1); + + fn get_next_id() -> u64 { + loop { + let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); + if id != 0 { + return u64::from(id); + } + } + } +} + +pub(crate) struct OwnedTasks<S: 'static> { + inner: Mutex<OwnedTasksInner<S>>, + id: u64, +} +pub(crate) struct LocalOwnedTasks<S: 'static> { + inner: UnsafeCell<OwnedTasksInner<S>>, + id: u64, + _not_send_or_sync: PhantomData<*const ()>, +} +struct OwnedTasksInner<S: 'static> { + list: LinkedList<Task<S>, <Task<S> as Link>::Target>, + closed: bool, +} + +impl<S: 'static> OwnedTasks<S> { + pub(crate) fn new() -> Self { + Self { + inner: Mutex::new(OwnedTasksInner { + list: LinkedList::new(), + closed: false, + }), + id: get_next_id(), + } + } + + /// Binds the provided task to this OwnedTasks instance. 
This fails if the + /// OwnedTasks has been closed. + pub(crate) fn bind<T>( + &self, + task: T, + scheduler: S, + ) -> (JoinHandle<T::Output>, Option<Notified<S>>) + where + S: Schedule, + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (task, notified, join) = super::new_task(task, scheduler); + + unsafe { + // safety: We just created the task, so we have exclusive access + // to the field. + task.header().set_owner_id(self.id); + } + + let mut lock = self.inner.lock(); + if lock.closed { + drop(lock); + drop(notified); + task.shutdown(); + (join, None) + } else { + lock.list.push_front(task); + (join, Some(notified)) + } + } + + /// Asserts that the given task is owned by this OwnedTasks and convert it to + /// a LocalNotified, giving the thread permission to poll this task. + #[inline] + pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { + assert_eq!(task.header().get_owner_id(), self.id); + + // safety: All tasks bound to this OwnedTasks are Send, so it is safe + // to poll it on this thread no matter what thread we are on. + LocalNotified { + task: task.0, + _not_send: PhantomData, + } + } + + /// Shuts down all tasks in the collection. This call also closes the + /// collection, preventing new items from being added. + pub(crate) fn close_and_shutdown_all(&self) + where + S: Schedule, + { + // The first iteration of the loop was unrolled so it can set the + // closed bool. + let first_task = { + let mut lock = self.inner.lock(); + lock.closed = true; + lock.list.pop_back() + }; + match first_task { + Some(task) => task.shutdown(), + None => return, + } + + loop { + let task = match self.inner.lock().list.pop_back() { + Some(task) => task, + None => return, + }; + + task.shutdown(); + } + } + + pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { + let task_id = task.header().get_owner_id(); + if task_id == 0 { + // The task is unowned. 
+ return None; + } + + assert_eq!(task_id, self.id); + + // safety: We just checked that the provided task is not in some other + // linked list. + unsafe { self.inner.lock().list.remove(task.header().into()) } + } + + pub(crate) fn is_empty(&self) -> bool { + self.inner.lock().list.is_empty() + } +} + +impl<S: 'static> LocalOwnedTasks<S> { + pub(crate) fn new() -> Self { + Self { + inner: UnsafeCell::new(OwnedTasksInner { + list: LinkedList::new(), + closed: false, + }), + id: get_next_id(), + _not_send_or_sync: PhantomData, + } + } + + pub(crate) fn bind<T>( + &self, + task: T, + scheduler: S, + ) -> (JoinHandle<T::Output>, Option<Notified<S>>) + where + S: Schedule, + T: Future + 'static, + T::Output: 'static, + { + let (task, notified, join) = super::new_task(task, scheduler); + + unsafe { + // safety: We just created the task, so we have exclusive access + // to the field. + task.header().set_owner_id(self.id); + } + + if self.is_closed() { + drop(notified); + task.shutdown(); + (join, None) + } else { + self.with_inner(|inner| { + inner.list.push_front(task); + }); + (join, Some(notified)) + } + } + + /// Shuts down all tasks in the collection. This call also closes the + /// collection, preventing new items from being added. + pub(crate) fn close_and_shutdown_all(&self) + where + S: Schedule, + { + self.with_inner(|inner| inner.closed = true); + + while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) { + task.shutdown(); + } + } + + pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { + let task_id = task.header().get_owner_id(); + if task_id == 0 { + // The task is unowned. + return None; + } + + assert_eq!(task_id, self.id); + + self.with_inner(|inner| + // safety: We just checked that the provided task is not in some + // other linked list. 
+ unsafe { inner.list.remove(task.header().into()) }) + } + + /// Asserts that the given task is owned by this LocalOwnedTasks and convert + /// it to a LocalNotified, giving the thread permission to poll this task. + #[inline] + pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { + assert_eq!(task.header().get_owner_id(), self.id); + + // safety: The task was bound to this LocalOwnedTasks, and the + // LocalOwnedTasks is not Send or Sync, so we are on the right thread + // for polling this task. + LocalNotified { + task: task.0, + _not_send: PhantomData, + } + } + + #[inline] + fn with_inner<F, T>(&self, f: F) -> T + where + F: FnOnce(&mut OwnedTasksInner<S>) -> T, + { + // safety: This type is not Sync, so concurrent calls of this method + // can't happen. Furthermore, all uses of this method in this file make + // sure that they don't call `with_inner` recursively. + self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) }) + } + + pub(crate) fn is_closed(&self) -> bool { + self.with_inner(|inner| inner.closed) + } + + pub(crate) fn is_empty(&self) -> bool { + self.with_inner(|inner| inner.list.is_empty()) + } +} + +#[cfg(all(test))] +mod tests { + use super::*; + + // This test may run in parallel with other tests, so we only test that ids + // come in increasing order. + #[test] + fn test_id_not_broken() { + let mut last_id = get_next_id(); + assert_ne!(last_id, 0); + + for _ in 0..1000 { + let next_id = get_next_id(); + assert_ne!(next_id, 0); + assert!(last_id < next_id); + last_id = next_id; + } + } +} diff --git a/src/runtime/task/mod.rs b/src/runtime/task/mod.rs index 7b49e95..1f18209 100644 --- a/src/runtime/task/mod.rs +++ b/src/runtime/task/mod.rs @@ -1,6 +1,143 @@ +//! The task module. +//! +//! The task module contains the code that manages spawned tasks and provides a +//! safe API for the rest of the runtime to use. Each task in a runtime is +//! stored in an OwnedTasks or LocalOwnedTasks object. +//! +//! 
# Task reference types +//! +//! A task is usually referenced by multiple handles, and there are several +//! types of handles. +//! +//! * OwnedTask - tasks stored in an OwnedTasks or LocalOwnedTasks are of this +//! reference type. +//! +//! * JoinHandle - each task has a JoinHandle that allows access to the output +//! of the task. +//! +//! * Waker - every waker for a task has this reference type. There can be any +//! number of waker references. +//! +//! * Notified - tracks whether the task is notified. +//! +//! * Unowned - this task reference type is used for tasks not stored in any +//! runtime. Mainly used for blocking tasks, but also in tests. +//! +//! The task uses a reference count to keep track of how many active references +//! exist. The Unowned reference type takes up two ref-counts. All other +//! reference types take pu a single ref-count. +//! +//! Besides the waker type, each task has at most one of each reference type. +//! +//! # State +//! +//! The task stores its state in an atomic usize with various bitfields for the +//! necessary information. The state has the following bitfields: +//! +//! * RUNNING - Tracks whether the task is currently being polled or cancelled. +//! This bit functions as a lock around the task. +//! +//! * COMPLETE - Is one once the future has fully completed and has been +//! dropped. Never unset once set. Never set together with RUNNING. +//! +//! * NOTIFIED - Tracks whether a Notified object currently exists. +//! +//! * CANCELLED - Is set to one for tasks that should be cancelled as soon as +//! possible. May take any value for completed tasks. +//! +//! * JOIN_INTEREST - Is set to one if there exists a JoinHandle. +//! +//! * JOIN_WAKER - Is set to one if the JoinHandle has set a waker. +//! +//! The rest of the bits are used for the ref-count. +//! +//! # Fields in the task +//! +//! The task has various fields. This section describes how and when it is safe +//! to access a field. +//! +//! 
* The state field is accessed with atomic instructions. +//! +//! * The OwnedTask reference has exclusive access to the `owned` field. +//! +//! * The Notified reference has exclusive access to the `queue_next` field. +//! +//! * The `owner_id` field can be set as part of construction of the task, but +//! is otherwise immutable and anyone can access the field immutably without +//! synchronization. +//! +//! * If COMPLETE is one, then the JoinHandle has exclusive access to the +//! stage field. If COMPLETE is zero, then the RUNNING bitfield functions as +//! a lock for the stage field, and it can be accessed only by the thread +//! that set RUNNING to one. +//! +//! * If JOIN_WAKER is zero, then the JoinHandle has exclusive access to the +//! join handle waker. If JOIN_WAKER and COMPLETE are both one, then the +//! thread that set COMPLETE to one has exclusive access to the join handle +//! waker. +//! +//! All other fields are immutable and can be accessed immutably without +//! synchronization by anyone. +//! +//! # Safety +//! +//! This section goes through various situations and explains why the API is +//! safe in that situation. +//! +//! ## Polling or dropping the future +//! +//! Any mutable access to the future happens after obtaining a lock by modifying +//! the RUNNING field, so exclusive access is ensured. +//! +//! When the task completes, exclusive access to the output is transferred to +//! the JoinHandle. If the JoinHandle is already dropped when the transition to +//! complete happens, the thread performing that transition retains exclusive +//! access to the output and should immediately drop it. +//! +//! ## Non-Send futures +//! +//! If a future is not Send, then it is bound to a LocalOwnedTasks. The future +//! will only ever be polled or dropped given a LocalNotified or inside a call +//! to LocalOwnedTasks::shutdown_all. In either case, it is guaranteed that the +//! future is on the right thread. +//! +//! 
If the task is never removed from the LocalOwnedTasks, then it is leaked, so +//! there is no risk that the task is dropped on some other thread when the last +//! ref-count drops. +//! +//! ## Non-Send output +//! +//! When a task completes, the output is placed in the stage of the task. Then, +//! a transition that sets COMPLETE to true is performed, and the value of +//! JOIN_INTEREST when this transition happens is read. +//! +//! If JOIN_INTEREST is zero when the transition to COMPLETE happens, then the +//! output is immediately dropped. +//! +//! If JOIN_INTEREST is one when the transition to COMPLETE happens, then the +//! JoinHandle is responsible for cleaning up the output. If the output is not +//! Send, then this happens: +//! +//! 1. The output is created on the thread that the future was polled on. Since +//! only non-Send futures can have non-Send output, the future was polled on +//! the thread that the future was spawned from. +//! 2. Since JoinHandle<Output> is not Send if Output is not Send, the +//! JoinHandle is also on the thread that the future was spawned from. +//! 3. Thus, the JoinHandle will not move the output across threads when it +//! takes or drops the output. +//! +//! ## Recursive poll/shutdown +//! +//! Calling poll from inside a shutdown call or vice-versa is not prevented by +//! the API exposed by the task module, so this has to be safe. In either case, +//! the lock in the RUNNING bitfield makes the inner call return immediately. If +//! the inner call is a `shutdown` call, then the CANCELLED bit is set, and the +//! poll call will notice it when the poll finishes, and the task is cancelled +//! at that point. + mod core; use self::core::Cell; -pub(crate) use self::core::Header; +use self::core::Header; mod error; #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 @@ -9,10 +146,18 @@ pub use self::error::JoinError; mod harness; use self::harness::Harness; +cfg_rt_multi_thread! 
{ + mod inject; + pub(super) use self::inject::Inject; +} + mod join; #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 pub use self::join::JoinHandle; +mod list; +pub(crate) use self::list::{LocalOwnedTasks, OwnedTasks}; + mod raw; use self::raw::RawTask; @@ -21,19 +166,14 @@ use self::state::State; mod waker; -cfg_rt_multi_thread! { - mod stack; - pub(crate) use self::stack::TransferStack; -} - +use crate::future::Future; use crate::util::linked_list; -use std::future::Future; use std::marker::PhantomData; use std::ptr::NonNull; use std::{fmt, mem}; -/// An owned handle to the task, tracked by ref count +/// An owned handle to the task, tracked by ref count. #[repr(transparent)] pub(crate) struct Task<S: 'static> { raw: RawTask, @@ -43,30 +183,43 @@ pub(crate) struct Task<S: 'static> { unsafe impl<S> Send for Task<S> {} unsafe impl<S> Sync for Task<S> {} -/// A task was notified +/// A task was notified. #[repr(transparent)] pub(crate) struct Notified<S: 'static>(Task<S>); +// safety: This type cannot be used to touch the task without first verifying +// that the value is on a thread where it is safe to poll the task. unsafe impl<S: Schedule> Send for Notified<S> {} unsafe impl<S: Schedule> Sync for Notified<S> {} -/// Task result sent back +/// A non-Send variant of Notified with the invariant that it is on a thread +/// where it is safe to poll it. +#[repr(transparent)] +pub(crate) struct LocalNotified<S: 'static> { + task: Task<S>, + _not_send: PhantomData<*const ()>, +} + +/// A task that is not owned by any OwnedTasks. Used for blocking tasks. +/// This type holds two ref-counts. +pub(crate) struct UnownedTask<S: 'static> { + raw: RawTask, + _p: PhantomData<S>, +} + +// safety: This type can only be created given a Send task. +unsafe impl<S> Send for UnownedTask<S> {} +unsafe impl<S> Sync for UnownedTask<S> {} + +/// Task result sent back. 
pub(crate) type Result<T> = std::result::Result<T, JoinError>; pub(crate) trait Schedule: Sync + Sized + 'static { - /// Bind a task to the executor. - /// - /// Guaranteed to be called from the thread that called `poll` on the task. - /// The returned `Schedule` instance is associated with the task and is used - /// as `&self` in the other methods on this trait. - fn bind(task: Task<Self>) -> Self; - /// The task has completed work and is ready to be released. The scheduler - /// is free to drop it whenever. + /// should release it immediately and return it. The task module will batch + /// the ref-dec with setting other options. /// - /// If the scheduler can immediately release the task, it should return - /// it as part of the function. This enables the task module to batch - /// the ref-dec with other options. + /// If the scheduler has already released the task, then None is returned. fn release(&self, task: &Task<Self>) -> Option<Task<Self>>; /// Schedule the task @@ -80,71 +233,86 @@ pub(crate) trait Schedule: Sync + Sized + 'static { } cfg_rt! { - /// Create a new task with an associated join handle - pub(crate) fn joinable<T, S>(task: T) -> (Notified<S>, JoinHandle<T::Output>) + /// This is the constructor for a new task. Three references to the task are + /// created. The first task reference is usually put into an OwnedTasks + /// immediately. The Notified is sent to the scheduler as an ordinary + /// notification. + fn new_task<T, S>( + task: T, + scheduler: S + ) -> (Task<S>, Notified<S>, JoinHandle<T::Output>) where - T: Future + Send + 'static, S: Schedule, + T: Future + 'static, + T::Output: 'static, { - let raw = RawTask::new::<_, S>(task); - + let raw = RawTask::new::<T, S>(task, scheduler); let task = Task { raw, _p: PhantomData, }; - + let notified = Notified(Task { + raw, + _p: PhantomData, + }); let join = JoinHandle::new(raw); - (Notified(task), join) + (task, notified, join) } -} -cfg_rt! 
{ - /// Create a new `!Send` task with an associated join handle - pub(crate) unsafe fn joinable_local<T, S>(task: T) -> (Notified<S>, JoinHandle<T::Output>) + /// Creates a new task with an associated join handle. This method is used + /// only when the task is not going to be stored in an `OwnedTasks` list. + /// + /// Currently only blocking tasks use this method. + pub(crate) fn unowned<T, S>(task: T, scheduler: S) -> (UnownedTask<S>, JoinHandle<T::Output>) where - T: Future + 'static, S: Schedule, + T: Send + Future + 'static, + T::Output: Send + 'static, { - let raw = RawTask::new::<_, S>(task); + let (task, notified, join) = new_task(task, scheduler); - let task = Task { - raw, + // This transfers the ref-count of task and notified into an UnownedTask. + // This is valid because an UnownedTask holds two ref-counts. + let unowned = UnownedTask { + raw: task.raw, _p: PhantomData, }; + std::mem::forget(task); + std::mem::forget(notified); - let join = JoinHandle::new(raw); - - (Notified(task), join) + (unowned, join) } } impl<S: 'static> Task<S> { - pub(crate) unsafe fn from_raw(ptr: NonNull<Header>) -> Task<S> { + unsafe fn from_raw(ptr: NonNull<Header>) -> Task<S> { Task { raw: RawTask::from_raw(ptr), _p: PhantomData, } } - pub(crate) fn header(&self) -> &Header { + fn header(&self) -> &Header { self.raw.header() } } +impl<S: 'static> Notified<S> { + fn header(&self) -> &Header { + self.0.header() + } +} + cfg_rt_multi_thread! { impl<S: 'static> Notified<S> { - pub(crate) unsafe fn from_raw(ptr: NonNull<Header>) -> Notified<S> { + unsafe fn from_raw(ptr: NonNull<Header>) -> Notified<S> { Notified(Task::from_raw(ptr)) } - - pub(crate) fn header(&self) -> &Header { - self.0.header() - } } impl<S: 'static> Task<S> { - pub(crate) fn into_raw(self) -> NonNull<Header> { + fn into_raw(self) -> NonNull<Header> { let ret = self.header().into(); mem::forget(self); ret @@ -152,29 +320,69 @@ cfg_rt_multi_thread! 
{ } impl<S: 'static> Notified<S> { - pub(crate) fn into_raw(self) -> NonNull<Header> { + fn into_raw(self) -> NonNull<Header> { self.0.into_raw() } } } impl<S: Schedule> Task<S> { - /// Pre-emptively cancel the task as part of the shutdown process. - pub(crate) fn shutdown(&self) { - self.raw.shutdown(); + /// Pre-emptively cancels the task as part of the shutdown process. + pub(crate) fn shutdown(self) { + let raw = self.raw; + mem::forget(self); + raw.shutdown(); } } -impl<S: Schedule> Notified<S> { - /// Run the task +impl<S: Schedule> LocalNotified<S> { + /// Runs the task. pub(crate) fn run(self) { - self.0.raw.poll(); + let raw = self.task.raw; mem::forget(self); + raw.poll(); + } +} + +impl<S: Schedule> UnownedTask<S> { + // Used in test of the inject queue. + #[cfg(test)] + pub(super) fn into_notified(self) -> Notified<S> { + Notified(self.into_task()) + } + + fn into_task(self) -> Task<S> { + // Convert into a task. + let task = Task { + raw: self.raw, + _p: PhantomData, + }; + mem::forget(self); + + // Drop a ref-count since an UnownedTask holds two. + task.header().state.ref_dec(); + + task + } + + pub(crate) fn run(self) { + let raw = self.raw; + mem::forget(self); + + // Transfer one ref-count to a Task object. + let task = Task::<S> { + raw, + _p: PhantomData, + }; + + // Use the other ref-count to poll the task. + raw.poll(); + // Decrement our extra ref-count + drop(task); } - /// Pre-emptively cancel the task as part of the shutdown process. 
pub(crate) fn shutdown(self) { - self.0.shutdown(); + self.into_task().shutdown() } } @@ -188,6 +396,16 @@ impl<S: 'static> Drop for Task<S> { } } +impl<S: 'static> Drop for UnownedTask<S> { + fn drop(&mut self) { + // Decrement the ref count + if self.raw.header().state.ref_dec_twice() { + // Deallocate if this is the final ref count + self.raw.dealloc(); + } + } +} + impl<S> fmt::Debug for Task<S> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { write!(fmt, "Task({:p})", self.header()) @@ -202,7 +420,7 @@ impl<S> fmt::Debug for Notified<S> { /// # Safety /// -/// Tasks are pinned +/// Tasks are pinned. unsafe impl<S> linked_list::Link for Task<S> { type Handle = Task<S>; type Target = Header; diff --git a/src/runtime/task/raw.rs b/src/runtime/task/raw.rs index cae56d0..fbc9574 100644 --- a/src/runtime/task/raw.rs +++ b/src/runtime/task/raw.rs @@ -1,6 +1,6 @@ +use crate::future::Future; use crate::runtime::task::{Cell, Harness, Header, Schedule, State}; -use std::future::Future; use std::ptr::NonNull; use std::task::{Poll, Waker}; @@ -10,19 +10,22 @@ pub(super) struct RawTask { } pub(super) struct Vtable { - /// Poll the future + /// Polls the future. pub(super) poll: unsafe fn(NonNull<Header>), - /// Deallocate the memory + /// Deallocates the memory. pub(super) dealloc: unsafe fn(NonNull<Header>), - /// Read the task output, if complete + /// Reads the task output, if complete. pub(super) try_read_output: unsafe fn(NonNull<Header>, *mut (), &Waker), - /// The join handle has been dropped + /// The join handle has been dropped. pub(super) drop_join_handle_slow: unsafe fn(NonNull<Header>), - /// Scheduler is being shutdown + /// The task is remotely aborted. + pub(super) remote_abort: unsafe fn(NonNull<Header>), + + /// Scheduler is being shutdown. 
pub(super) shutdown: unsafe fn(NonNull<Header>), } @@ -33,17 +36,18 @@ pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable { dealloc: dealloc::<T, S>, try_read_output: try_read_output::<T, S>, drop_join_handle_slow: drop_join_handle_slow::<T, S>, + remote_abort: remote_abort::<T, S>, shutdown: shutdown::<T, S>, } } impl RawTask { - pub(super) fn new<T, S>(task: T) -> RawTask + pub(super) fn new<T, S>(task: T, scheduler: S) -> RawTask where T: Future, S: Schedule, { - let ptr = Box::into_raw(Cell::<_, S>::new(task, State::new())); + let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new())); let ptr = unsafe { NonNull::new_unchecked(ptr as *mut Header) }; RawTask { ptr } @@ -89,6 +93,11 @@ impl RawTask { let vtable = self.header().vtable; unsafe { (vtable.shutdown)(self.ptr) } } + + pub(super) fn remote_abort(self) { + let vtable = self.header().vtable; + unsafe { (vtable.remote_abort)(self.ptr) } + } } impl Clone for RawTask { @@ -125,6 +134,11 @@ unsafe fn drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>) { harness.drop_join_handle_slow() } +unsafe fn remote_abort<T: Future, S: Schedule>(ptr: NonNull<Header>) { + let harness = Harness::<T, S>::from_raw(ptr); + harness.remote_abort() +} + unsafe fn shutdown<T: Future, S: Schedule>(ptr: NonNull<Header>) { let harness = Harness::<T, S>::from_raw(ptr); harness.shutdown() diff --git a/src/runtime/task/stack.rs b/src/runtime/task/stack.rs deleted file mode 100644 index 9dd8d3f..0000000 --- a/src/runtime/task/stack.rs +++ /dev/null @@ -1,83 +0,0 @@ -use crate::loom::sync::atomic::AtomicPtr; -use crate::runtime::task::{Header, Task}; - -use std::marker::PhantomData; -use std::ptr::{self, NonNull}; -use std::sync::atomic::Ordering::{Acquire, Relaxed, Release}; - -/// Concurrent stack of tasks, used to pass ownership of a task from one worker -/// to another. 
-pub(crate) struct TransferStack<T: 'static> { - head: AtomicPtr<Header>, - _p: PhantomData<T>, -} - -impl<T: 'static> TransferStack<T> { - pub(crate) fn new() -> TransferStack<T> { - TransferStack { - head: AtomicPtr::new(ptr::null_mut()), - _p: PhantomData, - } - } - - pub(crate) fn push(&self, task: Task<T>) { - let task = task.into_raw(); - - // We don't care about any memory associated w/ setting the `head` - // field, just the current value. - // - // The compare-exchange creates a release sequence. - let mut curr = self.head.load(Relaxed); - - loop { - unsafe { - task.as_ref() - .stack_next - .with_mut(|ptr| *ptr = NonNull::new(curr)) - }; - - let res = self - .head - .compare_exchange(curr, task.as_ptr() as *mut _, Release, Relaxed); - - match res { - Ok(_) => return, - Err(actual) => { - curr = actual; - } - } - } - } - - pub(crate) fn drain(&self) -> impl Iterator<Item = Task<T>> { - struct Iter<T: 'static>(Option<NonNull<Header>>, PhantomData<T>); - - impl<T: 'static> Iterator for Iter<T> { - type Item = Task<T>; - - fn next(&mut self) -> Option<Task<T>> { - let task = self.0?; - - // Move the cursor forward - self.0 = unsafe { task.as_ref().stack_next.with(|ptr| *ptr) }; - - // Return the task - unsafe { Some(Task::from_raw(task)) } - } - } - - impl<T: 'static> Drop for Iter<T> { - fn drop(&mut self) { - use std::process; - - if self.0.is_some() { - // we have bugs - process::abort(); - } - } - } - - let ptr = self.head.swap(ptr::null_mut(), Acquire); - Iter(NonNull::new(ptr), PhantomData) - } -} diff --git a/src/runtime/task/state.rs b/src/runtime/task/state.rs index 21e9043..c2d5b28 100644 --- a/src/runtime/task/state.rs +++ b/src/runtime/task/state.rs @@ -8,7 +8,7 @@ pub(super) struct State { val: AtomicUsize, } -/// Current state value +/// Current state value. #[derive(Copy, Clone)] pub(super) struct Snapshot(usize); @@ -19,54 +19,87 @@ const RUNNING: usize = 0b0001; /// The task is complete. 
/// -/// Once this bit is set, it is never unset +/// Once this bit is set, it is never unset. const COMPLETE: usize = 0b0010; -/// Extracts the task's lifecycle value from the state +/// Extracts the task's lifecycle value from the state. const LIFECYCLE_MASK: usize = 0b11; /// Flag tracking if the task has been pushed into a run queue. const NOTIFIED: usize = 0b100; -/// The join handle is still around +/// The join handle is still around. +#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556 const JOIN_INTEREST: usize = 0b1_000; -/// A join handle waker has been set +/// A join handle waker has been set. +#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556 const JOIN_WAKER: usize = 0b10_000; /// The task has been forcibly cancelled. +#[allow(clippy::unusual_byte_groupings)] // https://github.com/rust-lang/rust-clippy/issues/6556 const CANCELLED: usize = 0b100_000; -/// All bits +/// All bits. const STATE_MASK: usize = LIFECYCLE_MASK | NOTIFIED | JOIN_INTEREST | JOIN_WAKER | CANCELLED; /// Bits used by the ref count portion of the state. const REF_COUNT_MASK: usize = !STATE_MASK; -/// Number of positions to shift the ref count +/// Number of positions to shift the ref count. const REF_COUNT_SHIFT: usize = REF_COUNT_MASK.count_zeros() as usize; -/// One ref count +/// One ref count. const REF_ONE: usize = 1 << REF_COUNT_SHIFT; -/// State a task is initialized with +/// State a task is initialized with. /// -/// A task is initialized with two references: one for the scheduler and one for -/// the `JoinHandle`. As the task starts with a `JoinHandle`, `JOIN_INTERST` is -/// set. A new task is immediately pushed into the run queue for execution and -/// starts with the `NOTIFIED` flag set. 
-const INITIAL_STATE: usize = (REF_ONE * 2) | JOIN_INTEREST | NOTIFIED; +/// A task is initialized with three references: +/// +/// * A reference that will be stored in an OwnedTasks or LocalOwnedTasks. +/// * A reference that will be sent to the scheduler as an ordinary notification. +/// * A reference for the JoinHandle. +/// +/// As the task starts with a `JoinHandle`, `JOIN_INTEREST` is set. +/// As the task starts with a `Notified`, `NOTIFIED` is set. +const INITIAL_STATE: usize = (REF_ONE * 3) | JOIN_INTEREST | NOTIFIED; + +#[must_use] +pub(super) enum TransitionToRunning { + Success, + Cancelled, + Failed, + Dealloc, +} + +#[must_use] +pub(super) enum TransitionToIdle { + Ok, + OkNotified, + OkDealloc, + Cancelled, +} + +#[must_use] +pub(super) enum TransitionToNotifiedByVal { + DoNothing, + Submit, + Dealloc, +} + +#[must_use] +pub(super) enum TransitionToNotifiedByRef { + DoNothing, + Submit, +} /// All transitions are performed via RMW operations. This establishes an /// unambiguous modification order. impl State { - /// Return a task's initial state + /// Returns a task's initial state. pub(super) fn new() -> State { - // A task is initialized with three references: one for the scheduler, - // one for the `JoinHandle`, one for the task handle made available in - // release. As the task starts with a `JoinHandle`, `JOIN_INTERST` is - // set. A new task is immediately pushed into the run queue for - // execution and starts with the `NOTIFIED` flag set. + // The raw task returned by this method has a ref-count of three. See + // the comment on INITIAL_STATE for more. State { val: AtomicUsize::new(INITIAL_STATE), } @@ -77,57 +110,72 @@ impl State { Snapshot(self.val.load(Acquire)) } - /// Attempt to transition the lifecycle to `Running`. - /// - /// If `ref_inc` is set, the reference count is also incremented. - /// - /// The `NOTIFIED` bit is always unset. 
- pub(super) fn transition_to_running(&self, ref_inc: bool) -> UpdateResult { - self.fetch_update(|curr| { - assert!(curr.is_notified()); - - let mut next = curr; + /// Attempts to transition the lifecycle to `Running`. This sets the + /// notified bit to false so notifications during the poll can be detected. + pub(super) fn transition_to_running(&self) -> TransitionToRunning { + self.fetch_update_action(|mut next| { + let action; + assert!(next.is_notified()); if !next.is_idle() { - return None; - } - - if ref_inc { - next.ref_inc(); + // This happens if the task is either currently running or if it + // has already completed, e.g. if it was cancelled during + // shutdown. Consume the ref-count and return. + next.ref_dec(); + if next.ref_count() == 0 { + action = TransitionToRunning::Dealloc; + } else { + action = TransitionToRunning::Failed; + } + } else { + // We are able to lock the RUNNING bit. + next.set_running(); + next.unset_notified(); + + if next.is_cancelled() { + action = TransitionToRunning::Cancelled; + } else { + action = TransitionToRunning::Success; + } } - - next.set_running(); - next.unset_notified(); - Some(next) + (action, Some(next)) }) } /// Transitions the task from `Running` -> `Idle`. /// - /// Returns `Ok` if the transition to `Idle` is successful, `Err` otherwise. - /// In both cases, a snapshot of the state from **after** the transition is - /// returned. - /// + /// Returns `true` if the transition to `Idle` is successful, `false` otherwise. /// The transition to `Idle` fails if the task has been flagged to be /// cancelled. 
- pub(super) fn transition_to_idle(&self) -> UpdateResult { - self.fetch_update(|curr| { + pub(super) fn transition_to_idle(&self) -> TransitionToIdle { + self.fetch_update_action(|curr| { assert!(curr.is_running()); if curr.is_cancelled() { - return None; + return (TransitionToIdle::Cancelled, None); } let mut next = curr; + let action; next.unset_running(); - if next.is_notified() { - // The caller needs to schedule the task. To do this, it needs a - // waker. The waker requires a ref count. + if !next.is_notified() { + // Polling the future consumes the ref-count of the Notified. + next.ref_dec(); + if next.ref_count() == 0 { + action = TransitionToIdle::OkDealloc; + } else { + action = TransitionToIdle::Ok; + } + } else { + // The caller will schedule a new notification, so we create a + // new ref-count for the notification. Our own ref-count is kept + // for now, and the caller will drop it shortly. next.ref_inc(); + action = TransitionToIdle::OkNotified; } - Some(next) + (action, Some(next)) }) } @@ -142,42 +190,123 @@ impl State { Snapshot(prev.0 ^ DELTA) } - /// Transition from `Complete` -> `Terminal`, decrementing the reference - /// count by 1. + /// Transitions from `Complete` -> `Terminal`, decrementing the reference + /// count the specified number of times. /// - /// When `ref_dec` is set, an additional ref count decrement is performed. - /// This is used to batch atomic ops when possible. - pub(super) fn transition_to_terminal(&self, complete: bool, ref_dec: bool) -> Snapshot { - self.fetch_update(|mut snapshot| { - if complete { - snapshot.set_complete(); - } else { - assert!(snapshot.is_complete()); - } + /// Returns true if the task should be deallocated. 
+ pub(super) fn transition_to_terminal(&self, count: usize) -> bool { + let prev = Snapshot(self.val.fetch_sub(count * REF_ONE, AcqRel)); + assert!( + prev.ref_count() >= count, + "current: {}, sub: {}", + prev.ref_count(), + count + ); + prev.ref_count() == count + } + + /// Transitions the state to `NOTIFIED`. + /// + /// If no task needs to be submitted, a ref-count is consumed. + /// + /// If a task needs to be submitted, the ref-count is incremented for the + /// new Notified. + pub(super) fn transition_to_notified_by_val(&self) -> TransitionToNotifiedByVal { + self.fetch_update_action(|mut snapshot| { + let action; + + if snapshot.is_running() { + // If the task is running, we mark it as notified, but we should + // not submit anything as the thread currently running the + // future is responsible for that. + snapshot.set_notified(); + snapshot.ref_dec(); - // Decrement the primary handle - snapshot.ref_dec(); + // The thread that set the running bit also holds a ref-count. + assert!(snapshot.ref_count() > 0); - if ref_dec { - // Decrement a second time + action = TransitionToNotifiedByVal::DoNothing; + } else if snapshot.is_complete() || snapshot.is_notified() { + // We do not need to submit any notifications, but we have to + // decrement the ref-count. snapshot.ref_dec(); + + if snapshot.ref_count() == 0 { + action = TransitionToNotifiedByVal::Dealloc; + } else { + action = TransitionToNotifiedByVal::DoNothing; + } + } else { + // We create a new notified that we can submit. The caller + // retains ownership of the ref-count they passed in. + snapshot.set_notified(); + snapshot.ref_inc(); + action = TransitionToNotifiedByVal::Submit; } - Some(snapshot) + (action, Some(snapshot)) }) - .unwrap() } /// Transitions the state to `NOTIFIED`. 
+ pub(super) fn transition_to_notified_by_ref(&self) -> TransitionToNotifiedByRef { + self.fetch_update_action(|mut snapshot| { + if snapshot.is_complete() || snapshot.is_notified() { + // There is nothing to do in this case. + (TransitionToNotifiedByRef::DoNothing, None) + } else if snapshot.is_running() { + // If the task is running, we mark it as notified, but we should + // not submit as the thread currently running the future is + // responsible for that. + snapshot.set_notified(); + (TransitionToNotifiedByRef::DoNothing, Some(snapshot)) + } else { + // The task is idle and not notified. We should submit a + // notification. + snapshot.set_notified(); + snapshot.ref_inc(); + (TransitionToNotifiedByRef::Submit, Some(snapshot)) + } + }) + } + + /// Sets the cancelled bit and transitions the state to `NOTIFIED` if idle. /// /// Returns `true` if the task needs to be submitted to the pool for - /// execution - pub(super) fn transition_to_notified(&self) -> bool { - let prev = Snapshot(self.val.fetch_or(NOTIFIED, AcqRel)); - prev.will_need_queueing() + /// execution. + pub(super) fn transition_to_notified_and_cancel(&self) -> bool { + self.fetch_update_action(|mut snapshot| { + if snapshot.is_cancelled() || snapshot.is_complete() { + // Aborts to completed or cancelled tasks are no-ops. + (false, None) + } else if snapshot.is_running() { + // If the task is running, we mark it as cancelled. The thread + // running the task will notice the cancelled bit when it + // stops polling and it will kill the task. + // + // The set_notified() call is not strictly necessary but it will + // in some cases let a wake_by_ref call return without having + // to perform a compare_exchange. + snapshot.set_notified(); + snapshot.set_cancelled(); + (false, Some(snapshot)) + } else { + // The task is idle. We set the cancelled and notified bits and + // submit a notification if the notified bit was not already + // set. 
+ snapshot.set_cancelled(); + if !snapshot.is_notified() { + snapshot.set_notified(); + snapshot.ref_inc(); + (true, Some(snapshot)) + } else { + (false, Some(snapshot)) + } + } + }) } - /// Set the `CANCELLED` bit and attempt to transition to `Running`. + /// Sets the `CANCELLED` bit and attempts to transition to `Running`. /// /// Returns `true` if the transition to `Running` succeeded. pub(super) fn transition_to_shutdown(&self) -> bool { @@ -188,17 +317,11 @@ impl State { if snapshot.is_idle() { snapshot.set_running(); - - if snapshot.is_notified() { - // If the task is idle and notified, this indicates the task is - // in the run queue and is considered owned by the scheduler. - // The shutdown operation claims ownership of the task, which - // means we need to assign an additional ref-count to the task - // in the queue. - snapshot.ref_inc(); - } } + // If the task was not idle, the thread currently running the task + // will notice the cancelled bit and cancel it once the poll + // completes. snapshot.set_cancelled(); Some(snapshot) }); @@ -207,7 +330,7 @@ impl State { } /// Optimistically tries to swap the state assuming the join handle is - /// __immediately__ dropped on spawn + /// __immediately__ dropped on spawn. pub(super) fn drop_join_handle_fast(&self) -> Result<(), ()> { use std::sync::atomic::Ordering::Relaxed; @@ -229,7 +352,7 @@ impl State { .map_err(|_| ()) } - /// Try to unset the JOIN_INTEREST flag. + /// Tries to unset the JOIN_INTEREST flag. /// /// Returns `Ok` if the operation happens before the task transitions to a /// completed state, `Err` otherwise. @@ -248,7 +371,7 @@ impl State { }) } - /// Set the `JOIN_WAKER` bit. + /// Sets the `JOIN_WAKER` bit. /// /// Returns `Ok` if the bit is set, `Err` otherwise. This operation fails if /// the task has completed. @@ -306,7 +429,7 @@ impl State { let prev = self.val.fetch_add(REF_ONE, Relaxed); // If the reference count overflowed, abort. 
- if prev > isize::max_value() as usize { + if prev > isize::MAX as usize { process::abort(); } } @@ -314,9 +437,39 @@ impl State { /// Returns `true` if the task should be released. pub(super) fn ref_dec(&self) -> bool { let prev = Snapshot(self.val.fetch_sub(REF_ONE, AcqRel)); + assert!(prev.ref_count() >= 1); prev.ref_count() == 1 } + /// Returns `true` if the task should be released. + pub(super) fn ref_dec_twice(&self) -> bool { + let prev = Snapshot(self.val.fetch_sub(2 * REF_ONE, AcqRel)); + assert!(prev.ref_count() >= 2); + prev.ref_count() == 2 + } + + fn fetch_update_action<F, T>(&self, mut f: F) -> T + where + F: FnMut(Snapshot) -> (T, Option<Snapshot>), + { + let mut curr = self.load(); + + loop { + let (output, next) = f(curr); + let next = match next { + Some(next) => next, + None => return output, + }; + + let res = self.val.compare_exchange(curr.0, next.0, AcqRel, Acquire); + + match res { + Ok(_) => return output, + Err(actual) => curr = Snapshot(actual), + } + } + } + fn fetch_update<F>(&self, mut f: F) -> Result<Snapshot, Snapshot> where F: FnMut(Snapshot) -> Option<Snapshot>, @@ -356,6 +509,10 @@ impl Snapshot { self.0 &= !NOTIFIED } + fn set_notified(&mut self) { + self.0 |= NOTIFIED + } + pub(super) fn is_running(self) -> bool { self.0 & RUNNING == RUNNING } @@ -376,10 +533,6 @@ impl Snapshot { self.0 |= CANCELLED; } - fn set_complete(&mut self) { - self.0 |= COMPLETE; - } - /// Returns `true` if the task's future has completed execution. 
pub(super) fn is_complete(self) -> bool { self.0 & COMPLETE == COMPLETE @@ -410,7 +563,7 @@ impl Snapshot { } fn ref_inc(&mut self) { - assert!(self.0 <= isize::max_value() as usize); + assert!(self.0 <= isize::MAX as usize); self.0 += REF_ONE; } @@ -418,10 +571,6 @@ impl Snapshot { assert!(self.ref_count() > 0); self.0 -= REF_ONE } - - fn will_need_queueing(self) -> bool { - !self.is_notified() && self.is_idle() - } } impl fmt::Debug for State { diff --git a/src/runtime/task/waker.rs b/src/runtime/task/waker.rs index 5c2d478..b7313b4 100644 --- a/src/runtime/task/waker.rs +++ b/src/runtime/task/waker.rs @@ -1,7 +1,7 @@ +use crate::future::Future; use crate::runtime::task::harness::Harness; use crate::runtime::task::{Header, Schedule}; -use std::future::Future; use std::marker::PhantomData; use std::mem::ManuallyDrop; use std::ops; @@ -44,12 +44,38 @@ impl<S> ops::Deref for WakerRef<'_, S> { } } +cfg_trace! { + macro_rules! trace { + ($harness:expr, $op:expr) => { + if let Some(id) = $harness.id() { + tracing::trace!( + target: "tokio::task::waker", + op = $op, + task.id = id.into_u64(), + ); + } + } + } +} + +cfg_not_trace! { + macro_rules! 
trace { + ($harness:expr, $op:expr) => { + // noop + let _ = &$harness; + } + } +} + unsafe fn clone_waker<T, S>(ptr: *const ()) -> RawWaker where T: Future, S: Schedule, { let header = ptr as *const Header; + let ptr = NonNull::new_unchecked(ptr as *mut Header); + let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.clone"); (*header).state.ref_inc(); raw_waker::<T, S>(header) } @@ -61,6 +87,7 @@ where { let ptr = NonNull::new_unchecked(ptr as *mut Header); let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.drop"); harness.drop_reference(); } @@ -71,6 +98,7 @@ where { let ptr = NonNull::new_unchecked(ptr as *mut Header); let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.wake"); harness.wake_by_val(); } @@ -82,6 +110,7 @@ where { let ptr = NonNull::new_unchecked(ptr as *mut Header); let harness = Harness::<T, S>::from_raw(ptr); + trace!(harness, "waker.wake_by_ref"); harness.wake_by_ref(); } diff --git a/src/runtime/tests/loom_basic_scheduler.rs b/src/runtime/tests/loom_basic_scheduler.rs index e6221d3..d2894b9 100644 --- a/src/runtime/tests/loom_basic_scheduler.rs +++ b/src/runtime/tests/loom_basic_scheduler.rs @@ -63,6 +63,45 @@ fn block_on_num_polls() { }); } +#[test] +fn assert_no_unnecessary_polls() { + loom::model(|| { + // // After we poll outer future, woken should reset to false + let rt = Builder::new_current_thread().build().unwrap(); + let (tx, rx) = oneshot::channel(); + let pending_cnt = Arc::new(AtomicUsize::new(0)); + + rt.spawn(async move { + for _ in 0..24 { + task::yield_now().await; + } + tx.send(()).unwrap(); + }); + + let pending_cnt_clone = pending_cnt.clone(); + rt.block_on(async move { + // use task::yield_now() to ensure woken set to true + // ResetFuture will be polled at most once + // Here comes two cases + // 1. recv no message from channel, ResetFuture will be polled + // but get Pending and we record ResetFuture.pending_cnt ++. 
+ // Then when message arrive, ResetFuture returns Ready. So we + // expect ResetFuture.pending_cnt = 1 + // 2. recv message from channel, ResetFuture returns Ready immediately. + // We expect ResetFuture.pending_cnt = 0 + task::yield_now().await; + ResetFuture { + rx, + pending_cnt: pending_cnt_clone, + } + .await; + }); + + let pending_cnt = pending_cnt.load(Acquire); + assert!(pending_cnt <= 1); + }); +} + struct BlockedFuture { rx: Receiver<()>, num_polls: Arc<AtomicUsize>, @@ -80,3 +119,22 @@ impl Future for BlockedFuture { } } } + +struct ResetFuture { + rx: Receiver<()>, + pending_cnt: Arc<AtomicUsize>, +} + +impl Future for ResetFuture { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match Pin::new(&mut self.rx).poll(cx) { + Poll::Pending => { + self.pending_cnt.fetch_add(1, Release); + Poll::Pending + } + _ => Poll::Ready(()), + } + } +} diff --git a/src/runtime/tests/loom_local.rs b/src/runtime/tests/loom_local.rs new file mode 100644 index 0000000..d9a07a4 --- /dev/null +++ b/src/runtime/tests/loom_local.rs @@ -0,0 +1,47 @@ +use crate::runtime::tests::loom_oneshot as oneshot; +use crate::runtime::Builder; +use crate::task::LocalSet; + +use std::task::Poll; + +/// Waking a runtime will attempt to push a task into a queue of notifications +/// in the runtime, however the tasks in such a queue usually have a reference +/// to the runtime itself. This means that if they are not properly removed at +/// runtime shutdown, this will cause a memory leak. +/// +/// This test verifies that waking something during shutdown of a LocalSet does +/// not result in tasks lingering in the queue once shutdown is complete. This +/// is verified using loom's leak finder. 
+#[test] +fn wake_during_shutdown() { + loom::model(|| { + let rt = Builder::new_current_thread().build().unwrap(); + let ls = LocalSet::new(); + + let (send, recv) = oneshot::channel(); + + ls.spawn_local(async move { + let mut send = Some(send); + + let () = futures::future::poll_fn(|cx| { + if let Some(send) = send.take() { + send.send(cx.waker().clone()); + } + + Poll::Pending + }) + .await; + }); + + let handle = loom::thread::spawn(move || { + let waker = recv.recv(); + waker.wake(); + }); + + ls.block_on(&rt, crate::task::yield_now()); + + drop(ls); + handle.join().unwrap(); + drop(rt); + }); +} diff --git a/src/runtime/tests/loom_oneshot.rs b/src/runtime/tests/loom_oneshot.rs index c126fe4..87eb638 100644 --- a/src/runtime/tests/loom_oneshot.rs +++ b/src/runtime/tests/loom_oneshot.rs @@ -1,7 +1,6 @@ +use crate::loom::sync::{Arc, Mutex}; use loom::sync::Notify; -use std::sync::{Arc, Mutex}; - pub(crate) fn channel<T>() -> (Sender<T>, Receiver<T>) { let inner = Arc::new(Inner { notify: Notify::new(), @@ -31,7 +30,7 @@ struct Inner<T> { impl<T> Sender<T> { pub(crate) fn send(self, value: T) { - *self.inner.value.lock().unwrap() = Some(value); + *self.inner.value.lock() = Some(value); self.inner.notify.notify(); } } @@ -39,7 +38,7 @@ impl<T> Sender<T> { impl<T> Receiver<T> { pub(crate) fn recv(self) -> T { loop { - if let Some(v) = self.inner.value.lock().unwrap().take() { + if let Some(v) = self.inner.value.lock().take() { return v; } diff --git a/src/runtime/tests/loom_pool.rs b/src/runtime/tests/loom_pool.rs index 06ad641..b3ecd43 100644 --- a/src/runtime/tests/loom_pool.rs +++ b/src/runtime/tests/loom_pool.rs @@ -11,7 +11,7 @@ use crate::{spawn, task}; use tokio_test::assert_ok; use loom::sync::atomic::{AtomicBool, AtomicUsize}; -use loom::sync::{Arc, Mutex}; +use loom::sync::Arc; use pin_project_lite::pin_project; use std::future::Future; @@ -19,6 +19,57 @@ use std::pin::Pin; use std::sync::atomic::Ordering::{Relaxed, SeqCst}; use std::task::{Context, 
Poll}; +mod atomic_take { + use loom::sync::atomic::AtomicBool; + use std::mem::MaybeUninit; + use std::sync::atomic::Ordering::SeqCst; + + pub(super) struct AtomicTake<T> { + inner: MaybeUninit<T>, + taken: AtomicBool, + } + + impl<T> AtomicTake<T> { + pub(super) fn new(value: T) -> Self { + Self { + inner: MaybeUninit::new(value), + taken: AtomicBool::new(false), + } + } + + pub(super) fn take(&self) -> Option<T> { + // safety: Only one thread will see the boolean change from false + // to true, so that thread is able to take the value. + match self.taken.fetch_or(true, SeqCst) { + false => unsafe { Some(std::ptr::read(self.inner.as_ptr())) }, + true => None, + } + } + } + + impl<T> Drop for AtomicTake<T> { + fn drop(&mut self) { + drop(self.take()); + } + } +} + +#[derive(Clone)] +struct AtomicOneshot<T> { + value: std::sync::Arc<atomic_take::AtomicTake<oneshot::Sender<T>>>, +} +impl<T> AtomicOneshot<T> { + fn new(sender: oneshot::Sender<T>) -> Self { + Self { + value: std::sync::Arc::new(atomic_take::AtomicTake::new(sender)), + } + } + + fn assert_send(&self, value: T) { + self.value.take().unwrap().send(value); + } +} + /// Tests are divided into groups to make the runs faster on CI. 
mod group_a { use super::*; @@ -52,7 +103,7 @@ mod group_a { let c1 = Arc::new(AtomicUsize::new(0)); let (tx, rx) = oneshot::channel(); - let tx1 = Arc::new(Mutex::new(Some(tx))); + let tx1 = AtomicOneshot::new(tx); // Spawn a task let c2 = c1.clone(); @@ -60,7 +111,7 @@ mod group_a { pool.spawn(track(async move { spawn(track(async move { if 1 == c1.fetch_add(1, Relaxed) { - tx1.lock().unwrap().take().unwrap().send(()); + tx1.assert_send(()); } })); })); @@ -69,7 +120,7 @@ mod group_a { pool.spawn(track(async move { spawn(track(async move { if 1 == c2.fetch_add(1, Relaxed) { - tx2.lock().unwrap().take().unwrap().send(()); + tx2.assert_send(()); } })); })); @@ -119,7 +170,7 @@ mod group_b { let (block_tx, block_rx) = oneshot::channel(); let (done_tx, done_rx) = oneshot::channel(); - let done_tx = Arc::new(Mutex::new(Some(done_tx))); + let done_tx = AtomicOneshot::new(done_tx); pool.spawn(track(async move { crate::task::block_in_place(move || { @@ -136,7 +187,7 @@ mod group_b { pool.spawn(track(async move { if NUM == cnt.fetch_add(1, Relaxed) + 1 { - done_tx.lock().unwrap().take().unwrap().send(()); + done_tx.assert_send(()); } })); } @@ -159,23 +210,6 @@ mod group_b { } #[test] - fn pool_shutdown() { - loom::model(|| { - let pool = mk_pool(2); - - pool.spawn(track(async move { - gated2(true).await; - })); - - pool.spawn(track(async move { - gated2(false).await; - })); - - drop(pool); - }); - } - - #[test] fn join_output() { loom::model(|| { let rt = mk_pool(1); @@ -223,10 +257,6 @@ mod group_b { }); }); } -} - -mod group_c { - use super::*; #[test] fn shutdown_with_notification() { @@ -255,6 +285,27 @@ mod group_c { } } +mod group_c { + use super::*; + + #[test] + fn pool_shutdown() { + loom::model(|| { + let pool = mk_pool(2); + + pool.spawn(track(async move { + gated2(true).await; + })); + + pool.spawn(track(async move { + gated2(false).await; + })); + + drop(pool); + }); + } +} + mod group_d { use super::*; @@ -266,17 +317,17 @@ mod group_d { let c1 = 
Arc::new(AtomicUsize::new(0)); let (done_tx, done_rx) = oneshot::channel(); - let done_tx1 = Arc::new(Mutex::new(Some(done_tx))); + let done_tx1 = AtomicOneshot::new(done_tx); + let done_tx2 = done_tx1.clone(); // Spawn a task let c2 = c1.clone(); - let done_tx2 = done_tx1.clone(); pool.spawn(track(async move { gated().await; gated().await; if 1 == c1.fetch_add(1, Relaxed) { - done_tx1.lock().unwrap().take().unwrap().send(()); + done_tx1.assert_send(()); } })); @@ -286,7 +337,7 @@ mod group_d { gated().await; if 1 == c2.fetch_add(1, Relaxed) { - done_tx2.lock().unwrap().take().unwrap().send(()); + done_tx2.assert_send(()); } })); diff --git a/src/runtime/tests/loom_queue.rs b/src/runtime/tests/loom_queue.rs index de02610..2cbb0a1 100644 --- a/src/runtime/tests/loom_queue.rs +++ b/src/runtime/tests/loom_queue.rs @@ -1,5 +1,7 @@ +use crate::runtime::blocking::NoopSchedule; use crate::runtime::queue; -use crate::runtime::task::{self, Schedule, Task}; +use crate::runtime::stats::WorkerStatsBatcher; +use crate::runtime::task::Inject; use loom::thread; @@ -7,14 +9,15 @@ use loom::thread; fn basic() { loom::model(|| { let (steal, mut local) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); let th = thread::spawn(move || { + let mut stats = WorkerStatsBatcher::new(0); let (_, mut local) = queue::local(); let mut n = 0; for _ in 0..3 { - if steal.steal_into(&mut local).is_some() { + if steal.steal_into(&mut local, &mut stats).is_some() { n += 1; } @@ -30,7 +33,7 @@ fn basic() { for _ in 0..2 { for _ in 0..2 { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -39,7 +42,7 @@ fn basic() { } // Push another task - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); while local.pop().is_some() { @@ -61,13 +64,14 @@ fn basic() { fn steal_overflow() { loom::model(|| { let (steal, 
mut local) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); let th = thread::spawn(move || { + let mut stats = WorkerStatsBatcher::new(0); let (_, mut local) = queue::local(); let mut n = 0; - if steal.steal_into(&mut local).is_some() { + if steal.steal_into(&mut local, &mut stats).is_some() { n += 1; } @@ -81,7 +85,7 @@ fn steal_overflow() { let mut n = 0; // push a task, pop a task - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); if local.pop().is_some() { @@ -89,7 +93,7 @@ fn steal_overflow() { } for _ in 0..6 { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -111,10 +115,11 @@ fn steal_overflow() { fn multi_stealer() { const NUM_TASKS: usize = 5; - fn steal_tasks(steal: queue::Steal<Runtime>) -> usize { + fn steal_tasks(steal: queue::Steal<NoopSchedule>) -> usize { + let mut stats = WorkerStatsBatcher::new(0); let (_, mut local) = queue::local(); - if steal.steal_into(&mut local).is_none() { + if steal.steal_into(&mut local, &mut stats).is_none() { return 0; } @@ -129,11 +134,11 @@ fn multi_stealer() { loom::model(|| { let (steal, mut local) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); // Push work for _ in 0..NUM_TASKS { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -164,23 +169,25 @@ fn multi_stealer() { #[test] fn chained_steal() { loom::model(|| { + let mut stats = WorkerStatsBatcher::new(0); let (s1, mut l1) = queue::local(); let (s2, mut l2) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); // Load up some tasks for _ in 0..4 { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); l1.push_back(task, &inject); - let (task, _) = 
task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); l2.push_back(task, &inject); } // Spawn a task to steal from **our** queue let th = thread::spawn(move || { + let mut stats = WorkerStatsBatcher::new(0); let (_, mut local) = queue::local(); - s1.steal_into(&mut local); + s1.steal_into(&mut local, &mut stats); while local.pop().is_some() {} }); @@ -188,7 +195,7 @@ fn chained_steal() { // Drain our tasks, then attempt to steal while l1.pop().is_some() {} - s2.steal_into(&mut l1); + s2.steal_into(&mut l1, &mut stats); th.join().unwrap(); @@ -197,20 +204,3 @@ fn chained_steal() { while inject.pop().is_some() {} }); } - -struct Runtime; - -impl Schedule for Runtime { - fn bind(task: Task<Self>) -> Runtime { - std::mem::forget(task); - Runtime - } - - fn release(&self, _task: &Task<Self>) -> Option<Task<Self>> { - None - } - - fn schedule(&self, _task: task::Notified<Self>) { - unreachable!(); - } -} diff --git a/src/runtime/tests/loom_shutdown_join.rs b/src/runtime/tests/loom_shutdown_join.rs new file mode 100644 index 0000000..6fbc4bf --- /dev/null +++ b/src/runtime/tests/loom_shutdown_join.rs @@ -0,0 +1,28 @@ +use crate::runtime::{Builder, Handle}; + +#[test] +fn join_handle_cancel_on_shutdown() { + let mut builder = loom::model::Builder::new(); + builder.preemption_bound = Some(2); + builder.check(|| { + use futures::future::FutureExt; + + let rt = Builder::new_multi_thread() + .worker_threads(2) + .build() + .unwrap(); + + let handle = rt.block_on(async move { Handle::current() }); + + let jh1 = handle.spawn(futures::future::pending::<()>()); + + drop(rt); + + let jh2 = handle.spawn(futures::future::pending::<()>()); + + let err1 = jh1.now_or_never().unwrap().unwrap_err(); + let err2 = jh2.now_or_never().unwrap().unwrap_err(); + assert!(err1.is_cancelled()); + assert!(err2.is_cancelled()); + }); +} diff --git a/src/runtime/tests/mod.rs b/src/runtime/tests/mod.rs index ebb48de..be36d6f 100644 --- a/src/runtime/tests/mod.rs +++ 
b/src/runtime/tests/mod.rs @@ -1,14 +1,49 @@ +use self::unowned_wrapper::unowned; + +mod unowned_wrapper { + use crate::runtime::blocking::NoopSchedule; + use crate::runtime::task::{JoinHandle, Notified}; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + pub(crate) fn unowned<T>(task: T) -> (Notified<NoopSchedule>, JoinHandle<T::Output>) + where + T: std::future::Future + Send + 'static, + T::Output: Send + 'static, + { + use tracing::Instrument; + let span = tracing::trace_span!("test_span"); + let task = task.instrument(span); + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule); + (task.into_notified(), handle) + } + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + pub(crate) fn unowned<T>(task: T) -> (Notified<NoopSchedule>, JoinHandle<T::Output>) + where + T: std::future::Future + Send + 'static, + T::Output: Send + 'static, + { + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule); + (task.into_notified(), handle) + } +} + cfg_loom! { mod loom_basic_scheduler; + mod loom_local; mod loom_blocking; mod loom_oneshot; mod loom_pool; mod loom_queue; + mod loom_shutdown_join; } cfg_not_loom! 
{ mod queue; + #[cfg(not(miri))] + mod task_combinations; + #[cfg(miri)] mod task; } diff --git a/src/runtime/tests/queue.rs b/src/runtime/tests/queue.rs index d228d5d..47f1b01 100644 --- a/src/runtime/tests/queue.rs +++ b/src/runtime/tests/queue.rs @@ -1,5 +1,6 @@ use crate::runtime::queue; -use crate::runtime::task::{self, Schedule, Task}; +use crate::runtime::stats::WorkerStatsBatcher; +use crate::runtime::task::{self, Inject, Schedule, Task}; use std::thread; use std::time::Duration; @@ -7,10 +8,10 @@ use std::time::Duration; #[test] fn fits_256() { let (_, mut local) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); for _ in 0..256 { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -22,10 +23,10 @@ fn fits_256() { #[test] fn overflow() { let (_, mut local) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); for _ in 0..257 { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -44,16 +45,18 @@ fn overflow() { #[test] fn steal_batch() { + let mut stats = WorkerStatsBatcher::new(0); + let (steal1, mut local1) = queue::local(); let (_, mut local2) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); for _ in 0..4 { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local1.push_back(task, &inject); } - assert!(steal1.steal_into(&mut local2).is_some()); + assert!(steal1.steal_into(&mut local2, &mut stats).is_some()); for _ in 0..1 { assert!(local2.pop().is_some()); @@ -78,14 +81,15 @@ fn stress1() { for _ in 0..NUM_ITER { let (steal, mut local) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); let th = thread::spawn(move || { + let mut stats = WorkerStatsBatcher::new(0); let (_, mut local) = queue::local(); 
let mut n = 0; for _ in 0..NUM_STEAL { - if steal.steal_into(&mut local).is_some() { + if steal.steal_into(&mut local, &mut stats).is_some() { n += 1; } @@ -103,7 +107,7 @@ fn stress1() { for _ in 0..NUM_LOCAL { for _ in 0..NUM_PUSH { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); } @@ -134,14 +138,15 @@ fn stress2() { for _ in 0..NUM_ITER { let (steal, mut local) = queue::local(); - let inject = queue::Inject::new(); + let inject = Inject::new(); let th = thread::spawn(move || { + let mut stats = WorkerStatsBatcher::new(0); let (_, mut local) = queue::local(); let mut n = 0; for _ in 0..NUM_STEAL { - if steal.steal_into(&mut local).is_some() { + if steal.steal_into(&mut local, &mut stats).is_some() { n += 1; } @@ -158,7 +163,7 @@ fn stress2() { let mut num_pop = 0; for i in 0..NUM_TASKS { - let (task, _) = task::joinable::<_, Runtime>(async {}); + let (task, _) = super::unowned(async {}); local.push_back(task, &inject); if i % 128 == 0 && local.pop().is_some() { @@ -187,11 +192,6 @@ fn stress2() { struct Runtime; impl Schedule for Runtime { - fn bind(task: Task<Self>) -> Runtime { - std::mem::forget(task); - Runtime - } - fn release(&self, _task: &Task<Self>) -> Option<Task<Self>> { None } diff --git a/src/runtime/tests/task.rs b/src/runtime/tests/task.rs index a34526f..04e1b56 100644 --- a/src/runtime/tests/task.rs +++ b/src/runtime/tests/task.rs @@ -1,44 +1,191 @@ -use crate::runtime::task::{self, Schedule, Task}; -use crate::util::linked_list::{Link, LinkedList}; +use crate::runtime::blocking::NoopSchedule; +use crate::runtime::task::{self, unowned, JoinHandle, OwnedTasks, Schedule, Task}; use crate::util::TryLock; use std::collections::VecDeque; +use std::future::Future; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +struct AssertDropHandle { + is_dropped: Arc<AtomicBool>, +} +impl AssertDropHandle { + #[track_caller] + fn assert_dropped(&self) { + 
assert!(self.is_dropped.load(Ordering::SeqCst)); + } + + #[track_caller] + fn assert_not_dropped(&self) { + assert!(!self.is_dropped.load(Ordering::SeqCst)); + } +} + +struct AssertDrop { + is_dropped: Arc<AtomicBool>, +} +impl AssertDrop { + fn new() -> (Self, AssertDropHandle) { + let shared = Arc::new(AtomicBool::new(false)); + ( + AssertDrop { + is_dropped: shared.clone(), + }, + AssertDropHandle { + is_dropped: shared.clone(), + }, + ) + } +} +impl Drop for AssertDrop { + fn drop(&mut self) { + self.is_dropped.store(true, Ordering::SeqCst); + } +} + +// A Notified does not shut down on drop, but it is dropped once the ref-count +// hits zero. +#[test] +fn create_drop1() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(notified); + handle.assert_not_dropped(); + drop(join); + handle.assert_dropped(); +} + +#[test] +fn create_drop2() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(join); + handle.assert_not_dropped(); + drop(notified); + handle.assert_dropped(); +} + +// Shutting down through Notified works #[test] -fn create_drop() { - let _ = task::joinable::<_, Runtime>(async { unreachable!() }); +fn create_shutdown1() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + drop(join); + handle.assert_not_dropped(); + notified.shutdown(); + handle.assert_dropped(); +} + +#[test] +fn create_shutdown2() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + ); + handle.assert_not_dropped(); + notified.shutdown(); + handle.assert_dropped(); + drop(join); +} + +#[test] +fn unowned_poll() { + let (task, _) = unowned(async {}, NoopSchedule); + task.run(); } #[test] fn schedule() { with(|rt| { - let 
(task, _) = task::joinable(async { + rt.spawn(async { crate::task::yield_now().await; }); - rt.schedule(task); - assert_eq!(2, rt.tick()); + rt.shutdown(); }) } #[test] fn shutdown() { with(|rt| { - let (task, _) = task::joinable(async { + rt.spawn(async { loop { crate::task::yield_now().await; } }); - rt.schedule(task); rt.tick_max(1); rt.shutdown(); }) } +#[test] +fn shutdown_immediately() { + with(|rt| { + rt.spawn(async { + loop { + crate::task::yield_now().await; + } + }); + + rt.shutdown(); + }) +} + +#[test] +fn spawn_during_shutdown() { + static DID_SPAWN: AtomicBool = AtomicBool::new(false); + + struct SpawnOnDrop(Runtime); + impl Drop for SpawnOnDrop { + fn drop(&mut self) { + DID_SPAWN.store(true, Ordering::SeqCst); + self.0.spawn(async {}); + } + } + + with(|rt| { + let rt2 = rt.clone(); + rt.spawn(async move { + let _spawn_on_drop = SpawnOnDrop(rt2); + + loop { + crate::task::yield_now().await; + } + }); + + rt.tick_max(1); + rt.shutdown(); + }); + + assert!(DID_SPAWN.load(Ordering::SeqCst)); +} + fn with(f: impl FnOnce(Runtime)) { struct Reset; @@ -51,10 +198,9 @@ fn with(f: impl FnOnce(Runtime)) { let _reset = Reset; let rt = Runtime(Arc::new(Inner { - released: task::TransferStack::new(), + owned: OwnedTasks::new(), core: TryLock::new(Core { queue: VecDeque::new(), - tasks: LinkedList::new(), }), })); @@ -66,20 +212,33 @@ fn with(f: impl FnOnce(Runtime)) { struct Runtime(Arc<Inner>); struct Inner { - released: task::TransferStack<Runtime>, core: TryLock<Core>, + owned: OwnedTasks<Runtime>, } struct Core { queue: VecDeque<task::Notified<Runtime>>, - tasks: LinkedList<Task<Runtime>, <Task<Runtime> as Link>::Target>, } static CURRENT: TryLock<Option<Runtime>> = TryLock::new(None); impl Runtime { + fn spawn<T>(&self, future: T) -> JoinHandle<T::Output> + where + T: 'static + Send + Future, + T::Output: 'static + Send, + { + let (handle, notified) = self.0.owned.bind(future, self.clone()); + + if let Some(notified) = notified { + self.schedule(notified); 
+ } + + handle + } + fn tick(&self) -> usize { - self.tick_max(usize::max_value()) + self.tick_max(usize::MAX) } fn tick_max(&self, max: usize) -> usize { @@ -88,11 +247,10 @@ impl Runtime { while !self.is_empty() && n < max { let task = self.next_task(); n += 1; + let task = self.0.owned.assert_owner(task); task.run(); } - self.0.maintenance(); - n } @@ -107,50 +265,21 @@ impl Runtime { fn shutdown(&self) { let mut core = self.0.core.try_lock().unwrap(); - for task in core.tasks.iter() { - task.shutdown(); - } + self.0.owned.close_and_shutdown_all(); while let Some(task) = core.queue.pop_back() { - task.shutdown(); + drop(task); } drop(core); - while !self.0.core.try_lock().unwrap().tasks.is_empty() { - self.0.maintenance(); - } - } -} - -impl Inner { - fn maintenance(&self) { - use std::mem::ManuallyDrop; - - for task in self.released.drain() { - let task = ManuallyDrop::new(task); - - // safety: see worker.rs - unsafe { - let ptr = task.header().into(); - self.core.try_lock().unwrap().tasks.remove(ptr); - } - } + assert!(self.0.owned.is_empty()); } } impl Schedule for Runtime { - fn bind(task: Task<Self>) -> Runtime { - let rt = CURRENT.try_lock().unwrap().as_ref().unwrap().clone(); - rt.0.core.try_lock().unwrap().tasks.push_front(task); - rt - } - fn release(&self, task: &Task<Self>) -> Option<Task<Self>> { - // safety: copying worker.rs - let task = unsafe { Task::from_raw(task.header().into()) }; - self.0.released.push(task); - None + self.0.owned.remove(task) } fn schedule(&self, task: task::Notified<Self>) { diff --git a/src/runtime/tests/task_combinations.rs b/src/runtime/tests/task_combinations.rs new file mode 100644 index 0000000..76ce233 --- /dev/null +++ b/src/runtime/tests/task_combinations.rs @@ -0,0 +1,380 @@ +use std::future::Future; +use std::panic; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use crate::runtime::Builder; +use crate::sync::oneshot; +use crate::task::JoinHandle; + +use futures::future::FutureExt; + +// Enums for each 
option in the combinations being tested + +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiRuntime { + CurrentThread, + Multi1, + Multi2, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiLocalSet { + Yes, + No, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiTask { + PanicOnRun, + PanicOnDrop, + PanicOnRunAndDrop, + NoPanic, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiOutput { + PanicOnDrop, + NoPanic, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiJoinInterest { + Polled, + NotPolled, +} +#[allow(clippy::enum_variant_names)] // we aren't using glob imports +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiJoinHandle { + DropImmediately = 1, + DropFirstPoll = 2, + DropAfterNoConsume = 3, + DropAfterConsume = 4, +} +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiAbort { + NotAborted = 0, + AbortedImmediately = 1, + AbortedFirstPoll = 2, + AbortedAfterFinish = 3, + AbortedAfterConsumeOutput = 4, +} + +#[test] +fn test_combinations() { + let mut rt = &[ + CombiRuntime::CurrentThread, + CombiRuntime::Multi1, + CombiRuntime::Multi2, + ][..]; + + if cfg!(miri) { + rt = &[CombiRuntime::CurrentThread]; + } + + let ls = [CombiLocalSet::Yes, CombiLocalSet::No]; + let task = [ + CombiTask::NoPanic, + CombiTask::PanicOnRun, + CombiTask::PanicOnDrop, + CombiTask::PanicOnRunAndDrop, + ]; + let output = [CombiOutput::NoPanic, CombiOutput::PanicOnDrop]; + let ji = [CombiJoinInterest::Polled, CombiJoinInterest::NotPolled]; + let jh = [ + CombiJoinHandle::DropImmediately, + CombiJoinHandle::DropFirstPoll, + CombiJoinHandle::DropAfterNoConsume, + CombiJoinHandle::DropAfterConsume, + ]; + let abort = [ + CombiAbort::NotAborted, + CombiAbort::AbortedImmediately, + CombiAbort::AbortedFirstPoll, + CombiAbort::AbortedAfterFinish, + CombiAbort::AbortedAfterConsumeOutput, + ]; + + for rt in rt.iter().copied() { + for ls in ls.iter().copied() { + for task in task.iter().copied() { + for output in output.iter().copied() { + for ji in 
ji.iter().copied() { + for jh in jh.iter().copied() { + for abort in abort.iter().copied() { + test_combination(rt, ls, task, output, ji, jh, abort); + } + } + } + } + } + } + } +} + +fn test_combination( + rt: CombiRuntime, + ls: CombiLocalSet, + task: CombiTask, + output: CombiOutput, + ji: CombiJoinInterest, + jh: CombiJoinHandle, + abort: CombiAbort, +) { + if (jh as usize) < (abort as usize) { + // drop before abort not possible + return; + } + if (task == CombiTask::PanicOnDrop) && (output == CombiOutput::PanicOnDrop) { + // this causes double panic + return; + } + if (task == CombiTask::PanicOnRunAndDrop) && (abort != CombiAbort::AbortedImmediately) { + // this causes double panic + return; + } + + println!("Runtime {:?}, LocalSet {:?}, Task {:?}, Output {:?}, JoinInterest {:?}, JoinHandle {:?}, Abort {:?}", rt, ls, task, output, ji, jh, abort); + + // A runtime optionally with a LocalSet + struct Rt { + rt: crate::runtime::Runtime, + ls: Option<crate::task::LocalSet>, + } + impl Rt { + fn new(rt: CombiRuntime, ls: CombiLocalSet) -> Self { + let rt = match rt { + CombiRuntime::CurrentThread => Builder::new_current_thread().build().unwrap(), + CombiRuntime::Multi1 => Builder::new_multi_thread() + .worker_threads(1) + .build() + .unwrap(), + CombiRuntime::Multi2 => Builder::new_multi_thread() + .worker_threads(2) + .build() + .unwrap(), + }; + + let ls = match ls { + CombiLocalSet::Yes => Some(crate::task::LocalSet::new()), + CombiLocalSet::No => None, + }; + + Self { rt, ls } + } + fn block_on<T>(&self, task: T) -> T::Output + where + T: Future, + { + match &self.ls { + Some(ls) => ls.block_on(&self.rt, task), + None => self.rt.block_on(task), + } + } + fn spawn<T>(&self, task: T) -> JoinHandle<T::Output> + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + match &self.ls { + Some(ls) => ls.spawn_local(task), + None => self.rt.spawn(task), + } + } + } + + // The type used for the output of the future + struct Output { + panic_on_drop: 
bool, + on_drop: Option<oneshot::Sender<()>>, + } + impl Output { + fn disarm(&mut self) { + self.panic_on_drop = false; + } + } + impl Drop for Output { + fn drop(&mut self) { + let _ = self.on_drop.take().unwrap().send(()); + if self.panic_on_drop { + panic!("Panicking in Output"); + } + } + } + + // A wrapper around the future that is spawned + struct FutWrapper<F> { + inner: F, + on_drop: Option<oneshot::Sender<()>>, + panic_on_drop: bool, + } + impl<F: Future> Future for FutWrapper<F> { + type Output = F::Output; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> { + unsafe { + let me = Pin::into_inner_unchecked(self); + let inner = Pin::new_unchecked(&mut me.inner); + inner.poll(cx) + } + } + } + impl<F> Drop for FutWrapper<F> { + fn drop(&mut self) { + let _: Result<(), ()> = self.on_drop.take().unwrap().send(()); + if self.panic_on_drop { + panic!("Panicking in FutWrapper"); + } + } + } + + // The channels passed to the task + struct Signals { + on_first_poll: Option<oneshot::Sender<()>>, + wait_complete: Option<oneshot::Receiver<()>>, + on_output_drop: Option<oneshot::Sender<()>>, + } + + // The task we will spawn + async fn my_task(mut signal: Signals, task: CombiTask, out: CombiOutput) -> Output { + // Signal that we have been polled once + let _ = signal.on_first_poll.take().unwrap().send(()); + + // Wait for a signal, then complete the future + let _ = signal.wait_complete.take().unwrap().await; + + // If the task gets past wait_complete without yielding, then aborts + // may not be caught without this yield_now. 
+ crate::task::yield_now().await; + + if task == CombiTask::PanicOnRun || task == CombiTask::PanicOnRunAndDrop { + panic!("Panicking in my_task on {:?}", std::thread::current().id()); + } + + Output { + panic_on_drop: out == CombiOutput::PanicOnDrop, + on_drop: signal.on_output_drop.take(), + } + } + + let rt = Rt::new(rt, ls); + + let (on_first_poll, wait_first_poll) = oneshot::channel(); + let (on_complete, wait_complete) = oneshot::channel(); + let (on_future_drop, wait_future_drop) = oneshot::channel(); + let (on_output_drop, wait_output_drop) = oneshot::channel(); + let signal = Signals { + on_first_poll: Some(on_first_poll), + wait_complete: Some(wait_complete), + on_output_drop: Some(on_output_drop), + }; + + // === Spawn task === + let mut handle = Some(rt.spawn(FutWrapper { + inner: my_task(signal, task, output), + on_drop: Some(on_future_drop), + panic_on_drop: task == CombiTask::PanicOnDrop || task == CombiTask::PanicOnRunAndDrop, + })); + + // Keep track of whether the task has been killed with an abort + let mut aborted = false; + + // If we want to poll the JoinHandle, do it now + if ji == CombiJoinInterest::Polled { + assert!( + handle.as_mut().unwrap().now_or_never().is_none(), + "Polling handle succeeded" + ); + } + + if abort == CombiAbort::AbortedImmediately { + handle.as_mut().unwrap().abort(); + aborted = true; + } + if jh == CombiJoinHandle::DropImmediately { + drop(handle.take().unwrap()); + } + + // === Wait for first poll === + let got_polled = rt.block_on(wait_first_poll).is_ok(); + if !got_polled { + // it's possible that we are aborted but still got polled + assert!( + aborted, + "Task completed without ever being polled but was not aborted." 
+ ); + } + + if abort == CombiAbort::AbortedFirstPoll { + handle.as_mut().unwrap().abort(); + aborted = true; + } + if jh == CombiJoinHandle::DropFirstPoll { + drop(handle.take().unwrap()); + } + + // Signal the future that it can return now + let _ = on_complete.send(()); + // === Wait for future to be dropped === + assert!( + rt.block_on(wait_future_drop).is_ok(), + "The future should always be dropped." + ); + + if abort == CombiAbort::AbortedAfterFinish { + // Don't set aborted to true here as the task already finished + handle.as_mut().unwrap().abort(); + } + if jh == CombiJoinHandle::DropAfterNoConsume { + // The runtime will usually have dropped every ref-count at this point, + // in which case dropping the JoinHandle drops the output. + // + // (But it might race and still hold a ref-count) + let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + drop(handle.take().unwrap()); + })); + if panic.is_err() { + assert!( + (output == CombiOutput::PanicOnDrop) + && (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) + && !aborted, + "Dropping JoinHandle shouldn't panic here" + ); + } + } + + // Check whether we drop after consuming the output + if jh == CombiJoinHandle::DropAfterConsume { + // Using as_mut() to not immediately drop the handle + let result = rt.block_on(handle.as_mut().unwrap()); + + match result { + Ok(mut output) => { + // Don't panic here. 
+ output.disarm(); + assert!(!aborted, "Task was aborted but returned output"); + } + Err(err) if err.is_cancelled() => assert!(aborted, "Cancelled output but not aborted"), + Err(err) if err.is_panic() => { + assert!( + (task == CombiTask::PanicOnRun) + || (task == CombiTask::PanicOnDrop) + || (task == CombiTask::PanicOnRunAndDrop) + || (output == CombiOutput::PanicOnDrop), + "Panic but nothing should panic" + ); + } + _ => unreachable!(), + } + + let handle = handle.take().unwrap(); + if abort == CombiAbort::AbortedAfterConsumeOutput { + handle.abort(); + } + drop(handle); + } + + // The output should have been dropped now. Check whether the output + // object was created at all. + let output_created = rt.block_on(wait_output_drop).is_ok(); + assert_eq!( + output_created, + (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) && !aborted, + "Creation of output object" + ); +} diff --git a/src/runtime/thread_pool/idle.rs b/src/runtime/thread_pool/idle.rs index b77cce5..6b7ee12 100644 --- a/src/runtime/thread_pool/idle.rs +++ b/src/runtime/thread_pool/idle.rs @@ -42,11 +42,11 @@ impl Idle { /// worker currently sleeping. pub(super) fn worker_to_notify(&self) -> Option<usize> { // If at least one worker is spinning, work being notified will - // eventully be found. A searching thread will find **some** work and + // eventually be found. A searching thread will find **some** work and // notify another worker, eventually leading to our work being found. // // For this to happen, this load must happen before the thread - // transitioning `num_searching` to zero. Acquire / Relese does not + // transitioning `num_searching` to zero. Acquire / Release does not // provide sufficient guarantees, so this load is done with `SeqCst` and // will pair with the `fetch_sub(1)` when transitioning out of // searching. 
@@ -126,7 +126,7 @@ impl Idle { } } - /// Returns `true` if `worker_id` is contained in the sleep set + /// Returns `true` if `worker_id` is contained in the sleep set. pub(super) fn is_parked(&self, worker_id: usize) -> bool { let sleepers = self.sleepers.lock(); sleepers.contains(&worker_id) diff --git a/src/runtime/thread_pool/mod.rs b/src/runtime/thread_pool/mod.rs index 47f8ee3..82e34c7 100644 --- a/src/runtime/thread_pool/mod.rs +++ b/src/runtime/thread_pool/mod.rs @@ -12,8 +12,9 @@ pub(crate) use worker::Launch; pub(crate) use worker::block_in_place; use crate::loom::sync::Arc; -use crate::runtime::task::{self, JoinHandle}; -use crate::runtime::Parker; +use crate::runtime::stats::RuntimeStats; +use crate::runtime::task::JoinHandle; +use crate::runtime::{Callback, Parker}; use std::fmt; use std::future::Future; @@ -23,14 +24,14 @@ pub(crate) struct ThreadPool { spawner: Spawner, } -/// Submit futures to the associated thread pool for execution. +/// Submits futures to the associated thread pool for execution. /// /// A `Spawner` instance is a handle to a single thread pool that allows the owner /// of the handle to spawn futures onto the thread pool. /// /// The `Spawner` handle is *only* used for spawning new futures. It does not /// impact the lifecycle of the thread pool in any way. The thread pool may -/// shutdown while there are outstanding `Spawner` instances. +/// shut down while there are outstanding `Spawner` instances. /// /// `Spawner` instances are obtained by calling [`ThreadPool::spawner`]. 
/// @@ -43,8 +44,13 @@ pub(crate) struct Spawner { // ===== impl ThreadPool ===== impl ThreadPool { - pub(crate) fn new(size: usize, parker: Parker) -> (ThreadPool, Launch) { - let (shared, launch) = worker::create(size, parker); + pub(crate) fn new( + size: usize, + parker: Parker, + before_park: Option<Callback>, + after_unpark: Option<Callback>, + ) -> (ThreadPool, Launch) { + let (shared, launch) = worker::create(size, parker, before_park, after_unpark); let spawner = Spawner { shared }; let thread_pool = ThreadPool { spawner }; @@ -90,17 +96,19 @@ impl Spawner { /// Spawns a future onto the thread pool pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output> where - F: Future + Send + 'static, + F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { - let (task, handle) = task::joinable(future); - self.shared.schedule(task, false); - handle + worker::Shared::bind_new_task(&self.shared, future) } pub(crate) fn shutdown(&mut self) { self.shared.close(); } + + pub(crate) fn stats(&self) -> &RuntimeStats { + self.shared.stats() + } } impl fmt::Debug for Spawner { diff --git a/src/runtime/thread_pool/worker.rs b/src/runtime/thread_pool/worker.rs index 86d3f91..ae8efe6 100644 --- a/src/runtime/thread_pool/worker.rs +++ b/src/runtime/thread_pool/worker.rs @@ -3,17 +3,71 @@ //! run queue and other state. When `block_in_place` is called, the worker's //! "core" is handed off to a new thread allowing the scheduler to continue to //! make progress while the originating thread blocks. +//! +//! # Shutdown +//! +//! Shutting down the runtime involves the following steps: +//! +//! 1. The Shared::close method is called. This closes the inject queue and +//! OwnedTasks instance and wakes up all worker threads. +//! +//! 2. Each worker thread observes the close signal next time it runs +//! Core::maintenance by checking whether the inject queue is closed. +//! The Core::is_shutdown flag is set to true. +//! +//! 3. 
The worker thread calls `pre_shutdown` in parallel. Here, the worker +//! will keep removing tasks from OwnedTasks until it is empty. No new +//! tasks can be pushed to the OwnedTasks during or after this step as it +//! was closed in step 1. +//! +//! 4. The workers call Shared::shutdown to enter the single-threaded phase of +//! shutdown. These calls will push their core to Shared::shutdown_cores, +//! and the last thread to push its core will finish the shutdown procedure. +//! +//! 5. The local run queue of each core is emptied, then the inject queue is +//! emptied. +//! +//! At this point, shutdown has completed. It is not possible for any of the +//! collections to contain any tasks at this point, as each collection was +//! closed first, then emptied afterwards. +//! +//! ## Spawns during shutdown +//! +//! When spawning tasks during shutdown, there are two cases: +//! +//! * The spawner observes the OwnedTasks being open, and the inject queue is +//! closed. +//! * The spawner observes the OwnedTasks being closed and doesn't check the +//! inject queue. +//! +//! The first case can only happen if the OwnedTasks::bind call happens before +//! or during step 1 of shutdown. In this case, the runtime will clean up the +//! task in step 3 of shutdown. +//! +//! In the latter case, the task was not spawned and the task is immediately +//! cancelled by the spawner. +//! +//! The correctness of shutdown requires both the inject queue and OwnedTasks +//! collection to have a closed bit. With a close bit on only the inject queue, +//! spawning could run into a situation where a task is successfully bound long +//! after the runtime has shut down. With a close bit on only the OwnedTasks, +//! the first spawning situation could result in the notification being pushed +//! to the inject queue after step 5 of shutdown, which would leave a task in +//! the inject queue indefinitely. This would be a ref-count cycle and a memory +//! leak. 
use crate::coop; +use crate::future::Future; use crate::loom::rand::seed; use crate::loom::sync::{Arc, Mutex}; use crate::park::{Park, Unpark}; use crate::runtime; use crate::runtime::enter::EnterContext; use crate::runtime::park::{Parker, Unparker}; +use crate::runtime::stats::{RuntimeStats, WorkerStatsBatcher}; +use crate::runtime::task::{Inject, JoinHandle, OwnedTasks}; use crate::runtime::thread_pool::{AtomicCell, Idle}; -use crate::runtime::{queue, task}; -use crate::util::linked_list::{Link, LinkedList}; +use crate::runtime::{queue, task, Callback}; use crate::util::FastRand; use std::cell::RefCell; @@ -44,7 +98,7 @@ struct Core { lifo_slot: Option<Notified>, /// The worker-local run queue. - run_queue: queue::Local<Arc<Worker>>, + run_queue: queue::Local<Arc<Shared>>, /// True if the worker is currently searching for more work. Searching /// involves attempting to steal from other workers. @@ -53,15 +107,15 @@ struct Core { /// True if the scheduler is being shutdown is_shutdown: bool, - /// Tasks owned by the core - tasks: LinkedList<Task, <Task as Link>::Target>, - /// Parker /// /// Stored in an `Option` as the parker is added / removed to make the /// borrow checker happy. park: Option<Parker>, + /// Batching stats so they can be submitted to RuntimeStats. + stats: WorkerStatsBatcher, + /// Fast random number generator. rand: FastRand, } @@ -72,28 +126,35 @@ pub(super) struct Shared { /// how they communicate between each other. remotes: Box<[Remote]>, - /// Submit work to the scheduler while **not** currently on a worker thread. - inject: queue::Inject<Arc<Worker>>, + /// Submits work to the scheduler while **not** currently on a worker thread. + inject: Inject<Arc<Shared>>, /// Coordinates idle workers idle: Idle, + /// Collection of all active tasks spawned onto this executor. 
+ owned: OwnedTasks<Arc<Shared>>, + /// Cores that have observed the shutdown signal /// /// The core is **not** placed back in the worker to avoid it from being /// stolen by a thread that was spawned as part of `block_in_place`. #[allow(clippy::vec_box)] // we're moving an already-boxed value shutdown_cores: Mutex<Vec<Box<Core>>>, + + /// Callback for a worker parking itself + before_park: Option<Callback>, + /// Callback for a worker unparking itself + after_unpark: Option<Callback>, + + /// Collects stats from the runtime. + stats: RuntimeStats, } /// Used to communicate with a worker from other threads. struct Remote { - /// Steal tasks from this worker. - steal: queue::Steal<Arc<Worker>>, - - /// Transfers tasks to be released. Any worker pushes tasks, only the owning - /// worker pops. - pending_drop: task::TransferStack<Arc<Worker>>, + /// Steals tasks from this worker. + steal: queue::Steal<Arc<Shared>>, /// Unparks the associated worker thread unpark: Unparker, @@ -117,20 +178,25 @@ pub(crate) struct Launch(Vec<Arc<Worker>>); type RunResult = Result<Box<Core>, ()>; /// A task handle -type Task = task::Task<Arc<Worker>>; +type Task = task::Task<Arc<Shared>>; /// A notified task handle -type Notified = task::Notified<Arc<Worker>>; +type Notified = task::Notified<Arc<Shared>>; // Tracks thread-local state scoped_thread_local!(static CURRENT: Context); -pub(super) fn create(size: usize, park: Parker) -> (Arc<Shared>, Launch) { +pub(super) fn create( + size: usize, + park: Parker, + before_park: Option<Callback>, + after_unpark: Option<Callback>, +) -> (Arc<Shared>, Launch) { let mut cores = vec![]; let mut remotes = vec![]; // Create the local queues - for _ in 0..size { + for i in 0..size { let (steal, run_queue) = queue::local(); let park = park.clone(); @@ -142,23 +208,23 @@ pub(super) fn create(size: usize, park: Parker) -> (Arc<Shared>, Launch) { run_queue, is_searching: false, is_shutdown: false, - tasks: LinkedList::new(), park: Some(park), + stats: 
WorkerStatsBatcher::new(i), rand: FastRand::new(seed()), })); - remotes.push(Remote { - steal, - pending_drop: task::TransferStack::new(), - unpark, - }); + remotes.push(Remote { steal, unpark }); } let shared = Arc::new(Shared { remotes: remotes.into_boxed_slice(), - inject: queue::Inject::new(), + inject: Inject::new(), idle: Idle::new(size), + owned: OwnedTasks::new(), shutdown_cores: Mutex::new(vec![]), + before_park, + after_unpark, + stats: RuntimeStats::new(size), }); let mut launch = Launch(vec![]); @@ -203,18 +269,20 @@ where CURRENT.with(|maybe_cx| { match (crate::runtime::enter::context(), maybe_cx.is_some()) { (EnterContext::Entered { .. }, true) => { - // We are on a thread pool runtime thread, so we just need to set up blocking. + // We are on a thread pool runtime thread, so we just need to + // set up blocking. had_entered = true; } (EnterContext::Entered { allow_blocking }, false) => { - // We are on an executor, but _not_ on the thread pool. - // That is _only_ okay if we are in a thread pool runtime's block_on method: + // We are on an executor, but _not_ on the thread pool. That is + // _only_ okay if we are in a thread pool runtime's block_on + // method: if allow_blocking { had_entered = true; return; } else { - // This probably means we are on the basic_scheduler or in a LocalSet, - // where it is _not_ okay to block. + // This probably means we are on the basic_scheduler or in a + // LocalSet, where it is _not_ okay to block. panic!("can call blocking only when running on the multi-threaded runtime"); } } @@ -337,11 +405,14 @@ impl Context { } fn run_task(&self, task: Notified, mut core: Box<Core>) -> RunResult { + let task = self.worker.shared.owned.assert_owner(task); + // Make sure the worker is not in the **searching** state. This enables // another idle worker to try to steal work. 
core.transition_from_searching(&self.worker); // Make the core available to the runtime context + core.stats.incr_poll_count(); *self.core.borrow_mut() = Some(core); // Run the task @@ -366,7 +437,9 @@ impl Context { if coop::has_budget_remaining() { // Run the LIFO task, then loop + core.stats.incr_poll_count(); *self.core.borrow_mut() = Some(core); + let task = self.worker.shared.owned.assert_owner(task); task.run(); } else { // Not enough budget left to run the LIFO task, push it to @@ -392,19 +465,26 @@ impl Context { } fn park(&self, mut core: Box<Core>) -> Box<Core> { - core.transition_to_parked(&self.worker); + if let Some(f) = &self.worker.shared.before_park { + f(); + } - while !core.is_shutdown { - core = self.park_timeout(core, None); + if core.transition_to_parked(&self.worker) { + while !core.is_shutdown { + core = self.park_timeout(core, None); - // Run regularly scheduled maintenance - core.maintenance(&self.worker); + // Run regularly scheduled maintenance + core.maintenance(&self.worker); - if core.transition_from_parked(&self.worker) { - return core; + if core.transition_from_parked(&self.worker) { + break; + } } } + if let Some(f) = &self.worker.shared.after_unpark { + f(); + } core } @@ -412,6 +492,8 @@ impl Context { // Take the parker out of core let mut park = core.park.take().expect("park missing"); + core.stats.about_to_park(); + // Store `core` in context *self.core.borrow_mut() = Some(core); @@ -433,6 +515,8 @@ impl Context { self.worker.shared.notify_parked(); } + core.stats.returned_from_park(); + core } } @@ -474,7 +558,10 @@ impl Core { } let target = &worker.shared.remotes[i]; - if let Some(task) = target.steal.steal_into(&mut self.run_queue) { + if let Some(task) = target + .steal + .steal_into(&mut self.run_queue, &mut self.stats) + { return Some(task); } } @@ -500,8 +587,15 @@ impl Core { worker.shared.transition_worker_from_searching(); } - /// Prepare the worker state for parking - fn transition_to_parked(&mut self, worker: 
&Worker) { + /// Prepares the worker state for parking. + /// + /// Returns true if the transition happened, false if there is work to do first. + fn transition_to_parked(&mut self, worker: &Worker) -> bool { + // Workers should not park if they have work to do + if self.lifo_slot.is_some() || self.run_queue.has_tasks() { + return false; + } + // When the final worker transitions **out** of searching to parked, it // must check all the queues one last time in case work materialized // between the last work scan and transitioning out of searching. @@ -517,6 +611,8 @@ impl Core { if is_last_searcher { worker.shared.notify_if_work_pending(); } + + true } /// Returns `true` if the transition happened. @@ -538,10 +634,9 @@ impl Core { true } - /// Runs maintenance work such as free pending tasks and check the pool's - /// state. + /// Runs maintenance work such as checking the pool's state. fn maintenance(&mut self, worker: &Worker) { - self.drain_pending_drop(worker); + self.stats.submit(&worker.shared.stats); if !self.is_shutdown { // Check if the scheduler has been shutdown @@ -549,31 +644,17 @@ impl Core { } } - // Signals all tasks to shut down, and waits for them to complete. Must run - // before we enter the single-threaded phase of shutdown processing. + /// Signals all tasks to shut down, and waits for them to complete. Must run + /// before we enter the single-threaded phase of shutdown processing. fn pre_shutdown(&mut self, worker: &Worker) { // Signal to all tasks to shut down. - for header in self.tasks.iter() { - header.shutdown(); - } + worker.shared.owned.close_and_shutdown_all(); - loop { - self.drain_pending_drop(worker); - - if self.tasks.is_empty() { - break; - } - - // Wait until signalled - let park = self.park.as_mut().expect("park missing"); - park.park().expect("park failed"); - } + self.stats.submit(&worker.shared.stats); } - // Shutdown the core + /// Shuts down the core. 
fn shutdown(&mut self) { - assert!(self.tasks.is_empty()); - // Take the core let mut park = self.park.take().expect("park missing"); @@ -582,142 +663,48 @@ impl Core { park.shutdown(); } - - fn drain_pending_drop(&mut self, worker: &Worker) { - use std::mem::ManuallyDrop; - - for task in worker.remote().pending_drop.drain() { - let task = ManuallyDrop::new(task); - - // safety: tasks are only pushed into the `pending_drop` stacks that - // are associated with the list they are inserted into. When a task - // is pushed into `pending_drop`, the ref-inc is skipped, so we must - // not ref-dec here. - // - // See `bind` and `release` implementations. - unsafe { - self.tasks.remove(task.header().into()); - } - } - } } impl Worker { - /// Returns a reference to the scheduler's injection queue - fn inject(&self) -> &queue::Inject<Arc<Worker>> { + /// Returns a reference to the scheduler's injection queue. + fn inject(&self) -> &Inject<Arc<Shared>> { &self.shared.inject } - - /// Return a reference to this worker's remote data - fn remote(&self) -> &Remote { - &self.shared.remotes[self.index] - } - - fn eq(&self, other: &Worker) -> bool { - self.shared.ptr_eq(&other.shared) && self.index == other.index - } } -impl task::Schedule for Arc<Worker> { - fn bind(task: Task) -> Arc<Worker> { - CURRENT.with(|maybe_cx| { - let cx = maybe_cx.expect("scheduler context missing"); - - // Track the task - cx.core - .borrow_mut() - .as_mut() - .expect("scheduler core missing") - .tasks - .push_front(task); - - // Return a clone of the worker - cx.worker.clone() - }) - } - +impl task::Schedule for Arc<Shared> { fn release(&self, task: &Task) -> Option<Task> { - use std::ptr::NonNull; - - enum Immediate { - // Task has been synchronously removed from the Core owned by the - // current thread - Removed(Option<Task>), - // Task is owned by another thread, so we need to notify it to clean - // up the task later. 
- MaybeRemote, - } - - let immediate = CURRENT.with(|maybe_cx| { - let cx = match maybe_cx { - Some(cx) => cx, - None => return Immediate::MaybeRemote, - }; - - if !self.eq(&cx.worker) { - // Task owned by another core, so we need to notify it. - return Immediate::MaybeRemote; - } - - let mut maybe_core = cx.core.borrow_mut(); - - if let Some(core) = &mut *maybe_core { - // Directly remove the task - // - // safety: the task is inserted in the list in `bind`. - unsafe { - let ptr = NonNull::from(task.header()); - return Immediate::Removed(core.tasks.remove(ptr)); - } - } - - Immediate::MaybeRemote - }); - - // Checks if we were called from within a worker, allowing for immediate - // removal of a scheduled task. Else we have to go through the slower - // process below where we remotely mark a task as dropped. - match immediate { - Immediate::Removed(task) => return task, - Immediate::MaybeRemote => (), - }; - - // Track the task to be released by the worker that owns it - // - // Safety: We get a new handle without incrementing the ref-count. - // A ref-count is held by the "owned" linked list and it is only - // ever removed from that list as part of the release process: this - // method or popping the task from `pending_drop`. Thus, we can rely - // on the ref-count held by the linked-list to keep the memory - // alive. - // - // When the task is removed from the stack, it is forgotten instead - // of dropped. - let task = unsafe { Task::from_raw(task.header().into()) }; - - self.remote().pending_drop.push(task); - - // The worker core has been handed off to another thread. In the - // event that the scheduler is currently shutting down, the thread - // that owns the task may be waiting on the release to complete - // shutdown. 
- if self.inject().is_closed() { - self.remote().unpark.unpark(); - } - - None + self.owned.remove(task) } fn schedule(&self, task: Notified) { - self.shared.schedule(task, false); + (**self).schedule(task, false); } fn yield_now(&self, task: Notified) { - self.shared.schedule(task, true); + (**self).schedule(task, true); } } impl Shared { + pub(super) fn bind_new_task<T>(me: &Arc<Self>, future: T) -> JoinHandle<T::Output> + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (handle, notified) = me.owned.bind(future, me.clone()); + + if let Some(notified) = notified { + me.schedule(notified, false); + } + + handle + } + + pub(crate) fn stats(&self) -> &RuntimeStats { + &self.stats + } + pub(super) fn schedule(&self, task: Notified, is_yield: bool) { CURRENT.with(|maybe_cx| { if let Some(cx) = maybe_cx { @@ -731,10 +718,10 @@ impl Shared { } } - // Otherwise, use the inject queue + // Otherwise, use the inject queue. self.inject.push(task); self.notify_parked(); - }); + }) } fn schedule_local(&self, core: &mut Core, task: Notified, is_yield: bool) { @@ -818,12 +805,18 @@ impl Shared { return; } + debug_assert!(self.owned.is_empty()); + for mut core in cores.drain(..) { core.shutdown(); } // Drain the injection queue - while self.inject.pop().is_some() {} + // + // We already shut down every task, so we can simply drop the tasks. 
+ while let Some(task) = self.inject.pop() { + drop(task); + } } fn ptr_eq(&self, other: &Shared) -> bool { diff --git a/src/signal/ctrl_c.rs b/src/signal/ctrl_c.rs index 1eeeb85..b26ab7e 100644 --- a/src/signal/ctrl_c.rs +++ b/src/signal/ctrl_c.rs @@ -47,6 +47,15 @@ use std::io; /// println!("received ctrl-c event"); /// } /// ``` +/// +/// Listen in the background: +/// +/// ```rust,no_run +/// tokio::spawn(async move { +/// tokio::signal::ctrl_c().await.unwrap(); +/// // Your handler here +/// }); +/// ``` pub async fn ctrl_c() -> io::Result<()> { os_impl::ctrl_c()?.recv().await; Ok(()) diff --git a/src/signal/mod.rs b/src/signal/mod.rs index fe572f0..882218a 100644 --- a/src/signal/mod.rs +++ b/src/signal/mod.rs @@ -1,4 +1,4 @@ -//! Asynchronous signal handling for Tokio +//! Asynchronous signal handling for Tokio. //! //! Note that signal handling is in general a very tricky topic and should be //! used with great care. This crate attempts to implement 'best practice' for diff --git a/src/signal/registry.rs b/src/signal/registry.rs index 8b89108..e0a2df9 100644 --- a/src/signal/registry.rs +++ b/src/signal/registry.rs @@ -240,17 +240,17 @@ mod tests { let registry = Registry::new(vec![EventInfo::default(), EventInfo::default()]); registry.record_event(0); - assert_eq!(false, registry.broadcast()); + assert!(!registry.broadcast()); let first = registry.register_listener(0); let second = registry.register_listener(1); registry.record_event(0); - assert_eq!(true, registry.broadcast()); + assert!(registry.broadcast()); drop(first); registry.record_event(0); - assert_eq!(false, registry.broadcast()); + assert!(!registry.broadcast()); drop(second); } diff --git a/src/signal/reusable_box.rs b/src/signal/reusable_box.rs index 426ecb0..796fa21 100644 --- a/src/signal/reusable_box.rs +++ b/src/signal/reusable_box.rs @@ -30,7 +30,7 @@ impl<T> ReusableBoxFuture<T> { Self { boxed } } - /// Replace the future currently stored in this box. 
+ /// Replaces the future currently stored in this box. /// /// This reallocates if and only if the layout of the provided future is /// different from the layout of the currently stored future. @@ -43,7 +43,7 @@ impl<T> ReusableBoxFuture<T> { } } - /// Replace the future currently stored in this box. + /// Replaces the future currently stored in this box. /// /// This function never reallocates, but returns an error if the provided /// future has a different size or alignment from the currently stored @@ -70,7 +70,7 @@ impl<T> ReusableBoxFuture<T> { } } - /// Set the current future. + /// Sets the current future. /// /// # Safety /// @@ -103,14 +103,14 @@ impl<T> ReusableBoxFuture<T> { } } - /// Get a pinned reference to the underlying future. + /// Gets a pinned reference to the underlying future. pub(crate) fn get_pin(&mut self) -> Pin<&mut (dyn Future<Output = T> + Send)> { // SAFETY: The user of this box cannot move the box, and we do not move it // either. unsafe { Pin::new_unchecked(self.boxed.as_mut()) } } - /// Poll the future stored inside this box. + /// Polls the future stored inside this box. pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<T> { self.get_pin().poll(cx) } @@ -119,7 +119,7 @@ impl<T> ReusableBoxFuture<T> { impl<T> Future for ReusableBoxFuture<T> { type Output = T; - /// Poll the future stored inside this box. + /// Polls the future stored inside this box. fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> { Pin::into_inner(self).get_pin().poll(cx) } diff --git a/src/signal/unix.rs b/src/signal/unix.rs index cb1d1cc..86ea9a9 100644 --- a/src/signal/unix.rs +++ b/src/signal/unix.rs @@ -4,12 +4,12 @@ //! `Signal` type for receiving notifications of signals. 
#![cfg(unix)] +#![cfg_attr(docsrs, doc(cfg(all(unix, feature = "signal"))))] use crate::signal::registry::{globals, EventId, EventInfo, Globals, Init, Storage}; use crate::signal::RxFuture; use crate::sync::watch; -use libc::c_int; use mio::net::UnixStream; use std::io::{self, Error, ErrorKind, Write}; use std::pin::Pin; @@ -61,7 +61,7 @@ impl Init for OsExtraData { /// Represents the specific kind of signal to listen for. #[derive(Debug, Clone, Copy)] -pub struct SignalKind(c_int); +pub struct SignalKind(libc::c_int); impl SignalKind { /// Allows for listening to any valid OS signal. @@ -74,8 +74,14 @@ impl SignalKind { /// // let signum = libc::OS_SPECIFIC_SIGNAL; /// let kind = SignalKind::from_raw(signum); /// ``` - pub fn from_raw(signum: c_int) -> Self { - Self(signum) + // Use `std::os::raw::c_int` on public API to prevent leaking a non-stable + // type alias from libc. + // `libc::c_int` and `std::os::raw::c_int` are currently the same type, and are + // unlikely to change to other types, but technically libc can change this + // in the future minor version. + // See https://github.com/tokio-rs/tokio/issues/3767 for more. + pub fn from_raw(signum: std::os::raw::c_int) -> Self { + Self(signum as libc::c_int) } /// Represents the SIGALRM signal. @@ -208,7 +214,7 @@ impl Default for SignalInfo { /// 2. Wake up the driver by writing a byte to a pipe /// /// Those two operations should both be async-signal safe. -fn action(globals: Pin<&'static Globals>, signal: c_int) { +fn action(globals: Pin<&'static Globals>, signal: libc::c_int) { globals.record_event(signal as EventId); // Send a wakeup, ignore any errors (anything reasonably possible is @@ -222,7 +228,7 @@ fn action(globals: Pin<&'static Globals>, signal: c_int) { /// /// This will register the signal handler if it hasn't already been registered, /// returning any error along the way if that fails. 
-fn signal_enable(signal: SignalKind, handle: Handle) -> io::Result<()> { +fn signal_enable(signal: SignalKind, handle: &Handle) -> io::Result<()> { let signal = signal.0; if signal < 0 || signal_hook_registry::FORBIDDEN.contains(&signal) { return Err(Error::new( @@ -352,7 +358,7 @@ pub struct Signal { /// * If the signal is one of /// [`signal_hook::FORBIDDEN`](fn@signal_hook_registry::register#panics) pub fn signal(kind: SignalKind) -> io::Result<Signal> { - let rx = signal_with_handle(kind, Handle::current())?; + let rx = signal_with_handle(kind, &Handle::current())?; Ok(Signal { inner: RxFuture::new(rx), @@ -361,7 +367,7 @@ pub fn signal(kind: SignalKind) -> io::Result<Signal> { pub(crate) fn signal_with_handle( kind: SignalKind, - handle: Handle, + handle: &Handle, ) -> io::Result<watch::Receiver<()>> { // Turn the signal delivery on once we are ready for it signal_enable(kind, handle)?; @@ -457,14 +463,14 @@ mod tests { #[test] fn signal_enable_error_on_invalid_input() { - signal_enable(SignalKind::from_raw(-1), Handle::default()).unwrap_err(); + signal_enable(SignalKind::from_raw(-1), &Handle::default()).unwrap_err(); } #[test] fn signal_enable_error_on_forbidden_input() { signal_enable( SignalKind::from_raw(signal_hook_registry::FORBIDDEN[0]), - Handle::default(), + &Handle::default(), ) .unwrap_err(); } diff --git a/src/signal/unix/driver.rs b/src/signal/unix/driver.rs index 315f3bd..5fe7c35 100644 --- a/src/signal/unix/driver.rs +++ b/src/signal/unix/driver.rs @@ -47,7 +47,7 @@ impl Driver { use std::mem::ManuallyDrop; use std::os::unix::io::{AsRawFd, FromRawFd}; - // NB: We give each driver a "fresh" reciever file descriptor to avoid + // NB: We give each driver a "fresh" receiver file descriptor to avoid // the issues described in alexcrichton/tokio-process#42. 
// // In the past we would reuse the actual receiver file descriptor and diff --git a/src/signal/windows.rs b/src/signal/windows.rs index c231d62..11ec6cb 100644 --- a/src/signal/windows.rs +++ b/src/signal/windows.rs @@ -5,127 +5,22 @@ //! `SetConsoleCtrlHandler` function which receives events of the type //! `CTRL_C_EVENT` and `CTRL_BREAK_EVENT`. -#![cfg(windows)] +#![cfg(any(windows, docsrs))] +#![cfg_attr(docsrs, doc(cfg(all(windows, feature = "signal"))))] -use crate::signal::registry::{globals, EventId, EventInfo, Init, Storage}; use crate::signal::RxFuture; - -use std::convert::TryFrom; use std::io; -use std::sync::Once; use std::task::{Context, Poll}; -use winapi::shared::minwindef::{BOOL, DWORD, FALSE, TRUE}; -use winapi::um::consoleapi::SetConsoleCtrlHandler; -use winapi::um::wincon::{CTRL_BREAK_EVENT, CTRL_C_EVENT}; - -#[derive(Debug)] -pub(crate) struct OsStorage { - ctrl_c: EventInfo, - ctrl_break: EventInfo, -} - -impl Init for OsStorage { - fn init() -> Self { - Self { - ctrl_c: EventInfo::default(), - ctrl_break: EventInfo::default(), - } - } -} - -impl Storage for OsStorage { - fn event_info(&self, id: EventId) -> Option<&EventInfo> { - match DWORD::try_from(id) { - Ok(CTRL_C_EVENT) => Some(&self.ctrl_c), - Ok(CTRL_BREAK_EVENT) => Some(&self.ctrl_break), - _ => None, - } - } - - fn for_each<'a, F>(&'a self, mut f: F) - where - F: FnMut(&'a EventInfo), - { - f(&self.ctrl_c); - f(&self.ctrl_break); - } -} - -#[derive(Debug)] -pub(crate) struct OsExtraData {} -impl Init for OsExtraData { - fn init() -> Self { - Self {} - } -} - -/// Stream of events discovered via `SetConsoleCtrlHandler`. -/// -/// This structure can be used to listen for events of the type `CTRL_C_EVENT` -/// and `CTRL_BREAK_EVENT`. The `Stream` trait is implemented for this struct -/// and will resolve for each notification received by the process. 
Note that -/// there are few limitations with this as well: -/// -/// * A notification to this process notifies *all* `Event` streams for that -/// event type. -/// * Notifications to an `Event` stream **are coalesced** if they aren't -/// processed quickly enough. This means that if two notifications are -/// received back-to-back, then the stream may only receive one item about the -/// two notifications. -#[must_use = "streams do nothing unless polled"] -#[derive(Debug)] -pub(crate) struct Event { - inner: RxFuture, -} - -impl Event { - fn new(signum: DWORD) -> io::Result<Self> { - global_init()?; - - let rx = globals().register_listener(signum as EventId); - - Ok(Self { - inner: RxFuture::new(rx), - }) - } -} +#[cfg(not(docsrs))] +#[path = "windows/sys.rs"] +mod imp; +#[cfg(not(docsrs))] +pub(crate) use self::imp::{OsExtraData, OsStorage}; -fn global_init() -> io::Result<()> { - static INIT: Once = Once::new(); - - let mut init = None; - - INIT.call_once(|| unsafe { - let rc = SetConsoleCtrlHandler(Some(handler), TRUE); - let ret = if rc == 0 { - Err(io::Error::last_os_error()) - } else { - Ok(()) - }; - - init = Some(ret); - }); - - init.unwrap_or_else(|| Ok(())) -} - -unsafe extern "system" fn handler(ty: DWORD) -> BOOL { - let globals = globals(); - globals.record_event(ty as EventId); - - // According to https://docs.microsoft.com/en-us/windows/console/handlerroutine - // the handler routine is always invoked in a new thread, thus we don't - // have the same restrictions as in Unix signal handlers, meaning we can - // go ahead and perform the broadcast here. - if globals.broadcast() { - TRUE - } else { - // No one is listening for this notification any more - // let the OS fire the next (possibly the default) handler. - FALSE - } -} +#[cfg(docsrs)] +#[path = "windows/stub.rs"] +mod imp; /// Creates a new stream which receives "ctrl-c" notifications sent to the /// process. 
@@ -150,7 +45,9 @@ unsafe extern "system" fn handler(ty: DWORD) -> BOOL { /// } /// ``` pub fn ctrl_c() -> io::Result<CtrlC> { - Event::new(CTRL_C_EVENT).map(|inner| CtrlC { inner }) + Ok(CtrlC { + inner: self::imp::ctrl_c()?, + }) } /// Represents a stream which receives "ctrl-c" notifications sent to the process @@ -163,7 +60,7 @@ pub fn ctrl_c() -> io::Result<CtrlC> { #[must_use = "streams do nothing unless polled"] #[derive(Debug)] pub struct CtrlC { - inner: Event, + inner: RxFuture, } impl CtrlC { @@ -191,7 +88,7 @@ impl CtrlC { /// } /// ``` pub async fn recv(&mut self) -> Option<()> { - self.inner.inner.recv().await + self.inner.recv().await } /// Polls to receive the next signal notification event, outside of an @@ -223,7 +120,7 @@ impl CtrlC { /// } /// ``` pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> { - self.inner.inner.poll_recv(cx) + self.inner.poll_recv(cx) } } @@ -237,7 +134,7 @@ impl CtrlC { #[must_use = "streams do nothing unless polled"] #[derive(Debug)] pub struct CtrlBreak { - inner: Event, + inner: RxFuture, } impl CtrlBreak { @@ -263,7 +160,7 @@ impl CtrlBreak { /// } /// ``` pub async fn recv(&mut self) -> Option<()> { - self.inner.inner.recv().await + self.inner.recv().await } /// Polls to receive the next signal notification event, outside of an @@ -295,7 +192,7 @@ impl CtrlBreak { /// } /// ``` pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> { - self.inner.inner.poll_recv(cx) + self.inner.poll_recv(cx) } } @@ -320,56 +217,7 @@ impl CtrlBreak { /// } /// ``` pub fn ctrl_break() -> io::Result<CtrlBreak> { - Event::new(CTRL_BREAK_EVENT).map(|inner| CtrlBreak { inner }) -} - -#[cfg(all(test, not(loom)))] -mod tests { - use super::*; - use crate::runtime::Runtime; - - use tokio_test::{assert_ok, assert_pending, assert_ready_ok, task}; - - #[test] - fn ctrl_c() { - let rt = rt(); - let _enter = rt.enter(); - - let mut ctrl_c = task::spawn(crate::signal::ctrl_c()); - - 
assert_pending!(ctrl_c.poll()); - - // Windows doesn't have a good programmatic way of sending events - // like sending signals on Unix, so we'll stub out the actual OS - // integration and test that our handling works. - unsafe { - super::handler(CTRL_C_EVENT); - } - - assert_ready_ok!(ctrl_c.poll()); - } - - #[test] - fn ctrl_break() { - let rt = rt(); - - rt.block_on(async { - let mut ctrl_break = assert_ok!(super::ctrl_break()); - - // Windows doesn't have a good programmatic way of sending events - // like sending signals on Unix, so we'll stub out the actual OS - // integration and test that our handling works. - unsafe { - super::handler(CTRL_BREAK_EVENT); - } - - ctrl_break.recv().await.unwrap(); - }); - } - - fn rt() -> Runtime { - crate::runtime::Builder::new_current_thread() - .build() - .unwrap() - } + Ok(CtrlBreak { + inner: self::imp::ctrl_break()?, + }) } diff --git a/src/signal/windows/stub.rs b/src/signal/windows/stub.rs new file mode 100644 index 0000000..8863054 --- /dev/null +++ b/src/signal/windows/stub.rs @@ -0,0 +1,13 @@ +//! Stub implementations for the platform API so that rustdoc can build linkable +//! documentation on non-windows platforms. 
+ +use crate::signal::RxFuture; +use std::io; + +pub(super) fn ctrl_c() -> io::Result<RxFuture> { + panic!() +} + +pub(super) fn ctrl_break() -> io::Result<RxFuture> { + panic!() +} diff --git a/src/signal/windows/sys.rs b/src/signal/windows/sys.rs new file mode 100644 index 0000000..8d29c35 --- /dev/null +++ b/src/signal/windows/sys.rs @@ -0,0 +1,153 @@ +use std::convert::TryFrom; +use std::io; +use std::sync::Once; + +use crate::signal::registry::{globals, EventId, EventInfo, Init, Storage}; +use crate::signal::RxFuture; + +use winapi::shared::minwindef::{BOOL, DWORD, FALSE, TRUE}; +use winapi::um::consoleapi::SetConsoleCtrlHandler; +use winapi::um::wincon::{CTRL_BREAK_EVENT, CTRL_C_EVENT}; + +pub(super) fn ctrl_c() -> io::Result<RxFuture> { + new(CTRL_C_EVENT) +} + +pub(super) fn ctrl_break() -> io::Result<RxFuture> { + new(CTRL_BREAK_EVENT) +} + +fn new(signum: DWORD) -> io::Result<RxFuture> { + global_init()?; + let rx = globals().register_listener(signum as EventId); + Ok(RxFuture::new(rx)) +} + +#[derive(Debug)] +pub(crate) struct OsStorage { + ctrl_c: EventInfo, + ctrl_break: EventInfo, +} + +impl Init for OsStorage { + fn init() -> Self { + Self { + ctrl_c: EventInfo::default(), + ctrl_break: EventInfo::default(), + } + } +} + +impl Storage for OsStorage { + fn event_info(&self, id: EventId) -> Option<&EventInfo> { + match DWORD::try_from(id) { + Ok(CTRL_C_EVENT) => Some(&self.ctrl_c), + Ok(CTRL_BREAK_EVENT) => Some(&self.ctrl_break), + _ => None, + } + } + + fn for_each<'a, F>(&'a self, mut f: F) + where + F: FnMut(&'a EventInfo), + { + f(&self.ctrl_c); + f(&self.ctrl_break); + } +} + +#[derive(Debug)] +pub(crate) struct OsExtraData {} + +impl Init for OsExtraData { + fn init() -> Self { + Self {} + } +} + +fn global_init() -> io::Result<()> { + static INIT: Once = Once::new(); + + let mut init = None; + + INIT.call_once(|| unsafe { + let rc = SetConsoleCtrlHandler(Some(handler), TRUE); + let ret = if rc == 0 { + Err(io::Error::last_os_error()) + } else { 
+ Ok(()) + }; + + init = Some(ret); + }); + + init.unwrap_or_else(|| Ok(())) +} + +unsafe extern "system" fn handler(ty: DWORD) -> BOOL { + let globals = globals(); + globals.record_event(ty as EventId); + + // According to https://docs.microsoft.com/en-us/windows/console/handlerroutine + // the handler routine is always invoked in a new thread, thus we don't + // have the same restrictions as in Unix signal handlers, meaning we can + // go ahead and perform the broadcast here. + if globals.broadcast() { + TRUE + } else { + // No one is listening for this notification any more + // let the OS fire the next (possibly the default) handler. + FALSE + } +} + +#[cfg(all(test, not(loom)))] +mod tests { + use super::*; + use crate::runtime::Runtime; + + use tokio_test::{assert_ok, assert_pending, assert_ready_ok, task}; + + #[test] + fn ctrl_c() { + let rt = rt(); + let _enter = rt.enter(); + + let mut ctrl_c = task::spawn(crate::signal::ctrl_c()); + + assert_pending!(ctrl_c.poll()); + + // Windows doesn't have a good programmatic way of sending events + // like sending signals on Unix, so we'll stub out the actual OS + // integration and test that our handling works. + unsafe { + super::handler(CTRL_C_EVENT); + } + + assert_ready_ok!(ctrl_c.poll()); + } + + #[test] + fn ctrl_break() { + let rt = rt(); + + rt.block_on(async { + let mut ctrl_break = assert_ok!(crate::signal::windows::ctrl_break()); + + // Windows doesn't have a good programmatic way of sending events + // like sending signals on Unix, so we'll stub out the actual OS + // integration and test that our handling works. 
+ unsafe { + super::handler(CTRL_BREAK_EVENT); + } + + ctrl_break.recv().await.unwrap(); + }); + } + + fn rt() -> Runtime { + crate::runtime::Builder::new_current_thread() + .build() + .unwrap() + } +} diff --git a/src/sync/barrier.rs b/src/sync/barrier.rs index a8b291f..0e39dac 100644 --- a/src/sync/barrier.rs +++ b/src/sync/barrier.rs @@ -1,8 +1,7 @@ +use crate::loom::sync::Mutex; use crate::sync::watch; -use std::sync::Mutex; - -/// A barrier enables multiple threads to synchronize the beginning of some computation. +/// A barrier enables multiple tasks to synchronize the beginning of some computation. /// /// ``` /// # #[tokio::main] @@ -52,10 +51,10 @@ struct BarrierState { } impl Barrier { - /// Creates a new barrier that can block a given number of threads. + /// Creates a new barrier that can block a given number of tasks. /// - /// A barrier will block `n`-1 threads which call [`Barrier::wait`] and then wake up all - /// threads at once when the `n`th thread calls `wait`. + /// A barrier will block `n`-1 tasks which call [`Barrier::wait`] and then wake up all + /// tasks at once when the `n`th task calls `wait`. pub fn new(mut n: usize) -> Barrier { let (waker, wait) = crate::sync::watch::channel(0); @@ -79,11 +78,11 @@ impl Barrier { /// Does not resolve until all tasks have rendezvoused here. /// - /// Barriers are re-usable after all threads have rendezvoused once, and can + /// Barriers are re-usable after all tasks have rendezvoused once, and can /// be used continuously. /// /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from - /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other threads + /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other tasks /// will receive a result that will return `false` from `is_leader`. pub async fn wait(&self) -> BarrierWaitResult { // NOTE: we are taking a _synchronous_ lock here. 
@@ -94,7 +93,7 @@ impl Barrier { // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across // a yield point, and thus marks the returned future as !Send. let generation = { - let mut state = self.state.lock().unwrap(); + let mut state = self.state.lock(); let generation = state.generation; state.arrived += 1; if state.arrived == self.n { @@ -129,14 +128,14 @@ impl Barrier { } } -/// A `BarrierWaitResult` is returned by `wait` when all threads in the `Barrier` have rendezvoused. +/// A `BarrierWaitResult` is returned by `wait` when all tasks in the `Barrier` have rendezvoused. #[derive(Debug, Clone)] pub struct BarrierWaitResult(bool); impl BarrierWaitResult { - /// Returns `true` if this thread from wait is the "leader thread". + /// Returns `true` if this task from wait is the "leader task". /// - /// Only one thread will have `true` returned from their result, all other threads will have + /// Only one task will have `true` returned from their result, all other tasks will have /// `false` returned. pub fn is_leader(&self) -> bool { self.0 diff --git a/src/sync/batch_semaphore.rs b/src/sync/batch_semaphore.rs index a0bf5ef..b5c39d2 100644 --- a/src/sync/batch_semaphore.rs +++ b/src/sync/batch_semaphore.rs @@ -1,5 +1,5 @@ #![cfg_attr(not(feature = "sync"), allow(unreachable_pub, dead_code))] -//! # Implementation Details +//! # Implementation Details. //! //! The semaphore is implemented using an intrusive linked list of waiters. An //! atomic counter tracks the number of available permits. 
If the semaphore does @@ -19,6 +19,7 @@ use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::{Mutex, MutexGuard}; use crate::util::linked_list::{self, LinkedList}; +use crate::util::WakeList; use std::future::Future; use std::marker::PhantomPinned; @@ -137,7 +138,7 @@ impl Semaphore { } } - /// Creates a new semaphore with the initial number of permits + /// Creates a new semaphore with the initial number of permits. /// /// Maximum number of permits on 32-bit platforms is `1<<29`. /// @@ -158,7 +159,7 @@ impl Semaphore { } } - /// Returns the current number of available permits + /// Returns the current number of available permits. pub(crate) fn available_permits(&self) -> usize { self.permits.load(Acquire) >> Self::PERMIT_SHIFT } @@ -196,7 +197,7 @@ impl Semaphore { } } - /// Returns true if the semaphore is closed + /// Returns true if the semaphore is closed. pub(crate) fn is_closed(&self) -> bool { self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED } @@ -239,12 +240,12 @@ impl Semaphore { /// If `rem` exceeds the number of permits needed by the wait list, the /// remainder are assigned back to the semaphore. fn add_permits_locked(&self, mut rem: usize, waiters: MutexGuard<'_, Waitlist>) { - let mut wakers: [Option<Waker>; 8] = Default::default(); + let mut wakers = WakeList::new(); let mut lock = Some(waiters); let mut is_empty = false; while rem > 0 { let mut waiters = lock.take().unwrap_or_else(|| self.waiters.lock()); - 'inner: for slot in &mut wakers[..] { + 'inner: while wakers.can_push() { // Was the waiter assigned enough permits to wake it? 
match waiters.queue.last() { Some(waiter) => { @@ -260,7 +261,11 @@ impl Semaphore { } }; let mut waiter = waiters.queue.pop_back().unwrap(); - *slot = unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) }; + if let Some(waker) = + unsafe { waiter.as_mut().waker.with_mut(|waker| (*waker).take()) } + { + wakers.push(waker); + } } if rem > 0 && is_empty { @@ -283,10 +288,7 @@ impl Semaphore { drop(waiters); // release the lock - wakers - .iter_mut() - .filter_map(Option::take) - .for_each(Waker::wake); + wakers.wake_all(); } assert_eq!(rem, 0); @@ -478,7 +480,7 @@ impl<'a> Acquire<'a> { let this = self.get_unchecked_mut(); ( Pin::new_unchecked(&mut this.node), - &this.semaphore, + this.semaphore, this.num_permits, &mut this.queued, ) diff --git a/src/sync/broadcast.rs b/src/sync/broadcast.rs index 3ef8f84..0d9cd3b 100644 --- a/src/sync/broadcast.rs +++ b/src/sync/broadcast.rs @@ -293,37 +293,37 @@ pub mod error { use self::error::*; -/// Data shared between senders and receivers +/// Data shared between senders and receivers. struct Shared<T> { - /// slots in the channel + /// slots in the channel. buffer: Box<[RwLock<Slot<T>>]>, - /// Mask a position -> index + /// Mask a position -> index. mask: usize, /// Tail of the queue. Includes the rx wait list. tail: Mutex<Tail>, - /// Number of outstanding Sender handles + /// Number of outstanding Sender handles. num_tx: AtomicUsize, } -/// Next position to write a value +/// Next position to write a value. struct Tail { - /// Next position to write to + /// Next position to write to. pos: u64, - /// Number of active receivers + /// Number of active receivers. rx_cnt: usize, - /// True if the channel is closed + /// True if the channel is closed. closed: bool, - /// Receivers waiting for a value + /// Receivers waiting for a value. waiters: LinkedList<Waiter, <Waiter as linked_list::Link>::Target>, } -/// Slot in the buffer +/// Slot in the buffer. 
struct Slot<T> { /// Remaining number of receivers that are expected to see this value. /// @@ -333,7 +333,7 @@ struct Slot<T> { /// acquired. rem: AtomicUsize, - /// Uniquely identifies the `send` stored in the slot + /// Uniquely identifies the `send` stored in the slot. pos: u64, /// True signals the channel is closed. @@ -346,9 +346,9 @@ struct Slot<T> { val: UnsafeCell<Option<T>>, } -/// An entry in the wait queue +/// An entry in the wait queue. struct Waiter { - /// True if queued + /// True if queued. queued: bool, /// Task waiting on the broadcast channel. @@ -365,12 +365,12 @@ struct RecvGuard<'a, T> { slot: RwLockReadGuard<'a, Slot<T>>, } -/// Receive a value future +/// Receive a value future. struct Recv<'a, T> { - /// Receiver being waited on + /// Receiver being waited on. receiver: &'a mut Receiver<T>, - /// Entry in the waiter `LinkedList` + /// Entry in the waiter `LinkedList`. waiter: UnsafeCell<Waiter>, } @@ -824,6 +824,13 @@ impl<T: Clone> Receiver<T> { /// the channel. A subsequent call to [`recv`] will return this value /// **unless** it has been since overwritten. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// channel. + /// /// [`Receiver`]: crate::sync::broadcast::Receiver /// [`recv`]: crate::sync::broadcast::Receiver::recv /// diff --git a/src/sync/mod.rs b/src/sync/mod.rs index d89a9dd..457e6ab 100644 --- a/src/sync/mod.rs +++ b/src/sync/mod.rs @@ -428,6 +428,11 @@ //! bounding of any kind. cfg_sync! { + /// Named future types. + pub mod futures { + pub use super::notify::Notified; + } + mod barrier; pub use barrier::{Barrier, BarrierWaitResult}; @@ -436,7 +441,7 @@ cfg_sync! 
{ pub mod mpsc; mod mutex; - pub use mutex::{Mutex, MutexGuard, TryLockError, OwnedMutexGuard}; + pub use mutex::{Mutex, MutexGuard, TryLockError, OwnedMutexGuard, MappedMutexGuard}; pub(crate) mod notify; pub use notify::Notify; diff --git a/src/sync/mpsc/block.rs b/src/sync/mpsc/block.rs index 1c9ab14..58f4a9f 100644 --- a/src/sync/mpsc/block.rs +++ b/src/sync/mpsc/block.rs @@ -1,6 +1,5 @@ use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::{AtomicPtr, AtomicUsize}; -use crate::loom::thread; use std::mem::MaybeUninit; use std::ops; @@ -41,7 +40,7 @@ struct Values<T>([UnsafeCell<MaybeUninit<T>>; BLOCK_CAP]); use super::BLOCK_CAP; -/// Masks an index to get the block identifier +/// Masks an index to get the block identifier. const BLOCK_MASK: usize = !(BLOCK_CAP - 1); /// Masks an index to get the value offset in a block. @@ -90,7 +89,7 @@ impl<T> Block<T> { } } - /// Returns `true` if the block matches the given index + /// Returns `true` if the block matches the given index. pub(crate) fn is_at_index(&self, index: usize) -> bool { debug_assert!(offset(index) == 0); self.start_index == index @@ -344,8 +343,7 @@ impl<T> Block<T> { Err(curr) => curr, }; - // When running outside of loom, this calls `spin_loop_hint`. - thread::yield_now(); + crate::loom::thread::yield_now(); } } } diff --git a/src/sync/mpsc/bounded.rs b/src/sync/mpsc/bounded.rs index 1f670bf..5a2bfa6 100644 --- a/src/sync/mpsc/bounded.rs +++ b/src/sync/mpsc/bounded.rs @@ -1,6 +1,6 @@ use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError}; use crate::sync::mpsc::chan; -use crate::sync::mpsc::error::{SendError, TrySendError}; +use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError}; cfg_time! { use crate::sync::mpsc::error::SendTimeoutError; @@ -10,19 +10,19 @@ cfg_time! { use std::fmt; use std::task::{Context, Poll}; -/// Send values to the associated `Receiver`. +/// Sends values to the associated `Receiver`. 
/// /// Instances are created by the [`channel`](channel) function. /// -/// To use the `Sender` in a poll function, you can use the [`PollSender`] -/// utility. +/// To convert the `Sender` into a `Sink` or use it in a poll function, you can +/// use the [`PollSender`] utility. /// /// [`PollSender`]: https://docs.rs/tokio-util/0.6/tokio_util/sync/struct.PollSender.html pub struct Sender<T> { chan: chan::Tx<T, Semaphore>, } -/// Permit to send one value into the channel. +/// Permits to send one value into the channel. /// /// `Permit` values are returned by [`Sender::reserve()`] and [`Sender::try_reserve()`] /// and are used to guarantee channel capacity before generating a message to send. @@ -33,7 +33,23 @@ pub struct Permit<'a, T> { chan: &'a chan::Tx<T, Semaphore>, } -/// Receive values from the associated `Sender`. +/// Owned permit to send one value into the channel. +/// +/// This is identical to the [`Permit`] type, except that it moves the sender +/// rather than borrowing it. +/// +/// `OwnedPermit` values are returned by [`Sender::reserve_owned()`] and +/// [`Sender::try_reserve_owned()`] and are used to guarantee channel capacity +/// before generating a message to send. +/// +/// [`Permit`]: Permit +/// [`Sender::reserve_owned()`]: Sender::reserve_owned +/// [`Sender::try_reserve_owned()`]: Sender::try_reserve_owned +pub struct OwnedPermit<T> { + chan: Option<chan::Tx<T, Semaphore>>, +} + +/// Receives values from the associated `Sender`. /// /// Instances are created by the [`channel`](channel) function. /// @@ -41,7 +57,7 @@ pub struct Permit<'a, T> { /// /// [`ReceiverStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.ReceiverStream.html pub struct Receiver<T> { - /// The channel receiver + /// The channel receiver. chan: chan::Rx<T, Semaphore>, } @@ -49,7 +65,7 @@ pub struct Receiver<T> { /// with backpressure. /// /// The channel will buffer up to the provided number of messages. 
Once the -/// buffer is full, attempts to `send` new messages will wait until a message is +/// buffer is full, attempts to send new messages will wait until a message is /// received from the channel. The provided buffer capacity must be at least 1. /// /// All data sent on `Sender` will become available on `Receiver` in the same @@ -60,7 +76,7 @@ pub struct Receiver<T> { /// /// If the `Receiver` is disconnected while trying to `send`, the `send` method /// will return a `SendError`. Similarly, if `Sender` is disconnected while -/// trying to `recv`, the `recv` method will return a `RecvError`. +/// trying to `recv`, the `recv` method will return `None`. /// /// # Panics /// @@ -118,11 +134,16 @@ impl<T> Receiver<T> { /// /// If there are no messages in the channel's buffer, but the channel has /// not yet been closed, this method will sleep until a message is sent or - /// the channel is closed. + /// the channel is closed. Note that if [`close`] is called, but there are + /// still outstanding [`Permits`] from before it was closed, the channel is + /// not considered closed by `recv` until the permits are released. /// - /// Note that if [`close`] is called, but there are still outstanding - /// [`Permits`] from before it was closed, the channel is not considered - /// closed by `recv` until the permits are released. + /// # Cancel safety + /// + /// This method is cancel safe. If `recv` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// channel. /// /// [`close`]: Self::close /// [`Permits`]: struct@crate::sync::mpsc::Permit @@ -166,6 +187,50 @@ impl<T> Receiver<T> { poll_fn(|cx| self.chan.recv(cx)).await } + /// Tries to receive the next value for this receiver. + /// + /// This method returns the [`Empty`] error if the channel is currently + /// empty, but there are still outstanding [senders] or [permits]. 
+ /// + /// This method returns the [`Disconnected`] error if the channel is + /// currently empty, and there are no outstanding [senders] or [permits]. + /// + /// Unlike the [`poll_recv`] method, this method will never return an + /// [`Empty`] error spuriously. + /// + /// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty + /// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected + /// [`poll_recv`]: Self::poll_recv + /// [senders]: crate::sync::mpsc::Sender + /// [permits]: crate::sync::mpsc::Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// use tokio::sync::mpsc::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(100); + /// + /// tx.send("hello").await.unwrap(); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + /// + /// tx.send("hello").await.unwrap(); + /// // Drop the last sender, closing the channel. + /// drop(tx); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); + /// } + /// ``` + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + self.chan.try_recv() + } + /// Blocking receive to call outside of asynchronous contexts. /// /// This method returns `None` if the channel has been closed and there are @@ -229,10 +294,11 @@ impl<T> Receiver<T> { /// /// To guarantee that no messages are dropped, after calling `close()`, /// `recv()` must be called until `None` is returned. If there are - /// outstanding [`Permit`] values, the `recv` method will not return `None` - /// until those are released. + /// outstanding [`Permit`] or [`OwnedPermit`] values, the `recv` method will + /// not return `None` until those are released. 
/// /// [`Permit`]: Permit + /// [`OwnedPermit`]: OwnedPermit /// /// # Examples /// @@ -269,7 +335,7 @@ impl<T> Receiver<T> { /// This method returns: /// /// * `Poll::Pending` if no messages are available but the channel is not - /// closed. + /// closed, or if a spurious failure happens. /// * `Poll::Ready(Some(message))` if a message is available. /// * `Poll::Ready(None)` if the channel has been closed and all messages /// sent before it was closed have been received. @@ -279,6 +345,12 @@ impl<T> Receiver<T> { /// receiver, or when the channel is closed. Note that on multiple calls to /// `poll_recv`, only the `Waker` from the `Context` passed to the most /// recent call is scheduled to receive a wakeup. + /// + /// If this method returns `Poll::Pending` due to a spurious failure, then + /// the `Waker` will be notified when the situation causing the spurious + /// failure has been resolved. Note that receiving such a wakeup does not + /// guarantee that the next call will succeed — it could fail with another + /// spurious failure. pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { self.chan.recv(cx) } @@ -318,6 +390,16 @@ impl<T> Sender<T> { /// [`close`]: Receiver::close /// [`Receiver`]: Receiver /// + /// # Cancel safety + /// + /// If `send` is used as the event in a [`tokio::select!`](crate::select) + /// statement and some other branch completes first, then it is guaranteed + /// that the message was not sent. + /// + /// This channel uses a queue to ensure that calls to `send` and `reserve` + /// complete in the order they were requested. Cancelling a call to + /// `send` makes you lose your place in the queue. + /// /// # Examples /// /// In the following example, each call to `send` will block until the @@ -359,6 +441,11 @@ impl<T> Sender<T> { /// This allows the producers to get notified when interest in the produced /// values is canceled and immediately stop doing work. 
/// + /// # Cancel safety + /// + /// This method is cancel safe. Once the channel is closed, it stays closed + /// forever and all future calls to `closed` will return immediately. + /// /// # Examples /// /// ``` @@ -585,7 +672,7 @@ impl<T> Sender<T> { self.chan.is_closed() } - /// Wait for channel capacity. Once capacity to send one message is + /// Waits for channel capacity. Once capacity to send one message is /// available, it is reserved for the caller. /// /// If the channel is full, the function waits for the number of unreceived @@ -600,6 +687,12 @@ impl<T> Sender<T> { /// [`Permit`]: Permit /// [`send`]: Permit::send /// + /// # Cancel safety + /// + /// This channel uses a queue to ensure that calls to `send` and `reserve` + /// complete in the order they were requested. Cancelling a call to + /// `reserve` makes you lose your place in the queue. + /// /// # Examples /// /// ``` @@ -624,15 +717,105 @@ impl<T> Sender<T> { /// } /// ``` pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> { + self.reserve_inner().await?; + Ok(Permit { chan: &self.chan }) + } + + /// Waits for channel capacity, moving the `Sender` and returning an owned + /// permit. Once capacity to send one message is available, it is reserved + /// for the caller. + /// + /// This moves the sender _by value_, and returns an owned permit that can + /// be used to send a message into the channel. Unlike [`Sender::reserve`], + /// this method may be used in cases where the permit must be valid for the + /// `'static` lifetime. `Sender`s may be cloned cheaply (`Sender::clone` is + /// essentially a reference count increment, comparable to [`Arc::clone`]), + /// so when multiple [`OwnedPermit`]s are needed or the `Sender` cannot be + /// moved, it can be cloned prior to calling `reserve_owned`. + /// + /// If the channel is full, the function waits for the number of unreceived + /// messages to become less than the channel capacity. 
Capacity to send one + /// message is reserved for the caller. An [`OwnedPermit`] is returned to + /// track the reserved capacity. The [`send`] function on [`OwnedPermit`] + /// consumes the reserved capacity. + /// + /// Dropping the [`OwnedPermit`] without sending a message releases the + /// capacity back to the channel. + /// + /// # Cancel safety + /// + /// This channel uses a queue to ensure that calls to `send` and `reserve` + /// complete in the order they were requested. Cancelling a call to + /// `reserve_owned` makes you lose your place in the queue. + /// + /// # Examples + /// Sending a message using an [`OwnedPermit`]: + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity, moving the sender. + /// let permit = tx.reserve_owned().await.unwrap(); + /// + /// // Send a message, consuming the permit and returning + /// // the moved sender. + /// let tx = permit.send(123); + /// + /// // The value sent on the permit is received. + /// assert_eq!(rx.recv().await.unwrap(), 123); + /// + /// // The sender can now be used again. + /// tx.send(456).await.unwrap(); + /// } + /// ``` + /// + /// When multiple [`OwnedPermit`]s are needed, or the sender cannot be moved + /// by value, it can be inexpensively cloned before calling `reserve_owned`: + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Clone the sender and reserve capacity. + /// let permit = tx.clone().reserve_owned().await.unwrap(); + /// + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Sending on the permit succeeds. 
+ /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); + /// } + /// ``` + /// + /// [`Sender::reserve`]: Sender::reserve + /// [`OwnedPermit`]: OwnedPermit + /// [`send`]: OwnedPermit::send + /// [`Arc::clone`]: std::sync::Arc::clone + pub async fn reserve_owned(self) -> Result<OwnedPermit<T>, SendError<()>> { + self.reserve_inner().await?; + Ok(OwnedPermit { + chan: Some(self.chan), + }) + } + + async fn reserve_inner(&self) -> Result<(), SendError<()>> { match self.chan.semaphore().0.acquire(1).await { - Ok(_) => {} - Err(_) => return Err(SendError(())), + Ok(_) => Ok(()), + Err(_) => Err(SendError(())), } - - Ok(Permit { chan: &self.chan }) } - /// Try to acquire a slot in the channel without waiting for the slot to become + /// Tries to acquire a slot in the channel without waiting for the slot to become /// available. /// /// If the channel is full this function will return [`TrySendError`], otherwise @@ -678,12 +861,80 @@ impl<T> Sender<T> { pub fn try_reserve(&self) -> Result<Permit<'_, T>, TrySendError<()>> { match self.chan.semaphore().0.try_acquire(1) { Ok(_) => {} - Err(_) => return Err(TrySendError::Full(())), + Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(())), + Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(())), } Ok(Permit { chan: &self.chan }) } + /// Tries to acquire a slot in the channel without waiting for the slot to become + /// available, returning an owned permit. + /// + /// This moves the sender _by value_, and returns an owned permit that can + /// be used to send a message into the channel. Unlike [`Sender::try_reserve`], + /// this method may be used in cases where the permit must be valid for the + /// `'static` lifetime. 
`Sender`s may be cloned cheaply (`Sender::clone` is + /// essentially a reference count increment, comparable to [`Arc::clone`]), + /// so when multiple [`OwnedPermit`]s are needed or the `Sender` cannot be + /// moved, it can be cloned prior to calling `try_reserve_owned`. + /// + /// If the channel is full this function will return a [`TrySendError`]. + /// Since the sender is taken by value, the `TrySendError` returned in this + /// case contains the sender, so that it may be used again. Otherwise, if + /// there is a slot available, this method will return an [`OwnedPermit`] + /// that can then be used to [`send`] on the channel with a guaranteed slot. + /// This function is similar to [`reserve_owned`] except it does not await + /// for the slot to become available. + /// + /// Dropping the [`OwnedPermit`] without sending a message releases the capacity back + /// to the channel. + /// + /// [`OwnedPermit`]: OwnedPermit + /// [`send`]: OwnedPermit::send + /// [`reserve_owned`]: Sender::reserve_owned + /// [`Arc::clone`]: std::sync::Arc::clone + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.clone().try_reserve_owned().unwrap(); + /// + /// // Trying to send directly on the `tx` will fail due to no + /// // available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Trying to reserve an additional slot on the `tx` will + /// // fail because there is no capacity. 
+ /// assert!(tx.try_reserve().is_err()); + /// + /// // Sending on the permit succeeds + /// permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); + /// + /// } + /// ``` + pub fn try_reserve_owned(self) -> Result<OwnedPermit<T>, TrySendError<Self>> { + match self.chan.semaphore().0.try_acquire(1) { + Ok(_) => {} + Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(self)), + Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(self)), + } + + Ok(OwnedPermit { + chan: Some(self.chan), + }) + } + /// Returns `true` if senders belong to the same channel. /// /// # Examples @@ -720,7 +971,7 @@ impl<T> Sender<T> { /// let permit = tx.reserve().await.unwrap(); /// assert_eq!(tx.capacity(), 4); /// - /// // Sending and receiving a value increases the caapcity by one. + /// // Sending and receiving a value increases the capacity by one. /// permit.send(()); /// rx.recv().await.unwrap(); /// assert_eq!(tx.capacity(), 5); @@ -804,6 +1055,8 @@ impl<T> Drop for Permit<'_, T> { // Add the permit back to the semaphore semaphore.add_permit(); + // If this is the last sender for this channel, wake the receiver so + // that it can be notified that the channel is closed. if semaphore.is_closed() && semaphore.is_idle() { self.chan.wake_rx(); } @@ -817,3 +1070,123 @@ impl<T> fmt::Debug for Permit<'_, T> { .finish() } } + +// ===== impl Permit ===== + +impl<T> OwnedPermit<T> { + /// Sends a value using the reserved capacity. + /// + /// Capacity for the message has already been reserved. The message is sent + /// to the receiver and the permit is consumed. The operation will succeed + /// even if the receiver half has been closed. See [`Receiver::close`] for + /// more details on performing a clean shutdown. + /// + /// Unlike [`Permit::send`], this method returns the [`Sender`] from which + /// the `OwnedPermit` was reserved. 
+ /// + /// [`Receiver::close`]: Receiver::close + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::channel(1); + /// + /// // Reserve capacity + /// let permit = tx.reserve_owned().await.unwrap(); + /// + /// // Send a message on the permit, returning the sender. + /// let tx = permit.send(456); + /// + /// // The value sent on the permit is received + /// assert_eq!(rx.recv().await.unwrap(), 456); + /// + /// // We may now reuse `tx` to send another message. + /// tx.send(789).await.unwrap(); + /// } + /// ``` + pub fn send(mut self, value: T) -> Sender<T> { + let chan = self.chan.take().unwrap_or_else(|| { + unreachable!("OwnedPermit channel is only taken when the permit is moved") + }); + chan.send(value); + + Sender { chan } + } + + /// Releases the reserved capacity *without* sending a message, returning the + /// [`Sender`]. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx) = mpsc::channel(1); + /// + /// // Clone the sender and reserve capacity + /// let permit = tx.clone().reserve_owned().await.unwrap(); + /// + /// // Trying to send on the original `tx` will fail, since the `permit` + /// // has reserved all the available capacity. + /// assert!(tx.try_send(123).is_err()); + /// + /// // Release the permit without sending a message, returning the clone + /// // of the sender. + /// let tx2 = permit.release(); + /// + /// // We may now reuse `tx` to send another message. 
+ /// tx.send(789).await.unwrap(); + /// # drop(rx); drop(tx2); + /// } + /// ``` + /// + /// [`Sender`]: Sender + pub fn release(mut self) -> Sender<T> { + use chan::Semaphore; + + let chan = self.chan.take().unwrap_or_else(|| { + unreachable!("OwnedPermit channel is only taken when the permit is moved") + }); + + // Add the permit back to the semaphore + chan.semaphore().add_permit(); + Sender { chan } + } +} + +impl<T> Drop for OwnedPermit<T> { + fn drop(&mut self) { + use chan::Semaphore; + + // Are we still holding onto the sender? + if let Some(chan) = self.chan.take() { + let semaphore = chan.semaphore(); + + // Add the permit back to the semaphore + semaphore.add_permit(); + + // If this `OwnedPermit` is holding the last sender for this + // channel, wake the receiver so that it can be notified that the + // channel is closed. + if semaphore.is_closed() && semaphore.is_idle() { + chan.wake_rx(); + } + } + + // Otherwise, do nothing. + } +} + +impl<T> fmt::Debug for OwnedPermit<T> { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("OwnedPermit") + .field("chan", &self.chan) + .finish() + } +} diff --git a/src/sync/mpsc/chan.rs b/src/sync/mpsc/chan.rs index 554d022..c3007de 100644 --- a/src/sync/mpsc/chan.rs +++ b/src/sync/mpsc/chan.rs @@ -2,6 +2,9 @@ use crate::loom::cell::UnsafeCell; use crate::loom::future::AtomicWaker; use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::Arc; +use crate::park::thread::CachedParkThread; +use crate::park::Park; +use crate::sync::mpsc::error::TryRecvError; use crate::sync::mpsc::list; use crate::sync::notify::Notify; @@ -11,7 +14,7 @@ use std::sync::atomic::Ordering::{AcqRel, Relaxed}; use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; -/// Channel sender +/// Channel sender. pub(crate) struct Tx<T, S> { inner: Arc<Chan<T, S>>, } @@ -22,7 +25,7 @@ impl<T, S: fmt::Debug> fmt::Debug for Tx<T, S> { } } -/// Channel receiver +/// Channel receiver. 
pub(crate) struct Rx<T, S: Semaphore> { inner: Arc<Chan<T, S>>, } @@ -44,7 +47,7 @@ pub(crate) trait Semaphore { } struct Chan<T, S> { - /// Notifies all tasks listening for the receiver being dropped + /// Notifies all tasks listening for the receiver being dropped. notify_rx_closed: Notify, /// Handle to the push half of the lock-free list. @@ -263,6 +266,51 @@ impl<T, S: Semaphore> Rx<T, S> { } }) } + + /// Try to receive the next value. + pub(crate) fn try_recv(&mut self) -> Result<T, TryRecvError> { + use super::list::TryPopResult; + + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + + macro_rules! try_recv { + () => { + match rx_fields.list.try_pop(&self.inner.tx) { + TryPopResult::Ok(value) => { + self.inner.semaphore.add_permit(); + return Ok(value); + } + TryPopResult::Closed => return Err(TryRecvError::Disconnected), + TryPopResult::Empty => return Err(TryRecvError::Empty), + TryPopResult::Busy => {} // fall through + } + }; + } + + try_recv!(); + + // If a previous `poll_recv` call has set a waker, we wake it here. + // This allows us to put our own CachedParkThread waker in the + // AtomicWaker slot instead. + // + // This is not a spurious wakeup to `poll_recv` since we just got a + // Busy from `try_pop`, which only happens if there are messages in + // the queue. + self.inner.rx_waker.wake(); + + // Park the thread until the problematic send has completed. + let mut park = CachedParkThread::new(); + let waker = park.unpark().into_waker(); + loop { + self.inner.rx_waker.register_by_ref(&waker); + // It is possible that the problematic send has now completed, + // so we have to check for messages again. + try_recv!(); + park.park().expect("park failed"); + } + }) + } } impl<T, S: Semaphore> Drop for Rx<T, S> { diff --git a/src/sync/mpsc/error.rs b/src/sync/mpsc/error.rs index a2d2824..b7b9cf7 100644 --- a/src/sync/mpsc/error.rs +++ b/src/sync/mpsc/error.rs @@ -1,4 +1,4 @@ -//! Channel error types +//! 
Channel error types. use std::error::Error; use std::fmt; @@ -51,18 +51,46 @@ impl<T> From<SendError<T>> for TrySendError<T> { } } +// ===== TryRecvError ===== + +/// Error returned by `try_recv`. +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +pub enum TryRecvError { + /// This **channel** is currently empty, but the **Sender**(s) have not yet + /// disconnected, so data may yet become available. + Empty, + /// The **channel**'s sending half has become disconnected, and there will + /// never be any more data received on it. + Disconnected, +} + +impl fmt::Display for TryRecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + TryRecvError::Empty => "receiving on an empty channel".fmt(fmt), + TryRecvError::Disconnected => "receiving on a closed channel".fmt(fmt), + } + } +} + +impl Error for TryRecvError {} + // ===== RecvError ===== /// Error returned by `Receiver`. #[derive(Debug)] +#[doc(hidden)] +#[deprecated(note = "This type is unused because recv returns an Option.")] pub struct RecvError(()); +#[allow(deprecated)] impl fmt::Display for RecvError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { write!(fmt, "channel closed") } } +#[allow(deprecated)] impl Error for RecvError {} cfg_time! { diff --git a/src/sync/mpsc/list.rs b/src/sync/mpsc/list.rs index 5dad2ba..e4eeb45 100644 --- a/src/sync/mpsc/list.rs +++ b/src/sync/mpsc/list.rs @@ -8,28 +8,40 @@ use std::fmt; use std::ptr::NonNull; use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release}; -/// List queue transmit handle +/// List queue transmit handle. pub(crate) struct Tx<T> { /// Tail in the `Block` mpmc list. block_tail: AtomicPtr<Block<T>>, - /// Position to push the next message. This reference a block and offset + /// Position to push the next message. This references a block and offset /// into the block. 
tail_position: AtomicUsize, } /// List queue receive handle pub(crate) struct Rx<T> { - /// Pointer to the block being processed + /// Pointer to the block being processed. head: NonNull<Block<T>>, - /// Next slot index to process + /// Next slot index to process. index: usize, - /// Pointer to the next block pending release + /// Pointer to the next block pending release. free_head: NonNull<Block<T>>, } +/// Return value of `Rx::try_pop`. +pub(crate) enum TryPopResult<T> { + /// Successfully popped a value. + Ok(T), + /// The channel is empty. + Empty, + /// The channel is empty and closed. + Closed, + /// The channel is not empty, but the first value is being written. + Busy, +} + pub(crate) fn channel<T>() -> (Tx<T>, Rx<T>) { // Create the initial block shared between the tx and rx halves. let initial_block = Box::new(Block::new(0)); @@ -67,7 +79,7 @@ impl<T> Tx<T> { } } - /// Closes the send half of the list + /// Closes the send half of the list. /// /// Similar process as pushing a value, but instead of writing the value & /// setting the ready flag, the TX_CLOSED flag is set on the block. @@ -218,7 +230,7 @@ impl<T> fmt::Debug for Tx<T> { } impl<T> Rx<T> { - /// Pops the next value off the queue + /// Pops the next value off the queue. pub(crate) fn pop(&mut self, tx: &Tx<T>) -> Option<block::Read<T>> { // Advance `head`, if needed if !self.try_advancing_head() { @@ -240,6 +252,26 @@ impl<T> Rx<T> { } } + /// Pops the next value off the queue, detecting whether the block + /// is busy or empty on failure. + /// + /// This function exists because `Rx::pop` can return `None` even if the + /// channel's queue contains a message that has been completely written. + /// This can happen if the fully delivered message is behind another message + /// that is in the middle of being written to the block, since the channel + /// can't return the messages out of order. 
+ pub(crate) fn try_pop(&mut self, tx: &Tx<T>) -> TryPopResult<T> { + let tail_position = tx.tail_position.load(Acquire); + let result = self.pop(tx); + + match result { + Some(block::Read::Value(t)) => TryPopResult::Ok(t), + Some(block::Read::Closed) => TryPopResult::Closed, + None if tail_position == self.index => TryPopResult::Empty, + None => TryPopResult::Busy, + } + } + /// Tries advancing the block pointer to the block referenced by `self.index`. /// /// Returns `true` if successful, `false` if there is no next block to load. diff --git a/src/sync/mpsc/mod.rs b/src/sync/mpsc/mod.rs index e7033f6..879e3dc 100644 --- a/src/sync/mpsc/mod.rs +++ b/src/sync/mpsc/mod.rs @@ -73,7 +73,7 @@ pub(super) mod block; mod bounded; -pub use self::bounded::{channel, Permit, Receiver, Sender}; +pub use self::bounded::{channel, OwnedPermit, Permit, Receiver, Sender}; mod chan; diff --git a/src/sync/mpsc/unbounded.rs b/src/sync/mpsc/unbounded.rs index ffdb34c..b133f9f 100644 --- a/src/sync/mpsc/unbounded.rs +++ b/src/sync/mpsc/unbounded.rs @@ -1,6 +1,6 @@ use crate::loom::sync::atomic::AtomicUsize; use crate::sync::mpsc::chan; -use crate::sync::mpsc::error::SendError; +use crate::sync::mpsc::error::{SendError, TryRecvError}; use std::fmt; use std::task::{Context, Poll}; @@ -82,6 +82,13 @@ impl<T> UnboundedReceiver<T> { /// `None` is returned when all `Sender` halves have dropped, indicating /// that no further values can be sent on the channel. /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// channel. + /// /// # Examples /// /// ``` @@ -122,6 +129,50 @@ impl<T> UnboundedReceiver<T> { poll_fn(|cx| self.poll_recv(cx)).await } + /// Tries to receive the next value for this receiver. 
+ /// + /// This method returns the [`Empty`] error if the channel is currently + /// empty, but there are still outstanding [senders] or [permits]. + /// + /// This method returns the [`Disconnected`] error if the channel is + /// currently empty, and there are no outstanding [senders] or [permits]. + /// + /// Unlike the [`poll_recv`] method, this method will never return an + /// [`Empty`] error spuriously. + /// + /// [`Empty`]: crate::sync::mpsc::error::TryRecvError::Empty + /// [`Disconnected`]: crate::sync::mpsc::error::TryRecvError::Disconnected + /// [`poll_recv`]: Self::poll_recv + /// [senders]: crate::sync::mpsc::Sender + /// [permits]: crate::sync::mpsc::Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// use tokio::sync::mpsc::error::TryRecvError; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx) = mpsc::unbounded_channel(); + /// + /// tx.send("hello").unwrap(); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + /// + /// tx.send("hello").unwrap(); + /// // Drop the last sender, closing the channel. + /// drop(tx); + /// + /// assert_eq!(Ok("hello"), rx.try_recv()); + /// assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); + /// } + /// ``` + pub fn try_recv(&mut self) -> Result<T, TryRecvError> { + self.chan.try_recv() + } + /// Blocking receive to call outside of asynchronous contexts. /// /// # Panics @@ -165,7 +216,7 @@ impl<T> UnboundedReceiver<T> { /// This method returns: /// /// * `Poll::Pending` if no messages are available but the channel is not - /// closed. + /// closed, or if a spurious failure happens. /// * `Poll::Ready(Some(message))` if a message is available. /// * `Poll::Ready(None)` if the channel has been closed and all messages /// sent before it was closed have been received. @@ -175,6 +226,12 @@ impl<T> UnboundedReceiver<T> { /// receiver, or when the channel is closed. 
Note that on multiple calls to /// `poll_recv`, only the `Waker` from the `Context` passed to the most /// recent call is scheduled to receive a wakeup. + /// + /// If this method returns `Poll::Pending` due to a spurious failure, then + /// the `Waker` will be notified when the situation causing the spurious + /// failure has been resolved. Note that receiving such a wakeup does not + /// guarantee that the next call will succeed — it could fail with another + /// spurious failure. pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { self.chan.recv(cx) } @@ -241,6 +298,11 @@ impl<T> UnboundedSender<T> { /// This allows the producers to get notified when interest in the produced /// values is canceled and immediately stop doing work. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once the channel is closed, it stays closed + /// forever and all future calls to `closed` will return immediately. + /// /// # Examples /// /// ``` @@ -270,6 +332,7 @@ impl<T> UnboundedSender<T> { pub async fn closed(&self) { self.chan.closed().await } + /// Checks if the channel has been closed. This happens when the /// [`UnboundedReceiver`] is dropped, or when the /// [`UnboundedReceiver::close`] method is called. diff --git a/src/sync/mutex.rs b/src/sync/mutex.rs index 0a118e7..4d9f988 100644 --- a/src/sync/mutex.rs +++ b/src/sync/mutex.rs @@ -4,9 +4,9 @@ use crate::sync::batch_semaphore as semaphore; use std::cell::UnsafeCell; use std::error::Error; -use std::fmt; use std::ops::{Deref, DerefMut}; use std::sync::Arc; +use std::{fmt, marker, mem}; /// An asynchronous `Mutex`-like type. /// @@ -160,6 +160,19 @@ pub struct OwnedMutexGuard<T: ?Sized> { lock: Arc<Mutex<T>>, } +/// A handle to a held `Mutex` that has had a function applied to it via [`MutexGuard::map`]. +/// +/// This can be used to hold a subfield of the protected data. 
+/// +/// [`MutexGuard::map`]: method@MutexGuard::map +#[must_use = "if unused the Mutex will immediately unlock"] +pub struct MappedMutexGuard<'a, T: ?Sized> { + s: &'a semaphore::Semaphore, + data: *mut T, + // Needed to tell the borrow checker that we are holding a `&mut T` + marker: marker::PhantomData<&'a mut T>, +} + // As long as T: Send, it's fine to send and share Mutex<T> between threads. // If T was not Send, sending and sharing a Mutex<T> would be bad, since you can // access T through Mutex<T>. @@ -167,6 +180,8 @@ unsafe impl<T> Send for Mutex<T> where T: ?Sized + Send {} unsafe impl<T> Sync for Mutex<T> where T: ?Sized + Send {} unsafe impl<T> Sync for MutexGuard<'_, T> where T: ?Sized + Send + Sync {} unsafe impl<T> Sync for OwnedMutexGuard<T> where T: ?Sized + Send + Sync {} +unsafe impl<'a, T> Sync for MappedMutexGuard<'a, T> where T: ?Sized + Sync + 'a {} +unsafe impl<'a, T> Send for MappedMutexGuard<'a, T> where T: ?Sized + Send + 'a {} /// Error returned from the [`Mutex::try_lock`], [`RwLock::try_read`] and /// [`RwLock::try_write`] functions. @@ -258,9 +273,15 @@ impl<T: ?Sized> Mutex<T> { } } - /// Locks this mutex, causing the current task - /// to yield until the lock has been acquired. - /// When the lock has been acquired, function returns a [`MutexGuard`]. + /// Locks this mutex, causing the current task to yield until the lock has + /// been acquired. When the lock has been acquired, function returns a + /// [`MutexGuard`]. + /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `lock` makes you lose your place in + /// the queue. /// /// # Examples /// @@ -280,6 +301,40 @@ impl<T: ?Sized> Mutex<T> { MutexGuard { lock: self } } + /// Blocking lock this mutex. When the lock has been acquired, function returns a + /// [`MutexGuard`]. 
+ /// + /// This method is intended for use cases where you + /// need to use this mutex in asynchronous code as well as in synchronous code. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::Mutex; + /// + /// #[tokio::main] + /// async fn main() { + /// let mutex = Arc::new(Mutex::new(1)); + /// + /// let mutex1 = Arc::clone(&mutex); + /// let sync_code = tokio::task::spawn_blocking(move || { + /// let mut n = mutex1.blocking_lock(); + /// *n = 2; + /// }); + /// + /// sync_code.await.unwrap(); + /// + /// let n = mutex.lock().await; + /// assert_eq!(*n, 2); + /// } + /// + /// ``` + #[cfg(feature = "sync")] + pub fn blocking_lock(&self) -> MutexGuard<'_, T> { + crate::future::block_on(self.lock()) + } + /// Locks this mutex, causing the current task to yield until the lock has /// been acquired. When the lock has been acquired, this returns an /// [`OwnedMutexGuard`]. @@ -290,6 +345,12 @@ impl<T: ?Sized> Mutex<T> { /// method, and the guard will live for the `'static` lifetime, as it keeps /// the `Mutex` alive by holding an `Arc`. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `lock_owned` makes you lose your + /// place in the queue. + /// /// # Examples /// /// ``` @@ -435,14 +496,14 @@ where } } -impl<T> std::fmt::Debug for Mutex<T> +impl<T: ?Sized> std::fmt::Debug for Mutex<T> where T: std::fmt::Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let mut d = f.debug_struct("Mutex"); match self.try_lock() { - Ok(inner) => d.field("data", &*inner), + Ok(inner) => d.field("data", &&*inner), Err(_) => d.field("data", &format_args!("<locked>")), }; d.finish() @@ -451,6 +512,129 @@ where // === impl MutexGuard === +impl<'a, T: ?Sized> MutexGuard<'a, T> { + /// Makes a new [`MappedMutexGuard`] for a component of the locked data. 
+ /// + /// This operation cannot fail as the [`MutexGuard`] passed in already locked the mutex. + /// + /// This is an associated function that needs to be used as `MutexGuard::map(...)`. A method + /// would interfere with methods of the same name on the contents of the locked data. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{Mutex, MutexGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let foo = Mutex::new(Foo(1)); + /// + /// { + /// let mut mapped = MutexGuard::map(foo.lock().await, |f| &mut f.0); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *foo.lock().await); + /// # } + /// ``` + /// + /// [`MutexGuard`]: struct@MutexGuard + /// [`MappedMutexGuard`]: struct@MappedMutexGuard + #[inline] + pub fn map<U, F>(mut this: Self, f: F) -> MappedMutexGuard<'a, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let data = f(&mut *this) as *mut U; + let s = &this.lock.s; + mem::forget(this); + MappedMutexGuard { + s, + data, + marker: marker::PhantomData, + } + } + + /// Attempts to make a new [`MappedMutexGuard`] for a component of the locked data. The + /// original guard is returned if the closure returns `None`. + /// + /// This operation cannot fail as the [`MutexGuard`] passed in already locked the mutex. + /// + /// This is an associated function that needs to be used as `MutexGuard::try_map(...)`. A + /// method would interfere with methods of the same name on the contents of the locked data. 
+ /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{Mutex, MutexGuard}; + /// + /// #[derive(Debug, Clone, Copy, PartialEq, Eq)] + /// struct Foo(u32); + /// + /// # #[tokio::main] + /// # async fn main() { + /// let foo = Mutex::new(Foo(1)); + /// + /// { + /// let mut mapped = MutexGuard::try_map(foo.lock().await, |f| Some(&mut f.0)) + /// .expect("should not fail"); + /// *mapped = 2; + /// } + /// + /// assert_eq!(Foo(2), *foo.lock().await); + /// # } + /// ``` + /// + /// [`MutexGuard`]: struct@MutexGuard + /// [`MappedMutexGuard`]: struct@MappedMutexGuard + #[inline] + pub fn try_map<U, F>(mut this: Self, f: F) -> Result<MappedMutexGuard<'a, U>, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut U, + None => return Err(this), + }; + let s = &this.lock.s; + mem::forget(this); + Ok(MappedMutexGuard { + s, + data, + marker: marker::PhantomData, + }) + } + + /// Returns a reference to the original `Mutex`. + /// + /// ``` + /// use tokio::sync::{Mutex, MutexGuard}; + /// + /// async fn unlock_and_relock<'l>(guard: MutexGuard<'l, u32>) -> MutexGuard<'l, u32> { + /// println!("1. contains: {:?}", *guard); + /// let mutex = MutexGuard::mutex(&guard); + /// drop(guard); + /// let guard = mutex.lock().await; + /// println!("2. contains: {:?}", *guard); + /// guard + /// } + /// # + /// # #[tokio::main] + /// # async fn main() { + /// # let mutex = Mutex::new(0u32); + /// # let guard = mutex.lock().await; + /// # unlock_and_relock(guard).await; + /// # } + /// ``` + #[inline] + pub fn mutex(this: &Self) -> &'a Mutex<T> { + this.lock + } +} + impl<T: ?Sized> Drop for MutexGuard<'_, T> { fn drop(&mut self) { self.lock.s.release(1) @@ -484,6 +668,35 @@ impl<T: ?Sized + fmt::Display> fmt::Display for MutexGuard<'_, T> { // === impl OwnedMutexGuard === +impl<T: ?Sized> OwnedMutexGuard<T> { + /// Returns a reference to the original `Arc<Mutex>`. 
+ /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{Mutex, OwnedMutexGuard}; + /// + /// async fn unlock_and_relock(guard: OwnedMutexGuard<u32>) -> OwnedMutexGuard<u32> { + /// println!("1. contains: {:?}", *guard); + /// let mutex: Arc<Mutex<u32>> = OwnedMutexGuard::mutex(&guard).clone(); + /// drop(guard); + /// let guard = mutex.lock_owned().await; + /// println!("2. contains: {:?}", *guard); + /// guard + /// } + /// # + /// # #[tokio::main] + /// # async fn main() { + /// # let mutex = Arc::new(Mutex::new(0u32)); + /// # let guard = mutex.lock_owned().await; + /// # unlock_and_relock(guard).await; + /// # } + /// ``` + #[inline] + pub fn mutex(this: &Self) -> &Arc<Mutex<T>> { + &this.lock + } +} + impl<T: ?Sized> Drop for OwnedMutexGuard<T> { fn drop(&mut self) { self.lock.s.release(1) @@ -514,3 +727,88 @@ impl<T: ?Sized + fmt::Display> fmt::Display for OwnedMutexGuard<T> { fmt::Display::fmt(&**self, f) } } + +// === impl MappedMutexGuard === + +impl<'a, T: ?Sized> MappedMutexGuard<'a, T> { + /// Makes a new [`MappedMutexGuard`] for a component of the locked data. + /// + /// This operation cannot fail as the [`MappedMutexGuard`] passed in already locked the mutex. + /// + /// This is an associated function that needs to be used as `MappedMutexGuard::map(...)`. A + /// method would interfere with methods of the same name on the contents of the locked data. + /// + /// [`MappedMutexGuard`]: struct@MappedMutexGuard + #[inline] + pub fn map<U, F>(mut this: Self, f: F) -> MappedMutexGuard<'a, U> + where + F: FnOnce(&mut T) -> &mut U, + { + let data = f(&mut *this) as *mut U; + let s = this.s; + mem::forget(this); + MappedMutexGuard { + s, + data, + marker: marker::PhantomData, + } + } + + /// Attempts to make a new [`MappedMutexGuard`] for a component of the locked data. The + /// original guard is returned if the closure returns `None`. + /// + /// This operation cannot fail as the [`MappedMutexGuard`] passed in already locked the mutex. 
+ /// + /// This is an associated function that needs to be used as `MappedMutexGuard::try_map(...)`. A + /// method would interfere with methods of the same name on the contents of the locked data. + /// + /// [`MappedMutexGuard`]: struct@MappedMutexGuard + #[inline] + pub fn try_map<U, F>(mut this: Self, f: F) -> Result<MappedMutexGuard<'a, U>, Self> + where + F: FnOnce(&mut T) -> Option<&mut U>, + { + let data = match f(&mut *this) { + Some(data) => data as *mut U, + None => return Err(this), + }; + let s = this.s; + mem::forget(this); + Ok(MappedMutexGuard { + s, + data, + marker: marker::PhantomData, + }) + } +} + +impl<'a, T: ?Sized> Drop for MappedMutexGuard<'a, T> { + fn drop(&mut self) { + self.s.release(1) + } +} + +impl<'a, T: ?Sized> Deref for MappedMutexGuard<'a, T> { + type Target = T; + fn deref(&self) -> &Self::Target { + unsafe { &*self.data } + } +} + +impl<'a, T: ?Sized> DerefMut for MappedMutexGuard<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.data } + } +} + +impl<'a, T: ?Sized + fmt::Debug> fmt::Debug for MappedMutexGuard<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, f) + } +} + +impl<'a, T: ?Sized + fmt::Display> fmt::Display for MappedMutexGuard<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, f) + } +} diff --git a/src/sync/notify.rs b/src/sync/notify.rs index 2d30da9..c93ce3b 100644 --- a/src/sync/notify.rs +++ b/src/sync/notify.rs @@ -8,6 +8,7 @@ use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::Mutex; use crate::util::linked_list::{self, LinkedList}; +use crate::util::WakeList; use std::cell::UnsafeCell; use std::future::Future; @@ -19,7 +20,7 @@ use std::task::{Context, Poll, Waker}; type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>; -/// Notify a single task to wake up. +/// Notifies a single task to wake up. 
/// /// `Notify` provides a basic mechanism to notify a single task of an event. /// `Notify` itself does not carry any data. Instead, it is to be used to signal @@ -56,13 +57,16 @@ type WaitList = LinkedList<Waiter, <Waiter as linked_list::Link>::Target>; /// let notify = Arc::new(Notify::new()); /// let notify2 = notify.clone(); /// -/// tokio::spawn(async move { +/// let handle = tokio::spawn(async move { /// notify2.notified().await; /// println!("received notification"); /// }); /// /// println!("sending notification"); /// notify.notify_one(); +/// +/// // Wait for task to receive notification. +/// handle.await.unwrap(); /// } /// ``` /// @@ -127,10 +131,10 @@ enum NotificationType { #[derive(Debug)] struct Waiter { - /// Intrusive linked-list pointers + /// Intrusive linked-list pointers. pointers: linked_list::Pointers<Waiter>, - /// Waiting task's waker + /// Waiting task's waker. waker: Option<Waker>, /// `true` if the notification has been assigned to this waiter. @@ -140,7 +144,7 @@ struct Waiter { _p: PhantomPinned, } -/// Future returned from `notified()` +/// Future returned from [`Notify::notified()`] #[derive(Debug)] pub struct Notified<'a> { /// The `Notify` being received on. @@ -167,13 +171,13 @@ const NOTIFY_WAITERS_SHIFT: usize = 2; const STATE_MASK: usize = (1 << NOTIFY_WAITERS_SHIFT) - 1; const NOTIFY_WAITERS_CALLS_MASK: usize = !STATE_MASK; -/// Initial "idle" state +/// Initial "idle" state. const EMPTY: usize = 0; /// One or more threads are currently waiting to be notified. const WAITING: usize = 1; -/// Pending notification +/// Pending notification. const NOTIFIED: usize = 2; fn set_state(data: usize, state: usize) -> usize { @@ -192,6 +196,10 @@ fn inc_num_notify_waiters_calls(data: usize) -> usize { data + (1 << NOTIFY_WAITERS_SHIFT) } +fn atomic_inc_num_notify_waiters_calls(data: &AtomicUsize) { + data.fetch_add(1 << NOTIFY_WAITERS_SHIFT, SeqCst); +} + impl Notify { /// Create a new `Notify`, initialized without a permit. 
/// @@ -242,6 +250,12 @@ impl Notify { /// /// [`notify_one()`]: Notify::notify_one /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute notifications in the order + /// they were requested. Cancelling a call to `notified` makes you lose your + /// place in the queue. + /// /// # Examples /// /// ``` @@ -278,7 +292,7 @@ impl Notify { } } - /// Notifies a waiting task + /// Notifies a waiting task. /// /// If a task is currently waiting, that task is notified. Otherwise, a /// permit is stored in this `Notify` value and the **next** call to @@ -348,7 +362,7 @@ impl Notify { } } - /// Notifies all waiting tasks + /// Notifies all waiting tasks. /// /// If a task is currently waiting, that task is notified. Unlike with /// `notify_one()`, no permit is stored to be used by the next call to @@ -381,10 +395,7 @@ impl Notify { /// } /// ``` pub fn notify_waiters(&self) { - const NUM_WAKERS: usize = 32; - - let mut wakers: [Option<Waker>; NUM_WAKERS] = Default::default(); - let mut curr_waker = 0; + let mut wakers = WakeList::new(); // There are waiters, the lock must be acquired to notify. let mut waiters = self.waiters.lock(); @@ -394,11 +405,9 @@ impl Notify { let curr = self.state.load(SeqCst); if let EMPTY | NOTIFIED = get_state(curr) { - // There are no waiting tasks. In this case, no synchronization is - // established between `notify` and `notified().await`. - // All we need to do is increment the number of times this - // method was called. - self.state.store(inc_num_notify_waiters_calls(curr), SeqCst); + // There are no waiting tasks. All we need to do is increment the + // number of times this method was called. + atomic_inc_num_notify_waiters_calls(&self.state); return; } @@ -406,7 +415,7 @@ impl Notify { // concurrently change, as holding the lock is required to // transition **out** of `WAITING`. 
'outer: loop { - while curr_waker < NUM_WAKERS { + while wakers.can_push() { match waiters.pop_back() { Some(mut waiter) => { // Safety: `waiters` lock is still held. @@ -417,8 +426,7 @@ impl Notify { waiter.notified = Some(NotificationType::AllWaiters); if let Some(waker) = waiter.waker.take() { - wakers[curr_waker] = Some(waker); - curr_waker += 1; + wakers.push(waker); } } None => { @@ -429,11 +437,7 @@ impl Notify { drop(waiters); - for waker in wakers.iter_mut().take(curr_waker) { - waker.take().unwrap().wake(); - } - - curr_waker = 0; + wakers.wake_all(); // Acquire the lock again. waiters = self.waiters.lock(); @@ -448,9 +452,7 @@ impl Notify { // Release the lock before notifying drop(waiters); - for waker in wakers.iter_mut().take(curr_waker) { - waker.take().unwrap().wake(); - } + wakers.wake_all(); } } @@ -520,7 +522,7 @@ impl Notified<'_> { is_unpin::<AtomicUsize>(); let me = self.get_unchecked_mut(); - (&me.notify, &mut me.state, &me.waiter) + (me.notify, &mut me.state, &me.waiter) } } } @@ -552,6 +554,10 @@ impl Future for Notified<'_> { return Poll::Ready(()); } + // Clone the waker before locking, a waker clone can be + // triggering arbitrary code. + let waker = cx.waker().clone(); + // Acquire the lock and attempt to transition to the waiting // state. let mut waiters = notify.waiters.lock(); @@ -613,7 +619,7 @@ impl Future for Notified<'_> { // Safety: called while locked. 
unsafe { - (*waiter.get()).waker = Some(cx.waker().clone()); + (*waiter.get()).waker = Some(waker); } // Insert the waiter into the linked list diff --git a/src/sync/once_cell.rs b/src/sync/once_cell.rs index fa9b1f1..d31a40e 100644 --- a/src/sync/once_cell.rs +++ b/src/sync/once_cell.rs @@ -1,4 +1,4 @@ -use super::Semaphore; +use super::{Semaphore, SemaphorePermit, TryAcquireError}; use crate::loom::cell::UnsafeCell; use std::error::Error; use std::fmt; @@ -8,15 +8,30 @@ use std::ops::Drop; use std::ptr; use std::sync::atomic::{AtomicBool, Ordering}; -/// A thread-safe cell which can be written to only once. +// This file contains an implementation of an OnceCell. The principle +// behind the safety the of the cell is that any thread with an `&OnceCell` may +// access the `value` field according the following rules: +// +// 1. When `value_set` is false, the `value` field may be modified by the +// thread holding the permit on the semaphore. +// 2. When `value_set` is true, the `value` field may be accessed immutably by +// any thread. +// +// It is an invariant that if the semaphore is closed, then `value_set` is true. +// The reverse does not necessarily hold — but if not, the semaphore may not +// have any available permits. +// +// A thread with a `&mut OnceCell` may modify the value in any way it wants as +// long as the invariants are upheld. + +/// A thread-safe cell that can be written to only once. /// -/// Provides the functionality to either set the value, in case `OnceCell` -/// is uninitialized, or get the already initialized value by using an async -/// function via [`OnceCell::get_or_init`]. -/// -/// [`OnceCell::get_or_init`]: crate::sync::OnceCell::get_or_init +/// A `OnceCell` is typically used for global variables that need to be +/// initialized once on first use, but need no further changes. The `OnceCell` +/// in Tokio allows the initialization procedure to be asynchronous. 
/// /// # Examples +/// /// ``` /// use tokio::sync::OnceCell; /// @@ -28,8 +43,28 @@ use std::sync::atomic::{AtomicBool, Ordering}; /// /// #[tokio::main] /// async fn main() { -/// let result1 = ONCE.get_or_init(some_computation).await; -/// assert_eq!(*result1, 2); +/// let result = ONCE.get_or_init(some_computation).await; +/// assert_eq!(*result, 2); +/// } +/// ``` +/// +/// It is often useful to write a wrapper method for accessing the value. +/// +/// ``` +/// use tokio::sync::OnceCell; +/// +/// static ONCE: OnceCell<u32> = OnceCell::const_new(); +/// +/// async fn get_global_integer() -> &'static u32 { +/// ONCE.get_or_init(|| async { +/// 1 + 1 +/// }).await +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// let result = get_global_integer().await; +/// assert_eq!(*result, 2); /// } /// ``` pub struct OnceCell<T> { @@ -68,7 +103,7 @@ impl<T: Eq> Eq for OnceCell<T> {} impl<T> Drop for OnceCell<T> { fn drop(&mut self) { - if self.initialized() { + if self.initialized_mut() { unsafe { self.value .with_mut(|ptr| ptr::drop_in_place((&mut *ptr).as_mut_ptr())); @@ -77,8 +112,20 @@ impl<T> Drop for OnceCell<T> { } } +impl<T> From<T> for OnceCell<T> { + fn from(value: T) -> Self { + let semaphore = Semaphore::new(0); + semaphore.close(); + OnceCell { + value_set: AtomicBool::new(true), + value: UnsafeCell::new(MaybeUninit::new(value)), + semaphore, + } + } +} + impl<T> OnceCell<T> { - /// Creates a new uninitialized OnceCell instance. + /// Creates a new empty `OnceCell` instance. pub fn new() -> Self { OnceCell { value_set: AtomicBool::new(false), @@ -87,26 +134,44 @@ impl<T> OnceCell<T> { } } - /// Creates a new initialized OnceCell instance if `value` is `Some`, otherwise - /// has the same functionality as [`OnceCell::new`]. + /// Creates a new `OnceCell` that contains the provided value, if any. + /// + /// If the `Option` is `None`, this is equivalent to `OnceCell::new`. 
/// /// [`OnceCell::new`]: crate::sync::OnceCell::new pub fn new_with(value: Option<T>) -> Self { if let Some(v) = value { - let semaphore = Semaphore::new(0); - semaphore.close(); - OnceCell { - value_set: AtomicBool::new(true), - value: UnsafeCell::new(MaybeUninit::new(v)), - semaphore, - } + OnceCell::from(v) } else { OnceCell::new() } } - /// Creates a new uninitialized OnceCell instance. - #[cfg(all(feature = "parking_lot", not(all(loom, test)),))] + /// Creates a new empty `OnceCell` instance. + /// + /// Equivalent to `OnceCell::new`, except that it can be used in static + /// variables. + /// + /// # Example + /// + /// ``` + /// use tokio::sync::OnceCell; + /// + /// static ONCE: OnceCell<u32> = OnceCell::const_new(); + /// + /// async fn get_global_integer() -> &'static u32 { + /// ONCE.get_or_init(|| async { + /// 1 + 1 + /// }).await + /// } + /// + /// #[tokio::main] + /// async fn main() { + /// let result = get_global_integer().await; + /// assert_eq!(*result, 2); + /// } + /// ``` + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] pub const fn const_new() -> Self { OnceCell { @@ -116,33 +181,48 @@ impl<T> OnceCell<T> { } } - /// Whether the value of the OnceCell is set or not. + /// Returns `true` if the `OnceCell` currently contains a value, and `false` + /// otherwise. pub fn initialized(&self) -> bool { + // Using acquire ordering so any threads that read a true from this + // atomic is able to read the value. self.value_set.load(Ordering::Acquire) } - // SAFETY: safe to call only once self.initialized() is true + /// Returns `true` if the `OnceCell` currently contains a value, and `false` + /// otherwise. + fn initialized_mut(&mut self) -> bool { + *self.value_set.get_mut() + } + + // SAFETY: The OnceCell must not be empty. unsafe fn get_unchecked(&self) -> &T { &*self.value.with(|ptr| (*ptr).as_ptr()) } - // SAFETY: safe to call only once self.initialized() is true. 
Safe because - // because of the mutable reference. + // SAFETY: The OnceCell must not be empty. unsafe fn get_unchecked_mut(&mut self) -> &mut T { &mut *self.value.with_mut(|ptr| (*ptr).as_mut_ptr()) } - // SAFETY: safe to call only once a permit on the semaphore has been - // acquired - unsafe fn set_value(&self, value: T) { - self.value.with_mut(|ptr| (*ptr).as_mut_ptr().write(value)); + fn set_value(&self, value: T, permit: SemaphorePermit<'_>) -> &T { + // SAFETY: We are holding the only permit on the semaphore. + unsafe { + self.value.with_mut(|ptr| (*ptr).as_mut_ptr().write(value)); + } + + // Using release ordering so any threads that read a true from this + // atomic is able to read the value we just stored. self.value_set.store(true, Ordering::Release); self.semaphore.close(); + permit.forget(); + + // SAFETY: We just initialized the cell. + unsafe { self.get_unchecked() } } - /// Tries to get a reference to the value of the OnceCell. - /// - /// Returns None if the value of the OnceCell hasn't previously been initialized. + /// Returns a reference to the value currently stored in the `OnceCell`, or + /// `None` if the `OnceCell` is empty. pub fn get(&self) -> Option<&T> { if self.initialized() { Some(unsafe { self.get_unchecked() }) @@ -151,179 +231,161 @@ impl<T> OnceCell<T> { } } - /// Tries to return a mutable reference to the value of the cell. + /// Returns a mutable reference to the value currently stored in the + /// `OnceCell`, or `None` if the `OnceCell` is empty. /// - /// Returns None if the cell hasn't previously been initialized. + /// Since this call borrows the `OnceCell` mutably, it is safe to mutate the + /// value inside the `OnceCell` — the mutable borrow statically guarantees + /// no other references exist. pub fn get_mut(&mut self) -> Option<&mut T> { - if self.initialized() { + if self.initialized_mut() { Some(unsafe { self.get_unchecked_mut() }) } else { None } } - /// Sets the value of the OnceCell to the argument value. 
+ /// Sets the value of the `OnceCell` to the given value if the `OnceCell` is + /// empty. + /// + /// If the `OnceCell` already has a value, this call will fail with an + /// [`SetError::AlreadyInitializedError`]. /// - /// If the value of the OnceCell was already set prior to this call - /// then [`SetError::AlreadyInitializedError`] is returned. If another thread - /// is initializing the cell while this method is called, - /// [`SetError::InitializingError`] is returned. In order to wait - /// for an ongoing initialization to finish, call - /// [`OnceCell::get_or_init`] instead. + /// If the `OnceCell` is empty, but some other task is currently trying to + /// set the value, this call will fail with [`SetError::InitializingError`]. /// /// [`SetError::AlreadyInitializedError`]: crate::sync::SetError::AlreadyInitializedError /// [`SetError::InitializingError`]: crate::sync::SetError::InitializingError - /// ['OnceCell::get_or_init`]: crate::sync::OnceCell::get_or_init pub fn set(&self, value: T) -> Result<(), SetError<T>> { - if !self.initialized() { - // Another thread might be initializing the cell, in which case `try_acquire` will - // return an error - match self.semaphore.try_acquire() { - Ok(_permit) => { - if !self.initialized() { - // SAFETY: There is only one permit on the semaphore, hence only one - // mutable reference is created - unsafe { self.set_value(value) }; - - return Ok(()); - } else { - unreachable!( - "acquired the permit after OnceCell value was already initialized." - ); - } - } - _ => { - // Couldn't acquire the permit, look if initializing process is already completed - if !self.initialized() { - return Err(SetError::InitializingError(value)); - } - } - } + if self.initialized() { + return Err(SetError::AlreadyInitializedError(value)); } - Err(SetError::AlreadyInitializedError(value)) + // Another task might be initializing the cell, in which case + // `try_acquire` will return an error. 
If we succeed to acquire the + // permit, then we can set the value. + match self.semaphore.try_acquire() { + Ok(permit) => { + debug_assert!(!self.initialized()); + self.set_value(value, permit); + Ok(()) + } + Err(TryAcquireError::NoPermits) => { + // Some other task is holding the permit. That task is + // currently trying to initialize the value. + Err(SetError::InitializingError(value)) + } + Err(TryAcquireError::Closed) => { + // The semaphore was closed. Some other task has initialized + // the value. + Err(SetError::AlreadyInitializedError(value)) + } + } } - /// Tries to initialize the value of the OnceCell using the async function `f`. - /// If the value of the OnceCell was already initialized prior to this call, - /// a reference to that initialized value is returned. If some other thread - /// initiated the initialization prior to this call and the initialization - /// hasn't completed, this call waits until the initialization is finished. + /// Gets the value currently in the `OnceCell`, or initialize it with the + /// given asynchronous operation. + /// + /// If some other task is currently working on initializing the `OnceCell`, + /// this call will wait for that other task to finish, then return the value + /// that the other task produced. + /// + /// If the provided operation is cancelled or panics, the initialization + /// attempt is cancelled. If there are other tasks waiting for the value to + /// be initialized, one of them will start another attempt at initializing + /// the value. /// - /// This will deadlock if `f` tries to initialize the cell itself. + /// This will deadlock if `f` tries to initialize the cell recursively. 
pub async fn get_or_init<F, Fut>(&self, f: F) -> &T where F: FnOnce() -> Fut, Fut: Future<Output = T>, { if self.initialized() { - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references + // SAFETY: The OnceCell has been fully initialized. unsafe { self.get_unchecked() } } else { - // After acquire().await we have either acquired a permit while self.value - // is still uninitialized, or the current thread is awoken after another thread - // has intialized the value and closed the semaphore, in which case self.initialized - // is true and we don't set the value here + // Here we try to acquire the semaphore permit. Holding the permit + // will allow us to set the value of the OnceCell, and prevents + // other tasks from initializing the OnceCell while we are holding + // it. match self.semaphore.acquire().await { - Ok(_permit) => { - if !self.initialized() { - // If `f()` panics or `select!` is called, this `get_or_init` call - // is aborted and the semaphore permit is dropped. - let value = f().await; - - // SAFETY: There is only one permit on the semaphore, hence only one - // mutable reference is created - unsafe { self.set_value(value) }; - - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references - unsafe { self.get_unchecked() } - } else { - unreachable!("acquired semaphore after value was already initialized."); - } + Ok(permit) => { + debug_assert!(!self.initialized()); + + // If `f()` panics or `select!` is called, this + // `get_or_init` call is aborted and the semaphore permit is + // dropped. 
+ let value = f().await; + + self.set_value(value, permit) } Err(_) => { - if self.initialized() { - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references - unsafe { self.get_unchecked() } - } else { - unreachable!( - "Semaphore closed, but the OnceCell has not been initialized." - ); - } + debug_assert!(self.initialized()); + + // SAFETY: The semaphore has been closed. This only happens + // when the OnceCell is fully initialized. + unsafe { self.get_unchecked() } } } } } - /// Tries to initialize the value of the OnceCell using the async function `f`. - /// If the value of the OnceCell was already initialized prior to this call, - /// a reference to that initialized value is returned. If some other thread - /// initiated the initialization prior to this call and the initialization - /// hasn't completed, this call waits until the initialization is finished. - /// If the function argument `f` returns an error, `get_or_try_init` - /// returns that error, otherwise the result of `f` will be stored in the cell. + /// Gets the value currently in the `OnceCell`, or initialize it with the + /// given asynchronous operation. + /// + /// If some other task is currently working on initializing the `OnceCell`, + /// this call will wait for that other task to finish, then return the value + /// that the other task produced. /// - /// This will deadlock if `f` tries to initialize the cell itself. + /// If the provided operation returns an error, is cancelled or panics, the + /// initialization attempt is cancelled. If there are other tasks waiting + /// for the value to be initialized, one of them will start another attempt + /// at initializing the value. + /// + /// This will deadlock if `f` tries to initialize the cell recursively. 
pub async fn get_or_try_init<E, F, Fut>(&self, f: F) -> Result<&T, E> where F: FnOnce() -> Fut, Fut: Future<Output = Result<T, E>>, { if self.initialized() { - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references + // SAFETY: The OnceCell has been fully initialized. unsafe { Ok(self.get_unchecked()) } } else { - // After acquire().await we have either acquired a permit while self.value - // is still uninitialized, or the current thread is awoken after another thread - // has intialized the value and closed the semaphore, in which case self.initialized - // is true and we don't set the value here + // Here we try to acquire the semaphore permit. Holding the permit + // will allow us to set the value of the OnceCell, and prevents + // other tasks from initializing the OnceCell while we are holding + // it. match self.semaphore.acquire().await { - Ok(_permit) => { - if !self.initialized() { - // If `f()` panics or `select!` is called, this `get_or_try_init` call - // is aborted and the semaphore permit is dropped. - let value = f().await; - - match value { - Ok(value) => { - // SAFETY: There is only one permit on the semaphore, hence only one - // mutable reference is created - unsafe { self.set_value(value) }; - - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references - unsafe { Ok(self.get_unchecked()) } - } - Err(e) => Err(e), - } - } else { - unreachable!("acquired semaphore after value was already initialized."); + Ok(permit) => { + debug_assert!(!self.initialized()); + + // If `f()` panics or `select!` is called, this + // `get_or_try_init` call is aborted and the semaphore + // permit is dropped. 
+ let value = f().await; + + match value { + Ok(value) => Ok(self.set_value(value, permit)), + Err(e) => Err(e), } } Err(_) => { - if self.initialized() { - // SAFETY: once the value is initialized, no mutable references are given out, so - // we can give out arbitrarily many immutable references - unsafe { Ok(self.get_unchecked()) } - } else { - unreachable!( - "Semaphore closed, but the OnceCell has not been initialized." - ); - } + debug_assert!(self.initialized()); + + // SAFETY: The semaphore has been closed. This only happens + // when the OnceCell is fully initialized. + unsafe { Ok(self.get_unchecked()) } } } } } - /// Moves the value out of the cell, destroying the cell in the process. - /// - /// Returns `None` if the cell is uninitialized. + /// Takes the value from the cell, destroying the cell in the process. + /// Returns `None` if the cell is empty. pub fn into_inner(mut self) -> Option<T> { - if self.initialized() { + if self.initialized_mut() { // Set to uninitialized for the destructor of `OnceCell` to work properly *self.value_set.get_mut() = false; Some(unsafe { self.value.with(|ptr| ptr::read(ptr).assume_init()) }) @@ -332,20 +394,18 @@ impl<T> OnceCell<T> { } } - /// Takes ownership of the current value, leaving the cell uninitialized. - /// - /// Returns `None` if the cell is uninitialized. + /// Takes ownership of the current value, leaving the cell empty. Returns + /// `None` if the cell is empty. pub fn take(&mut self) -> Option<T> { std::mem::take(self).into_inner() } } -// Since `get` gives us access to immutable references of the -// OnceCell, OnceCell can only be Sync if T is Sync, otherwise -// OnceCell would allow sharing references of !Sync values across -// threads. We need T to be Send in order for OnceCell to by Sync -// because we can use `set` on `&OnceCell<T>` to send -// values (of type T) across threads. 
+// Since `get` gives us access to immutable references of the OnceCell, OnceCell +// can only be Sync if T is Sync, otherwise OnceCell would allow sharing +// references of !Sync values across threads. We need T to be Send in order for +// OnceCell to by Sync because we can use `set` on `&OnceCell<T>` to send values +// (of type T) across threads. unsafe impl<T: Sync + Send> Sync for OnceCell<T> {} // Access to OnceCell's value is guarded by the semaphore permit @@ -353,20 +413,17 @@ unsafe impl<T: Sync + Send> Sync for OnceCell<T> {} // it's safe to send it to another thread unsafe impl<T: Send> Send for OnceCell<T> {} -/// Errors that can be returned from [`OnceCell::set`] +/// Errors that can be returned from [`OnceCell::set`]. /// /// [`OnceCell::set`]: crate::sync::OnceCell::set #[derive(Debug, PartialEq)] pub enum SetError<T> { - /// Error resulting from [`OnceCell::set`] calls if the cell was previously initialized. + /// The cell was already initialized when [`OnceCell::set`] was called. /// /// [`OnceCell::set`]: crate::sync::OnceCell::set AlreadyInitializedError(T), - /// Error resulting from [`OnceCell::set`] calls when the cell is currently being - /// inintialized during the calls to that method. - /// - /// [`OnceCell::set`]: crate::sync::OnceCell::set + /// The cell is currently being initialized. InitializingError(T), } diff --git a/src/sync/oneshot.rs b/src/sync/oneshot.rs index 0df6037..4fb22ec 100644 --- a/src/sync/oneshot.rs +++ b/src/sync/oneshot.rs @@ -51,6 +51,70 @@ //! } //! } //! ``` +//! +//! To use a oneshot channel in a `tokio::select!` loop, add `&mut` in front of +//! the channel. +//! +//! ``` +//! use tokio::sync::oneshot; +//! use tokio::time::{interval, sleep, Duration}; +//! +//! #[tokio::main] +//! # async fn _doc() {} +//! # #[tokio::main(flavor = "current_thread", start_paused = true)] +//! async fn main() { +//! let (send, mut recv) = oneshot::channel(); +//! let mut interval = interval(Duration::from_millis(100)); +//! +//! 
# let handle = +//! tokio::spawn(async move { +//! sleep(Duration::from_secs(1)).await; +//! send.send("shut down").unwrap(); +//! }); +//! +//! loop { +//! tokio::select! { +//! _ = interval.tick() => println!("Another 100ms"), +//! msg = &mut recv => { +//! println!("Got message: {}", msg.unwrap()); +//! break; +//! } +//! } +//! } +//! # handle.await.unwrap(); +//! } +//! ``` +//! +//! To use a `Sender` from a destructor, put it in an [`Option`] and call +//! [`Option::take`]. +//! +//! ``` +//! use tokio::sync::oneshot; +//! +//! struct SendOnDrop { +//! sender: Option<oneshot::Sender<&'static str>>, +//! } +//! impl Drop for SendOnDrop { +//! fn drop(&mut self) { +//! if let Some(sender) = self.sender.take() { +//! // Using `let _ =` to ignore send errors. +//! let _ = sender.send("I got dropped!"); +//! } +//! } +//! } +//! +//! #[tokio::main] +//! # async fn _doc() {} +//! # #[tokio::main(flavor = "current_thread")] +//! async fn main() { +//! let (send, recv) = oneshot::channel(); +//! +//! let send_on_drop = SendOnDrop { sender: Some(send) }; +//! drop(send_on_drop); +//! +//! assert_eq!(recv.await, Ok("I got dropped!")); +//! } +//! ``` use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::AtomicUsize; @@ -68,16 +132,98 @@ use std::task::{Context, Poll, Waker}; /// /// A pair of both a [`Sender`] and a [`Receiver`] are created by the /// [`channel`](fn@channel) function. 
+/// +/// # Examples +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel(); +/// +/// tokio::spawn(async move { +/// if let Err(_) = tx.send(3) { +/// println!("the receiver dropped"); +/// } +/// }); +/// +/// match rx.await { +/// Ok(v) => println!("got = {:?}", v), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +/// +/// If the sender is dropped without sending, the receiver will fail with +/// [`error::RecvError`]: +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// #[tokio::main] +/// async fn main() { +/// let (tx, rx) = oneshot::channel::<u32>(); +/// +/// tokio::spawn(async move { +/// drop(tx); +/// }); +/// +/// match rx.await { +/// Ok(_) => panic!("This doesn't happen"), +/// Err(_) => println!("the sender dropped"), +/// } +/// } +/// ``` +/// +/// To use a `Sender` from a destructor, put it in an [`Option`] and call +/// [`Option::take`]. +/// +/// ``` +/// use tokio::sync::oneshot; +/// +/// struct SendOnDrop { +/// sender: Option<oneshot::Sender<&'static str>>, +/// } +/// impl Drop for SendOnDrop { +/// fn drop(&mut self) { +/// if let Some(sender) = self.sender.take() { +/// // Using `let _ =` to ignore send errors. +/// let _ = sender.send("I got dropped!"); +/// } +/// } +/// } +/// +/// #[tokio::main] +/// # async fn _doc() {} +/// # #[tokio::main(flavor = "current_thread")] +/// async fn main() { +/// let (send, recv) = oneshot::channel(); +/// +/// let send_on_drop = SendOnDrop { sender: Some(send) }; +/// drop(send_on_drop); +/// +/// assert_eq!(recv.await, Ok("I got dropped!")); +/// } +/// ``` +/// +/// [`Option`]: std::option::Option +/// [`Option::take`]: std::option::Option::take #[derive(Debug)] pub struct Sender<T> { inner: Option<Arc<Inner<T>>>, } -/// Receive a value from the associated [`Sender`]. +/// Receives a value from the associated [`Sender`]. 
/// /// A pair of both a [`Sender`] and a [`Receiver`] are created by the /// [`channel`](fn@channel) function. /// +/// This channel has no `recv` method because the receiver itself implements the +/// [`Future`] trait. To receive a value, `.await` the `Receiver` object directly. +/// +/// [`Future`]: trait@std::future::Future +/// /// # Examples /// /// ``` @@ -120,13 +266,46 @@ pub struct Sender<T> { /// } /// } /// ``` +/// +/// To use a `Receiver` in a `tokio::select!` loop, add `&mut` in front of the +/// channel. +/// +/// ``` +/// use tokio::sync::oneshot; +/// use tokio::time::{interval, sleep, Duration}; +/// +/// #[tokio::main] +/// # async fn _doc() {} +/// # #[tokio::main(flavor = "current_thread", start_paused = true)] +/// async fn main() { +/// let (send, mut recv) = oneshot::channel(); +/// let mut interval = interval(Duration::from_millis(100)); +/// +/// # let handle = +/// tokio::spawn(async move { +/// sleep(Duration::from_secs(1)).await; +/// send.send("shut down").unwrap(); +/// }); +/// +/// loop { +/// tokio::select! { +/// _ = interval.tick() => println!("Another 100ms"), +/// msg = &mut recv => { +/// println!("Got message: {}", msg.unwrap()); +/// break; +/// } +/// } +/// } +/// # handle.await.unwrap(); +/// } +/// ``` #[derive(Debug)] pub struct Receiver<T> { inner: Option<Arc<Inner<T>>>, } pub mod error { - //! Oneshot error types + //! Oneshot error types. use std::fmt; @@ -171,7 +350,7 @@ pub mod error { use self::error::*; struct Inner<T> { - /// Manages the state of the inner cell + /// Manages the state of the inner cell. state: AtomicUsize, /// The value. This is set by `Sender` and read by `Receiver`. The state of @@ -179,9 +358,19 @@ struct Inner<T> { value: UnsafeCell<Option<T>>, /// The task to notify when the receiver drops without consuming the value. + /// + /// ## Safety + /// + /// The `TX_TASK_SET` bit in the `state` field is set if this field is + /// initialized. If that bit is unset, this field may be uninitialized. 
tx_task: Task, /// The task to notify when the value is sent. + /// + /// ## Safety + /// + /// The `RX_TASK_SET` bit in the `state` field is set if this field is + /// initialized. If that bit is unset, this field may be uninitialized. rx_task: Task, } @@ -220,7 +409,7 @@ impl Task { #[derive(Clone, Copy)] struct State(usize); -/// Create a new one-shot channel for sending single values across asynchronous +/// Creates a new one-shot channel for sending single values across asynchronous /// tasks. /// /// The function returns separate "send" and "receive" handles. The `Sender` @@ -311,11 +500,24 @@ impl<T> Sender<T> { let inner = self.inner.take().unwrap(); inner.value.with_mut(|ptr| unsafe { + // SAFETY: The receiver will not access the `UnsafeCell` unless the + // channel has been marked as "complete" (the `VALUE_SENT` state bit + // is set). + // That bit is only set by the sender later on in this method, and + // calling this method consumes `self`. Therefore, if it was possible to + // call this method, we know that the `VALUE_SENT` bit is unset, and + // the receiver is not currently accessing the `UnsafeCell`. *ptr = Some(t); }); if !inner.complete() { unsafe { + // SAFETY: The receiver will not access the `UnsafeCell` unless + // the channel has been marked as "complete". Calling + // `complete()` will return true if this bit is set, and false + // if it is not set. Thus, if `complete()` returned false, it is + // safe for us to access the value, because we know that the + // receiver will not. return Err(inner.consume_value().unwrap()); } } @@ -430,7 +632,7 @@ impl<T> Sender<T> { state.is_closed() } - /// Check whether the oneshot channel has been closed, and if not, schedules the + /// Checks whether the oneshot channel has been closed, and if not, schedules the /// `Waker` in the provided `Context` to receive a notification when the channel is /// closed. 
/// @@ -661,6 +863,11 @@ impl<T> Receiver<T> { let state = State::load(&inner.state, Acquire); if state.is_complete() { + // SAFETY: If `state.is_complete()` returns true, then the + // `VALUE_SENT` bit has been set and the sender side of the + // channel will no longer attempt to access the inner + // `UnsafeCell`. Therefore, it is now safe for us to access the + // cell. match unsafe { inner.consume_value() } { Some(value) => Ok(value), None => Err(TryRecvError::Closed), @@ -751,6 +958,11 @@ impl<T> Inner<T> { State::set_rx_task(&self.state); coop.made_progress(); + // SAFETY: If `state.is_complete()` returns true, then the + // `VALUE_SENT` bit has been set and the sender side of the + // channel will no longer attempt to access the inner + // `UnsafeCell`. Therefore, it is now safe for us to access the + // cell. return match unsafe { self.consume_value() } { Some(value) => Ready(Ok(value)), None => Ready(Err(RecvError(()))), @@ -797,6 +1009,14 @@ impl<T> Inner<T> { } /// Consumes the value. This function does not check `state`. + /// + /// # Safety + /// + /// Calling this method concurrently on multiple threads will result in a + /// data race. The `VALUE_SENT` state bit is used to ensure that only the + /// sender *or* the receiver will call this method at a given point in time. + /// If `VALUE_SENT` is not set, then only the sender may call this method; + /// if it is set, then only the receiver may call this method. unsafe fn consume_value(&self) -> Option<T> { self.value.with_mut(|ptr| (*ptr).take()) } @@ -837,9 +1057,28 @@ impl<T: fmt::Debug> fmt::Debug for Inner<T> { } } +/// Indicates that a waker for the receiving task has been set. +/// +/// # Safety +/// +/// If this bit is not set, the `rx_task` field may be uninitialized. const RX_TASK_SET: usize = 0b00001; +/// Indicates that a value has been stored in the channel's inner `UnsafeCell`. 
+/// +/// # Safety +/// +/// This bit controls which side of the channel is permitted to access the +/// `UnsafeCell`. If it is set, the `UnsafeCell` may ONLY be accessed by the +/// receiver. If this bit is NOT set, the `UnsafeCell` may ONLY be accessed by +/// the sender. const VALUE_SENT: usize = 0b00010; const CLOSED: usize = 0b00100; + +/// Indicates that a waker for the sending task has been set. +/// +/// # Safety +/// +/// If this bit is not set, the `tx_task` field may be uninitialized. const TX_TASK_SET: usize = 0b01000; impl State { @@ -852,11 +1091,38 @@ impl State { } fn set_complete(cell: &AtomicUsize) -> State { - // TODO: This could be `Release`, followed by an `Acquire` fence *if* - // the `RX_TASK_SET` flag is set. However, `loom` does not support - // fences yet. - let val = cell.fetch_or(VALUE_SENT, AcqRel); - State(val) + // This method is a compare-and-swap loop rather than a fetch-or like + // other `set_$WHATEVER` methods on `State`. This is because we must + // check if the state has been closed before setting the `VALUE_SENT` + // bit. + // + // We don't want to set both the `VALUE_SENT` bit if the `CLOSED` + // bit is already set, because `VALUE_SENT` will tell the receiver that + // it's okay to access the inner `UnsafeCell`. Immediately after calling + // `set_complete`, if the channel was closed, the sender will _also_ + // access the `UnsafeCell` to take the value back out, so if a + // `poll_recv` or `try_recv` call is occurring concurrently, both + // threads may try to access the `UnsafeCell` if we were to set the + // `VALUE_SENT` bit on a closed channel. + let mut state = cell.load(Ordering::Relaxed); + loop { + if State(state).is_closed() { + break; + } + // TODO: This could be `Release`, followed by an `Acquire` fence *if* + // the `RX_TASK_SET` flag is set. However, `loom` does not support + // fences yet. 
+ match cell.compare_exchange_weak( + state, + state | VALUE_SENT, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => break, + Err(actual) => state = actual, + } + } + State(state) } fn is_rx_task_set(self) -> bool { diff --git a/src/sync/rwlock.rs b/src/sync/rwlock.rs index 6f0c011..120bc72 100644 --- a/src/sync/rwlock.rs +++ b/src/sync/rwlock.rs @@ -299,6 +299,12 @@ impl<T: ?Sized> RwLock<T> { /// Returns an RAII guard which will drop this read access of the `RwLock` /// when dropped. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `read` makes you lose your place in + /// the queue. + /// /// # Examples /// /// ``` @@ -357,6 +363,12 @@ impl<T: ?Sized> RwLock<T> { /// Returns an RAII guard which will drop this read access of the `RwLock` /// when dropped. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `read_owned` makes you lose your + /// place in the queue. + /// /// # Examples /// /// ``` @@ -501,6 +513,12 @@ impl<T: ?Sized> RwLock<T> { /// Returns an RAII guard which will drop the write access of this `RwLock` /// when dropped. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `write` makes you lose your place + /// in the queue. + /// /// # Examples /// /// ``` @@ -543,6 +561,12 @@ impl<T: ?Sized> RwLock<T> { /// Returns an RAII guard which will drop the write access of this `RwLock` /// when dropped. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute locks in the order they + /// were requested. Cancelling a call to `write_owned` makes you lose your + /// place in the queue. 
+ /// /// # Examples /// /// ``` diff --git a/src/sync/rwlock/owned_read_guard.rs b/src/sync/rwlock/owned_read_guard.rs index b7f3926..1881295 100644 --- a/src/sync/rwlock/owned_read_guard.rs +++ b/src/sync/rwlock/owned_read_guard.rs @@ -22,7 +22,7 @@ pub struct OwnedRwLockReadGuard<T: ?Sized, U: ?Sized = T> { } impl<T: ?Sized, U: ?Sized> OwnedRwLockReadGuard<T, U> { - /// Make a new `OwnedRwLockReadGuard` for a component of the locked data. + /// Makes a new `OwnedRwLockReadGuard` for a component of the locked data. /// This operation cannot fail as the `OwnedRwLockReadGuard` passed in /// already locked the data. /// diff --git a/src/sync/rwlock/owned_write_guard.rs b/src/sync/rwlock/owned_write_guard.rs index 91b6595..0a78d28 100644 --- a/src/sync/rwlock/owned_write_guard.rs +++ b/src/sync/rwlock/owned_write_guard.rs @@ -24,7 +24,7 @@ pub struct OwnedRwLockWriteGuard<T: ?Sized> { } impl<T: ?Sized> OwnedRwLockWriteGuard<T> { - /// Make a new [`OwnedRwLockMappedWriteGuard`] for a component of the locked + /// Makes a new [`OwnedRwLockMappedWriteGuard`] for a component of the locked /// data. /// /// This operation cannot fail as the `OwnedRwLockWriteGuard` passed in diff --git a/src/sync/rwlock/owned_write_guard_mapped.rs b/src/sync/rwlock/owned_write_guard_mapped.rs index 6453236..d88ee01 100644 --- a/src/sync/rwlock/owned_write_guard_mapped.rs +++ b/src/sync/rwlock/owned_write_guard_mapped.rs @@ -23,7 +23,7 @@ pub struct OwnedRwLockMappedWriteGuard<T: ?Sized, U: ?Sized = T> { } impl<T: ?Sized, U: ?Sized> OwnedRwLockMappedWriteGuard<T, U> { - /// Make a new `OwnedRwLockMappedWriteGuard` for a component of the locked + /// Makes a new `OwnedRwLockMappedWriteGuard` for a component of the locked /// data. 
/// /// This operation cannot fail as the `OwnedRwLockMappedWriteGuard` passed diff --git a/src/sync/rwlock/read_guard.rs b/src/sync/rwlock/read_guard.rs index 38eec77..090b297 100644 --- a/src/sync/rwlock/read_guard.rs +++ b/src/sync/rwlock/read_guard.rs @@ -19,7 +19,7 @@ pub struct RwLockReadGuard<'a, T: ?Sized> { } impl<'a, T: ?Sized> RwLockReadGuard<'a, T> { - /// Make a new `RwLockReadGuard` for a component of the locked data. + /// Makes a new `RwLockReadGuard` for a component of the locked data. /// /// This operation cannot fail as the `RwLockReadGuard` passed in already /// locked the data. diff --git a/src/sync/rwlock/write_guard.rs b/src/sync/rwlock/write_guard.rs index 865a121..8c80ee7 100644 --- a/src/sync/rwlock/write_guard.rs +++ b/src/sync/rwlock/write_guard.rs @@ -22,7 +22,7 @@ pub struct RwLockWriteGuard<'a, T: ?Sized> { } impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> { - /// Make a new [`RwLockMappedWriteGuard`] for a component of the locked data. + /// Makes a new [`RwLockMappedWriteGuard`] for a component of the locked data. /// /// This operation cannot fail as the `RwLockWriteGuard` passed in already /// locked the data. diff --git a/src/sync/rwlock/write_guard_mapped.rs b/src/sync/rwlock/write_guard_mapped.rs index 9c5b1e7..3cf69de 100644 --- a/src/sync/rwlock/write_guard_mapped.rs +++ b/src/sync/rwlock/write_guard_mapped.rs @@ -21,7 +21,7 @@ pub struct RwLockMappedWriteGuard<'a, T: ?Sized> { } impl<'a, T: ?Sized> RwLockMappedWriteGuard<'a, T> { - /// Make a new `RwLockMappedWriteGuard` for a component of the locked data. + /// Makes a new `RwLockMappedWriteGuard` for a component of the locked data. /// /// This operation cannot fail as the `RwLockMappedWriteGuard` passed in already /// locked the data. 
diff --git a/src/sync/semaphore.rs b/src/sync/semaphore.rs index af75042..839b523 100644 --- a/src/sync/semaphore.rs +++ b/src/sync/semaphore.rs @@ -24,7 +24,55 @@ use std::sync::Arc; /// To use the `Semaphore` in a poll function, you can use the [`PollSemaphore`] /// utility. /// +/// # Examples +/// +/// Basic usage: +/// +/// ``` +/// use tokio::sync::{Semaphore, TryAcquireError}; +/// +/// #[tokio::main] +/// async fn main() { +/// let semaphore = Semaphore::new(3); +/// +/// let a_permit = semaphore.acquire().await.unwrap(); +/// let two_permits = semaphore.acquire_many(2).await.unwrap(); +/// +/// assert_eq!(semaphore.available_permits(), 0); +/// +/// let permit_attempt = semaphore.try_acquire(); +/// assert_eq!(permit_attempt.err(), Some(TryAcquireError::NoPermits)); +/// } +/// ``` +/// +/// Use [`Semaphore::acquire_owned`] to move permits across tasks: +/// +/// ``` +/// use std::sync::Arc; +/// use tokio::sync::Semaphore; +/// +/// #[tokio::main] +/// async fn main() { +/// let semaphore = Arc::new(Semaphore::new(3)); +/// let mut join_handles = Vec::new(); +/// +/// for _ in 0..5 { +/// let permit = semaphore.clone().acquire_owned().await.unwrap(); +/// join_handles.push(tokio::spawn(async move { +/// // perform task... +/// // explicitly own `permit` in the task +/// drop(permit); +/// })); +/// } +/// +/// for handle in join_handles { +/// handle.await.unwrap(); +/// } +/// } +/// ``` +/// /// [`PollSemaphore`]: https://docs.rs/tokio-util/0.6/tokio_util/sync/struct.PollSemaphore.html +/// [`Semaphore::acquire_owned`]: crate::sync::Semaphore::acquire_owned #[derive(Debug)] pub struct Semaphore { /// The low level semaphore @@ -79,6 +127,15 @@ impl Semaphore { } /// Creates a new semaphore with the initial number of permits. 
+ /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Semaphore; + /// + /// static SEM: Semaphore = Semaphore::const_new(10); + /// ``` + /// #[cfg(all(feature = "parking_lot", not(all(loom, test))))] #[cfg_attr(docsrs, doc(cfg(feature = "parking_lot")))] pub const fn const_new(permits: usize) -> Self { @@ -105,12 +162,38 @@ impl Semaphore { /// Otherwise, this returns a [`SemaphorePermit`] representing the /// acquired permit. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. Cancelling a call to `acquire` makes you lose your place + /// in the queue. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Semaphore; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Semaphore::new(2); + /// + /// let permit_1 = semaphore.acquire().await.unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = semaphore.acquire().await.unwrap(); + /// assert_eq!(semaphore.available_permits(), 0); + /// + /// drop(permit_1); + /// assert_eq!(semaphore.available_permits(), 1); + /// } + /// ``` + /// /// [`AcquireError`]: crate::sync::AcquireError /// [`SemaphorePermit`]: crate::sync::SemaphorePermit pub async fn acquire(&self) -> Result<SemaphorePermit<'_>, AcquireError> { self.ll_sem.acquire(1).await?; Ok(SemaphorePermit { - sem: &self, + sem: self, permits: 1, }) } @@ -121,12 +204,32 @@ impl Semaphore { /// Otherwise, this returns a [`SemaphorePermit`] representing the /// acquired permits. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. Cancelling a call to `acquire_many` makes you lose your + /// place in the queue. 
+ /// + /// # Examples + /// + /// ``` + /// use tokio::sync::Semaphore; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Semaphore::new(5); + /// + /// let permit = semaphore.acquire_many(3).await.unwrap(); + /// assert_eq!(semaphore.available_permits(), 2); + /// } + /// ``` + /// /// [`AcquireError`]: crate::sync::AcquireError /// [`SemaphorePermit`]: crate::sync::SemaphorePermit pub async fn acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, AcquireError> { self.ll_sem.acquire(n).await?; Ok(SemaphorePermit { - sem: &self, + sem: self, permits: n, }) } @@ -137,6 +240,25 @@ impl Semaphore { /// and a [`TryAcquireError::NoPermits`] if there are no permits left. Otherwise, /// this returns a [`SemaphorePermit`] representing the acquired permits. /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{Semaphore, TryAcquireError}; + /// + /// # fn main() { + /// let semaphore = Semaphore::new(2); + /// + /// let permit_1 = semaphore.try_acquire().unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = semaphore.try_acquire().unwrap(); + /// assert_eq!(semaphore.available_permits(), 0); + /// + /// let permit_3 = semaphore.try_acquire(); + /// assert_eq!(permit_3.err(), Some(TryAcquireError::NoPermits)); + /// # } + /// ``` + /// /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits /// [`SemaphorePermit`]: crate::sync::SemaphorePermit @@ -153,8 +275,24 @@ impl Semaphore { /// Tries to acquire `n` permits from the semaphore. /// /// If the semaphore has been closed, this returns a [`TryAcquireError::Closed`] - /// and a [`TryAcquireError::NoPermits`] if there are no permits left. Otherwise, - /// this returns a [`SemaphorePermit`] representing the acquired permits. + /// and a [`TryAcquireError::NoPermits`] if there are not enough permits left. 
+ /// Otherwise, this returns a [`SemaphorePermit`] representing the acquired permits. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::{Semaphore, TryAcquireError}; + /// + /// # fn main() { + /// let semaphore = Semaphore::new(4); + /// + /// let permit_1 = semaphore.try_acquire_many(3).unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = semaphore.try_acquire_many(2); + /// assert_eq!(permit_2.err(), Some(TryAcquireError::NoPermits)); + /// # } + /// ``` /// /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits @@ -176,6 +314,38 @@ impl Semaphore { /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the /// acquired permit. /// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. Cancelling a call to `acquire_owned` makes you lose your + /// place in the queue. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::Semaphore; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Arc::new(Semaphore::new(3)); + /// let mut join_handles = Vec::new(); + /// + /// for _ in 0..5 { + /// let permit = semaphore.clone().acquire_owned().await.unwrap(); + /// join_handles.push(tokio::spawn(async move { + /// // perform task... + /// // explicitly own `permit` in the task + /// drop(permit); + /// })); + /// } + /// + /// for handle in join_handles { + /// handle.await.unwrap(); + /// } + /// } + /// ``` + /// /// [`Arc`]: std::sync::Arc /// [`AcquireError`]: crate::sync::AcquireError /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit @@ -194,6 +364,38 @@ impl Semaphore { /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the /// acquired permit. 
/// + /// # Cancel safety + /// + /// This method uses a queue to fairly distribute permits in the order they + /// were requested. Cancelling a call to `acquire_many_owned` makes you lose + /// your place in the queue. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::Semaphore; + /// + /// #[tokio::main] + /// async fn main() { + /// let semaphore = Arc::new(Semaphore::new(10)); + /// let mut join_handles = Vec::new(); + /// + /// for _ in 0..5 { + /// let permit = semaphore.clone().acquire_many_owned(2).await.unwrap(); + /// join_handles.push(tokio::spawn(async move { + /// // perform task... + /// // explicitly own `permit` in the task + /// drop(permit); + /// })); + /// } + /// + /// for handle in join_handles { + /// handle.await.unwrap(); + /// } + /// } + /// ``` + /// /// [`Arc`]: std::sync::Arc /// [`AcquireError`]: crate::sync::AcquireError /// [`OwnedSemaphorePermit`]: crate::sync::OwnedSemaphorePermit @@ -216,6 +418,26 @@ impl Semaphore { /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the /// acquired permit. 
/// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{Semaphore, TryAcquireError}; + /// + /// # fn main() { + /// let semaphore = Arc::new(Semaphore::new(2)); + /// + /// let permit_1 = Arc::clone(&semaphore).try_acquire_owned().unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = Arc::clone(&semaphore).try_acquire_owned().unwrap(); + /// assert_eq!(semaphore.available_permits(), 0); + /// + /// let permit_3 = semaphore.try_acquire_owned(); + /// assert_eq!(permit_3.err(), Some(TryAcquireError::NoPermits)); + /// # } + /// ``` + /// /// [`Arc`]: std::sync::Arc /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits @@ -238,6 +460,23 @@ impl Semaphore { /// Otherwise, this returns a [`OwnedSemaphorePermit`] representing the /// acquired permit. /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::{Semaphore, TryAcquireError}; + /// + /// # fn main() { + /// let semaphore = Arc::new(Semaphore::new(4)); + /// + /// let permit_1 = Arc::clone(&semaphore).try_acquire_many_owned(3).unwrap(); + /// assert_eq!(semaphore.available_permits(), 1); + /// + /// let permit_2 = semaphore.try_acquire_many_owned(2); + /// assert_eq!(permit_2.err(), Some(TryAcquireError::NoPermits)); + /// # } + /// ``` + /// /// [`Arc`]: std::sync::Arc /// [`TryAcquireError::Closed`]: crate::sync::TryAcquireError::Closed /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits diff --git a/src/sync/task/atomic_waker.rs b/src/sync/task/atomic_waker.rs index 5917204..e1330fb 100644 --- a/src/sync/task/atomic_waker.rs +++ b/src/sync/task/atomic_waker.rs @@ -4,6 +4,7 @@ use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::{self, AtomicUsize}; use std::fmt; +use std::panic::{resume_unwind, AssertUnwindSafe, RefUnwindSafe, UnwindSafe}; use std::sync::atomic::Ordering::{AcqRel, 
Acquire, Release}; use std::task::Waker; @@ -27,9 +28,12 @@ pub(crate) struct AtomicWaker { waker: UnsafeCell<Option<Waker>>, } +impl RefUnwindSafe for AtomicWaker {} +impl UnwindSafe for AtomicWaker {} + // `AtomicWaker` is a multi-consumer, single-producer transfer cell. The cell // stores a `Waker` value produced by calls to `register` and many threads can -// race to take the waker by calling `wake. +// race to take the waker by calling `wake`. // // If a new `Waker` instance is produced by calling `register` before an existing // one is consumed, then the existing one is overwritten. @@ -84,7 +88,7 @@ pub(crate) struct AtomicWaker { // back to `WAITING`. This transition must succeed as, at this point, the state // cannot be transitioned by another thread. // -// If the thread is unable to obtain the lock, the `WAKING` bit is still. +// If the thread is unable to obtain the lock, the `WAKING` bit is still set. // This is because it has either been set by the current thread but the previous // value included the `REGISTERING` bit **or** a concurrent thread is in the // `WAKING` critical section. Either way, no action must be taken. @@ -123,7 +127,7 @@ pub(crate) struct AtomicWaker { // Thread A still holds the `wake` lock, the call to `register` will result // in the task waking itself and get scheduled again. -/// Idle state +/// Idle state. const WAITING: usize = 0; /// A new waker value is being registered with the `AtomicWaker` cell. 
@@ -171,6 +175,10 @@ impl AtomicWaker { where W: WakerRef, { + fn catch_unwind<F: FnOnce() -> R, R>(f: F) -> std::thread::Result<R> { + std::panic::catch_unwind(AssertUnwindSafe(f)) + } + match self .state .compare_exchange(WAITING, REGISTERING, Acquire, Acquire) @@ -178,8 +186,24 @@ impl AtomicWaker { { WAITING => { unsafe { - // Locked acquired, update the waker cell - self.waker.with_mut(|t| *t = Some(waker.into_waker())); + // If `into_waker` panics (because it's code outside of + // AtomicWaker) we need to prime a guard that is called on + // unwind to restore the waker to a WAITING state. Otherwise + // any future calls to register will incorrectly be stuck + // believing it's being updated by someone else. + let new_waker_or_panic = catch_unwind(move || waker.into_waker()); + + // Set the field to contain the new waker, or if + // `into_waker` panicked, leave the old value. + let mut maybe_panic = None; + let mut old_waker = None; + match new_waker_or_panic { + Ok(new_waker) => { + old_waker = self.waker.with_mut(|t| (*t).take()); + self.waker.with_mut(|t| *t = Some(new_waker)); + } + Err(panic) => maybe_panic = Some(panic), + } // Release the lock. If the state transitioned to include // the `WAKING` bit, this means that a wake has been @@ -193,33 +217,67 @@ impl AtomicWaker { .compare_exchange(REGISTERING, WAITING, AcqRel, Acquire); match res { - Ok(_) => {} + Ok(_) => { + // We don't want to give the caller the panic if it + // was someone else who put in that waker. + let _ = catch_unwind(move || { + drop(old_waker); + }); + } Err(actual) => { // This branch can only be reached if a // concurrent thread called `wake`. In this // case, `actual` **must** be `REGISTERING | - // `WAKING`. + // WAKING`. debug_assert_eq!(actual, REGISTERING | WAKING); // Take the waker to wake once the atomic operation has // completed. 
- let waker = self.waker.with_mut(|t| (*t).take()).unwrap(); + let mut waker = self.waker.with_mut(|t| (*t).take()); // Just swap, because no one could change state // while state == `Registering | `Waking` self.state.swap(WAITING, AcqRel); - // The atomic swap was complete, now - // wake the waker and return. - waker.wake(); + // If `into_waker` panicked, then the waker in the + // waker slot is actually the old waker. + if maybe_panic.is_some() { + old_waker = waker.take(); + } + + // We don't want to give the caller the panic if it + // was someone else who put in that waker. + if let Some(old_waker) = old_waker { + let _ = catch_unwind(move || { + old_waker.wake(); + }); + } + + // The atomic swap was complete, now wake the waker + // and return. + // + // If this panics, we end up in a consumed state and + // return the panic to the caller. + if let Some(waker) = waker { + debug_assert!(maybe_panic.is_none()); + waker.wake(); + } } } + + if let Some(panic) = maybe_panic { + // If `into_waker` panicked, return the panic to the caller. + resume_unwind(panic); + } } } WAKING => { // Currently in the process of waking the task, i.e., // `wake` is currently being called on the old waker. // So, we call wake on the new waker. + // + // If this panics, someone else is responsible for restoring the + // state of the waker. waker.wake(); // This is equivalent to a spin lock, so use a spin hint. @@ -245,6 +303,8 @@ impl AtomicWaker { /// If `register` has not been called yet, then this does nothing. pub(crate) fn wake(&self) { if let Some(waker) = self.take_waker() { + // If wake panics, we've consumed the waker which is a legitimate + // outcome. 
waker.wake(); } } diff --git a/src/sync/tests/atomic_waker.rs b/src/sync/tests/atomic_waker.rs index c832d62..b167a5d 100644 --- a/src/sync/tests/atomic_waker.rs +++ b/src/sync/tests/atomic_waker.rs @@ -32,3 +32,42 @@ fn wake_without_register() { assert!(!waker.is_woken()); } + +#[test] +fn atomic_waker_panic_safe() { + use std::panic; + use std::ptr; + use std::task::{RawWaker, RawWakerVTable, Waker}; + + static PANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new( + |_| panic!("clone"), + |_| unimplemented!("wake"), + |_| unimplemented!("wake_by_ref"), + |_| (), + ); + + static NONPANICKING_VTABLE: RawWakerVTable = RawWakerVTable::new( + |_| RawWaker::new(ptr::null(), &NONPANICKING_VTABLE), + |_| unimplemented!("wake"), + |_| unimplemented!("wake_by_ref"), + |_| (), + ); + + let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) }; + let nonpanicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NONPANICKING_VTABLE)) }; + + let atomic_waker = AtomicWaker::new(); + + let panicking = panic::AssertUnwindSafe(&panicking); + + let result = panic::catch_unwind(|| { + let panic::AssertUnwindSafe(panicking) = panicking; + atomic_waker.register_by_ref(panicking); + }); + + assert!(result.is_err()); + assert!(atomic_waker.take_waker().is_none()); + + atomic_waker.register_by_ref(&nonpanicking); + assert!(atomic_waker.take_waker().is_some()); +} diff --git a/src/sync/tests/loom_atomic_waker.rs b/src/sync/tests/loom_atomic_waker.rs index c148bcb..f8bae65 100644 --- a/src/sync/tests/loom_atomic_waker.rs +++ b/src/sync/tests/loom_atomic_waker.rs @@ -43,3 +43,58 @@ fn basic_notification() { })); }); } + +#[test] +fn test_panicky_waker() { + use std::panic; + use std::ptr; + use std::task::{RawWaker, RawWakerVTable, Waker}; + + static PANICKING_VTABLE: RawWakerVTable = + RawWakerVTable::new(|_| panic!("clone"), |_| (), |_| (), |_| ()); + + let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) }; + + // 
If you're working with this test (and I sure hope you never have to!), + // uncomment the following section because there will be a lot of panics + // which would otherwise log. + // + // We can't however leaved it uncommented, because it's global. + // panic::set_hook(Box::new(|_| ())); + + const NUM_NOTIFY: usize = 2; + + loom::model(move || { + let chan = Arc::new(Chan { + num: AtomicUsize::new(0), + task: AtomicWaker::new(), + }); + + for _ in 0..NUM_NOTIFY { + let chan = chan.clone(); + + thread::spawn(move || { + chan.num.fetch_add(1, Relaxed); + chan.task.wake(); + }); + } + + // Note: this panic should have no effect on the overall state of the + // waker and it should proceed as normal. + // + // A thread above might race to flag a wakeup, and a WAKING state will + // be preserved if this expected panic races with that so the below + // procedure should be allowed to continue uninterrupted. + let _ = panic::catch_unwind(|| chan.task.register_by_ref(&panicking)); + + block_on(poll_fn(move |cx| { + chan.task.register_by_ref(cx.waker()); + + if NUM_NOTIFY == chan.num.load(Relaxed) { + return Ready(()); + } + + Pending + })); + }); +} diff --git a/src/sync/tests/loom_mpsc.rs b/src/sync/tests/loom_mpsc.rs index c12313b..f165e70 100644 --- a/src/sync/tests/loom_mpsc.rs +++ b/src/sync/tests/loom_mpsc.rs @@ -132,3 +132,59 @@ fn dropping_unbounded_tx() { assert!(v.is_none()); }); } + +#[test] +fn try_recv() { + loom::model(|| { + use crate::sync::{mpsc, Semaphore}; + use loom::sync::{Arc, Mutex}; + + const PERMITS: usize = 2; + const TASKS: usize = 2; + const CYCLES: usize = 1; + + struct Context { + sem: Arc<Semaphore>, + tx: mpsc::Sender<()>, + rx: Mutex<mpsc::Receiver<()>>, + } + + fn run(ctx: &Context) { + block_on(async { + let permit = ctx.sem.acquire().await; + assert_ok!(ctx.rx.lock().unwrap().try_recv()); + crate::task::yield_now().await; + assert_ok!(ctx.tx.clone().try_send(())); + drop(permit); + }); + } + + let (tx, rx) = mpsc::channel(PERMITS); + let 
sem = Arc::new(Semaphore::new(PERMITS)); + let ctx = Arc::new(Context { + sem, + tx, + rx: Mutex::new(rx), + }); + + for _ in 0..PERMITS { + assert_ok!(ctx.tx.clone().try_send(())); + } + + let mut ths = Vec::new(); + + for _ in 0..TASKS { + let ctx = ctx.clone(); + + ths.push(thread::spawn(move || { + run(&ctx); + })); + } + + run(&ctx); + + for th in ths { + th.join().unwrap(); + } + }); +} diff --git a/src/sync/tests/loom_notify.rs b/src/sync/tests/loom_notify.rs index 4be949a..d484a75 100644 --- a/src/sync/tests/loom_notify.rs +++ b/src/sync/tests/loom_notify.rs @@ -33,12 +33,41 @@ fn notify_waiters() { tx.notify_waiters(); }); - th.join().unwrap(); - block_on(async { notified1.await; notified2.await; }); + + th.join().unwrap(); + }); +} + +#[test] +fn notify_waiters_and_one() { + loom::model(|| { + let notify = Arc::new(Notify::new()); + let tx1 = notify.clone(); + let tx2 = notify.clone(); + + let th1 = thread::spawn(move || { + tx1.notify_waiters(); + }); + + let th2 = thread::spawn(move || { + tx2.notify_one(); + }); + + let th3 = thread::spawn(move || { + let notified = notify.notified(); + + block_on(async { + notified.await; + }); + }); + + th1.join().unwrap(); + th2.join().unwrap(); + th3.join().unwrap(); }); } diff --git a/src/sync/tests/loom_oneshot.rs b/src/sync/tests/loom_oneshot.rs index 9729cfb..c5f7972 100644 --- a/src/sync/tests/loom_oneshot.rs +++ b/src/sync/tests/loom_oneshot.rs @@ -55,6 +55,35 @@ fn changing_rx_task() { }); } +#[test] +fn try_recv_close() { + // reproduces https://github.com/tokio-rs/tokio/issues/4225 + loom::model(|| { + let (tx, mut rx) = oneshot::channel(); + thread::spawn(move || { + let _ = tx.send(()); + }); + + rx.close(); + let _ = rx.try_recv(); + }) +} + +#[test] +fn recv_closed() { + // reproduces https://github.com/tokio-rs/tokio/issues/4225 + loom::model(|| { + let (tx, mut rx) = oneshot::channel(); + + thread::spawn(move || { + let _ = tx.send(1); + }); + + rx.close(); + let _ = block_on(rx); + }); +} + // TODO: 
Move this into `oneshot` proper. use std::future::Future; diff --git a/src/sync/tests/mod.rs b/src/sync/tests/mod.rs index c5d5601..ee76418 100644 --- a/src/sync/tests/mod.rs +++ b/src/sync/tests/mod.rs @@ -1,5 +1,6 @@ cfg_not_loom! { mod atomic_waker; + mod notify; mod semaphore_batch; } diff --git a/src/sync/tests/notify.rs b/src/sync/tests/notify.rs new file mode 100644 index 0000000..8c9a573 --- /dev/null +++ b/src/sync/tests/notify.rs @@ -0,0 +1,44 @@ +use crate::sync::Notify; +use std::future::Future; +use std::mem::ManuallyDrop; +use std::sync::Arc; +use std::task::{Context, RawWaker, RawWakerVTable, Waker}; + +#[test] +fn notify_clones_waker_before_lock() { + const VTABLE: &RawWakerVTable = &RawWakerVTable::new(clone_w, wake, wake_by_ref, drop_w); + + unsafe fn clone_w(data: *const ()) -> RawWaker { + let arc = ManuallyDrop::new(Arc::<Notify>::from_raw(data as *const Notify)); + // Or some other arbitrary code that shouldn't be executed while the + // Notify wait list is locked. + arc.notify_one(); + let _arc_clone: ManuallyDrop<_> = arc.clone(); + RawWaker::new(data, VTABLE) + } + + unsafe fn drop_w(data: *const ()) { + let _ = Arc::<Notify>::from_raw(data as *const Notify); + } + + unsafe fn wake(_data: *const ()) { + unreachable!() + } + + unsafe fn wake_by_ref(_data: *const ()) { + unreachable!() + } + + let notify = Arc::new(Notify::new()); + let notify2 = notify.clone(); + + let waker = + unsafe { Waker::from_raw(RawWaker::new(Arc::into_raw(notify2) as *const _, VTABLE)) }; + let mut cx = Context::from_waker(&waker); + + let future = notify.notified(); + pin!(future); + + // The result doesn't matter, we're just testing that we don't deadlock. 
+ let _ = future.poll(&mut cx); +} diff --git a/src/sync/watch.rs b/src/sync/watch.rs index bf6f0ac..7e45c11 100644 --- a/src/sync/watch.rs +++ b/src/sync/watch.rs @@ -56,8 +56,9 @@ use crate::sync::notify::Notify; use crate::loom::sync::atomic::AtomicUsize; -use crate::loom::sync::atomic::Ordering::{Relaxed, SeqCst}; +use crate::loom::sync::atomic::Ordering::Relaxed; use crate::loom::sync::{Arc, RwLock, RwLockReadGuard}; +use std::mem; use std::ops; /// Receives values from the associated [`Sender`](struct@Sender). @@ -74,7 +75,7 @@ pub struct Receiver<T> { shared: Arc<Shared<T>>, /// Last observed version - version: usize, + version: Version, } /// Sends values to the associated [`Receiver`](struct@Receiver). @@ -85,7 +86,7 @@ pub struct Sender<T> { shared: Arc<Shared<T>>, } -/// Returns a reference to the inner value +/// Returns a reference to the inner value. /// /// Outstanding borrows hold a read lock on the inner value. This means that /// long lived borrows could cause the produce half to block. It is recommended @@ -97,35 +98,33 @@ pub struct Ref<'a, T> { #[derive(Debug)] struct Shared<T> { - /// The most recent value + /// The most recent value. value: RwLock<T>, - /// The current version + /// The current version. /// /// The lowest bit represents a "closed" state. The rest of the bits /// represent the current version. - version: AtomicUsize, + state: AtomicState, - /// Tracks the number of `Receiver` instances + /// Tracks the number of `Receiver` instances. ref_count_rx: AtomicUsize, /// Notifies waiting receivers that the value changed. notify_rx: Notify, - /// Notifies any task listening for `Receiver` dropped events + /// Notifies any task listening for `Receiver` dropped events. notify_tx: Notify, } pub mod error { - //! Watch error types + //! Watch error types. use std::fmt; /// Error produced when sending a value fails. 
#[derive(Debug)] - pub struct SendError<T> { - pub(crate) inner: T, - } + pub struct SendError<T>(pub T); // ===== impl SendError ===== @@ -152,7 +151,72 @@ pub mod error { impl std::error::Error for RecvError {} } -const CLOSED: usize = 1; +use self::state::{AtomicState, Version}; +mod state { + use crate::loom::sync::atomic::AtomicUsize; + use crate::loom::sync::atomic::Ordering::SeqCst; + + const CLOSED: usize = 1; + + /// The version part of the state. The lowest bit is always zero. + #[derive(Copy, Clone, Debug, Eq, PartialEq)] + pub(super) struct Version(usize); + + /// Snapshot of the state. The first bit is used as the CLOSED bit. + /// The remaining bits are used as the version. + /// + /// The CLOSED bit tracks whether the Sender has been dropped. Dropping all + /// receivers does not set it. + #[derive(Copy, Clone, Debug)] + pub(super) struct StateSnapshot(usize); + + /// The state stored in an atomic integer. + #[derive(Debug)] + pub(super) struct AtomicState(AtomicUsize); + + impl Version { + /// Get the initial version when creating the channel. + pub(super) fn initial() -> Self { + Version(0) + } + } + + impl StateSnapshot { + /// Extract the version from the state. + pub(super) fn version(self) -> Version { + Version(self.0 & !CLOSED) + } + + /// Is the closed bit set? + pub(super) fn is_closed(self) -> bool { + (self.0 & CLOSED) == CLOSED + } + } + + impl AtomicState { + /// Create a new `AtomicState` that is not closed and which has the + /// version set to `Version::initial()`. + pub(super) fn new() -> Self { + AtomicState(AtomicUsize::new(0)) + } + + /// Load the current value of the state. + pub(super) fn load(&self) -> StateSnapshot { + StateSnapshot(self.0.load(SeqCst)) + } + + /// Increment the version counter. + pub(super) fn increment_version(&self) { + // Increment by two to avoid touching the CLOSED bit. + self.0.fetch_add(2, SeqCst); + } + + /// Set the closed bit in the state. 
+ pub(super) fn set_closed(&self) { + self.0.fetch_or(CLOSED, SeqCst); + } + } +} /// Creates a new watch channel, returning the "send" and "receive" handles. /// @@ -184,7 +248,7 @@ const CLOSED: usize = 1; pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) { let shared = Arc::new(Shared { value: RwLock::new(init), - version: AtomicUsize::new(0), + state: AtomicState::new(), ref_count_rx: AtomicUsize::new(1), notify_rx: Notify::new(), notify_tx: Notify::new(), @@ -194,26 +258,35 @@ pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) { shared: shared.clone(), }; - let rx = Receiver { shared, version: 0 }; + let rx = Receiver { + shared, + version: Version::initial(), + }; (tx, rx) } impl<T> Receiver<T> { - fn from_shared(version: usize, shared: Arc<Shared<T>>) -> Self { + fn from_shared(version: Version, shared: Arc<Shared<T>>) -> Self { // No synchronization necessary as this is only used as a counter and // not memory access. shared.ref_count_rx.fetch_add(1, Relaxed); - Self { version, shared } + Self { shared, version } } - /// Returns a reference to the most recently sent value + /// Returns a reference to the most recently sent value. + /// + /// This method does not mark the returned value as seen, so future calls to + /// [`changed`] may return immediately even if you have already seen the + /// value with a call to `borrow`. /// /// Outstanding borrows hold a read lock. This means that long lived borrows /// could cause the send half to block. It is recommended to keep the borrow /// as short lived as possible. /// + /// [`changed`]: Receiver::changed + /// /// # Examples /// /// ``` @@ -227,11 +300,40 @@ impl<T> Receiver<T> { Ref { inner } } - /// Wait for a change notification + /// Returns a reference to the most recently sent value and mark that value + /// as seen. + /// + /// This method marks the value as seen, so [`changed`] will not return + /// immediately if the newest value is one previously returned by + /// `borrow_and_update`. 
+ /// + /// Outstanding borrows hold a read lock. This means that long lived borrows + /// could cause the send half to block. It is recommended to keep the borrow + /// as short lived as possible. + /// + /// [`changed`]: Receiver::changed + pub fn borrow_and_update(&mut self) -> Ref<'_, T> { + let inner = self.shared.value.read().unwrap(); + self.version = self.shared.state.load().version(); + Ref { inner } + } + + /// Waits for a change notification, then marks the newest value as seen. + /// + /// If the newest value in the channel has not yet been marked seen when + /// this method is called, the method marks that value seen and returns + /// immediately. If the newest value has already been marked seen, then the + /// method sleeps until a new message is sent by the [`Sender`] connected to + /// this `Receiver`, or until the [`Sender`] is dropped. + /// + /// This method returns an error if and only if the [`Sender`] is dropped. + /// + /// # Cancel safety /// - /// Returns when a new value has been sent by the [`Sender`] since the last - /// time `changed()` was called. When the `Sender` half is dropped, `Err` is - /// returned. + /// This method is cancel safe. If you use it as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no values have been marked + /// seen by this call to `changed`. 
/// /// [`Sender`]: struct@Sender /// @@ -280,11 +382,11 @@ impl<T> Receiver<T> { fn maybe_changed<T>( shared: &Shared<T>, - version: &mut usize, + version: &mut Version, ) -> Option<Result<(), error::RecvError>> { // Load the version from the state - let state = shared.version.load(SeqCst); - let new_version = state & !CLOSED; + let state = shared.state.load(); + let new_version = state.version(); if *version != new_version { // Observe the new version and return @@ -292,7 +394,7 @@ fn maybe_changed<T>( return Some(Ok(())); } - if CLOSED == state & CLOSED { + if state.is_closed() { // All receivers have dropped. return Some(Err(error::RecvError(()))); } @@ -322,21 +424,57 @@ impl<T> Drop for Receiver<T> { impl<T> Sender<T> { /// Sends a new value via the channel, notifying all receivers. + /// + /// This method fails if the channel has been closed, which happens when + /// every receiver has been dropped. pub fn send(&self, value: T) -> Result<(), error::SendError<T>> { // This is pretty much only useful as a hint anyway, so synchronization isn't critical. - if 0 == self.shared.ref_count_rx.load(Relaxed) { - return Err(error::SendError { inner: value }); + if 0 == self.receiver_count() { + return Err(error::SendError(value)); } - *self.shared.value.write().unwrap() = value; + self.send_replace(value); + Ok(()) + } + + /// Sends a new value via the channel, notifying all receivers and returning + /// the previous value in the channel. + /// + /// This can be useful for reusing the buffers inside a watched value. + /// Additionally, this method permits sending values even when there are no + /// receivers. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// let (tx, _rx) = watch::channel(1); + /// assert_eq!(tx.send_replace(2), 1); + /// assert_eq!(tx.send_replace(3), 2); + /// ``` + pub fn send_replace(&self, value: T) -> T { + let old = { + // Acquire the write lock and update the value. 
+ let mut lock = self.shared.value.write().unwrap(); + let old = mem::replace(&mut *lock, value); + + self.shared.state.increment_version(); - // Update the version. 2 is used so that the CLOSED bit is not set. - self.shared.version.fetch_add(2, SeqCst); + // Release the write lock. + // + // Incrementing the version counter while holding the lock ensures + // that receivers are able to figure out the version number of the + // value they are currently looking at. + drop(lock); + + old + }; // Notify all watchers self.shared.notify_rx.notify_waiters(); - Ok(()) + old } /// Returns a reference to the most recently sent value @@ -371,7 +509,7 @@ impl<T> Sender<T> { /// assert!(tx.is_closed()); /// ``` pub fn is_closed(&self) -> bool { - self.shared.ref_count_rx.load(Relaxed) == 0 + self.receiver_count() == 0 } /// Completes when all receivers have dropped. @@ -379,6 +517,11 @@ impl<T> Sender<T> { /// This allows the producer to get notified when interest in the produced /// values is canceled and immediately stop doing work. /// + /// # Cancel safety + /// + /// This method is cancel safe. Once the channel is closed, it stays closed + /// forever and all future calls to `closed` will return immediately. + /// /// # Examples /// /// ``` @@ -399,29 +542,109 @@ impl<T> Sender<T> { /// } /// ``` pub async fn closed(&self) { - let notified = self.shared.notify_tx.notified(); + while self.receiver_count() > 0 { + let notified = self.shared.notify_tx.notified(); - if self.shared.ref_count_rx.load(Relaxed) == 0 { - return; - } + if self.receiver_count() == 0 { + return; + } - notified.await; - debug_assert_eq!(0, self.shared.ref_count_rx.load(Relaxed)); + notified.await; + // The channel could have been reopened in the meantime by calling + // `subscribe`, so we loop again. + } } - cfg_signal_internal! 
{ - pub(crate) fn subscribe(&self) -> Receiver<T> { - let shared = self.shared.clone(); - let version = shared.version.load(SeqCst); + /// Creates a new [`Receiver`] connected to this `Sender`. + /// + /// All messages sent before this call to `subscribe` are initially marked + /// as seen by the new `Receiver`. + /// + /// This method can be called even if there are no other receivers. In this + /// case, the channel is reopened. + /// + /// # Examples + /// + /// The new channel will receive messages sent on this `Sender`. + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = watch::channel(0u64); + /// + /// tx.send(5).unwrap(); + /// + /// let rx = tx.subscribe(); + /// assert_eq!(5, *rx.borrow()); + /// + /// tx.send(10).unwrap(); + /// assert_eq!(10, *rx.borrow()); + /// } + /// ``` + /// + /// The most recent message is considered seen by the channel, so this test + /// is guaranteed to pass. + /// + /// ``` + /// use tokio::sync::watch; + /// use tokio::time::Duration; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, _rx) = watch::channel(0u64); + /// tx.send(5).unwrap(); + /// let mut rx = tx.subscribe(); + /// + /// tokio::spawn(async move { + /// // by spawning and sleeping, the message is sent after `main` + /// // hits the call to `changed`. + /// # if false { + /// tokio::time::sleep(Duration::from_millis(10)).await; + /// # } + /// tx.send(100).unwrap(); + /// }); + /// + /// rx.changed().await.unwrap(); + /// assert_eq!(100, *rx.borrow()); + /// } + /// ``` + pub fn subscribe(&self) -> Receiver<T> { + let shared = self.shared.clone(); + let version = shared.state.load().version(); + + // The CLOSED bit in the state tracks only whether the sender is + // dropped, so we do not need to unset it if this reopens the channel. + Receiver::from_shared(version, shared) + } - Receiver::from_shared(version, shared) - } + /// Returns the number of receivers that currently exist. 
+ /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, rx1) = watch::channel("hello"); + /// + /// assert_eq!(1, tx.receiver_count()); + /// + /// let mut _rx2 = rx1.clone(); + /// + /// assert_eq!(2, tx.receiver_count()); + /// } + /// ``` + pub fn receiver_count(&self) -> usize { + self.shared.ref_count_rx.load(Relaxed) } } impl<T> Drop for Sender<T> { fn drop(&mut self) { - self.shared.version.fetch_or(CLOSED, SeqCst); + self.shared.state.set_closed(); self.shared.notify_rx.notify_waiters(); } } diff --git a/src/task/blocking.rs b/src/task/blocking.rs index 28bbcdb..825f25f 100644 --- a/src/task/blocking.rs +++ b/src/task/blocking.rs @@ -5,19 +5,24 @@ cfg_rt_multi_thread! { /// blocking the executor. /// /// In general, issuing a blocking call or performing a lot of compute in a - /// future without yielding is not okay, as it may prevent the executor from - /// driving other futures forward. This function runs the closure on the - /// current thread by having the thread temporarily cease from being a core - /// thread, and turns it into a blocking thread. See the [CPU-bound tasks - /// and blocking code][blocking] section for more information. - /// - /// Although this function avoids starving other independently spawned - /// tasks, any other code running concurrently in the same task will be - /// suspended during the call to `block_in_place`. This can happen e.g. when - /// using the [`join!`] macro. To avoid this issue, use [`spawn_blocking`] - /// instead. - /// - /// Note that this function can only be used when using the `multi_thread` runtime. + /// future without yielding is problematic, as it may prevent the executor + /// from driving other tasks forward. 
Calling this function informs the + /// executor that the currently executing task is about to block the thread, + /// so the executor is able to hand off any other tasks it has to a new + /// worker thread before that happens. See the [CPU-bound tasks and blocking + /// code][blocking] section for more information. + /// + /// Be aware that although this function avoids starving other independently + /// spawned tasks, any other code running concurrently in the same task will + /// be suspended during the call to `block_in_place`. This can happen e.g. + /// when using the [`join!`] macro. To avoid this issue, use + /// [`spawn_blocking`] instead of `block_in_place`. + /// + /// Note that this function cannot be used within a [`current_thread`] runtime + /// because in this case there are no other worker threads to hand off tasks + /// to. On the other hand, calling the function outside a runtime is + /// allowed. In this case, `block_in_place` just calls the provided closure + /// normally. /// /// Code running behind `block_in_place` cannot be cancelled. When you shut /// down the executor, it will wait indefinitely for all blocking operations @@ -43,6 +48,28 @@ cfg_rt_multi_thread! { /// }); /// # } /// ``` + /// + /// Code running inside `block_in_place` may use `block_on` to reenter the + /// async context. + /// + /// ``` + /// use tokio::task; + /// use tokio::runtime::Handle; + /// + /// # async fn docs() { + /// task::block_in_place(move || { + /// Handle::current().block_on(async move { + /// // do something async + /// }); + /// }); + /// # } + /// ``` + /// + /// # Panics + /// + /// This function panics if called from a [`current_thread`] runtime. + /// + /// [`current_thread`]: fn@crate::runtime::Builder::new_current_thread pub fn block_in_place<F, R>(f: F) -> R where F: FnOnce() -> R, @@ -62,13 +89,14 @@ cfg_rt! 
{ /// /// Tokio will spawn more blocking threads when they are requested through this /// function until the upper limit configured on the [`Builder`] is reached. - /// This limit is very large by default, because `spawn_blocking` is often used - /// for various kinds of IO operations that cannot be performed asynchronously. - /// When you run CPU-bound code using `spawn_blocking`, you should keep this - /// large upper limit in mind. When running many CPU-bound computations, a - /// semaphore or some other synchronization primitive should be used to limit - /// the number of computation executed in parallel. Specialized CPU-bound - /// executors, such as [rayon], may also be a good fit. + /// After reaching the upper limit, the tasks are put in a queue. + /// The thread limit is very large by default, because `spawn_blocking` is often + /// used for various kinds of IO operations that cannot be performed + /// asynchronously. When you run CPU-bound code using `spawn_blocking`, you + /// should keep this large upper limit in mind. When running many CPU-bound + /// computations, a semaphore or some other synchronization primitive should be + /// used to limit the number of computation executed in parallel. Specialized + /// CPU-bound executors, such as [rayon], may also be a good fit. /// /// This function is intended for non-async operations that eventually finish on /// their own. If you want to spawn an ordinary thread, you should use @@ -84,27 +112,82 @@ cfg_rt! { /// still spawn additional threads for blocking operations. The basic /// scheduler's single thread is only used for asynchronous code. /// + /// # Related APIs and patterns for bridging asynchronous and blocking code + /// + /// In simple cases, it is sufficient to have the closure accept input + /// parameters at creation time and return a single value (or struct/tuple, etc.). 
+ /// + /// For more complex situations in which it is desirable to stream data to or from + /// the synchronous context, the [`mpsc channel`] has `blocking_send` and + /// `blocking_recv` methods for use in non-async code such as the thread created + /// by `spawn_blocking`. + /// + /// Another option is [`SyncIoBridge`] for cases where the synchronous context + /// is operating on byte streams. For example, you might use an asynchronous + /// HTTP client such as [hyper] to fetch data, but perform complex parsing + /// of the payload body using a library written for synchronous I/O. + /// + /// Finally, see also [Bridging with sync code][bridgesync] for discussions + /// around the opposite case of using Tokio as part of a larger synchronous + /// codebase. + /// /// [`Builder`]: struct@crate::runtime::Builder /// [blocking]: ../index.html#cpu-bound-tasks-and-blocking-code /// [rayon]: https://docs.rs/rayon + /// [`mpsc channel`]: crate::sync::mpsc + /// [`SyncIoBridge`]: https://docs.rs/tokio-util/0.6/tokio_util/io/struct.SyncIoBridge.html + /// [hyper]: https://docs.rs/hyper /// [`thread::spawn`]: fn@std::thread::spawn /// [`shutdown_timeout`]: fn@crate::runtime::Runtime::shutdown_timeout + /// [bridgesync]: https://tokio.rs/tokio/topics/bridging /// /// # Examples /// + /// Pass an input value and receive result of computation: + /// /// ``` /// use tokio::task; /// /// # async fn docs() -> Result<(), Box<dyn std::error::Error>>{ + /// // Initial input + /// let mut v = "Hello, ".to_string(); /// let res = task::spawn_blocking(move || { - /// // do some compute-heavy work or call synchronous code - /// "done computing" + /// // Stand-in for compute-heavy work or using synchronous APIs + /// v.push_str("world"); + /// // Pass ownership of the value back to the asynchronous context + /// v /// }).await?; /// - /// assert_eq!(res, "done computing"); + /// // `res` is the value returned from the thread + /// assert_eq!(res.as_str(), "Hello, world"); /// # Ok(()) 
/// # } /// ``` + /// + /// Use a channel: + /// + /// ``` + /// use tokio::task; + /// use tokio::sync::mpsc; + /// + /// # async fn docs() { + /// let (tx, mut rx) = mpsc::channel(2); + /// let start = 5; + /// let worker = task::spawn_blocking(move || { + /// for x in 0..10 { + /// // Stand in for complex computation + /// tx.blocking_send(start + x).unwrap(); + /// } + /// }); + /// + /// let mut acc = 0; + /// while let Some(v) = rx.recv().await { + /// acc += v; + /// } + /// assert_eq!(acc, 95); + /// worker.await.unwrap(); + /// # } + /// ``` #[cfg_attr(tokio_track_caller, track_caller)] pub fn spawn_blocking<F, R>(f: F) -> JoinHandle<R> where diff --git a/src/task/builder.rs b/src/task/builder.rs new file mode 100644 index 0000000..f991fc6 --- /dev/null +++ b/src/task/builder.rs @@ -0,0 +1,102 @@ +#![allow(unreachable_pub)] +use crate::{runtime::context, task::JoinHandle}; +use std::future::Future; + +/// Factory which is used to configure the properties of a new task. +/// +/// Methods can be chained in order to configure it. +/// +/// Currently, there is only one configuration option: +/// +/// - [`name`], which specifies an associated name for +/// the task +/// +/// There are three types of task that can be spawned from a Builder: +/// - [`spawn_local`] for executing futures on the current thread +/// - [`spawn`] for executing [`Send`] futures on the runtime +/// - [`spawn_blocking`] for executing blocking code in the +/// blocking thread pool. +/// +/// ## Example +/// +/// ```no_run +/// use tokio::net::{TcpListener, TcpStream}; +/// +/// use std::io; +/// +/// async fn process(socket: TcpStream) { +/// // ... 
+/// # drop(socket); +/// } +/// +/// #[tokio::main] +/// async fn main() -> io::Result<()> { +/// let listener = TcpListener::bind("127.0.0.1:8080").await?; +/// +/// loop { +/// let (socket, _) = listener.accept().await?; +/// +/// tokio::task::Builder::new() +/// .name("tcp connection handler") +/// .spawn(async move { +/// // Process each socket concurrently. +/// process(socket).await +/// }); +/// } +/// } +/// ``` +#[derive(Default, Debug)] +pub struct Builder<'a> { + name: Option<&'a str>, +} + +impl<'a> Builder<'a> { + /// Creates a new task builder. + pub fn new() -> Self { + Self::default() + } + + /// Assigns a name to the task which will be spawned. + pub fn name(&self, name: &'a str) -> Self { + Self { name: Some(name) } + } + + /// Spawns a task on the executor. + /// + /// See [`task::spawn`](crate::task::spawn) for + /// more details. + #[cfg_attr(tokio_track_caller, track_caller)] + pub fn spawn<Fut>(self, future: Fut) -> JoinHandle<Fut::Output> + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + super::spawn::spawn_inner(future, self.name) + } + + /// Spawns a task on the current thread. + /// + /// See [`task::spawn_local`](crate::task::spawn_local) + /// for more details. + #[cfg_attr(tokio_track_caller, track_caller)] + pub fn spawn_local<Fut>(self, future: Fut) -> JoinHandle<Fut::Output> + where + Fut: Future + 'static, + Fut::Output: 'static, + { + super::local::spawn_local_inner(future, self.name) + } + + /// Spawns blocking code on the blocking threadpool. + /// + /// See [`task::spawn_blocking`](crate::task::spawn_blocking) + /// for more details. 
+ #[cfg_attr(tokio_track_caller, track_caller)] + pub fn spawn_blocking<Function, Output>(self, function: Function) -> JoinHandle<Output> + where + Function: FnOnce() -> Output + Send + 'static, + Output: Send + 'static, + { + context::current().spawn_blocking_inner(function, self.name) + } +} diff --git a/src/task/local.rs b/src/task/local.rs index 64f1ac5..4a5d313 100644 --- a/src/task/local.rs +++ b/src/task/local.rs @@ -1,15 +1,15 @@ //! Runs `!Send` futures on the current thread. -use crate::runtime::task::{self, JoinHandle, Task}; +use crate::loom::sync::{Arc, Mutex}; +use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task}; use crate::sync::AtomicWaker; -use crate::util::linked_list::{Link, LinkedList}; +use crate::util::VecDequeCell; -use std::cell::{Cell, RefCell}; +use std::cell::Cell; use std::collections::VecDeque; use std::fmt; use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; -use std::sync::{Arc, Mutex}; use std::task::Poll; use pin_project_lite::pin_project; @@ -211,10 +211,10 @@ cfg_rt! { /// [`task::spawn_local`]: fn@spawn_local /// [`mpsc`]: mod@crate::sync::mpsc pub struct LocalSet { - /// Current scheduler tick + /// Current scheduler tick. tick: Cell<u8>, - /// State available from thread-local + /// State available from thread-local. context: Context, /// This type should not be Send. @@ -222,29 +222,24 @@ cfg_rt! { } } -/// State available from the thread-local +/// State available from the thread-local. struct Context { - /// Owned task set and local run queue - tasks: RefCell<Tasks>, - - /// State shared between threads. - shared: Arc<Shared>, -} - -struct Tasks { /// Collection of all active tasks spawned onto this executor. - owned: LinkedList<Task<Arc<Shared>>, <Task<Arc<Shared>> as Link>::Target>, + owned: LocalOwnedTasks<Arc<Shared>>, /// Local run queue sender and receiver. 
- queue: VecDeque<task::Notified<Arc<Shared>>>, + queue: VecDequeCell<task::Notified<Arc<Shared>>>, + + /// State shared between threads. + shared: Arc<Shared>, } /// LocalSet state shared between threads. struct Shared { - /// Remote run queue sender - queue: Mutex<VecDeque<task::Notified<Arc<Shared>>>>, + /// Remote run queue sender. + queue: Mutex<Option<VecDeque<task::Notified<Arc<Shared>>>>>, - /// Wake the `LocalSet` task + /// Wake the `LocalSet` task. waker: AtomicWaker, } @@ -297,27 +292,36 @@ cfg_rt! { F: Future + 'static, F::Output: 'static, { - let future = crate::util::trace::task(future, "local"); + spawn_local_inner(future, None) + } + + pub(super) fn spawn_local_inner<F>(future: F, name: Option<&str>) -> JoinHandle<F::Output> + where F: Future + 'static, + F::Output: 'static + { + let future = crate::util::trace::task(future, "local", name); CURRENT.with(|maybe_cx| { let cx = maybe_cx .expect("`spawn_local` called from outside of a `task::LocalSet`"); - // Safety: Tasks are only polled and dropped from the thread that - // spawns them. - let (task, handle) = unsafe { task::joinable_local(future) }; - cx.tasks.borrow_mut().queue.push_back(task); + let (handle, notified) = cx.owned.bind(future, cx.shared.clone()); + + if let Some(notified) = notified { + cx.shared.schedule(notified); + } + handle }) } } -/// Initial queue capacity +/// Initial queue capacity. const INITIAL_CAPACITY: usize = 64; /// Max number of tasks to poll per tick. const MAX_TASKS_PER_TICK: usize = 61; -/// How often it check the remote queue first +/// How often it check the remote queue first. 
const REMOTE_FIRST_INTERVAL: u8 = 31; impl LocalSet { @@ -326,12 +330,10 @@ impl LocalSet { LocalSet { tick: Cell::new(0), context: Context { - tasks: RefCell::new(Tasks { - owned: LinkedList::new(), - queue: VecDeque::with_capacity(INITIAL_CAPACITY), - }), + owned: LocalOwnedTasks::new(), + queue: VecDequeCell::with_capacity(INITIAL_CAPACITY), shared: Arc::new(Shared { - queue: Mutex::new(VecDeque::with_capacity(INITIAL_CAPACITY)), + queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))), waker: AtomicWaker::new(), }), }, @@ -381,9 +383,14 @@ impl LocalSet { F: Future + 'static, F::Output: 'static, { - let future = crate::util::trace::task(future, "local"); - let (task, handle) = unsafe { task::joinable_local(future) }; - self.context.tasks.borrow_mut().queue.push_back(task); + let future = crate::util::trace::task(future, "local", None); + + let (handle, notified) = self.context.owned.bind(future, self.context.shared.clone()); + + if let Some(notified) = notified { + self.context.shared.schedule(notified); + } + self.context.shared.waker.wake(); handle } @@ -459,7 +466,7 @@ impl LocalSet { rt.block_on(self.run_until(future)) } - /// Run a future to completion on the local set, returning its output. + /// Runs a future to completion on the local set, returning its output. /// /// This returns a future that runs the given future with a local set, /// allowing it to call [`spawn_local`] to spawn additional `!Send` futures. @@ -498,7 +505,7 @@ impl LocalSet { run_until.await } - /// Tick the scheduler, returning whether the local future needs to be + /// Ticks the scheduler, returning whether the local future needs to be /// notified again. 
fn tick(&self) -> bool { for _ in 0..MAX_TASKS_PER_TICK { @@ -522,26 +529,30 @@ impl LocalSet { true } - fn next_task(&self) -> Option<task::Notified<Arc<Shared>>> { + fn next_task(&self) -> Option<task::LocalNotified<Arc<Shared>>> { let tick = self.tick.get(); self.tick.set(tick.wrapping_add(1)); - if tick % REMOTE_FIRST_INTERVAL == 0 { + let task = if tick % REMOTE_FIRST_INTERVAL == 0 { self.context .shared .queue .lock() - .unwrap() - .pop_front() - .or_else(|| self.context.tasks.borrow_mut().queue.pop_front()) + .as_mut() + .and_then(|queue| queue.pop_front()) + .or_else(|| self.context.queue.pop_front()) } else { - self.context - .tasks - .borrow_mut() - .queue - .pop_front() - .or_else(|| self.context.shared.queue.lock().unwrap().pop_front()) - } + self.context.queue.pop_front().or_else(|| { + self.context + .shared + .queue + .lock() + .as_mut() + .and_then(|queue| queue.pop_front()) + }) + }; + + task.map(|task| self.context.owned.assert_owner(task)) } fn with<T>(&self, f: impl FnOnce() -> T) -> T { @@ -567,7 +578,7 @@ impl Future for LocalSet { // there are still tasks remaining in the run queue. cx.waker().wake_by_ref(); Poll::Pending - } else if self.context.tasks.borrow().owned.is_empty() { + } else if self.context.owned.is_empty() { // If the scheduler has no remaining futures, we're done! Poll::Ready(()) } else { @@ -588,27 +599,24 @@ impl Default for LocalSet { impl Drop for LocalSet { fn drop(&mut self) { self.with(|| { - // Loop required here to ensure borrow is dropped between iterations - #[allow(clippy::while_let_loop)] - loop { - let task = match self.context.tasks.borrow_mut().owned.pop_back() { - Some(task) => task, - None => break, - }; - - // Safety: same as `run_unchecked`. - task.shutdown(); - } - - for task in self.context.tasks.borrow_mut().queue.drain(..) { - task.shutdown(); + // Shut down all tasks in the LocalOwnedTasks and close it to + // prevent new tasks from ever being added. 
+ self.context.owned.close_and_shutdown_all(); + + // We already called shutdown on all tasks above, so there is no + // need to call shutdown. + for task in self.context.queue.take() { + drop(task); } - for task in self.context.shared.queue.lock().unwrap().drain(..) { - task.shutdown(); + // Take the queue from the Shared object to prevent pushing + // notifications to it in the future. + let queue = self.context.shared.queue.lock().take().unwrap(); + for task in queue { + drop(task); } - assert!(self.context.tasks.borrow().owned.is_empty()); + assert!(self.context.owned.is_empty()); }); } } @@ -651,11 +659,19 @@ impl Shared { fn schedule(&self, task: task::Notified<Arc<Self>>) { CURRENT.with(|maybe_cx| match maybe_cx { Some(cx) if cx.shared.ptr_eq(self) => { - cx.tasks.borrow_mut().queue.push_back(task); + cx.queue.push_back(task); } _ => { - self.queue.lock().unwrap().push_back(task); - self.waker.wake(); + // First check whether the queue is still there (if not, the + // LocalSet is dropped). Then push to it if so, and if not, + // do nothing. + let mut lock = self.queue.lock(); + + if let Some(queue) = lock.as_mut() { + queue.push_back(task); + drop(lock); + self.waker.wake(); + } } }); } @@ -666,26 +682,11 @@ impl Shared { } impl task::Schedule for Arc<Shared> { - fn bind(task: Task<Self>) -> Arc<Shared> { - CURRENT.with(|maybe_cx| { - let cx = maybe_cx.expect("scheduler context missing"); - cx.tasks.borrow_mut().owned.push_front(task); - cx.shared.clone() - }) - } - fn release(&self, task: &Task<Self>) -> Option<Task<Self>> { - use std::ptr::NonNull; - CURRENT.with(|maybe_cx| { let cx = maybe_cx.expect("scheduler context missing"); - assert!(cx.shared.ptr_eq(self)); - - let ptr = NonNull::from(task.header()); - // safety: task must be contained by list. It is inserted into the - // list in `bind`. 
- unsafe { cx.tasks.borrow_mut().owned.remove(ptr) } + cx.owned.remove(task) }) } diff --git a/src/task/mod.rs b/src/task/mod.rs index abae818..ea98787 100644 --- a/src/task/mod.rs +++ b/src/task/mod.rs @@ -86,7 +86,7 @@ //! ``` //! //! Again, like `std::thread`'s [`JoinHandle` type][thread_join], if the spawned -//! task panics, awaiting its `JoinHandle` will return a [`JoinError`]`. For +//! task panics, awaiting its `JoinHandle` will return a [`JoinError`]. For //! example: //! //! ``` @@ -122,6 +122,11 @@ //! Instead, Tokio provides two APIs for running blocking operations in an //! asynchronous context: [`task::spawn_blocking`] and [`task::block_in_place`]. //! +//! Be aware that if you call a non-async method from async code, that non-async +//! method is still inside the asynchronous context, so you should also avoid +//! blocking operations there. This includes destructors of objects destroyed in +//! async code. +//! //! #### spawn_blocking //! //! The `task::spawn_blocking` function is similar to the `task::spawn` function @@ -294,4 +299,14 @@ cfg_rt! { mod unconstrained; pub use unconstrained::{unconstrained, Unconstrained}; + + cfg_trace! { + mod builder; + pub use builder::Builder; + } + + /// Task-related futures. + pub mod futures { + pub use super::task_local::TaskLocalFuture; + } } diff --git a/src/task/spawn.rs b/src/task/spawn.rs index d846fb4..065d38f 100644 --- a/src/task/spawn.rs +++ b/src/task/spawn.rs @@ -1,6 +1,4 @@ -use crate::runtime; -use crate::task::JoinHandle; -use crate::util::error::CONTEXT_MISSING_ERROR; +use crate::{task::JoinHandle, util::error::CONTEXT_MISSING_ERROR}; use std::future::Future; @@ -124,14 +122,28 @@ cfg_rt! 
{ /// error[E0391]: cycle detected when processing `main` /// ``` #[cfg_attr(tokio_track_caller, track_caller)] - pub fn spawn<T>(task: T) -> JoinHandle<T::Output> + pub fn spawn<T>(future: T) -> JoinHandle<T::Output> where T: Future + Send + 'static, T::Output: Send + 'static, { - let spawn_handle = runtime::context::spawn_handle() - .expect(CONTEXT_MISSING_ERROR); - let task = crate::util::trace::task(task, "task"); + // preventing stack overflows on debug mode, by quickly sending the + // task to the heap. + if cfg!(debug_assertions) && std::mem::size_of::<T>() > 2048 { + spawn_inner(Box::pin(future), None) + } else { + spawn_inner(future, None) + } + } + + #[cfg_attr(tokio_track_caller, track_caller)] + pub(super) fn spawn_inner<T>(future: T, name: Option<&str>) -> JoinHandle<T::Output> + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let spawn_handle = crate::runtime::context::spawn_handle().expect(CONTEXT_MISSING_ERROR); + let task = crate::util::trace::task(future, "task", name); spawn_handle.spawn(task) } } diff --git a/src/task/task_local.rs b/src/task/task_local.rs index 6571ffd..b6e7df4 100644 --- a/src/task/task_local.rs +++ b/src/task/task_local.rs @@ -2,6 +2,7 @@ use pin_project_lite::pin_project; use std::cell::RefCell; use std::error::Error; use std::future::Future; +use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{Context, Poll}; use std::{fmt, thread}; @@ -115,16 +116,16 @@ impl<T: 'static> LocalKey<T> { /// }).await; /// # } /// ``` - pub async fn scope<F>(&'static self, value: T, f: F) -> F::Output + pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F> where F: Future, { TaskLocalFuture { - local: &self, + local: self, slot: Some(value), future: f, + _pinned: PhantomPinned, } - .await } /// Sets a value `T` as the task-local value for the closure `F`. 
@@ -148,12 +149,14 @@ impl<T: 'static> LocalKey<T> { where F: FnOnce() -> R, { - let mut scope = TaskLocalFuture { - local: &self, + let scope = TaskLocalFuture { + local: self, slot: Some(value), future: (), + _pinned: PhantomPinned, }; - Pin::new(&mut scope).with_task(|_| f()) + crate::pin!(scope); + scope.with_task(|_| f()) } /// Accesses the current task-local and runs the provided closure. @@ -206,11 +209,37 @@ impl<T: 'static> fmt::Debug for LocalKey<T> { } pin_project! { - struct TaskLocalFuture<T: StaticLifetime, F> { + /// A future that sets a value `T` of a task local for the future `F` during + /// its execution. + /// + /// The value of the task-local must be `'static` and will be dropped on the + /// completion of the future. + /// + /// Created by the function [`LocalKey::scope`](self::LocalKey::scope). + /// + /// ### Examples + /// + /// ``` + /// # async fn dox() { + /// tokio::task_local! { + /// static NUMBER: u32; + /// } + /// + /// NUMBER.scope(1, async move { + /// println!("task local value: {}", NUMBER.get()); + /// }).await; + /// # } + /// ``` + pub struct TaskLocalFuture<T, F> + where + T: 'static + { local: &'static LocalKey<T>, slot: Option<T>, #[pin] future: F, + #[pin] + _pinned: PhantomPinned, } } @@ -252,10 +281,6 @@ impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> { } } -// Required to make `pin_project` happy. -trait StaticLifetime: 'static {} -impl<T: 'static> StaticLifetime for T {} - /// An error returned by [`LocalKey::try_with`](method@LocalKey::try_with). #[derive(Clone, Copy, Eq, PartialEq)] pub struct AccessError { diff --git a/src/task/unconstrained.rs b/src/task/unconstrained.rs index 4a62f81..31c732b 100644 --- a/src/task/unconstrained.rs +++ b/src/task/unconstrained.rs @@ -5,6 +5,7 @@ use std::task::{Context, Poll}; pin_project! { /// Future for the [`unconstrained`](unconstrained) method. 
+ #[cfg_attr(docsrs, doc(cfg(feature = "rt")))] #[must_use = "Unconstrained does nothing unless polled"] pub struct Unconstrained<F> { #[pin] @@ -38,6 +39,7 @@ where /// otherwise. /// /// See also the usage example in the [task module](index.html#unconstrained). +#[cfg_attr(docsrs, doc(cfg(feature = "rt")))] pub fn unconstrained<F>(inner: F) -> Unconstrained<F> { Unconstrained { inner } } diff --git a/src/task/yield_now.rs b/src/task/yield_now.rs index 251cb93..5eeb46a 100644 --- a/src/task/yield_now.rs +++ b/src/task/yield_now.rs @@ -2,37 +2,58 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -cfg_rt! { - /// Yields execution back to the Tokio runtime. - /// - /// A task yields by awaiting on `yield_now()`, and may resume when that - /// future completes (with no output.) The current task will be re-added as - /// a pending task at the _back_ of the pending queue. Any other pending - /// tasks will be scheduled. No other waking is required for the task to - /// continue. - /// - /// See also the usage example in the [task module](index.html#yield_now). - #[must_use = "yield_now does nothing unless polled/`await`-ed"] - pub async fn yield_now() { - /// Yield implementation - struct YieldNow { - yielded: bool, - } - - impl Future for YieldNow { - type Output = (); +/// Yields execution back to the Tokio runtime. +/// +/// A task yields by awaiting on `yield_now()`, and may resume when that future +/// completes (with no output.) The current task will be re-added as a pending +/// task at the _back_ of the pending queue. Any other pending tasks will be +/// scheduled. No other waking is required for the task to continue. +/// +/// See also the usage example in the [task module](index.html#yield_now). +/// +/// ## Non-guarantees +/// +/// This function may not yield all the way up to the executor if there are any +/// special combinators above it in the call stack. 
For example, if a +/// [`tokio::select!`] has another branch complete during the same poll as the +/// `yield_now()`, then the yield is not propagated all the way up to the +/// runtime. +/// +/// It is generally not guaranteed that the runtime behaves like you expect it +/// to when deciding which task to schedule next after a call to `yield_now()`. +/// In particular, the runtime may choose to poll the task that just ran +/// `yield_now()` again immediately without polling any other tasks first. For +/// example, the runtime will not drive the IO driver between every poll of a +/// task, and this could result in the runtime polling the current task again +/// immediately even if there is another task that could make progress if that +/// other task is waiting for a notification from the IO driver. +/// +/// In general, changes to the order in which the runtime polls tasks is not +/// considered a breaking change, and your program should be correct no matter +/// which order the runtime polls your tasks in. +/// +/// [`tokio::select!`]: macro@crate::select +#[must_use = "yield_now does nothing unless polled/`await`-ed"] +#[cfg_attr(docsrs, doc(cfg(feature = "rt")))] +pub async fn yield_now() { + /// Yield implementation + struct YieldNow { + yielded: bool, + } - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { - if self.yielded { - return Poll::Ready(()); - } + impl Future for YieldNow { + type Output = (); - self.yielded = true; - cx.waker().wake_by_ref(); - Poll::Pending + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + if self.yielded { + return Poll::Ready(()); } - } - YieldNow { yielded: false }.await + self.yielded = true; + cx.waker().wake_by_ref(); + Poll::Pending + } } + + YieldNow { yielded: false }.await } diff --git a/src/time/clock.rs b/src/time/clock.rs index 8957800..41be9ba 100644 --- a/src/time/clock.rs +++ b/src/time/clock.rs @@ -7,7 +7,7 @@ //! configurable. cfg_not_test_util! 
{ - use crate::time::{Duration, Instant}; + use crate::time::{Instant}; #[derive(Debug, Clone)] pub(crate) struct Clock {} @@ -24,20 +24,12 @@ cfg_not_test_util! { pub(crate) fn now(&self) -> Instant { now() } - - pub(crate) fn is_paused(&self) -> bool { - false - } - - pub(crate) fn advance(&self, _dur: Duration) { - unreachable!(); - } } } cfg_test_util! { use crate::time::{Duration, Instant}; - use std::sync::{Arc, Mutex}; + use crate::loom::sync::{Arc, Mutex}; cfg_rt! { fn clock() -> Option<Clock> { @@ -65,32 +57,51 @@ cfg_test_util! { /// Instant to use as the clock's base instant. base: std::time::Instant, - /// Instant at which the clock was last unfrozen + /// Instant at which the clock was last unfrozen. unfrozen: Option<std::time::Instant>, } - /// Pause time + /// Pauses time. /// /// The current value of `Instant::now()` is saved and all subsequent calls - /// to `Instant::now()` until the timer wheel is checked again will return - /// the saved value. Once the timer wheel is checked, time will immediately - /// advance to the next registered `Sleep`. This is useful for running tests - /// that depend on time. + /// to `Instant::now()` will return the saved value. The saved value can be + /// changed by [`advance`] or by the time auto-advancing once the runtime + /// has no work to do. This only affects the `Instant` type in Tokio, and + /// the `Instant` in std continues to work as normal. /// /// Pausing time requires the `current_thread` Tokio runtime. This is the /// default runtime used by `#[tokio::test]`. The runtime can be initialized /// with time in a paused state using the `Builder::start_paused` method. 
/// + /// For cases where time is immediately paused, it is better to pause + /// the time using the `main` or `test` macro: + /// ``` + /// #[tokio::main(flavor = "current_thread", start_paused = true)] + /// async fn main() { + /// println!("Hello world"); + /// } + /// ``` + /// /// # Panics /// /// Panics if time is already frozen or if called from outside of a /// `current_thread` Tokio runtime. + /// + /// # Auto-advance + /// + /// If time is paused and the runtime has no work to do, the clock is + /// auto-advanced to the next pending timer. This means that [`Sleep`] or + /// other timer-backed primitives can cause the runtime to advance the + /// current time when awaited. + /// + /// [`Sleep`]: crate::time::Sleep + /// [`advance`]: crate::time::advance pub fn pause() { let clock = clock().expect("time cannot be frozen from outside the Tokio runtime"); clock.pause(); } - /// Resume time + /// Resumes time. /// /// Clears the saved `Instant::now()` value. Subsequent calls to /// `Instant::now()` will return the value returned by the system call. @@ -101,7 +112,7 @@ cfg_test_util! { /// runtime. pub fn resume() { let clock = clock().expect("time cannot be frozen from outside the Tokio runtime"); - let mut inner = clock.inner.lock().unwrap(); + let mut inner = clock.inner.lock(); if inner.unfrozen.is_some() { panic!("time is not frozen"); @@ -110,35 +121,45 @@ cfg_test_util! { inner.unfrozen = Some(std::time::Instant::now()); } - /// Advance time + /// Advances time. /// /// Increments the saved `Instant::now()` value by `duration`. Subsequent /// calls to `Instant::now()` will return the result of the increment. /// + /// This function will make the current time jump forward by the given + /// duration in one jump. This means that all `sleep` calls with a deadline + /// before the new time will immediately complete "at the same time", and + /// the runtime is free to poll them in any order. 
Additionally, this + /// method will not wait for the `sleep` calls it advanced past to complete. + /// If you want to do that, you should instead call [`sleep`] and rely on + /// the runtime's auto-advance feature. + /// + /// Note that calls to `sleep` are not guaranteed to complete the first time + /// they are polled after a call to `advance`. For example, this can happen + /// if the runtime has not yet touched the timer driver after the call to + /// `advance`. However if they don't, the runtime will poll the task again + /// shortly. + /// /// # Panics /// /// Panics if time is not frozen or if called from outside of the Tokio /// runtime. + /// + /// # Auto-advance + /// + /// If the time is paused and there is no work to do, the runtime advances + /// time to the next timer. See [`pause`](pause#auto-advance) for more + /// details. + /// + /// [`sleep`]: fn@crate::time::sleep pub async fn advance(duration: Duration) { - use crate::future::poll_fn; - use std::task::Poll; - let clock = clock().expect("time cannot be frozen from outside the Tokio runtime"); clock.advance(duration); - let mut yielded = false; - poll_fn(|cx| { - if yielded { - Poll::Ready(()) - } else { - yielded = true; - cx.waker().wake_by_ref(); - Poll::Pending - } - }).await; + crate::task::yield_now().await; } - /// Return the current instant, factoring in frozen time. + /// Returns the current instant, factoring in frozen time. pub(crate) fn now() -> Instant { if let Some(clock) = clock() { clock.now() @@ -148,7 +169,7 @@ cfg_test_util! { } impl Clock { - /// Return a new `Clock` instance that uses the current execution context's + /// Returns a new `Clock` instance that uses the current execution context's /// source of time. pub(crate) fn new(enable_pausing: bool, start_paused: bool) -> Clock { let now = std::time::Instant::now(); @@ -169,7 +190,7 @@ cfg_test_util! 
{ } pub(crate) fn pause(&self) { - let mut inner = self.inner.lock().unwrap(); + let mut inner = self.inner.lock(); if !inner.enable_pausing { drop(inner); // avoid poisoning the lock @@ -183,12 +204,12 @@ cfg_test_util! { } pub(crate) fn is_paused(&self) -> bool { - let inner = self.inner.lock().unwrap(); + let inner = self.inner.lock(); inner.unfrozen.is_none() } pub(crate) fn advance(&self, duration: Duration) { - let mut inner = self.inner.lock().unwrap(); + let mut inner = self.inner.lock(); if inner.unfrozen.is_some() { panic!("time is not frozen"); @@ -198,7 +219,7 @@ cfg_test_util! { } pub(crate) fn now(&self) -> Instant { - let inner = self.inner.lock().unwrap(); + let inner = self.inner.lock(); let mut ret = inner.base; diff --git a/src/time/driver/entry.rs b/src/time/driver/entry.rs index e630fa8..9e9f0dc 100644 --- a/src/time/driver/entry.rs +++ b/src/time/driver/entry.rs @@ -68,7 +68,7 @@ use std::{marker::PhantomPinned, pin::Pin, ptr::NonNull}; type TimerResult = Result<(), crate::time::error::Error>; -const STATE_DEREGISTERED: u64 = u64::max_value(); +const STATE_DEREGISTERED: u64 = u64::MAX; const STATE_PENDING_FIRE: u64 = STATE_DEREGISTERED - 1; const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE; @@ -85,10 +85,10 @@ const STATE_MIN_VALUE: u64 = STATE_PENDING_FIRE; /// requires only the driver lock. pub(super) struct StateCell { /// Holds either the scheduled expiration time for this timer, or (if the - /// timer has been fired and is unregistered), [`u64::max_value()`]. + /// timer has been fired and is unregistered), `u64::MAX`. state: AtomicU64, /// If the timer is fired (an Acquire order read on state shows - /// `u64::max_value()`), holds the result that should be returned from + /// `u64::MAX`), holds the result that should be returned from /// polling the timer. Otherwise, the contents are unspecified and reading /// without holding the driver lock is undefined behavior. 
result: UnsafeCell<TimerResult>, @@ -125,7 +125,7 @@ impl StateCell { fn when(&self) -> Option<u64> { let cur_state = self.state.load(Ordering::Relaxed); - if cur_state == u64::max_value() { + if cur_state == u64::MAX { None } else { Some(cur_state) @@ -271,7 +271,7 @@ impl StateCell { /// ordering, but is conservative - if it returns false, the timer is /// definitely _not_ registered. pub(super) fn might_be_registered(&self) -> bool { - self.state.load(Ordering::Relaxed) != u64::max_value() + self.state.load(Ordering::Relaxed) != u64::MAX } } @@ -345,7 +345,7 @@ impl TimerShared { } } - /// Gets the cached time-of-expiration value + /// Gets the cached time-of-expiration value. pub(super) fn cached_when(&self) -> u64 { // Cached-when is only accessed under the driver lock, so we can use relaxed self.driver_state.0.cached_when.load(Ordering::Relaxed) @@ -591,7 +591,7 @@ impl TimerHandle { match self.inner.as_ref().state.mark_pending(not_after) { Ok(()) => { // mark this as being on the pending queue in cached_when - self.inner.as_ref().set_cached_when(u64::max_value()); + self.inner.as_ref().set_cached_when(u64::MAX); Ok(()) } Err(tick) => { diff --git a/src/time/driver/handle.rs b/src/time/driver/handle.rs index 9a05a54..7aaf65a 100644 --- a/src/time/driver/handle.rs +++ b/src/time/driver/handle.rs @@ -16,17 +16,17 @@ impl Handle { Handle { time_source, inner } } - /// Returns the time source associated with this handle + /// Returns the time source associated with this handle. pub(super) fn time_source(&self) -> &ClockTime { &self.time_source } - /// Access the driver's inner structure + /// Access the driver's inner structure. pub(super) fn get(&self) -> &super::Inner { &*self.inner } - // Check whether the driver has been shutdown + /// Checks whether the driver has been shutdown. pub(super) fn is_shutdown(&self) -> bool { self.inner.is_shutdown() } @@ -76,7 +76,7 @@ cfg_not_rt! 
{ /// lazy, and so outside executed inside the runtime successfully without /// panicking. pub(crate) fn current() -> Self { - panic!(crate::util::error::CONTEXT_MISSING_ERROR) + panic!("{}", crate::util::error::CONTEXT_MISSING_ERROR) } } } diff --git a/src/time/driver/mod.rs b/src/time/driver/mod.rs index 3eb1004..cf2290b 100644 --- a/src/time/driver/mod.rs +++ b/src/time/driver/mod.rs @@ -4,7 +4,7 @@ #![allow(unused_unsafe)] #![cfg_attr(not(feature = "rt"), allow(dead_code))] -//! Time driver +//! Time driver. mod entry; pub(self) use self::entry::{EntryList, TimerEntry, TimerHandle, TimerShared}; @@ -83,14 +83,23 @@ use std::{num::NonZeroU64, ptr::NonNull, task::Waker}; /// [interval]: crate::time::Interval #[derive(Debug)] pub(crate) struct Driver<P: Park + 'static> { - /// Timing backend in use + /// Timing backend in use. time_source: ClockTime, - /// Shared state + /// Shared state. handle: Handle, - /// Parker to delegate to + /// Parker to delegate to. park: P, + + // When `true`, a call to `park_timeout` should immediately return and time + // should not advance. One reason for this to be `true` is if the task + // passed to `Runtime::block_on` called `task::yield_now()`. + // + // While it may look racy, it only has any effect when the clock is paused + // and pausing the clock is restricted to a single-threaded runtime. + #[cfg(feature = "test-util")] + did_wake: Arc<AtomicBool>, } /// A structure which handles conversion from Instants to u64 timestamps. @@ -137,25 +146,25 @@ struct Inner { // The state is split like this so `Handle` can access `is_shutdown` without locking the mutex pub(super) state: Mutex<InnerState>, - /// True if the driver is being shutdown + /// True if the driver is being shutdown. pub(super) is_shutdown: AtomicBool, } /// Time state shared which must be protected by a `Mutex` struct InnerState { - /// Timing backend in use + /// Timing backend in use. time_source: ClockTime, /// The last published timer `elapsed` value. 
elapsed: u64, - /// The earliest time at which we promise to wake up without unparking + /// The earliest time at which we promise to wake up without unparking. next_wake: Option<NonZeroU64>, - /// Timer wheel + /// Timer wheel. wheel: wheel::Wheel, - /// Unparker that can be used to wake the time driver + /// Unparker that can be used to wake the time driver. unpark: Box<dyn Unpark>, } @@ -178,6 +187,8 @@ where time_source, handle: Handle::new(Arc::new(inner)), park, + #[cfg(feature = "test-util")] + did_wake: Arc::new(AtomicBool::new(false)), } } @@ -192,8 +203,6 @@ where } fn park_internal(&mut self, limit: Option<Duration>) -> Result<(), P::Error> { - let clock = &self.time_source.clock; - let mut lock = self.handle.get().state.lock(); assert!(!self.handle.is_shutdown()); @@ -217,26 +226,14 @@ where duration = std::cmp::min(limit, duration); } - if clock.is_paused() { - self.park.park_timeout(Duration::from_secs(0))?; - - // Simulate advancing time - clock.advance(duration); - } else { - self.park.park_timeout(duration)?; - } + self.park_timeout(duration)?; } else { self.park.park_timeout(Duration::from_secs(0))?; } } None => { if let Some(duration) = limit { - if clock.is_paused() { - self.park.park_timeout(Duration::from_secs(0))?; - clock.advance(duration); - } else { - self.park.park_timeout(duration)?; - } + self.park_timeout(duration)?; } else { self.park.park()?; } @@ -248,6 +245,39 @@ where Ok(()) } + + cfg_test_util! { + fn park_timeout(&mut self, duration: Duration) -> Result<(), P::Error> { + let clock = &self.time_source.clock; + + if clock.is_paused() { + self.park.park_timeout(Duration::from_secs(0))?; + + // If the time driver was woken, then the park completed + // before the "duration" elapsed (usually caused by a + // yield in `Runtime::block_on`). In this case, we don't + // advance the clock. 
+ if !self.did_wake() { + // Simulate advancing time + clock.advance(duration); + } + } else { + self.park.park_timeout(duration)?; + } + + Ok(()) + } + + fn did_wake(&self) -> bool { + self.did_wake.swap(false, Ordering::SeqCst) + } + } + + cfg_not_test_util! { + fn park_timeout(&mut self, duration: Duration) -> Result<(), P::Error> { + self.park.park_timeout(duration) + } + } } impl Handle { @@ -258,13 +288,21 @@ impl Handle { self.process_at_time(now) } - pub(self) fn process_at_time(&self, now: u64) { + pub(self) fn process_at_time(&self, mut now: u64) { let mut waker_list: [Option<Waker>; 32] = Default::default(); let mut waker_idx = 0; let mut lock = self.get().lock(); - assert!(now >= lock.elapsed); + if now < lock.elapsed { + // Time went backwards! This normally shouldn't happen as the Rust language + // guarantees that an Instant is monotonic, but can happen when running + // Linux in a VM on a Windows host due to std incorrectly trusting the + // hardware clock to be monotonic. + // + // See <https://github.com/tokio-rs/tokio/issues/3619> for more information. 
+ now = lock.elapsed; + } while let Some(entry) = lock.wheel.poll(now) { debug_assert!(unsafe { entry.is_pending() }); @@ -387,11 +425,11 @@ impl<P> Park for Driver<P> where P: Park + 'static, { - type Unpark = P::Unpark; + type Unpark = TimerUnpark<P>; type Error = P::Error; fn unpark(&self) -> Self::Unpark { - self.park.unpark() + TimerUnpark::new(self) } fn park(&mut self) -> Result<(), Self::Error> { @@ -426,6 +464,33 @@ where } } +pub(crate) struct TimerUnpark<P: Park + 'static> { + inner: P::Unpark, + + #[cfg(feature = "test-util")] + did_wake: Arc<AtomicBool>, +} + +impl<P: Park + 'static> TimerUnpark<P> { + fn new(driver: &Driver<P>) -> TimerUnpark<P> { + TimerUnpark { + inner: driver.park.unpark(), + + #[cfg(feature = "test-util")] + did_wake: driver.did_wake.clone(), + } + } +} + +impl<P: Park + 'static> Unpark for TimerUnpark<P> { + fn unpark(&self) { + #[cfg(feature = "test-util")] + self.did_wake.store(true, Ordering::SeqCst); + + self.inner.unpark(); + } +} + // ===== impl Inner ===== impl Inner { diff --git a/src/time/driver/sleep.rs b/src/time/driver/sleep.rs index 8658813..43ff694 100644 --- a/src/time/driver/sleep.rs +++ b/src/time/driver/sleep.rs @@ -1,25 +1,53 @@ use crate::time::driver::{Handle, TimerEntry}; use crate::time::{error::Error, Duration, Instant}; +use crate::util::trace; use pin_project_lite::pin_project; use std::future::Future; +use std::panic::Location; use std::pin::Pin; use std::task::{self, Poll}; +cfg_trace! { + use crate::time::driver::ClockTime; +} + /// Waits until `deadline` is reached. /// /// No work is performed while awaiting on the sleep future to complete. `Sleep` /// operates at millisecond granularity and should not be used for tasks that /// require high-resolution timers. /// +/// To run something regularly on a schedule, see [`interval`]. +/// /// # Cancellation /// /// Canceling a sleep instance is done by dropping the returned future. No additional /// cleanup work is required. 
+/// +/// # Examples +/// +/// Wait 100ms and print "100 ms have elapsed". +/// +/// ``` +/// use tokio::time::{sleep_until, Instant, Duration}; +/// +/// #[tokio::main] +/// async fn main() { +/// sleep_until(Instant::now() + Duration::from_millis(100)).await; +/// println!("100 ms have elapsed"); +/// } +/// ``` +/// +/// See the documentation for the [`Sleep`] type for more examples. +/// +/// [`Sleep`]: struct@crate::time::Sleep +/// [`interval`]: crate::time::interval() // Alias for old name in 0.x #[cfg_attr(docsrs, doc(alias = "delay_until"))] +#[cfg_attr(tokio_track_caller, track_caller)] pub fn sleep_until(deadline: Instant) -> Sleep { - Sleep::new_timeout(deadline) + return Sleep::new_timeout(deadline, trace::caller_location()); } /// Waits until `duration` has elapsed. @@ -54,13 +82,20 @@ pub fn sleep_until(deadline: Instant) -> Sleep { /// } /// ``` /// +/// See the documentation for the [`Sleep`] type for more examples. +/// +/// [`Sleep`]: struct@crate::time::Sleep /// [`interval`]: crate::time::interval() // Alias for old name in 0.x #[cfg_attr(docsrs, doc(alias = "delay_for"))] +#[cfg_attr(docsrs, doc(alias = "wait"))] +#[cfg_attr(tokio_track_caller, track_caller)] pub fn sleep(duration: Duration) -> Sleep { + let location = trace::caller_location(); + match Instant::now().checked_add(duration) { - Some(deadline) => sleep_until(deadline), - None => sleep_until(Instant::far_future()), + Some(deadline) => Sleep::new_timeout(deadline, location), + None => Sleep::new_timeout(Instant::far_future(), location), } } @@ -157,7 +192,7 @@ pin_project! { #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct Sleep { - deadline: Instant, + inner: Inner, // The link between the `Sleep` instance and the timer that drives it. #[pin] @@ -165,21 +200,87 @@ pin_project! { } } +cfg_trace! 
{ + #[derive(Debug)] + struct Inner { + deadline: Instant, + resource_span: tracing::Span, + async_op_span: tracing::Span, + time_source: ClockTime, + } +} + +cfg_not_trace! { + #[derive(Debug)] + struct Inner { + deadline: Instant, + } +} + impl Sleep { - pub(crate) fn new_timeout(deadline: Instant) -> Sleep { + #[cfg_attr(not(all(tokio_unstable, feature = "tracing")), allow(unused_variables))] + pub(crate) fn new_timeout( + deadline: Instant, + location: Option<&'static Location<'static>>, + ) -> Sleep { let handle = Handle::current(); let entry = TimerEntry::new(&handle, deadline); - Sleep { deadline, entry } + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = { + let time_source = handle.time_source().clone(); + let deadline_tick = time_source.deadline_to_tick(deadline); + let duration = deadline_tick.checked_sub(time_source.now()).unwrap_or(0); + + #[cfg(tokio_track_caller)] + let location = location.expect("should have location if tracking caller"); + + #[cfg(tokio_track_caller)] + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "Sleep", + kind = "timer", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + + #[cfg(not(tokio_track_caller))] + let resource_span = + tracing::trace_span!("runtime.resource", concrete_type = "Sleep", kind = "timer"); + + let async_op_span = + tracing::trace_span!("runtime.resource.async_op", source = "Sleep::new_timeout"); + + tracing::trace!( + target: "runtime::resource::state_update", + parent: resource_span.id(), + duration = duration, + duration.unit = "ms", + duration.op = "override", + ); + + Inner { + deadline, + resource_span, + async_op_span, + time_source, + } + }; + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = Inner { deadline }; + + Sleep { inner, entry } } - pub(crate) fn far_future() -> Sleep { - Self::new_timeout(Instant::far_future()) + pub(crate) fn far_future(location: Option<&'static 
Location<'static>>) -> Sleep { + Self::new_timeout(Instant::far_future(), location) } /// Returns the instant at which the future will complete. pub fn deadline(&self) -> Instant { - self.deadline + self.inner.deadline } /// Returns `true` if `Sleep` has elapsed. @@ -215,39 +316,87 @@ impl Sleep { /// # } /// ``` /// + /// See also the top-level examples. + /// /// [`Pin::as_mut`]: fn@std::pin::Pin::as_mut pub fn reset(self: Pin<&mut Self>, deadline: Instant) { + self.reset_inner(deadline) + } + + fn reset_inner(self: Pin<&mut Self>, deadline: Instant) { let me = self.project(); me.entry.reset(deadline); - *me.deadline = deadline; + (*me.inner).deadline = deadline; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + { + me.inner.async_op_span = + tracing::trace_span!("runtime.resource.async_op", source = "Sleep::reset"); + + tracing::trace!( + target: "runtime::resource::state_update", + parent: me.inner.resource_span.id(), + duration = { + let now = me.inner.time_source.now(); + let deadline_tick = me.inner.time_source.deadline_to_tick(deadline); + deadline_tick.checked_sub(now).unwrap_or(0) + }, + duration.unit = "ms", + duration.op = "override", + ); + } } - fn poll_elapsed(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> { - let me = self.project(); + cfg_not_trace! { + fn poll_elapsed(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> { + let me = self.project(); - // Keep track of task budget - let coop = ready!(crate::coop::poll_proceed(cx)); + // Keep track of task budget + let coop = ready!(crate::coop::poll_proceed(cx)); - me.entry.poll_elapsed(cx).map(move |r| { - coop.made_progress(); - r - }) + me.entry.poll_elapsed(cx).map(move |r| { + coop.made_progress(); + r + }) + } + } + + cfg_trace! 
{ + fn poll_elapsed(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> { + let me = self.project(); + // Keep track of task budget + let coop = ready!(trace_poll_op!( + "poll_elapsed", + crate::coop::poll_proceed(cx), + me.inner.resource_span.id(), + )); + + let result = me.entry.poll_elapsed(cx).map(move |r| { + coop.made_progress(); + r + }); + + trace_poll_op!("poll_elapsed", result, me.inner.resource_span.id()) + } } } impl Future for Sleep { type Output = (); + // `poll_elapsed` can return an error in two cases: + // + // - AtCapacity: this is a pathological case where far too many + // sleep instances have been scheduled. + // - Shutdown: No timer has been setup, which is a mis-use error. + // + // Both cases are extremely rare, and pretty accurately fit into + // "logic errors", so we just panic in this case. A user couldn't + // really do much better if we passed the error onwards. fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> { - // `poll_elapsed` can return an error in two cases: - // - // - AtCapacity: this is a pathological case where far too many - // sleep instances have been scheduled. - // - Shutdown: No timer has been setup, which is a mis-use error. - // - // Both cases are extremely rare, and pretty accurately fit into - // "logic errors", so we just panic in this case. A user couldn't - // really do much better if we passed the error onwards. 
+ #[cfg(all(tokio_unstable, feature = "tracing"))] + let _span = self.inner.async_op_span.clone().entered(); + match ready!(self.as_mut().poll_elapsed(cx)) { Ok(()) => Poll::Ready(()), Err(e) => panic!("timer error: {}", e), diff --git a/src/time/driver/wheel/level.rs b/src/time/driver/wheel/level.rs index 81d6b58..34d3176 100644 --- a/src/time/driver/wheel/level.rs +++ b/src/time/driver/wheel/level.rs @@ -250,7 +250,7 @@ fn level_range(level: usize) -> u64 { LEVEL_MULT as u64 * slot_range(level) } -/// Convert a duration (milliseconds) and a level to a slot position +/// Converts a duration (milliseconds) and a level to a slot position. fn slot_for(duration: u64, level: usize) -> usize { ((duration >> (level * 6)) % LEVEL_MULT as u64) as usize } diff --git a/src/time/driver/wheel/mod.rs b/src/time/driver/wheel/mod.rs index 24bf517..f088f2c 100644 --- a/src/time/driver/wheel/mod.rs +++ b/src/time/driver/wheel/mod.rs @@ -46,11 +46,11 @@ pub(crate) struct Wheel { /// precision of 1 millisecond. const NUM_LEVELS: usize = 6; -/// The maximum duration of a `Sleep` +/// The maximum duration of a `Sleep`. pub(super) const MAX_DURATION: u64 = (1 << (6 * NUM_LEVELS)) - 1; impl Wheel { - /// Create a new timing wheel + /// Creates a new timing wheel. pub(crate) fn new() -> Wheel { let levels = (0..NUM_LEVELS).map(Level::new).collect(); @@ -61,13 +61,13 @@ impl Wheel { } } - /// Return the number of milliseconds that have elapsed since the timing + /// Returns the number of milliseconds that have elapsed since the timing /// wheel's creation. pub(crate) fn elapsed(&self) -> u64 { self.elapsed } - /// Insert an entry into the timing wheel. + /// Inserts an entry into the timing wheel. /// /// # Arguments /// @@ -115,11 +115,11 @@ impl Wheel { Ok(when) } - /// Remove `item` from the timing wheel. + /// Removes `item` from the timing wheel. 
pub(crate) unsafe fn remove(&mut self, item: NonNull<TimerShared>) { unsafe { let when = item.as_ref().cached_when(); - if when == u64::max_value() { + if when == u64::MAX { self.pending.remove(item); } else { debug_assert!( @@ -136,7 +136,7 @@ impl Wheel { } } - /// Instant at which to poll + /// Instant at which to poll. pub(crate) fn poll_at(&self) -> Option<u64> { self.next_expiration().map(|expiration| expiration.deadline) } diff --git a/src/time/driver/wheel/stack.rs b/src/time/driver/wheel/stack.rs index e7ed137..80651c3 100644 --- a/src/time/driver/wheel/stack.rs +++ b/src/time/driver/wheel/stack.rs @@ -3,7 +3,7 @@ use crate::time::driver::Entry; use std::ptr; -/// A doubly linked stack +/// A doubly linked stack. #[derive(Debug)] pub(crate) struct Stack { head: Option<OwnedItem>, @@ -50,7 +50,7 @@ impl Stack { self.head = Some(entry); } - /// Pops an item from the stack + /// Pops an item from the stack. pub(crate) fn pop(&mut self) -> Option<OwnedItem> { let entry = self.head.take(); diff --git a/src/time/error.rs b/src/time/error.rs index 8674feb..63f0a3b 100644 --- a/src/time/error.rs +++ b/src/time/error.rs @@ -40,7 +40,7 @@ impl From<Kind> for Error { } } -/// Error returned by `Timeout`. +/// Errors returned by `Timeout`. #[derive(Debug, PartialEq)] pub struct Elapsed(()); @@ -72,7 +72,7 @@ impl Error { matches!(self.0, Kind::AtCapacity) } - /// Create an error representing a misconfigured timer. + /// Creates an error representing a misconfigured timer. pub fn invalid() -> Error { Error(Invalid) } diff --git a/src/time/instant.rs b/src/time/instant.rs index 1f8e663..f7cf12d 100644 --- a/src/time/instant.rs +++ b/src/time/instant.rs @@ -98,7 +98,7 @@ impl Instant { } /// Returns the amount of time elapsed from another instant to this one, or - /// zero duration if that instant is earlier than this one. + /// zero duration if that instant is later than this one. 
/// /// # Examples /// diff --git a/src/time/interval.rs b/src/time/interval.rs index 20cfcec..7e07e51 100644 --- a/src/time/interval.rs +++ b/src/time/interval.rs @@ -1,17 +1,20 @@ use crate::future::poll_fn; use crate::time::{sleep_until, Duration, Instant, Sleep}; -use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; +use std::{convert::TryInto, future::Future}; -/// Creates new `Interval` that yields with interval of `duration`. The first -/// tick completes immediately. +/// Creates new [`Interval`] that yields with interval of `period`. The first +/// tick completes immediately. The default [`MissedTickBehavior`] is +/// [`Burst`](MissedTickBehavior::Burst), but this can be configured +/// by calling [`set_missed_tick_behavior`](Interval::set_missed_tick_behavior). /// -/// An interval will tick indefinitely. At any time, the `Interval` value can be -/// dropped. This cancels the interval. +/// An interval will tick indefinitely. At any time, the [`Interval`] value can +/// be dropped. This cancels the interval. /// -/// This function is equivalent to `interval_at(Instant::now(), period)`. +/// This function is equivalent to +/// [`interval_at(Instant::now(), period)`](interval_at). /// /// # Panics /// @@ -26,9 +29,9 @@ use std::task::{Context, Poll}; /// async fn main() { /// let mut interval = time::interval(Duration::from_millis(10)); /// -/// interval.tick().await; -/// interval.tick().await; -/// interval.tick().await; +/// interval.tick().await; // ticks immediately +/// interval.tick().await; // ticks after 10ms +/// interval.tick().await; // ticks after 10ms /// /// // approximately 20ms have elapsed. /// } @@ -36,10 +39,10 @@ use std::task::{Context, Poll}; /// /// A simple example using `interval` to execute a task every two seconds. 
/// -/// The difference between `interval` and [`sleep`] is that an `interval` -/// measures the time since the last tick, which means that `.tick().await` +/// The difference between `interval` and [`sleep`] is that an [`Interval`] +/// measures the time since the last tick, which means that [`.tick().await`] /// may wait for a shorter time than the duration specified for the interval -/// if some time has passed between calls to `.tick().await`. +/// if some time has passed between calls to [`.tick().await`]. /// /// If the tick in the example below was replaced with [`sleep`], the task /// would only be executed once every three seconds, and not every two @@ -64,17 +67,20 @@ use std::task::{Context, Poll}; /// ``` /// /// [`sleep`]: crate::time::sleep() +/// [`.tick().await`]: Interval::tick pub fn interval(period: Duration) -> Interval { assert!(period > Duration::new(0, 0), "`period` must be non-zero."); interval_at(Instant::now(), period) } -/// Creates new `Interval` that yields with interval of `period` with the -/// first tick completing at `start`. +/// Creates new [`Interval`] that yields with interval of `period` with the +/// first tick completing at `start`. The default [`MissedTickBehavior`] is +/// [`Burst`](MissedTickBehavior::Burst), but this can be configured +/// by calling [`set_missed_tick_behavior`](Interval::set_missed_tick_behavior). /// -/// An interval will tick indefinitely. At any time, the `Interval` value can be -/// dropped. This cancels the interval. +/// An interval will tick indefinitely. At any time, the [`Interval`] value can +/// be dropped. This cancels the interval. 
/// /// # Panics /// @@ -90,9 +96,9 @@ pub fn interval(period: Duration) -> Interval { /// let start = Instant::now() + Duration::from_millis(50); /// let mut interval = interval_at(start, Duration::from_millis(10)); /// -/// interval.tick().await; -/// interval.tick().await; -/// interval.tick().await; +/// interval.tick().await; // ticks after 50ms +/// interval.tick().await; // ticks after 10ms +/// interval.tick().await; // ticks after 10ms /// /// // approximately 70ms have elapsed. /// } @@ -103,19 +109,249 @@ pub fn interval_at(start: Instant, period: Duration) -> Interval { Interval { delay: Box::pin(sleep_until(start)), period, + missed_tick_behavior: Default::default(), } } -/// Interval returned by [`interval`](interval) and [`interval_at`](interval_at). +/// Defines the behavior of an [`Interval`] when it misses a tick. +/// +/// Sometimes, an [`Interval`]'s tick is missed. For example, consider the +/// following: +/// +/// ``` +/// use tokio::time::{self, Duration}; +/// # async fn task_that_takes_one_to_three_millis() {} +/// +/// #[tokio::main] +/// async fn main() { +/// // ticks every 2 seconds +/// let mut interval = time::interval(Duration::from_millis(2)); +/// for _ in 0..5 { +/// interval.tick().await; +/// // if this takes more than 2 milliseconds, a tick will be delayed +/// task_that_takes_one_to_three_millis().await; +/// } +/// } +/// ``` +/// +/// Generally, a tick is missed if too much time is spent without calling +/// [`Interval::tick()`]. +/// +/// By default, when a tick is missed, [`Interval`] fires ticks as quickly as it +/// can until it is "caught up" in time to where it should be. +/// `MissedTickBehavior` can be used to specify a different behavior for +/// [`Interval`] to exhibit. Each variant represents a different strategy. +/// +/// Note that because the executor cannot guarantee exact precision with timers, +/// these strategies will only apply when the delay is greater than 5 +/// milliseconds. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MissedTickBehavior { + /// Ticks as fast as possible until caught up. + /// + /// When this strategy is used, [`Interval`] schedules ticks "normally" (the + /// same as it would have if the ticks hadn't been delayed), which results + /// in it firing ticks as fast as possible until it is caught up in time to + /// where it should be. Unlike [`Delay`] and [`Skip`], the ticks yielded + /// when `Burst` is used (the [`Instant`]s that [`tick`](Interval::tick) + /// yields) aren't different than they would have been if a tick had not + /// been missed. Like [`Skip`], and unlike [`Delay`], the ticks may be + /// shortened. + /// + /// This looks something like this: + /// ```text + /// Expected ticks: | 1 | 2 | 3 | 4 | 5 | 6 | + /// Actual ticks: | work -----| delay | work | work | work -| work -----| + /// ``` + /// + /// In code: + /// + /// ``` + /// use tokio::time::{interval, Duration}; + /// # async fn task_that_takes_200_millis() {} + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let mut interval = interval(Duration::from_millis(50)); + /// + /// task_that_takes_200_millis().await; + /// // The `Interval` has missed a tick + /// + /// // Since we have exceeded our timeout, this will resolve immediately + /// interval.tick().await; + /// + /// // Since we are more than 100ms after the start of `interval`, this will + /// // also resolve immediately. + /// interval.tick().await; + /// + /// // Also resolves immediately, because it was supposed to resolve at + /// // 150ms after the start of `interval` + /// interval.tick().await; + /// + /// // Resolves immediately + /// interval.tick().await; + /// + /// // Since we have gotten to 200ms after the start of `interval`, this + /// // will resolve after 50ms + /// interval.tick().await; + /// # } + /// ``` + /// + /// This is the default behavior when [`Interval`] is created with + /// [`interval`] and [`interval_at`]. 
+ /// + /// [`Delay`]: MissedTickBehavior::Delay + /// [`Skip`]: MissedTickBehavior::Skip + Burst, + + /// Tick at multiples of `period` from when [`tick`] was called, rather than + /// from `start`. + /// + /// When this strategy is used and [`Interval`] has missed a tick, instead + /// of scheduling ticks to fire at multiples of `period` from `start` (the + /// time when the first tick was fired), it schedules all future ticks to + /// happen at a regular `period` from the point when [`tick`] was called. + /// Unlike [`Burst`] and [`Skip`], ticks are not shortened, and they aren't + /// guaranteed to happen at a multiple of `period` from `start` any longer. + /// + /// This looks something like this: + /// ```text + /// Expected ticks: | 1 | 2 | 3 | 4 | 5 | 6 | + /// Actual ticks: | work -----| delay | work -----| work -----| work -----| + /// ``` + /// + /// In code: + /// + /// ``` + /// use tokio::time::{interval, Duration, MissedTickBehavior}; + /// # async fn task_that_takes_more_than_50_millis() {} + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let mut interval = interval(Duration::from_millis(50)); + /// interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + /// + /// task_that_takes_more_than_50_millis().await; + /// // The `Interval` has missed a tick + /// + /// // Since we have exceeded our timeout, this will resolve immediately + /// interval.tick().await; + /// + /// // But this one, rather than also resolving immediately, as might happen + /// // with the `Burst` or `Skip` behaviors, will not resolve until + /// // 50ms after the call to `tick` up above. 
That is, in `tick`, when we + /// // recognize that we missed a tick, we schedule the next tick to happen + /// // 50ms (or whatever the `period` is) from right then, not from when + /// // were were *supposed* to tick + /// interval.tick().await; + /// # } + /// ``` + /// + /// [`Burst`]: MissedTickBehavior::Burst + /// [`Skip`]: MissedTickBehavior::Skip + /// [`tick`]: Interval::tick + Delay, + + /// Skips missed ticks and tick on the next multiple of `period` from + /// `start`. + /// + /// When this strategy is used, [`Interval`] schedules the next tick to fire + /// at the next-closest tick that is a multiple of `period` away from + /// `start` (the point where [`Interval`] first ticked). Like [`Burst`], all + /// ticks remain multiples of `period` away from `start`, but unlike + /// [`Burst`], the ticks may not be *one* multiple of `period` away from the + /// last tick. Like [`Delay`], the ticks are no longer the same as they + /// would have been if ticks had not been missed, but unlike [`Delay`], and + /// like [`Burst`], the ticks may be shortened to be less than one `period` + /// away from each other. 
+ /// + /// This looks something like this: + /// ```text + /// Expected ticks: | 1 | 2 | 3 | 4 | 5 | 6 | + /// Actual ticks: | work -----| delay | work ---| work -----| work -----| + /// ``` + /// + /// In code: + /// + /// ``` + /// use tokio::time::{interval, Duration, MissedTickBehavior}; + /// # async fn task_that_takes_75_millis() {} + /// + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let mut interval = interval(Duration::from_millis(50)); + /// interval.set_missed_tick_behavior(MissedTickBehavior::Skip); + /// + /// task_that_takes_75_millis().await; + /// // The `Interval` has missed a tick + /// + /// // Since we have exceeded our timeout, this will resolve immediately + /// interval.tick().await; + /// + /// // This one will resolve after 25ms, 100ms after the start of + /// // `interval`, which is the closest multiple of `period` from the start + /// // of `interval` after the call to `tick` up above. + /// interval.tick().await; + /// # } + /// ``` + /// + /// [`Burst`]: MissedTickBehavior::Burst + /// [`Delay`]: MissedTickBehavior::Delay + Skip, +} + +impl MissedTickBehavior { + /// If a tick is missed, this method is called to determine when the next tick should happen. + fn next_timeout(&self, timeout: Instant, now: Instant, period: Duration) -> Instant { + match self { + Self::Burst => timeout + period, + Self::Delay => now + period, + Self::Skip => { + now + period + - Duration::from_nanos( + ((now - timeout).as_nanos() % period.as_nanos()) + .try_into() + // This operation is practically guaranteed not to + // fail, as in order for it to fail, `period` would + // have to be longer than `now - timeout`, and both + // would have to be longer than 584 years. + // + // If it did fail, there's not a good way to pass + // the error along to the user, so we just panic. 
+ .expect( + "too much time has elapsed since the interval was supposed to tick", + ), + ) + } + } + } +} + +impl Default for MissedTickBehavior { + /// Returns [`MissedTickBehavior::Burst`]. + /// + /// For most usecases, the [`Burst`] strategy is what is desired. + /// Additionally, to preserve backwards compatibility, the [`Burst`] + /// strategy must be the default. For these reasons, + /// [`MissedTickBehavior::Burst`] is the default for [`MissedTickBehavior`]. + /// See [`Burst`] for more details. + /// + /// [`Burst`]: MissedTickBehavior::Burst + fn default() -> Self { + Self::Burst + } +} + +/// Interval returned by [`interval`] and [`interval_at`]. /// /// This type allows you to wait on a sequence of instants with a certain -/// duration between each instant. Unlike calling [`sleep`](crate::time::sleep) -/// in a loop, this lets you count the time spent between the calls to `sleep` -/// as well. +/// duration between each instant. Unlike calling [`sleep`] in a loop, this lets +/// you count the time spent between the calls to [`sleep`] as well. /// /// An `Interval` can be turned into a `Stream` with [`IntervalStream`]. /// -/// [`IntervalStream`]: https://docs.rs/tokio-stream/0.1/tokio_stream/wrappers/struct.IntervalStream.html +/// [`IntervalStream`]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.IntervalStream.html +/// [`sleep`]: crate::time::sleep #[derive(Debug)] pub struct Interval { /// Future that completes the next time the `Interval` yields a value. @@ -123,11 +359,19 @@ pub struct Interval { /// The duration between values yielded by `Interval`. period: Duration, + + /// The strategy `Interval` should use when a tick is missed. + missed_tick_behavior: MissedTickBehavior, } impl Interval { /// Completes when the next instant in the interval has been reached. /// + /// # Cancel safety + /// + /// This method is cancellation safe. 
If `tick` is used as the branch in a `tokio::select!` and + /// another branch completes first, then no tick has been consumed. + /// /// # Examples /// /// ``` @@ -150,7 +394,7 @@ impl Interval { poll_fn(|cx| self.poll_tick(cx)).await } - /// Poll for the next instant in the interval to be reached. + /// Polls for the next instant in the interval to be reached. /// /// This method can return the following values: /// @@ -159,21 +403,50 @@ impl Interval { /// /// When this method returns `Poll::Pending`, the current task is scheduled /// to receive a wakeup when the instant has elapsed. Note that on multiple - /// calls to `poll_tick`, only the `Waker` from the `Context` passed to the - /// most recent call is scheduled to receive a wakeup. + /// calls to `poll_tick`, only the [`Waker`](std::task::Waker) from the + /// [`Context`] passed to the most recent call is scheduled to receive a + /// wakeup. pub fn poll_tick(&mut self, cx: &mut Context<'_>) -> Poll<Instant> { // Wait for the delay to be done ready!(Pin::new(&mut self.delay).poll(cx)); - // Get the `now` by looking at the `delay` deadline - let now = self.delay.deadline(); + // Get the time when we were scheduled to tick + let timeout = self.delay.deadline(); + + let now = Instant::now(); + + // If a tick was not missed, and thus we are being called before the + // next tick is due, just schedule the next tick normally, one `period` + // after `timeout` + // + // However, if a tick took excessively long and we are now behind, + // schedule the next tick according to how the user specified with + // `MissedTickBehavior` + let next = if now > timeout + Duration::from_millis(5) { + self.missed_tick_behavior + .next_timeout(timeout, now, self.period) + } else { + timeout + self.period + }; - // The next interval value is `duration` after the one that just - // yielded. 
- let next = now + self.period; self.delay.as_mut().reset(next); - // Return the current instant - Poll::Ready(now) + // Return the time when we were scheduled to tick + Poll::Ready(timeout) + } + + /// Returns the [`MissedTickBehavior`] strategy currently being used. + pub fn missed_tick_behavior(&self) -> MissedTickBehavior { + self.missed_tick_behavior + } + + /// Sets the [`MissedTickBehavior`] strategy that should be used. + pub fn set_missed_tick_behavior(&mut self, behavior: MissedTickBehavior) { + self.missed_tick_behavior = behavior; + } + + /// Returns the period of the interval. + pub fn period(&self) -> Duration { + self.period } } diff --git a/src/time/mod.rs b/src/time/mod.rs index 98bb2af..281990e 100644 --- a/src/time/mod.rs +++ b/src/time/mod.rs @@ -3,21 +3,21 @@ //! This module provides a number of types for executing code after a set period //! of time. //! -//! * `Sleep` is a future that does no work and completes at a specific `Instant` +//! * [`Sleep`] is a future that does no work and completes at a specific [`Instant`] //! in time. //! -//! * `Interval` is a stream yielding a value at a fixed period. It is -//! initialized with a `Duration` and repeatedly yields each time the duration +//! * [`Interval`] is a stream yielding a value at a fixed period. It is +//! initialized with a [`Duration`] and repeatedly yields each time the duration //! elapses. //! -//! * `Timeout`: Wraps a future or stream, setting an upper bound to the amount +//! * [`Timeout`]: Wraps a future or stream, setting an upper bound to the amount //! of time it is allowed to execute. If the future or stream does not //! complete in time, then it is canceled and an error is returned. //! //! These types are sufficient for handling a large number of scenarios //! involving time. //! -//! These types must be used from within the context of the `Runtime`. +//! These types must be used from within the context of the [`Runtime`](crate::runtime::Runtime). //! //! # Examples //! 
@@ -55,8 +55,8 @@ //! A simple example using [`interval`] to execute a task every two seconds. //! //! The difference between [`interval`] and [`sleep`] is that an [`interval`] -//! measures the time since the last tick, which means that `.tick().await` -//! may wait for a shorter time than the duration specified for the interval +//! measures the time since the last tick, which means that `.tick().await` may +//! wait for a shorter time than the duration specified for the interval //! if some time has passed between calls to `.tick().await`. //! //! If the tick in the example below was replaced with [`sleep`], the task @@ -81,7 +81,6 @@ //! } //! ``` //! -//! [`sleep`]: crate::time::sleep() //! [`interval`]: crate::time::interval() mod clock; @@ -100,7 +99,7 @@ mod instant; pub use self::instant::Instant; mod interval; -pub use interval::{interval, interval_at, Interval}; +pub use interval::{interval, interval_at, Interval, MissedTickBehavior}; mod timeout; #[doc(inline)] diff --git a/src/time/timeout.rs b/src/time/timeout.rs index 61964ad..6725caa 100644 --- a/src/time/timeout.rs +++ b/src/time/timeout.rs @@ -4,14 +4,17 @@ //! //! [`Timeout`]: struct@Timeout -use crate::time::{error::Elapsed, sleep_until, Duration, Instant, Sleep}; +use crate::{ + time::{error::Elapsed, sleep_until, Duration, Instant, Sleep}, + util::trace, +}; use pin_project_lite::pin_project; use std::future::Future; use std::pin::Pin; use std::task::{self, Poll}; -/// Require a `Future` to complete before the specified duration has elapsed. +/// Requires a `Future` to complete before the specified duration has elapsed. /// /// If the future completes before the duration has elapsed, then the completed /// value is returned. 
Otherwise, an error is returned and the future is @@ -45,19 +48,22 @@ use std::task::{self, Poll}; /// } /// # } /// ``` +#[cfg_attr(tokio_track_caller, track_caller)] pub fn timeout<T>(duration: Duration, future: T) -> Timeout<T> where T: Future, { + let location = trace::caller_location(); + let deadline = Instant::now().checked_add(duration); let delay = match deadline { - Some(deadline) => Sleep::new_timeout(deadline), - None => Sleep::far_future(), + Some(deadline) => Sleep::new_timeout(deadline, location), + None => Sleep::far_future(location), }; Timeout::new_with_delay(future, delay) } -/// Require a `Future` to complete before the specified instant in time. +/// Requires a `Future` to complete before the specified instant in time. /// /// If the future completes before the instant is reached, then the completed /// value is returned. Otherwise, an error is returned. diff --git a/src/util/bit.rs b/src/util/bit.rs index 392a0e8..a43c2c2 100644 --- a/src/util/bit.rs +++ b/src/util/bit.rs @@ -27,7 +27,7 @@ impl Pack { pointer_width() - (self.mask >> self.shift).leading_zeros() } - /// Max representable value + /// Max representable value. pub(crate) const fn max_value(&self) -> usize { (1 << self.width()) - 1 } @@ -60,7 +60,7 @@ impl fmt::Debug for Pack { } } -/// Returns the width of a pointer in bits +/// Returns the width of a pointer in bits. pub(crate) const fn pointer_width() -> u32 { std::mem::size_of::<usize>() as u32 * 8 } @@ -71,7 +71,7 @@ pub(crate) const fn mask_for(n: u32) -> usize { shift | (shift - 1) } -/// Unpack a value using a mask & shift +/// Unpacks a value using a mask & shift. 
pub(crate) const fn unpack(src: usize, mask: usize, shift: u32) -> usize { (src & mask) >> shift } diff --git a/src/util/error.rs b/src/util/error.rs index 0e52364..8f252c0 100644 --- a/src/util/error.rs +++ b/src/util/error.rs @@ -7,3 +7,11 @@ pub(crate) const CONTEXT_MISSING_ERROR: &str = /// Error string explaining that the Tokio context is shutting down and cannot drive timers. pub(crate) const RUNTIME_SHUTTING_DOWN_ERROR: &str = "A Tokio 1.x context was found, but it is being shutdown."; + +// some combinations of features might not use this +#[allow(dead_code)] +/// Error string explaining that the Tokio context is not available because the +/// thread-local storing it has been destroyed. This usually only happens during +/// destructors of other thread-locals. +pub(crate) const THREAD_LOCAL_DESTROYED_ERROR: &str = + "The Tokio context thread-local variable has been destroyed."; diff --git a/src/util/linked_list.rs b/src/util/linked_list.rs index dd00e14..894d216 100644 --- a/src/util/linked_list.rs +++ b/src/util/linked_list.rs @@ -1,6 +1,6 @@ #![cfg_attr(not(feature = "full"), allow(dead_code))] -//! An intrusive double linked list of data +//! An intrusive double linked list of data. //! //! The data structure supports tracking pinned nodes. Most of the data //! structure's APIs are `unsafe` as they require the caller to ensure the @@ -46,10 +46,11 @@ pub(crate) unsafe trait Link { /// This is usually a pointer-ish type. type Handle; - /// Node type + /// Node type. type Target; - /// Convert the handle to a raw pointer without consuming the handle + /// Convert the handle to a raw pointer without consuming the handle. + #[allow(clippy::wrong_self_convention)] fn as_raw(handle: &Self::Handle) -> NonNull<Self::Target>; /// Convert the raw pointer to a handle @@ -59,7 +60,7 @@ pub(crate) unsafe trait Link { unsafe fn pointers(target: NonNull<Self::Target>) -> NonNull<Pointers<Self::Target>>; } -/// Previous / next pointers +/// Previous / next pointers. 
pub(crate) struct Pointers<T> { inner: UnsafeCell<PointersInner<T>>, } @@ -77,7 +78,7 @@ pub(crate) struct Pointers<T> { /// #[repr(C)]. /// /// See this link for more information: -/// https://github.com/rust-lang/rust/pull/82834 +/// <https://github.com/rust-lang/rust/pull/82834> #[repr(C)] struct PointersInner<T> { /// The previous node in the list. null if there is no previous node. @@ -93,7 +94,7 @@ struct PointersInner<T> { next: Option<NonNull<T>>, /// This type is !Unpin due to the heuristic from: - /// https://github.com/rust-lang/rust/pull/82834 + /// <https://github.com/rust-lang/rust/pull/82834> _pin: PhantomPinned, } @@ -235,37 +236,6 @@ impl<L: Link> Default for LinkedList<L, L::Target> { } } -// ===== impl Iter ===== - -cfg_rt_multi_thread! { - pub(crate) struct Iter<'a, T: Link> { - curr: Option<NonNull<T::Target>>, - _p: core::marker::PhantomData<&'a T>, - } - - impl<L: Link> LinkedList<L, L::Target> { - pub(crate) fn iter(&self) -> Iter<'_, L> { - Iter { - curr: self.head, - _p: core::marker::PhantomData, - } - } - } - - impl<'a, T: Link> Iterator for Iter<'a, T> { - type Item = &'a T::Target; - - fn next(&mut self) -> Option<&'a T::Target> { - let curr = self.curr?; - // safety: the pointer references data contained by the list - self.curr = unsafe { T::pointers(curr).as_ref() }.get_next(); - - // safety: the value is still owned by the linked list. - Some(unsafe { &*curr.as_ptr() }) - } - } -} - // ===== impl DrainFilter ===== cfg_io_readiness! { @@ -644,24 +614,6 @@ mod tests { } } - #[test] - fn iter() { - let a = entry(5); - let b = entry(7); - - let mut list = LinkedList::<&Entry, <&Entry as Link>::Target>::new(); - - assert_eq!(0, list.iter().count()); - - list.push_front(a.as_ref()); - list.push_front(b.as_ref()); - - let mut i = list.iter(); - assert_eq!(7, i.next().unwrap().val); - assert_eq!(5, i.next().unwrap().val); - assert!(i.next().is_none()); - } - proptest::proptest! 
{ #[test] fn fuzz_linked_list(ops: Vec<usize>) { diff --git a/src/util/mod.rs b/src/util/mod.rs index b267125..df30f2b 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -4,6 +4,29 @@ cfg_io_driver! { } #[cfg(any( + // io driver uses `WakeList` directly + feature = "net", + feature = "process", + // `sync` enables `Notify` and `batch_semaphore`, which require `WakeList`. + feature = "sync", + // `fs` uses `batch_semaphore`, which requires `WakeList`. + feature = "fs", + // rt and signal use `Notify`, which requires `WakeList`. + feature = "rt", + feature = "signal", +))] +mod wake_list; +#[cfg(any( + feature = "net", + feature = "process", + feature = "sync", + feature = "fs", + feature = "rt", + feature = "signal", +))] +pub(crate) use wake_list::WakeList; + +#[cfg(any( feature = "fs", feature = "net", feature = "process", @@ -21,6 +44,12 @@ cfg_rt! { mod wake; pub(crate) use wake::WakerRef; pub(crate) use wake::{waker_ref, Wake}; + + mod sync_wrapper; + pub(crate) use sync_wrapper::SyncWrapper; + + mod vec_deque_cell; + pub(crate) use vec_deque_cell::VecDequeCell; } cfg_rt_multi_thread! { diff --git a/src/util/rand.rs b/src/util/rand.rs index 5660103..6b19c8b 100644 --- a/src/util/rand.rs +++ b/src/util/rand.rs @@ -1,12 +1,12 @@ use std::cell::Cell; -/// Fast random number generate +/// Fast random number generate. /// /// Implement xorshift64+: 2 32-bit xorshift sequences added together. 
-/// Shift triplet [17,7,16] was calculated as indicated in Marsaglia's -/// Xorshift paper: https://www.jstatsoft.org/article/view/v008i14/xorshift.pdf +/// Shift triplet `[17,7,16]` was calculated as indicated in Marsaglia's +/// Xorshift paper: <https://www.jstatsoft.org/article/view/v008i14/xorshift.pdf> /// This generator passes the SmallCrush suite, part of TestU01 framework: -/// http://simul.iro.umontreal.ca/testu01/tu01.html +/// <http://simul.iro.umontreal.ca/testu01/tu01.html> #[derive(Debug)] pub(crate) struct FastRand { one: Cell<u32>, @@ -14,7 +14,7 @@ pub(crate) struct FastRand { } impl FastRand { - /// Initialize a new, thread-local, fast random number generator. + /// Initializes a new, thread-local, fast random number generator. pub(crate) fn new(seed: u64) -> FastRand { let one = (seed >> 32) as u32; let mut two = seed as u32; diff --git a/src/util/slab.rs b/src/util/slab.rs index efc72e1..97355d5 100644 --- a/src/util/slab.rs +++ b/src/util/slab.rs @@ -85,11 +85,11 @@ pub(crate) struct Address(usize); /// An entry in the slab. pub(crate) trait Entry: Default { - /// Reset the entry's value and track the generation. + /// Resets the entry's value and track the generation. fn reset(&self); } -/// A reference to a value stored in the slab +/// A reference to a value stored in the slab. pub(crate) struct Ref<T> { value: *const Value<T>, } @@ -101,9 +101,9 @@ const NUM_PAGES: usize = 19; const PAGE_INITIAL_SIZE: usize = 32; const PAGE_INDEX_SHIFT: u32 = PAGE_INITIAL_SIZE.trailing_zeros() + 1; -/// A page in the slab +/// A page in the slab. struct Page<T> { - /// Slots + /// Slots. slots: Mutex<Slots<T>>, // Number of slots currently being used. This is not guaranteed to be up to @@ -116,7 +116,7 @@ struct Page<T> { // The number of slots the page can hold. len: usize, - // Length of all previous pages combined + // Length of all previous pages combined. 
prev_len: usize, } @@ -128,9 +128,9 @@ struct CachedPage<T> { init: usize, } -/// Page state +/// Page state. struct Slots<T> { - /// Slots + /// Slots. slots: Vec<Slot<T>>, head: usize, @@ -159,9 +159,9 @@ struct Slot<T> { next: u32, } -/// Value paired with a reference to the page +/// Value paired with a reference to the page. struct Value<T> { - /// Value stored in the value + /// Value stored in the value. value: T, /// Pointer to the page containing the slot. @@ -171,7 +171,7 @@ struct Value<T> { } impl<T> Slab<T> { - /// Create a new, empty, slab + /// Create a new, empty, slab. pub(crate) fn new() -> Slab<T> { // Initializing arrays is a bit annoying. Instead of manually writing // out an array and every single entry, `Default::default()` is used to @@ -296,7 +296,7 @@ impl<T> Slab<T> { // Remove the slots vector from the page. This is done so that the // freeing process is done outside of the lock's critical section. - let vec = mem::replace(&mut slots.slots, vec![]); + let vec = mem::take(&mut slots.slots); slots.head = 0; // Drop the lock so we can drop the vector outside the lock below. @@ -455,7 +455,7 @@ impl<T> Page<T> { addr.0 - self.prev_len } - /// Returns the address for the given slot + /// Returns the address for the given slot. fn addr(&self, slot: usize) -> Address { Address(slot + self.prev_len) } @@ -478,7 +478,7 @@ impl<T> Default for Page<T> { } impl<T> Page<T> { - /// Release a slot into the page's free list + /// Release a slot into the page's free list. fn release(&self, value: *const Value<T>) { let mut locked = self.slots.lock(); @@ -492,7 +492,7 @@ impl<T> Page<T> { } impl<T> CachedPage<T> { - /// Refresh the cache + /// Refreshes the cache. fn refresh(&mut self, page: &Page<T>) { let slots = page.slots.lock(); @@ -502,7 +502,7 @@ impl<T> CachedPage<T> { } } - // Get a value by index + /// Gets a value by index. 
fn get(&self, idx: usize) -> &T { assert!(idx < self.init); @@ -576,7 +576,7 @@ impl<T: Entry> Slot<T> { } impl<T> Value<T> { - // Release the slot, returning the `Arc<Page<T>>` logically owned by the ref. + /// Releases the slot, returning the `Arc<Page<T>>` logically owned by the ref. fn release(&self) -> Arc<Page<T>> { // Safety: called by `Ref`, which owns an `Arc<Page<T>>` instance. let page = unsafe { Arc::from_raw(self.page) }; diff --git a/src/util/sync_wrapper.rs b/src/util/sync_wrapper.rs new file mode 100644 index 0000000..5ffc8f9 --- /dev/null +++ b/src/util/sync_wrapper.rs @@ -0,0 +1,26 @@ +//! This module contains a type that can make `Send + !Sync` types `Sync` by +//! disallowing all immutable access to the value. +//! +//! A similar primitive is provided in the `sync_wrapper` crate. + +pub(crate) struct SyncWrapper<T> { + value: T, +} + +// safety: The SyncWrapper being send allows you to send the inner value across +// thread boundaries. +unsafe impl<T: Send> Send for SyncWrapper<T> {} + +// safety: An immutable reference to a SyncWrapper is useless, so moving such an +// immutable reference across threads is safe. +unsafe impl<T> Sync for SyncWrapper<T> {} + +impl<T> SyncWrapper<T> { + pub(crate) fn new(value: T) -> Self { + Self { value } + } + + pub(crate) fn into_inner(self) -> T { + self.value + } +} diff --git a/src/util/trace.rs b/src/util/trace.rs index 96a9db9..e3c26f9 100644 --- a/src/util/trace.rs +++ b/src/util/trace.rs @@ -4,32 +4,45 @@ cfg_trace! 
{ #[inline] #[cfg_attr(tokio_track_caller, track_caller)] - pub(crate) fn task<F>(task: F, kind: &'static str) -> Instrumented<F> { + pub(crate) fn task<F>(task: F, kind: &'static str, name: Option<&str>) -> Instrumented<F> { use tracing::instrument::Instrument; #[cfg(tokio_track_caller)] let location = std::panic::Location::caller(); #[cfg(tokio_track_caller)] let span = tracing::trace_span!( target: "tokio::task", - "task", + "runtime.spawn", %kind, - spawn.location = %format_args!("{}:{}:{}", location.file(), location.line(), location.column()), + task.name = %name.unwrap_or_default(), + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), ); #[cfg(not(tokio_track_caller))] let span = tracing::trace_span!( target: "tokio::task", - "task", + "runtime.spawn", %kind, + task.name = %name.unwrap_or_default(), ); task.instrument(span) } } } +cfg_time! { + #[cfg_attr(tokio_track_caller, track_caller)] + pub(crate) fn caller_location() -> Option<&'static std::panic::Location<'static>> { + #[cfg(all(tokio_track_caller, tokio_unstable, feature = "tracing"))] + return Some(std::panic::Location::caller()); + #[cfg(not(all(tokio_track_caller, tokio_unstable, feature = "tracing")))] + None + } +} cfg_not_trace! { cfg_rt! { #[inline] - pub(crate) fn task<F>(task: F, _: &'static str) -> F { + pub(crate) fn task<F>(task: F, _: &'static str, _name: Option<&str>) -> F { // nop task } diff --git a/src/util/vec_deque_cell.rs b/src/util/vec_deque_cell.rs new file mode 100644 index 0000000..b4e124c --- /dev/null +++ b/src/util/vec_deque_cell.rs @@ -0,0 +1,53 @@ +use crate::loom::cell::UnsafeCell; + +use std::collections::VecDeque; +use std::marker::PhantomData; + +/// This type is like VecDeque, except that it is not Sync and can be modified +/// through immutable references. 
+pub(crate) struct VecDequeCell<T> { + inner: UnsafeCell<VecDeque<T>>, + _not_sync: PhantomData<*const ()>, +} + +// This is Send for the same reasons that RefCell<VecDeque<T>> is Send. +unsafe impl<T: Send> Send for VecDequeCell<T> {} + +impl<T> VecDequeCell<T> { + pub(crate) fn with_capacity(cap: usize) -> Self { + Self { + inner: UnsafeCell::new(VecDeque::with_capacity(cap)), + _not_sync: PhantomData, + } + } + + /// Safety: This method may not be called recursively. + #[inline] + unsafe fn with_inner<F, R>(&self, f: F) -> R + where + F: FnOnce(&mut VecDeque<T>) -> R, + { + // safety: This type is not Sync, so concurrent calls of this method + // cannot happen. Furthermore, the caller guarantees that the method is + // not called recursively. Finally, this is the only place that can + // create mutable references to the inner VecDeque. This ensures that + // any mutable references created here are exclusive. + self.inner.with_mut(|ptr| f(&mut *ptr)) + } + + pub(crate) fn pop_front(&self) -> Option<T> { + unsafe { self.with_inner(VecDeque::pop_front) } + } + + pub(crate) fn push_back(&self, item: T) { + unsafe { + self.with_inner(|inner| inner.push_back(item)); + } + } + + /// Replaces the inner VecDeque with an empty VecDeque and return the current + /// contents. + pub(crate) fn take(&self) -> VecDeque<T> { + unsafe { self.with_inner(|inner| std::mem::take(inner)) } + } +} diff --git a/src/util/wake.rs b/src/util/wake.rs index 001577d..8f89668 100644 --- a/src/util/wake.rs +++ b/src/util/wake.rs @@ -4,12 +4,12 @@ use std::ops::Deref; use std::sync::Arc; use std::task::{RawWaker, RawWakerVTable, Waker}; -/// Simplified waking interface based on Arcs +/// Simplified waking interface based on Arcs. pub(crate) trait Wake: Send + Sync { - /// Wake by value + /// Wake by value. fn wake(self: Arc<Self>); - /// Wake by reference + /// Wake by reference. 
fn wake_by_ref(arc_self: &Arc<Self>); } @@ -54,11 +54,7 @@ unsafe fn inc_ref_count<T: Wake>(data: *const ()) { let arc = ManuallyDrop::new(Arc::<T>::from_raw(data as *const T)); // Now increase refcount, but don't drop new refcount either - let arc_clone: ManuallyDrop<_> = arc.clone(); - - // Drop explicitly to avoid clippy warnings - drop(arc); - drop(arc_clone); + let _arc_clone: ManuallyDrop<_> = arc.clone(); } unsafe fn clone_arc_raw<T: Wake>(data: *const ()) -> RawWaker { diff --git a/src/util/wake_list.rs b/src/util/wake_list.rs new file mode 100644 index 0000000..aa569dd --- /dev/null +++ b/src/util/wake_list.rs @@ -0,0 +1,53 @@ +use core::mem::MaybeUninit; +use core::ptr; +use std::task::Waker; + +const NUM_WAKERS: usize = 32; + +pub(crate) struct WakeList { + inner: [MaybeUninit<Waker>; NUM_WAKERS], + curr: usize, +} + +impl WakeList { + pub(crate) fn new() -> Self { + Self { + inner: unsafe { + // safety: Create an uninitialized array of `MaybeUninit`. The + // `assume_init` is safe because the type we are claiming to + // have initialized here is a bunch of `MaybeUninit`s, which do + // not require initialization. 
+ MaybeUninit::uninit().assume_init() + }, + curr: 0, + } + } + + #[inline] + pub(crate) fn can_push(&self) -> bool { + self.curr < NUM_WAKERS + } + + pub(crate) fn push(&mut self, val: Waker) { + debug_assert!(self.can_push()); + + self.inner[self.curr] = MaybeUninit::new(val); + self.curr += 1; + } + + pub(crate) fn wake_all(&mut self) { + assert!(self.curr <= NUM_WAKERS); + while self.curr > 0 { + self.curr -= 1; + let waker = unsafe { ptr::read(self.inner[self.curr].as_mut_ptr()) }; + waker.wake(); + } + } +} + +impl Drop for WakeList { + fn drop(&mut self) { + let slice = ptr::slice_from_raw_parts_mut(self.inner.as_mut_ptr() as *mut Waker, self.curr); + unsafe { ptr::drop_in_place(slice) }; + } +} diff --git a/tests/async_send_sync.rs b/tests/async_send_sync.rs index 211c572..aa14970 100644 --- a/tests/async_send_sync.rs +++ b/tests/async_send_sync.rs @@ -1,16 +1,33 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -#![allow(clippy::type_complexity)] +#![allow(clippy::type_complexity, clippy::diverging_sub_expression)] use std::cell::Cell; use std::future::Future; -use std::io::{Cursor, SeekFrom}; +use std::io::SeekFrom; use std::net::SocketAddr; use std::pin::Pin; use std::rc::Rc; use tokio::net::TcpStream; use tokio::time::{Duration, Instant}; +// The names of these structs behaves better when sorted. 
+// Send: Yes, Sync: Yes +#[derive(Clone)] +struct YY {} + +// Send: Yes, Sync: No +#[derive(Clone)] +struct YN { + _value: Cell<u8>, +} + +// Send: No, Sync: No +#[derive(Clone)] +struct NN { + _value: Rc<u8>, +} + #[allow(dead_code)] type BoxFutureSync<T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + Sync>>; #[allow(dead_code)] @@ -19,11 +36,11 @@ type BoxFutureSend<T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + type BoxFuture<T> = std::pin::Pin<Box<dyn std::future::Future<Output = T>>>; #[allow(dead_code)] -type BoxAsyncRead = std::pin::Pin<Box<dyn tokio::io::AsyncBufRead>>; +type BoxAsyncRead = std::pin::Pin<Box<dyn tokio::io::AsyncBufRead + Send + Sync>>; #[allow(dead_code)] -type BoxAsyncSeek = std::pin::Pin<Box<dyn tokio::io::AsyncSeek>>; +type BoxAsyncSeek = std::pin::Pin<Box<dyn tokio::io::AsyncSeek + Send + Sync>>; #[allow(dead_code)] -type BoxAsyncWrite = std::pin::Pin<Box<dyn tokio::io::AsyncWrite>>; +type BoxAsyncWrite = std::pin::Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>; #[allow(dead_code)] fn require_send<T: Send>(_t: &T) {} @@ -59,310 +76,594 @@ macro_rules! into_todo { x }}; } -macro_rules! assert_value { - ($type:ty: Send & Sync) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f: $type = todo!(); - require_send(&f); - require_sync(&f); - }; - }; - ($type:ty: !Send & Sync) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f: $type = todo!(); - AmbiguousIfSend::some_item(&f); - require_sync(&f); - }; - }; - ($type:ty: Send & !Sync) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f: $type = todo!(); - require_send(&f); - AmbiguousIfSync::some_item(&f); - }; + +macro_rules! 
async_assert_fn_send { + (Send & $(!)?Sync & $(!)?Unpin, $value:expr) => { + require_send(&$value); }; - ($type:ty: !Send & !Sync) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f: $type = todo!(); - AmbiguousIfSend::some_item(&f); - AmbiguousIfSync::some_item(&f); - }; - }; - ($type:ty: Unpin) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f: $type = todo!(); - require_unpin(&f); - }; + (!Send & $(!)?Sync & $(!)?Unpin, $value:expr) => { + AmbiguousIfSend::some_item(&$value); }; } -macro_rules! async_assert_fn { - ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): Send & Sync) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f = $($f $(::<$($generic),*>)? )::+( $( into_todo!($arg) ),* ); - require_send(&f); - require_sync(&f); - }; +macro_rules! async_assert_fn_sync { + ($(!)?Send & Sync & $(!)?Unpin, $value:expr) => { + require_sync(&$value); }; - ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): Send & !Sync) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f = $($f $(::<$($generic),*>)? )::+( $( into_todo!($arg) ),* ); - require_send(&f); - AmbiguousIfSync::some_item(&f); - }; + ($(!)?Send & !Sync & $(!)?Unpin, $value:expr) => { + AmbiguousIfSync::some_item(&$value); }; - ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): !Send & Sync) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f = $($f $(::<$($generic),*>)? )::+( $( into_todo!($arg) ),* ); - AmbiguousIfSend::some_item(&f); - require_sync(&f); - }; +} +macro_rules! async_assert_fn_unpin { + ($(!)?Send & $(!)?Sync & Unpin, $value:expr) => { + require_unpin(&$value); }; - ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): !Send & !Sync) => { - #[allow(unreachable_code)] - #[allow(unused_variables)] - const _: fn() = || { - let f = $($f $(::<$($generic),*>)? 
)::+( $( into_todo!($arg) ),* ); - AmbiguousIfSend::some_item(&f); - AmbiguousIfSync::some_item(&f); - }; + ($(!)?Send & $(!)?Sync & !Unpin, $value:expr) => { + AmbiguousIfUnpin::some_item(&$value); }; - ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): !Unpin) => { +} + +macro_rules! async_assert_fn { + ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): $($tok:tt)*) => { #[allow(unreachable_code)] #[allow(unused_variables)] const _: fn() = || { let f = $($f $(::<$($generic),*>)? )::+( $( into_todo!($arg) ),* ); - AmbiguousIfUnpin::some_item(&f); + async_assert_fn_send!($($tok)*, f); + async_assert_fn_sync!($($tok)*, f); + async_assert_fn_unpin!($($tok)*, f); }; }; - ($($f:ident $(< $($generic:ty),* > )? )::+($($arg:ty),*): Unpin) => { +} +macro_rules! assert_value { + ($type:ty: $($tok:tt)*) => { #[allow(unreachable_code)] #[allow(unused_variables)] const _: fn() = || { - let f = $($f $(::<$($generic),*>)? )::+( $( into_todo!($arg) ),* ); - require_unpin(&f); + let f: $type = todo!(); + async_assert_fn_send!($($tok)*, f); + async_assert_fn_sync!($($tok)*, f); + async_assert_fn_unpin!($($tok)*, f); }; }; } -async_assert_fn!(tokio::io::copy(&mut TcpStream, &mut TcpStream): Send & Sync); -async_assert_fn!(tokio::io::empty(): Send & Sync); -async_assert_fn!(tokio::io::repeat(u8): Send & Sync); -async_assert_fn!(tokio::io::sink(): Send & Sync); -async_assert_fn!(tokio::io::split(TcpStream): Send & Sync); -async_assert_fn!(tokio::io::stderr(): Send & Sync); -async_assert_fn!(tokio::io::stdin(): Send & Sync); -async_assert_fn!(tokio::io::stdout(): Send & Sync); -async_assert_fn!(tokio::io::Split<Cursor<Vec<u8>>>::next_segment(_): Send & Sync); - -async_assert_fn!(tokio::fs::canonicalize(&str): Send & Sync); -async_assert_fn!(tokio::fs::copy(&str, &str): Send & Sync); -async_assert_fn!(tokio::fs::create_dir(&str): Send & Sync); -async_assert_fn!(tokio::fs::create_dir_all(&str): Send & Sync); -async_assert_fn!(tokio::fs::hard_link(&str, &str): Send & Sync); 
-async_assert_fn!(tokio::fs::metadata(&str): Send & Sync); -async_assert_fn!(tokio::fs::read(&str): Send & Sync); -async_assert_fn!(tokio::fs::read_dir(&str): Send & Sync); -async_assert_fn!(tokio::fs::read_link(&str): Send & Sync); -async_assert_fn!(tokio::fs::read_to_string(&str): Send & Sync); -async_assert_fn!(tokio::fs::remove_dir(&str): Send & Sync); -async_assert_fn!(tokio::fs::remove_dir_all(&str): Send & Sync); -async_assert_fn!(tokio::fs::remove_file(&str): Send & Sync); -async_assert_fn!(tokio::fs::rename(&str, &str): Send & Sync); -async_assert_fn!(tokio::fs::set_permissions(&str, std::fs::Permissions): Send & Sync); -async_assert_fn!(tokio::fs::symlink_metadata(&str): Send & Sync); -async_assert_fn!(tokio::fs::write(&str, Vec<u8>): Send & Sync); -async_assert_fn!(tokio::fs::ReadDir::next_entry(_): Send & Sync); -async_assert_fn!(tokio::fs::OpenOptions::open(_, &str): Send & Sync); -async_assert_fn!(tokio::fs::DirEntry::metadata(_): Send & Sync); -async_assert_fn!(tokio::fs::DirEntry::file_type(_): Send & Sync); +assert_value!(tokio::fs::DirBuilder: Send & Sync & Unpin); +assert_value!(tokio::fs::DirEntry: Send & Sync & Unpin); +assert_value!(tokio::fs::File: Send & Sync & Unpin); +assert_value!(tokio::fs::OpenOptions: Send & Sync & Unpin); +assert_value!(tokio::fs::ReadDir: Send & Sync & Unpin); -async_assert_fn!(tokio::fs::File::open(&str): Send & Sync); -async_assert_fn!(tokio::fs::File::create(&str): Send & Sync); -async_assert_fn!(tokio::fs::File::sync_all(_): Send & Sync); -async_assert_fn!(tokio::fs::File::sync_data(_): Send & Sync); -async_assert_fn!(tokio::fs::File::set_len(_, u64): Send & Sync); -async_assert_fn!(tokio::fs::File::metadata(_): Send & Sync); -async_assert_fn!(tokio::fs::File::try_clone(_): Send & Sync); -async_assert_fn!(tokio::fs::File::into_std(_): Send & Sync); -async_assert_fn!(tokio::fs::File::set_permissions(_, std::fs::Permissions): Send & Sync); +async_assert_fn!(tokio::fs::canonicalize(&str): Send & Sync & !Unpin); 
+async_assert_fn!(tokio::fs::copy(&str, &str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::create_dir(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::create_dir_all(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::hard_link(&str, &str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::metadata(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::read(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::read_dir(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::read_link(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::read_to_string(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::remove_dir(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::remove_dir_all(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::remove_file(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::rename(&str, &str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::set_permissions(&str, std::fs::Permissions): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::symlink_metadata(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::write(&str, Vec<u8>): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::ReadDir::next_entry(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::OpenOptions::open(_, &str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::DirBuilder::create(_, &str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::DirEntry::metadata(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::DirEntry::file_type(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::open(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::create(&str): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::sync_all(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::sync_data(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::set_len(_, u64): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::metadata(_): Send & Sync & !Unpin); 
+async_assert_fn!(tokio::fs::File::try_clone(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::into_std(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::fs::File::set_permissions(_, std::fs::Permissions): Send & Sync & !Unpin); -async_assert_fn!(tokio::net::lookup_host(SocketAddr): Send & Sync); -async_assert_fn!(tokio::net::TcpListener::bind(SocketAddr): Send & Sync); -async_assert_fn!(tokio::net::TcpListener::accept(_): Send & Sync); -async_assert_fn!(tokio::net::TcpStream::connect(SocketAddr): Send & Sync); -async_assert_fn!(tokio::net::TcpStream::peek(_, &mut [u8]): Send & Sync); -async_assert_fn!(tokio::net::tcp::ReadHalf::peek(_, &mut [u8]): Send & Sync); -async_assert_fn!(tokio::net::UdpSocket::bind(SocketAddr): Send & Sync); -async_assert_fn!(tokio::net::UdpSocket::connect(_, SocketAddr): Send & Sync); -async_assert_fn!(tokio::net::UdpSocket::send(_, &[u8]): Send & Sync); -async_assert_fn!(tokio::net::UdpSocket::recv(_, &mut [u8]): Send & Sync); -async_assert_fn!(tokio::net::UdpSocket::send_to(_, &[u8], SocketAddr): Send & Sync); -async_assert_fn!(tokio::net::UdpSocket::recv_from(_, &mut [u8]): Send & Sync); +assert_value!(tokio::net::TcpListener: Send & Sync & Unpin); +assert_value!(tokio::net::TcpSocket: Send & Sync & Unpin); +assert_value!(tokio::net::TcpStream: Send & Sync & Unpin); +assert_value!(tokio::net::UdpSocket: Send & Sync & Unpin); +assert_value!(tokio::net::tcp::OwnedReadHalf: Send & Sync & Unpin); +assert_value!(tokio::net::tcp::OwnedWriteHalf: Send & Sync & Unpin); +assert_value!(tokio::net::tcp::ReadHalf<'_>: Send & Sync & Unpin); +assert_value!(tokio::net::tcp::ReuniteError: Send & Sync & Unpin); +assert_value!(tokio::net::tcp::WriteHalf<'_>: Send & Sync & Unpin); +async_assert_fn!(tokio::net::TcpListener::accept(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::TcpListener::bind(SocketAddr): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::TcpStream::connect(SocketAddr): Send & Sync & !Unpin); 
+async_assert_fn!(tokio::net::TcpStream::peek(_, &mut [u8]): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::TcpStream::readable(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::TcpStream::ready(_, tokio::io::Interest): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::TcpStream::writable(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::bind(SocketAddr): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::connect(_, SocketAddr): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::peek_from(_, &mut [u8]): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::readable(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::ready(_, tokio::io::Interest): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::recv(_, &mut [u8]): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::recv_from(_, &mut [u8]): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::send(_, &[u8]): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::send_to(_, &[u8], SocketAddr): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::UdpSocket::writable(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::lookup_host(SocketAddr): Send & Sync & !Unpin); +async_assert_fn!(tokio::net::tcp::ReadHalf::peek(_, &mut [u8]): Send & Sync & !Unpin); #[cfg(unix)] mod unix_datagram { use super::*; - async_assert_fn!(tokio::net::UnixListener::bind(&str): Send & Sync); - async_assert_fn!(tokio::net::UnixListener::accept(_): Send & Sync); - async_assert_fn!(tokio::net::UnixDatagram::send(_, &[u8]): Send & Sync); - async_assert_fn!(tokio::net::UnixDatagram::recv(_, &mut [u8]): Send & Sync); - async_assert_fn!(tokio::net::UnixDatagram::send_to(_, &[u8], &str): Send & Sync); - async_assert_fn!(tokio::net::UnixDatagram::recv_from(_, &mut [u8]): Send & Sync); - async_assert_fn!(tokio::net::UnixStream::connect(&str): Send & Sync); + use tokio::net::*; + assert_value!(UnixDatagram: Send & Sync 
& Unpin); + assert_value!(UnixListener: Send & Sync & Unpin); + assert_value!(UnixStream: Send & Sync & Unpin); + assert_value!(unix::OwnedReadHalf: Send & Sync & Unpin); + assert_value!(unix::OwnedWriteHalf: Send & Sync & Unpin); + assert_value!(unix::ReadHalf<'_>: Send & Sync & Unpin); + assert_value!(unix::ReuniteError: Send & Sync & Unpin); + assert_value!(unix::SocketAddr: Send & Sync & Unpin); + assert_value!(unix::UCred: Send & Sync & Unpin); + assert_value!(unix::WriteHalf<'_>: Send & Sync & Unpin); + async_assert_fn!(UnixDatagram::readable(_): Send & Sync & !Unpin); + async_assert_fn!(UnixDatagram::ready(_, tokio::io::Interest): Send & Sync & !Unpin); + async_assert_fn!(UnixDatagram::recv(_, &mut [u8]): Send & Sync & !Unpin); + async_assert_fn!(UnixDatagram::recv_from(_, &mut [u8]): Send & Sync & !Unpin); + async_assert_fn!(UnixDatagram::send(_, &[u8]): Send & Sync & !Unpin); + async_assert_fn!(UnixDatagram::send_to(_, &[u8], &str): Send & Sync & !Unpin); + async_assert_fn!(UnixDatagram::writable(_): Send & Sync & !Unpin); + async_assert_fn!(UnixListener::accept(_): Send & Sync & !Unpin); + async_assert_fn!(UnixStream::connect(&str): Send & Sync & !Unpin); + async_assert_fn!(UnixStream::readable(_): Send & Sync & !Unpin); + async_assert_fn!(UnixStream::ready(_, tokio::io::Interest): Send & Sync & !Unpin); + async_assert_fn!(UnixStream::writable(_): Send & Sync & !Unpin); } -async_assert_fn!(tokio::process::Child::wait_with_output(_): Send & Sync); -async_assert_fn!(tokio::signal::ctrl_c(): Send & Sync); -#[cfg(unix)] -async_assert_fn!(tokio::signal::unix::Signal::recv(_): Send & Sync); +#[cfg(windows)] +mod windows_named_pipe { + use super::*; + use tokio::net::windows::named_pipe::*; + assert_value!(ClientOptions: Send & Sync & Unpin); + assert_value!(NamedPipeClient: Send & Sync & Unpin); + assert_value!(NamedPipeServer: Send & Sync & Unpin); + assert_value!(PipeEnd: Send & Sync & Unpin); + assert_value!(PipeInfo: Send & Sync & Unpin); + 
assert_value!(PipeMode: Send & Sync & Unpin); + assert_value!(ServerOptions: Send & Sync & Unpin); + async_assert_fn!(NamedPipeClient::readable(_): Send & Sync & !Unpin); + async_assert_fn!(NamedPipeClient::ready(_, tokio::io::Interest): Send & Sync & !Unpin); + async_assert_fn!(NamedPipeClient::writable(_): Send & Sync & !Unpin); + async_assert_fn!(NamedPipeServer::connect(_): Send & Sync & !Unpin); + async_assert_fn!(NamedPipeServer::readable(_): Send & Sync & !Unpin); + async_assert_fn!(NamedPipeServer::ready(_, tokio::io::Interest): Send & Sync & !Unpin); + async_assert_fn!(NamedPipeServer::writable(_): Send & Sync & !Unpin); +} -async_assert_fn!(tokio::sync::Barrier::wait(_): Send & Sync); -async_assert_fn!(tokio::sync::Mutex<u8>::lock(_): Send & Sync); -async_assert_fn!(tokio::sync::Mutex<Cell<u8>>::lock(_): Send & Sync); -async_assert_fn!(tokio::sync::Mutex<Rc<u8>>::lock(_): !Send & !Sync); -async_assert_fn!(tokio::sync::Mutex<u8>::lock_owned(_): Send & Sync); -async_assert_fn!(tokio::sync::Mutex<Cell<u8>>::lock_owned(_): Send & Sync); -async_assert_fn!(tokio::sync::Mutex<Rc<u8>>::lock_owned(_): !Send & !Sync); -async_assert_fn!(tokio::sync::Notify::notified(_): Send & Sync); -async_assert_fn!(tokio::sync::RwLock<u8>::read(_): Send & Sync); -async_assert_fn!(tokio::sync::RwLock<u8>::write(_): Send & Sync); -async_assert_fn!(tokio::sync::RwLock<Cell<u8>>::read(_): !Send & !Sync); -async_assert_fn!(tokio::sync::RwLock<Cell<u8>>::write(_): !Send & !Sync); -async_assert_fn!(tokio::sync::RwLock<Rc<u8>>::read(_): !Send & !Sync); -async_assert_fn!(tokio::sync::RwLock<Rc<u8>>::write(_): !Send & !Sync); -async_assert_fn!(tokio::sync::Semaphore::acquire(_): Send & Sync); +assert_value!(tokio::process::Child: Send & Sync & Unpin); +assert_value!(tokio::process::ChildStderr: Send & Sync & Unpin); +assert_value!(tokio::process::ChildStdin: Send & Sync & Unpin); +assert_value!(tokio::process::ChildStdout: Send & Sync & Unpin); +assert_value!(tokio::process::Command: Send 
& Sync & Unpin); +async_assert_fn!(tokio::process::Child::kill(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::process::Child::wait(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::process::Child::wait_with_output(_): Send & Sync & !Unpin); -async_assert_fn!(tokio::sync::broadcast::Receiver<u8>::recv(_): Send & Sync); -async_assert_fn!(tokio::sync::broadcast::Receiver<Cell<u8>>::recv(_): Send & Sync); -async_assert_fn!(tokio::sync::broadcast::Receiver<Rc<u8>>::recv(_): !Send & !Sync); +async_assert_fn!(tokio::signal::ctrl_c(): Send & Sync & !Unpin); +#[cfg(unix)] +mod unix_signal { + use super::*; + assert_value!(tokio::signal::unix::Signal: Send & Sync & Unpin); + assert_value!(tokio::signal::unix::SignalKind: Send & Sync & Unpin); + async_assert_fn!(tokio::signal::unix::Signal::recv(_): Send & Sync & !Unpin); +} +#[cfg(windows)] +mod windows_signal { + use super::*; + assert_value!(tokio::signal::windows::CtrlC: Send & Sync & Unpin); + assert_value!(tokio::signal::windows::CtrlBreak: Send & Sync & Unpin); + async_assert_fn!(tokio::signal::windows::CtrlC::recv(_): Send & Sync & !Unpin); + async_assert_fn!(tokio::signal::windows::CtrlBreak::recv(_): Send & Sync & !Unpin); +} -async_assert_fn!(tokio::sync::mpsc::Receiver<u8>::recv(_): Send & Sync); -async_assert_fn!(tokio::sync::mpsc::Receiver<Cell<u8>>::recv(_): Send & Sync); -async_assert_fn!(tokio::sync::mpsc::Receiver<Rc<u8>>::recv(_): !Send & !Sync); -async_assert_fn!(tokio::sync::mpsc::Sender<u8>::send(_, u8): Send & Sync); -async_assert_fn!(tokio::sync::mpsc::Sender<Cell<u8>>::send(_, Cell<u8>): Send & !Sync); -async_assert_fn!(tokio::sync::mpsc::Sender<Rc<u8>>::send(_, Rc<u8>): !Send & !Sync); +assert_value!(tokio::sync::AcquireError: Send & Sync & Unpin); +assert_value!(tokio::sync::Barrier: Send & Sync & Unpin); +assert_value!(tokio::sync::BarrierWaitResult: Send & Sync & Unpin); +assert_value!(tokio::sync::MappedMutexGuard<'_, NN>: !Send & !Sync & Unpin); 
+assert_value!(tokio::sync::MappedMutexGuard<'_, YN>: Send & !Sync & Unpin); +assert_value!(tokio::sync::MappedMutexGuard<'_, YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::Mutex<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::Mutex<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::Mutex<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::MutexGuard<'_, NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::MutexGuard<'_, YN>: Send & !Sync & Unpin); +assert_value!(tokio::sync::MutexGuard<'_, YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::Notify: Send & Sync & Unpin); +assert_value!(tokio::sync::OnceCell<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::OnceCell<YN>: Send & !Sync & Unpin); +assert_value!(tokio::sync::OnceCell<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::OwnedMutexGuard<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::OwnedMutexGuard<YN>: Send & !Sync & Unpin); +assert_value!(tokio::sync::OwnedMutexGuard<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockMappedWriteGuard<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockMappedWriteGuard<YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockMappedWriteGuard<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockReadGuard<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockReadGuard<YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockReadGuard<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockWriteGuard<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockWriteGuard<YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::OwnedRwLockWriteGuard<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::OwnedSemaphorePermit: Send & Sync & Unpin); +assert_value!(tokio::sync::RwLock<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::RwLock<YN>: Send & !Sync & Unpin); +assert_value!(tokio::sync::RwLock<YY>: Send & Sync & 
Unpin); +assert_value!(tokio::sync::RwLockMappedWriteGuard<'_, NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::RwLockMappedWriteGuard<'_, YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::RwLockMappedWriteGuard<'_, YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::RwLockReadGuard<'_, NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::RwLockReadGuard<'_, YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::RwLockReadGuard<'_, YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::RwLockWriteGuard<'_, NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::RwLockWriteGuard<'_, YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::RwLockWriteGuard<'_, YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::Semaphore: Send & Sync & Unpin); +assert_value!(tokio::sync::SemaphorePermit<'_>: Send & Sync & Unpin); +assert_value!(tokio::sync::TryAcquireError: Send & Sync & Unpin); +assert_value!(tokio::sync::TryLockError: Send & Sync & Unpin); +assert_value!(tokio::sync::broadcast::Receiver<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::broadcast::Receiver<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::broadcast::Receiver<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::broadcast::Sender<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::broadcast::Sender<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::broadcast::Sender<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::futures::Notified<'_>: Send & Sync & !Unpin); +assert_value!(tokio::sync::mpsc::OwnedPermit<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::OwnedPermit<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::OwnedPermit<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::Permit<'_, NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::Permit<'_, YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::Permit<'_, YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::Receiver<NN>: !Send & 
!Sync & Unpin); +assert_value!(tokio::sync::mpsc::Receiver<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::Receiver<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::Sender<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::Sender<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::Sender<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::UnboundedReceiver<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::UnboundedReceiver<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::UnboundedReceiver<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::UnboundedSender<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::UnboundedSender<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::UnboundedSender<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::SendError<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::SendError<YN>: Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::SendError<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::SendTimeoutError<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::SendTimeoutError<YN>: Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::SendTimeoutError<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::TrySendError<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::TrySendError<YN>: Send & !Sync & Unpin); +assert_value!(tokio::sync::mpsc::error::TrySendError<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::oneshot::Receiver<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::oneshot::Receiver<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::oneshot::Receiver<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::oneshot::Sender<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::oneshot::Sender<YN>: Send & Sync & Unpin); +assert_value!(tokio::sync::oneshot::Sender<YY>: Send & 
Sync & Unpin); +assert_value!(tokio::sync::watch::Receiver<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::watch::Receiver<YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::watch::Receiver<YY>: Send & Sync & Unpin); +assert_value!(tokio::sync::watch::Ref<'_, NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::watch::Ref<'_, YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::watch::Ref<'_, YY>: !Send & Sync & Unpin); +assert_value!(tokio::sync::watch::Sender<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::watch::Sender<YN>: !Send & !Sync & Unpin); +assert_value!(tokio::sync::watch::Sender<YY>: Send & Sync & Unpin); +async_assert_fn!(tokio::sync::Barrier::wait(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Mutex<NN>::lock(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::Mutex<NN>::lock_owned(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::Mutex<YN>::lock(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Mutex<YN>::lock_owned(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Mutex<YY>::lock(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Mutex<YY>::lock_owned(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Notify::notified(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<NN>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = NN> + Send + Sync>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<NN>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = NN> + Send>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<NN>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = NN>>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<NN>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<NN>> + Send + Sync>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<NN>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<NN>> + Send>>): !Send & !Sync & 
!Unpin); +async_assert_fn!(tokio::sync::OnceCell<NN>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<NN>>>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YN>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = YN> + Send + Sync>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YN>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = YN> + Send>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YN>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = YN>>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YN>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<YN>> + Send + Sync>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YN>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<YN>> + Send>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YN>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<YN>>>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YY>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = YY> + Send + Sync>>): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YY>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = YY> + Send>>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YY>::get_or_init( _, fn() -> Pin<Box<dyn Future<Output = YY>>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YY>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<YY>> + Send + Sync>>): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YY>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<YY>> + Send>>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::OnceCell<YY>::get_or_try_init( _, fn() -> Pin<Box<dyn Future<Output = std::io::Result<YY>>>>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::RwLock<NN>::read(_): !Send & !Sync & 
!Unpin); +async_assert_fn!(tokio::sync::RwLock<NN>::write(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::RwLock<YN>::read(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::RwLock<YN>::write(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::RwLock<YY>::read(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::RwLock<YY>::write(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Semaphore::acquire(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Semaphore::acquire_many(_, u32): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Semaphore::acquire_many_owned(_, u32): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::Semaphore::acquire_owned(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::broadcast::Receiver<NN>::recv(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::broadcast::Receiver<YN>::recv(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::broadcast::Receiver<YY>::recv(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Receiver<NN>::recv(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Receiver<YN>::recv(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Receiver<YY>::recv(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<NN>::closed(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<NN>::reserve(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<NN>::reserve_owned(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<NN>::send(_, NN): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<NN>::send_timeout(_, NN, Duration): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YN>::closed(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YN>::reserve(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YN>::reserve_owned(_): Send & Sync & !Unpin); 
+async_assert_fn!(tokio::sync::mpsc::Sender<YN>::send(_, YN): Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YN>::send_timeout(_, YN, Duration): Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YY>::closed(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YY>::reserve(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YY>::reserve_owned(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YY>::send(_, YY): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::Sender<YY>::send_timeout(_, YY, Duration): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::UnboundedReceiver<NN>::recv(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::UnboundedReceiver<YN>::recv(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::UnboundedReceiver<YY>::recv(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::UnboundedSender<NN>::closed(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::UnboundedSender<YN>::closed(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::mpsc::UnboundedSender<YY>::closed(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::oneshot::Sender<NN>::closed(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::oneshot::Sender<YN>::closed(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::oneshot::Sender<YY>::closed(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::watch::Receiver<NN>::changed(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::watch::Receiver<YN>::changed(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::watch::Receiver<YY>::changed(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::sync::watch::Sender<NN>::closed(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::watch::Sender<YN>::closed(_): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::sync::watch::Sender<YY>::closed(_): Send & Sync & !Unpin); 
-async_assert_fn!(tokio::sync::mpsc::UnboundedReceiver<u8>::recv(_): Send & Sync); -async_assert_fn!(tokio::sync::mpsc::UnboundedReceiver<Cell<u8>>::recv(_): Send & Sync); -async_assert_fn!(tokio::sync::mpsc::UnboundedReceiver<Rc<u8>>::recv(_): !Send & !Sync); +async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFutureSync<()>): Send & Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFutureSend<()>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFuture<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFutureSync<()>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFutureSend<()>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFuture<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, BoxFutureSync<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, BoxFutureSend<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, BoxFuture<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::LocalSet::run_until(_, BoxFutureSync<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::task::unconstrained(BoxFuture<()>): !Send & !Sync & Unpin); +async_assert_fn!(tokio::task::unconstrained(BoxFutureSend<()>): Send & !Sync & Unpin); +async_assert_fn!(tokio::task::unconstrained(BoxFutureSync<()>): Send & Sync & Unpin); +assert_value!(tokio::task::LocalSet: !Send & !Sync & Unpin); +assert_value!(tokio::task::JoinHandle<YY>: Send & Sync & Unpin); +assert_value!(tokio::task::JoinHandle<YN>: Send & Sync & Unpin); +assert_value!(tokio::task::JoinHandle<NN>: !Send & !Sync & Unpin); +assert_value!(tokio::task::JoinError: Send & Sync & Unpin); -async_assert_fn!(tokio::sync::watch::Receiver<u8>::changed(_): Send 
& Sync); -async_assert_fn!(tokio::sync::watch::Sender<u8>::closed(_): Send & Sync); -async_assert_fn!(tokio::sync::watch::Sender<Cell<u8>>::closed(_): !Send & !Sync); -async_assert_fn!(tokio::sync::watch::Sender<Rc<u8>>::closed(_): !Send & !Sync); +assert_value!(tokio::runtime::Builder: Send & Sync & Unpin); +assert_value!(tokio::runtime::EnterGuard<'_>: Send & Sync & Unpin); +assert_value!(tokio::runtime::Handle: Send & Sync & Unpin); +assert_value!(tokio::runtime::Runtime: Send & Sync & Unpin); -async_assert_fn!(tokio::sync::OnceCell<u8>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = u8> + Send + Sync>>): Send & Sync); -async_assert_fn!(tokio::sync::OnceCell<u8>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = u8> + Send>>): Send & !Sync); -async_assert_fn!(tokio::sync::OnceCell<u8>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = u8>>>): !Send & !Sync); -async_assert_fn!(tokio::sync::OnceCell<Cell<u8>>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = Cell<u8>> + Send + Sync>>): !Send & !Sync); -async_assert_fn!(tokio::sync::OnceCell<Cell<u8>>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = Cell<u8>> + Send>>): !Send & !Sync); -async_assert_fn!(tokio::sync::OnceCell<Cell<u8>>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = Cell<u8>>>>): !Send & !Sync); -async_assert_fn!(tokio::sync::OnceCell<Rc<u8>>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = Rc<u8>> + Send + Sync>>): !Send & !Sync); -async_assert_fn!(tokio::sync::OnceCell<Rc<u8>>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = Rc<u8>> + Send>>): !Send & !Sync); -async_assert_fn!(tokio::sync::OnceCell<Rc<u8>>::get_or_init( - _, fn() -> Pin<Box<dyn Future<Output = Rc<u8>>>>): !Send & !Sync); -assert_value!(tokio::sync::OnceCell<u8>: Send & Sync); -assert_value!(tokio::sync::OnceCell<Cell<u8>>: Send & !Sync); -assert_value!(tokio::sync::OnceCell<Rc<u8>>: !Send & !Sync); +assert_value!(tokio::time::Interval: Send & Sync & Unpin); 
+assert_value!(tokio::time::Instant: Send & Sync & Unpin); +assert_value!(tokio::time::Sleep: Send & Sync & !Unpin); +assert_value!(tokio::time::Timeout<BoxFutureSync<()>>: Send & Sync & !Unpin); +assert_value!(tokio::time::Timeout<BoxFutureSend<()>>: Send & !Sync & !Unpin); +assert_value!(tokio::time::Timeout<BoxFuture<()>>: !Send & !Sync & !Unpin); +assert_value!(tokio::time::error::Elapsed: Send & Sync & Unpin); +assert_value!(tokio::time::error::Error: Send & Sync & Unpin); +async_assert_fn!(tokio::time::advance(Duration): Send & Sync & !Unpin); +async_assert_fn!(tokio::time::sleep(Duration): Send & Sync & !Unpin); +async_assert_fn!(tokio::time::sleep_until(Instant): Send & Sync & !Unpin); +async_assert_fn!(tokio::time::timeout(Duration, BoxFutureSync<()>): Send & Sync & !Unpin); +async_assert_fn!(tokio::time::timeout(Duration, BoxFutureSend<()>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::time::timeout(Duration, BoxFuture<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::time::timeout_at(Instant, BoxFutureSync<()>): Send & Sync & !Unpin); +async_assert_fn!(tokio::time::timeout_at(Instant, BoxFutureSend<()>): Send & !Sync & !Unpin); +async_assert_fn!(tokio::time::timeout_at(Instant, BoxFuture<()>): !Send & !Sync & !Unpin); +async_assert_fn!(tokio::time::Interval::tick(_): Send & Sync & !Unpin); -async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFutureSync<()>): Send & Sync); -async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFutureSend<()>): Send & !Sync); -async_assert_fn!(tokio::task::LocalKey<u32>::scope(_, u32, BoxFuture<()>): !Send & !Sync); -async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFutureSync<()>): Send & !Sync); -async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFutureSend<()>): Send & !Sync); -async_assert_fn!(tokio::task::LocalKey<Cell<u32>>::scope(_, Cell<u32>, BoxFuture<()>): !Send & !Sync); -async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, 
BoxFutureSync<()>): !Send & !Sync); -async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, BoxFutureSend<()>): !Send & !Sync); -async_assert_fn!(tokio::task::LocalKey<Rc<u32>>::scope(_, Rc<u32>, BoxFuture<()>): !Send & !Sync); -async_assert_fn!(tokio::task::LocalSet::run_until(_, BoxFutureSync<()>): !Send & !Sync); -assert_value!(tokio::task::LocalSet: !Send & !Sync); +assert_value!(tokio::io::BufReader<TcpStream>: Send & Sync & Unpin); +assert_value!(tokio::io::BufStream<TcpStream>: Send & Sync & Unpin); +assert_value!(tokio::io::BufWriter<TcpStream>: Send & Sync & Unpin); +assert_value!(tokio::io::DuplexStream: Send & Sync & Unpin); +assert_value!(tokio::io::Empty: Send & Sync & Unpin); +assert_value!(tokio::io::Interest: Send & Sync & Unpin); +assert_value!(tokio::io::Lines<TcpStream>: Send & Sync & Unpin); +assert_value!(tokio::io::ReadBuf<'_>: Send & Sync & Unpin); +assert_value!(tokio::io::ReadHalf<TcpStream>: Send & Sync & Unpin); +assert_value!(tokio::io::Ready: Send & Sync & Unpin); +assert_value!(tokio::io::Repeat: Send & Sync & Unpin); +assert_value!(tokio::io::Sink: Send & Sync & Unpin); +assert_value!(tokio::io::Split<TcpStream>: Send & Sync & Unpin); +assert_value!(tokio::io::Stderr: Send & Sync & Unpin); +assert_value!(tokio::io::Stdin: Send & Sync & Unpin); +assert_value!(tokio::io::Stdout: Send & Sync & Unpin); +assert_value!(tokio::io::Take<TcpStream>: Send & Sync & Unpin); +assert_value!(tokio::io::WriteHalf<TcpStream>: Send & Sync & Unpin); +async_assert_fn!(tokio::io::copy(&mut TcpStream, &mut TcpStream): Send & Sync & !Unpin); +async_assert_fn!( + tokio::io::copy_bidirectional(&mut TcpStream, &mut TcpStream): Send & Sync & !Unpin +); +async_assert_fn!(tokio::io::copy_buf(&mut tokio::io::BufReader<TcpStream>, &mut TcpStream): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::empty(): Send & Sync & Unpin); +async_assert_fn!(tokio::io::repeat(u8): Send & Sync & Unpin); +async_assert_fn!(tokio::io::sink(): Send & Sync & Unpin); 
+async_assert_fn!(tokio::io::split(TcpStream): Send & Sync & Unpin); +async_assert_fn!(tokio::io::stderr(): Send & Sync & Unpin); +async_assert_fn!(tokio::io::stdin(): Send & Sync & Unpin); +async_assert_fn!(tokio::io::stdout(): Send & Sync & Unpin); +async_assert_fn!(tokio::io::Split<tokio::io::BufReader<TcpStream>>::next_segment(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::Lines<tokio::io::BufReader<TcpStream>>::next_line(_): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncBufReadExt::read_until(&mut BoxAsyncRead, u8, &mut Vec<u8>): Send & Sync & !Unpin); +async_assert_fn!( + tokio::io::AsyncBufReadExt::read_line(&mut BoxAsyncRead, &mut String): Send & Sync & !Unpin +); +async_assert_fn!(tokio::io::AsyncBufReadExt::fill_buf(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read(&mut BoxAsyncRead, &mut [u8]): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_buf(&mut BoxAsyncRead, &mut Vec<u8>): Send & Sync & !Unpin); +async_assert_fn!( + tokio::io::AsyncReadExt::read_exact(&mut BoxAsyncRead, &mut [u8]): Send & Sync & !Unpin +); +async_assert_fn!(tokio::io::AsyncReadExt::read_u8(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i8(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_u16(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i16(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_u32(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i32(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_u64(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i64(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_u128(&mut BoxAsyncRead): Send & Sync & !Unpin); 
+async_assert_fn!(tokio::io::AsyncReadExt::read_i128(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_f32(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_f64(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_u16_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i16_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_u32_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i32_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_u64_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i64_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_u128_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_i128_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_f32_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_f64_le(&mut BoxAsyncRead): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncReadExt::read_to_end(&mut BoxAsyncRead, &mut Vec<u8>): Send & Sync & !Unpin); +async_assert_fn!( + tokio::io::AsyncReadExt::read_to_string(&mut BoxAsyncRead, &mut String): Send & Sync & !Unpin +); +async_assert_fn!(tokio::io::AsyncSeekExt::seek(&mut BoxAsyncSeek, SeekFrom): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncSeekExt::stream_position(&mut BoxAsyncSeek): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncWriteExt::write(&mut BoxAsyncWrite, &[u8]): Send & Sync & !Unpin); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_vectored(&mut BoxAsyncWrite, _): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_buf(&mut BoxAsyncWrite, &mut 
bytes::Bytes): Send + & Sync + & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_all_buf(&mut BoxAsyncWrite, &mut bytes::Bytes): Send + & Sync + & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_all(&mut BoxAsyncWrite, &[u8]): Send & Sync & !Unpin +); +async_assert_fn!(tokio::io::AsyncWriteExt::write_u8(&mut BoxAsyncWrite, u8): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncWriteExt::write_i8(&mut BoxAsyncWrite, i8): Send & Sync & !Unpin); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_u16(&mut BoxAsyncWrite, u16): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_i16(&mut BoxAsyncWrite, i16): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_u32(&mut BoxAsyncWrite, u32): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_i32(&mut BoxAsyncWrite, i32): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_u64(&mut BoxAsyncWrite, u64): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_i64(&mut BoxAsyncWrite, i64): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_u128(&mut BoxAsyncWrite, u128): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_i128(&mut BoxAsyncWrite, i128): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_f32(&mut BoxAsyncWrite, f32): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_f64(&mut BoxAsyncWrite, f64): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_u16_le(&mut BoxAsyncWrite, u16): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_i16_le(&mut BoxAsyncWrite, i16): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_u32_le(&mut BoxAsyncWrite, u32): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_i32_le(&mut BoxAsyncWrite, i32): Send & 
Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_u64_le(&mut BoxAsyncWrite, u64): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_i64_le(&mut BoxAsyncWrite, i64): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_u128_le(&mut BoxAsyncWrite, u128): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_i128_le(&mut BoxAsyncWrite, i128): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_f32_le(&mut BoxAsyncWrite, f32): Send & Sync & !Unpin +); +async_assert_fn!( + tokio::io::AsyncWriteExt::write_f64_le(&mut BoxAsyncWrite, f64): Send & Sync & !Unpin +); +async_assert_fn!(tokio::io::AsyncWriteExt::flush(&mut BoxAsyncWrite): Send & Sync & !Unpin); +async_assert_fn!(tokio::io::AsyncWriteExt::shutdown(&mut BoxAsyncWrite): Send & Sync & !Unpin); -async_assert_fn!(tokio::time::advance(Duration): Send & Sync); -async_assert_fn!(tokio::time::sleep(Duration): Send & Sync); -async_assert_fn!(tokio::time::sleep_until(Instant): Send & Sync); -async_assert_fn!(tokio::time::timeout(Duration, BoxFutureSync<()>): Send & Sync); -async_assert_fn!(tokio::time::timeout(Duration, BoxFutureSend<()>): Send & !Sync); -async_assert_fn!(tokio::time::timeout(Duration, BoxFuture<()>): !Send & !Sync); -async_assert_fn!(tokio::time::timeout_at(Instant, BoxFutureSync<()>): Send & Sync); -async_assert_fn!(tokio::time::timeout_at(Instant, BoxFutureSend<()>): Send & !Sync); -async_assert_fn!(tokio::time::timeout_at(Instant, BoxFuture<()>): !Send & !Sync); -async_assert_fn!(tokio::time::Interval::tick(_): Send & Sync); +#[cfg(unix)] +mod unix_asyncfd { + use super::*; + use tokio::io::unix::*; -assert_value!(tokio::time::Interval: Unpin); -async_assert_fn!(tokio::time::sleep(Duration): !Unpin); -async_assert_fn!(tokio::time::sleep_until(Instant): !Unpin); -async_assert_fn!(tokio::time::timeout(Duration, BoxFuture<()>): !Unpin); 
-async_assert_fn!(tokio::time::timeout_at(Instant, BoxFuture<()>): !Unpin); -async_assert_fn!(tokio::time::Interval::tick(_): !Unpin); -async_assert_fn!(tokio::io::AsyncBufReadExt::read_until(&mut BoxAsyncRead, u8, &mut Vec<u8>): !Unpin); -async_assert_fn!(tokio::io::AsyncBufReadExt::read_line(&mut BoxAsyncRead, &mut String): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read(&mut BoxAsyncRead, &mut [u8]): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_exact(&mut BoxAsyncRead, &mut [u8]): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u8(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i8(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u16(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i16(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u32(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i32(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u64(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i64(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u128(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i128(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u16_le(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i16_le(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u32_le(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i32_le(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u64_le(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i64_le(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_u128_le(&mut BoxAsyncRead): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_i128_le(&mut BoxAsyncRead): 
!Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_to_end(&mut BoxAsyncRead, &mut Vec<u8>): !Unpin); -async_assert_fn!(tokio::io::AsyncReadExt::read_to_string(&mut BoxAsyncRead, &mut String): !Unpin); -async_assert_fn!(tokio::io::AsyncSeekExt::seek(&mut BoxAsyncSeek, SeekFrom): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write(&mut BoxAsyncWrite, &[u8]): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_all(&mut BoxAsyncWrite, &[u8]): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u8(&mut BoxAsyncWrite, u8): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i8(&mut BoxAsyncWrite, i8): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u16(&mut BoxAsyncWrite, u16): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i16(&mut BoxAsyncWrite, i16): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u32(&mut BoxAsyncWrite, u32): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i32(&mut BoxAsyncWrite, i32): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u64(&mut BoxAsyncWrite, u64): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i64(&mut BoxAsyncWrite, i64): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u128(&mut BoxAsyncWrite, u128): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i128(&mut BoxAsyncWrite, i128): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u16_le(&mut BoxAsyncWrite, u16): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i16_le(&mut BoxAsyncWrite, i16): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u32_le(&mut BoxAsyncWrite, u32): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i32_le(&mut BoxAsyncWrite, i32): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u64_le(&mut BoxAsyncWrite, u64): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i64_le(&mut BoxAsyncWrite, i64): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_u128_le(&mut BoxAsyncWrite, u128): 
!Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::write_i128_le(&mut BoxAsyncWrite, i128): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::flush(&mut BoxAsyncWrite): !Unpin); -async_assert_fn!(tokio::io::AsyncWriteExt::shutdown(&mut BoxAsyncWrite): !Unpin); + struct ImplsFd<T> { + _t: T, + } + impl<T> std::os::unix::io::AsRawFd for ImplsFd<T> { + fn as_raw_fd(&self) -> std::os::unix::io::RawFd { + unreachable!() + } + } + + assert_value!(AsyncFd<ImplsFd<YY>>: Send & Sync & Unpin); + assert_value!(AsyncFd<ImplsFd<YN>>: Send & !Sync & Unpin); + assert_value!(AsyncFd<ImplsFd<NN>>: !Send & !Sync & Unpin); + assert_value!(AsyncFdReadyGuard<'_, ImplsFd<YY>>: Send & Sync & Unpin); + assert_value!(AsyncFdReadyGuard<'_, ImplsFd<YN>>: !Send & !Sync & Unpin); + assert_value!(AsyncFdReadyGuard<'_, ImplsFd<NN>>: !Send & !Sync & Unpin); + assert_value!(AsyncFdReadyMutGuard<'_, ImplsFd<YY>>: Send & Sync & Unpin); + assert_value!(AsyncFdReadyMutGuard<'_, ImplsFd<YN>>: Send & !Sync & Unpin); + assert_value!(AsyncFdReadyMutGuard<'_, ImplsFd<NN>>: !Send & !Sync & Unpin); + assert_value!(TryIoError: Send & Sync & Unpin); + async_assert_fn!(AsyncFd<ImplsFd<YY>>::readable(_): Send & Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<YY>>::readable_mut(_): Send & Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<YY>>::writable(_): Send & Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<YY>>::writable_mut(_): Send & Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<YN>>::readable(_): !Send & !Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<YN>>::readable_mut(_): Send & !Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<YN>>::writable(_): !Send & !Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<YN>>::writable_mut(_): Send & !Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<NN>>::readable(_): !Send & !Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<NN>>::readable_mut(_): !Send & !Sync & !Unpin); + async_assert_fn!(AsyncFd<ImplsFd<NN>>::writable(_): !Send & !Sync & 
!Unpin); + async_assert_fn!(AsyncFd<ImplsFd<NN>>::writable_mut(_): !Send & !Sync & !Unpin); +} diff --git a/tests/fs_file.rs b/tests/fs_file.rs index bf2f1d7..f645e61 100644 --- a/tests/fs_file.rs +++ b/tests/fs_file.rs @@ -1,12 +1,11 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::fs::File; -use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; -use tokio_test::task; - use std::io::prelude::*; use tempfile::NamedTempFile; +use tokio::fs::File; +use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt, SeekFrom}; +use tokio_test::task; const HELLO: &[u8] = b"hello world..."; @@ -51,6 +50,19 @@ async fn basic_write_and_shutdown() { } #[tokio::test] +async fn rewind_seek_position() { + let tempfile = tempfile(); + + let mut file = File::create(tempfile.path()).await.unwrap(); + + file.seek(SeekFrom::Current(10)).await.unwrap(); + + file.rewind().await.unwrap(); + + assert_eq!(file.stream_position().await.unwrap(), 0); +} + +#[tokio::test] async fn coop() { let mut tempfile = tempfile(); tempfile.write_all(HELLO).unwrap(); diff --git a/tests/io_async_fd.rs b/tests/io_async_fd.rs index d1586bb..5a6875e 100644 --- a/tests/io_async_fd.rs +++ b/tests/io_async_fd.rs @@ -13,10 +13,9 @@ use std::{ task::{Context, Waker}, }; -use nix::errno::Errno; use nix::unistd::{close, read, write}; -use futures::{poll, FutureExt}; +use futures::poll; use tokio::io::unix::{AsyncFd, AsyncFdReadyGuard}; use tokio_test::{assert_err, assert_pending}; @@ -56,10 +55,6 @@ impl TestWaker { } } -fn is_blocking(e: &nix::Error) -> bool { - Some(Errno::EAGAIN) == e.as_errno() -} - #[derive(Debug)] struct FileDescriptor { fd: RawFd, @@ -73,11 +68,7 @@ impl AsRawFd for FileDescriptor { impl Read for &FileDescriptor { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { - match read(self.fd, buf) { - Ok(n) => Ok(n), - Err(e) if is_blocking(&e) => Err(ErrorKind::WouldBlock.into()), - Err(e) => Err(io::Error::new(ErrorKind::Other, e)), - } + read(self.fd, 
buf).map_err(io::Error::from) } } @@ -89,11 +80,7 @@ impl Read for FileDescriptor { impl Write for &FileDescriptor { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - match write(self.fd, buf) { - Ok(n) => Ok(n), - Err(e) if is_blocking(&e) => Err(ErrorKind::WouldBlock.into()), - Err(e) => Err(io::Error::new(ErrorKind::Other, e)), - } + write(self.fd, buf).map_err(io::Error::from) } fn flush(&mut self) -> io::Result<()> { @@ -176,10 +163,11 @@ async fn initially_writable() { afd_a.writable().await.unwrap().clear_ready(); afd_b.writable().await.unwrap().clear_ready(); - futures::select_biased! { - _ = tokio::time::sleep(Duration::from_millis(10)).fuse() => {}, - _ = afd_a.readable().fuse() => panic!("Unexpected readable state"), - _ = afd_b.readable().fuse() => panic!("Unexpected readable state"), + tokio::select! { + biased; + _ = tokio::time::sleep(Duration::from_millis(10)) => {}, + _ = afd_a.readable() => panic!("Unexpected readable state"), + _ = afd_b.readable() => panic!("Unexpected readable state"), } } @@ -366,12 +354,13 @@ async fn multiple_waiters() { futures::future::pending::<()>().await; }; - futures::select_biased! { - guard = afd_a.readable().fuse() => { + tokio::select! 
{ + biased; + guard = afd_a.readable() => { tokio::task::yield_now().await; guard.unwrap().clear_ready() }, - _ = notify_barrier.fuse() => unreachable!(), + _ = notify_barrier => unreachable!(), } std::mem::drop(afd_a); diff --git a/tests/io_buf_reader.rs b/tests/io_buf_reader.rs new file mode 100644 index 0000000..0d3f6ba --- /dev/null +++ b/tests/io_buf_reader.rs @@ -0,0 +1,379 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +// https://github.com/rust-lang/futures-rs/blob/1803948ff091b4eabf7f3bf39e16bbbdefca5cc8/futures/tests/io_buf_reader.rs + +use futures::task::{noop_waker_ref, Context, Poll}; +use std::cmp; +use std::io::{self, Cursor}; +use std::pin::Pin; +use tokio::io::{ + AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, AsyncWriteExt, + BufReader, ReadBuf, SeekFrom, +}; +use tokio_test::task::spawn; +use tokio_test::{assert_pending, assert_ready}; + +macro_rules! run_fill_buf { + ($reader:expr) => {{ + let mut cx = Context::from_waker(noop_waker_ref()); + loop { + if let Poll::Ready(x) = Pin::new(&mut $reader).poll_fill_buf(&mut cx) { + break x; + } + } + }}; +} + +struct MaybePending<'a> { + inner: &'a [u8], + ready_read: bool, + ready_fill_buf: bool, +} + +impl<'a> MaybePending<'a> { + fn new(inner: &'a [u8]) -> Self { + Self { + inner, + ready_read: false, + ready_fill_buf: false, + } + } +} + +impl AsyncRead for MaybePending<'_> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + if self.ready_read { + self.ready_read = false; + Pin::new(&mut self.inner).poll_read(cx, buf) + } else { + self.ready_read = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + } +} + +impl AsyncBufRead for MaybePending<'_> { + fn poll_fill_buf(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { + if self.ready_fill_buf { + self.ready_fill_buf = false; + if self.inner.is_empty() { + return Poll::Ready(Ok(&[])); + } + let len = 
cmp::min(2, self.inner.len()); + Poll::Ready(Ok(&self.inner[0..len])) + } else { + self.ready_fill_buf = true; + Poll::Pending + } + } + + fn consume(mut self: Pin<&mut Self>, amt: usize) { + self.inner = &self.inner[amt..]; + } +} + +#[tokio::test] +async fn test_buffered_reader() { + let inner: &[u8] = &[5, 6, 7, 0, 1, 2, 3, 4]; + let mut reader = BufReader::with_capacity(2, inner); + + let mut buf = [0, 0, 0]; + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 3); + assert_eq!(buf, [5, 6, 7]); + assert_eq!(reader.buffer(), []); + + let mut buf = [0, 0]; + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 2); + assert_eq!(buf, [0, 1]); + assert_eq!(reader.buffer(), []); + + let mut buf = [0]; + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 1); + assert_eq!(buf, [2]); + assert_eq!(reader.buffer(), [3]); + + let mut buf = [0, 0, 0]; + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 1); + assert_eq!(buf, [3, 0, 0]); + assert_eq!(reader.buffer(), []); + + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 1); + assert_eq!(buf, [4, 0, 0]); + assert_eq!(reader.buffer(), []); + + assert_eq!(reader.read(&mut buf).await.unwrap(), 0); +} + +#[tokio::test] +async fn test_buffered_reader_seek() { + let inner: &[u8] = &[5, 6, 7, 0, 1, 2, 3, 4]; + let mut reader = BufReader::with_capacity(2, Cursor::new(inner)); + + assert_eq!(reader.seek(SeekFrom::Start(3)).await.unwrap(), 3); + assert_eq!(run_fill_buf!(reader).unwrap(), &[0, 1][..]); + assert!(reader.seek(SeekFrom::Current(i64::MIN)).await.is_err()); + assert_eq!(run_fill_buf!(reader).unwrap(), &[0, 1][..]); + assert_eq!(reader.seek(SeekFrom::Current(1)).await.unwrap(), 4); + assert_eq!(run_fill_buf!(reader).unwrap(), &[1, 2][..]); + Pin::new(&mut reader).consume(1); + assert_eq!(reader.seek(SeekFrom::Current(-2)).await.unwrap(), 3); +} + +#[tokio::test] +async fn test_buffered_reader_seek_underflow() { + // gimmick 
reader that yields its position modulo 256 for each byte + struct PositionReader { + pos: u64, + } + impl AsyncRead for PositionReader { + fn poll_read( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let b = buf.initialize_unfilled(); + let len = b.len(); + for x in b { + *x = self.pos as u8; + self.pos = self.pos.wrapping_add(1); + } + buf.advance(len); + Poll::Ready(Ok(())) + } + } + impl AsyncSeek for PositionReader { + fn start_seek(mut self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + match pos { + SeekFrom::Start(n) => { + self.pos = n; + } + SeekFrom::Current(n) => { + self.pos = self.pos.wrapping_add(n as u64); + } + SeekFrom::End(n) => { + self.pos = u64::MAX.wrapping_add(n as u64); + } + } + Ok(()) + } + fn poll_complete(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<u64>> { + Poll::Ready(Ok(self.pos)) + } + } + + let mut reader = BufReader::with_capacity(5, PositionReader { pos: 0 }); + assert_eq!(run_fill_buf!(reader).unwrap(), &[0, 1, 2, 3, 4][..]); + assert_eq!(reader.seek(SeekFrom::End(-5)).await.unwrap(), u64::MAX - 5); + assert_eq!(run_fill_buf!(reader).unwrap().len(), 5); + // the following seek will require two underlying seeks + let expected = 9_223_372_036_854_775_802; + assert_eq!( + reader.seek(SeekFrom::Current(i64::MIN)).await.unwrap(), + expected + ); + assert_eq!(run_fill_buf!(reader).unwrap().len(), 5); + // seeking to 0 should empty the buffer. + assert_eq!(reader.seek(SeekFrom::Current(0)).await.unwrap(), expected); + assert_eq!(reader.get_ref().pos, expected); +} + +#[tokio::test] +async fn test_short_reads() { + /// A dummy reader intended at testing short-reads propagation. 
+ struct ShortReader { + lengths: Vec<usize>, + } + + impl AsyncRead for ShortReader { + fn poll_read( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + if !self.lengths.is_empty() { + buf.advance(self.lengths.remove(0)); + } + Poll::Ready(Ok(())) + } + } + + let inner = ShortReader { + lengths: vec![0, 1, 2, 0, 1, 0], + }; + let mut reader = BufReader::new(inner); + let mut buf = [0, 0]; + assert_eq!(reader.read(&mut buf).await.unwrap(), 0); + assert_eq!(reader.read(&mut buf).await.unwrap(), 1); + assert_eq!(reader.read(&mut buf).await.unwrap(), 2); + assert_eq!(reader.read(&mut buf).await.unwrap(), 0); + assert_eq!(reader.read(&mut buf).await.unwrap(), 1); + assert_eq!(reader.read(&mut buf).await.unwrap(), 0); + assert_eq!(reader.read(&mut buf).await.unwrap(), 0); +} + +#[tokio::test] +async fn maybe_pending() { + let inner: &[u8] = &[5, 6, 7, 0, 1, 2, 3, 4]; + let mut reader = BufReader::with_capacity(2, MaybePending::new(inner)); + + let mut buf = [0, 0, 0]; + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 3); + assert_eq!(buf, [5, 6, 7]); + assert_eq!(reader.buffer(), []); + + let mut buf = [0, 0]; + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 2); + assert_eq!(buf, [0, 1]); + assert_eq!(reader.buffer(), []); + + let mut buf = [0]; + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 1); + assert_eq!(buf, [2]); + assert_eq!(reader.buffer(), [3]); + + let mut buf = [0, 0, 0]; + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 1); + assert_eq!(buf, [3, 0, 0]); + assert_eq!(reader.buffer(), []); + + let nread = reader.read(&mut buf).await.unwrap(); + assert_eq!(nread, 1); + assert_eq!(buf, [4, 0, 0]); + assert_eq!(reader.buffer(), []); + + assert_eq!(reader.read(&mut buf).await.unwrap(), 0); +} + +#[tokio::test] +async fn maybe_pending_buf_read() { + let inner = MaybePending::new(&[0, 1, 2, 3, 1, 0]); + let mut reader = 
BufReader::with_capacity(2, inner); + let mut v = Vec::new(); + reader.read_until(3, &mut v).await.unwrap(); + assert_eq!(v, [0, 1, 2, 3]); + v.clear(); + reader.read_until(1, &mut v).await.unwrap(); + assert_eq!(v, [1]); + v.clear(); + reader.read_until(8, &mut v).await.unwrap(); + assert_eq!(v, [0]); + v.clear(); + reader.read_until(9, &mut v).await.unwrap(); + assert_eq!(v, []); +} + +// https://github.com/rust-lang/futures-rs/pull/1573#discussion_r281162309 +#[tokio::test] +async fn maybe_pending_seek() { + struct MaybePendingSeek<'a> { + inner: Cursor<&'a [u8]>, + ready: bool, + seek_res: Option<io::Result<()>>, + } + + impl<'a> MaybePendingSeek<'a> { + fn new(inner: &'a [u8]) -> Self { + Self { + inner: Cursor::new(inner), + ready: true, + seek_res: None, + } + } + } + + impl AsyncRead for MaybePendingSeek<'_> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + Pin::new(&mut self.inner).poll_read(cx, buf) + } + } + + impl AsyncBufRead for MaybePendingSeek<'_> { + fn poll_fill_buf( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<io::Result<&[u8]>> { + let this: *mut Self = &mut *self as *mut _; + Pin::new(&mut unsafe { &mut *this }.inner).poll_fill_buf(cx) + } + + fn consume(mut self: Pin<&mut Self>, amt: usize) { + Pin::new(&mut self.inner).consume(amt) + } + } + + impl AsyncSeek for MaybePendingSeek<'_> { + fn start_seek(mut self: Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + self.seek_res = Some(Pin::new(&mut self.inner).start_seek(pos)); + Ok(()) + } + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { + if self.ready { + self.ready = false; + self.seek_res.take().unwrap_or(Ok(()))?; + Pin::new(&mut self.inner).poll_complete(cx) + } else { + self.ready = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } + + let inner: &[u8] = &[5, 6, 7, 0, 1, 2, 3, 4]; + let mut reader = BufReader::with_capacity(2, 
MaybePendingSeek::new(inner)); + + assert_eq!(reader.seek(SeekFrom::Current(3)).await.unwrap(), 3); + assert_eq!(run_fill_buf!(reader).unwrap(), &[0, 1][..]); + assert!(reader.seek(SeekFrom::Current(i64::MIN)).await.is_err()); + assert_eq!(run_fill_buf!(reader).unwrap(), &[0, 1][..]); + assert_eq!(reader.seek(SeekFrom::Current(1)).await.unwrap(), 4); + assert_eq!(run_fill_buf!(reader).unwrap(), &[1, 2][..]); + Pin::new(&mut reader).consume(1); + assert_eq!(reader.seek(SeekFrom::Current(-2)).await.unwrap(), 3); +} + +// This tests the AsyncBufReadExt::fill_buf wrapper. +#[tokio::test] +async fn test_fill_buf_wrapper() { + let (mut write, read) = tokio::io::duplex(16); + + let mut read = BufReader::new(read); + write.write_all(b"hello world").await.unwrap(); + + assert_eq!(read.fill_buf().await.unwrap(), b"hello world"); + read.consume(b"hello ".len()); + assert_eq!(read.fill_buf().await.unwrap(), b"world"); + assert_eq!(read.fill_buf().await.unwrap(), b"world"); + read.consume(b"world".len()); + + let mut fill = spawn(read.fill_buf()); + assert_pending!(fill.poll()); + + write.write_all(b"foo bar").await.unwrap(); + assert_eq!(assert_ready!(fill.poll()).unwrap(), b"foo bar"); + drop(fill); + + drop(write); + assert_eq!(read.fill_buf().await.unwrap(), b"foo bar"); + read.consume(b"foo bar".len()); + assert_eq!(read.fill_buf().await.unwrap(), b""); +} diff --git a/tests/io_buf_writer.rs b/tests/io_buf_writer.rs new file mode 100644 index 0000000..47a0d46 --- /dev/null +++ b/tests/io_buf_writer.rs @@ -0,0 +1,537 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +// https://github.com/rust-lang/futures-rs/blob/1803948ff091b4eabf7f3bf39e16bbbdefca5cc8/futures/tests/io_buf_writer.rs + +use futures::task::{Context, Poll}; +use std::io::{self, Cursor}; +use std::pin::Pin; +use tokio::io::{AsyncSeek, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter, SeekFrom}; + +use futures::future; +use tokio_test::assert_ok; + +use std::cmp; +use std::io::IoSlice; + +mod 
support { + pub(crate) mod io_vec; +} +use support::io_vec::IoBufs; + +struct MaybePending { + inner: Vec<u8>, + ready: bool, +} + +impl MaybePending { + fn new(inner: Vec<u8>) -> Self { + Self { + inner, + ready: false, + } + } +} + +impl AsyncWrite for MaybePending { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + if self.ready { + self.ready = false; + Pin::new(&mut self.inner).poll_write(cx, buf) + } else { + self.ready = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } +} + +async fn write_vectored<W>(writer: &mut W, bufs: &[IoSlice<'_>]) -> io::Result<usize> +where + W: AsyncWrite + Unpin, +{ + let mut writer = Pin::new(writer); + future::poll_fn(|cx| writer.as_mut().poll_write_vectored(cx, bufs)).await +} + +#[tokio::test] +async fn buf_writer() { + let mut writer = BufWriter::with_capacity(2, Vec::new()); + + writer.write(&[0, 1]).await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!(*writer.get_ref(), [0, 1]); + + writer.write(&[2]).await.unwrap(); + assert_eq!(writer.buffer(), [2]); + assert_eq!(*writer.get_ref(), [0, 1]); + + writer.write(&[3]).await.unwrap(); + assert_eq!(writer.buffer(), [2, 3]); + assert_eq!(*writer.get_ref(), [0, 1]); + + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!(*writer.get_ref(), [0, 1, 2, 3]); + + writer.write(&[4]).await.unwrap(); + writer.write(&[5]).await.unwrap(); + assert_eq!(writer.buffer(), [4, 5]); + assert_eq!(*writer.get_ref(), [0, 1, 2, 3]); + + writer.write(&[6]).await.unwrap(); + assert_eq!(writer.buffer(), [6]); + assert_eq!(*writer.get_ref(), [0, 1, 2, 3, 4, 5]); + + writer.write(&[7, 8]).await.unwrap(); + 
assert_eq!(writer.buffer(), []); + assert_eq!(*writer.get_ref(), [0, 1, 2, 3, 4, 5, 6, 7, 8]); + + writer.write(&[9, 10, 11]).await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!(*writer.get_ref(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]); + + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!(*writer.get_ref(), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]); +} + +#[tokio::test] +async fn buf_writer_inner_flushes() { + let mut w = BufWriter::with_capacity(3, Vec::new()); + w.write(&[0, 1]).await.unwrap(); + assert_eq!(*w.get_ref(), []); + w.flush().await.unwrap(); + let w = w.into_inner(); + assert_eq!(w, [0, 1]); +} + +#[tokio::test] +async fn buf_writer_seek() { + let mut w = BufWriter::with_capacity(3, Cursor::new(Vec::new())); + w.write_all(&[0, 1, 2, 3, 4, 5]).await.unwrap(); + w.write_all(&[6, 7]).await.unwrap(); + assert_eq!(w.seek(SeekFrom::Current(0)).await.unwrap(), 8); + assert_eq!(&w.get_ref().get_ref()[..], &[0, 1, 2, 3, 4, 5, 6, 7][..]); + assert_eq!(w.seek(SeekFrom::Start(2)).await.unwrap(), 2); + w.write_all(&[8, 9]).await.unwrap(); + w.flush().await.unwrap(); + assert_eq!(&w.into_inner().into_inner()[..], &[0, 1, 8, 9, 4, 5, 6, 7]); +} + +#[tokio::test] +async fn maybe_pending_buf_writer() { + let mut writer = BufWriter::with_capacity(2, MaybePending::new(Vec::new())); + + writer.write(&[0, 1]).await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!(&writer.get_ref().inner, &[0, 1]); + + writer.write(&[2]).await.unwrap(); + assert_eq!(writer.buffer(), [2]); + assert_eq!(&writer.get_ref().inner, &[0, 1]); + + writer.write(&[3]).await.unwrap(); + assert_eq!(writer.buffer(), [2, 3]); + assert_eq!(&writer.get_ref().inner, &[0, 1]); + + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!(&writer.get_ref().inner, &[0, 1, 2, 3]); + + writer.write(&[4]).await.unwrap(); + writer.write(&[5]).await.unwrap(); + assert_eq!(writer.buffer(), [4, 5]); + assert_eq!(&writer.get_ref().inner, &[0, 1, 
2, 3]); + + writer.write(&[6]).await.unwrap(); + assert_eq!(writer.buffer(), [6]); + assert_eq!(writer.get_ref().inner, &[0, 1, 2, 3, 4, 5]); + + writer.write(&[7, 8]).await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!(writer.get_ref().inner, &[0, 1, 2, 3, 4, 5, 6, 7, 8]); + + writer.write(&[9, 10, 11]).await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!( + writer.get_ref().inner, + &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + ); + + writer.flush().await.unwrap(); + assert_eq!(writer.buffer(), []); + assert_eq!( + &writer.get_ref().inner, + &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + ); +} + +#[tokio::test] +async fn maybe_pending_buf_writer_inner_flushes() { + let mut w = BufWriter::with_capacity(3, MaybePending::new(Vec::new())); + w.write(&[0, 1]).await.unwrap(); + assert_eq!(&w.get_ref().inner, &[]); + w.flush().await.unwrap(); + let w = w.into_inner().inner; + assert_eq!(w, [0, 1]); +} + +#[tokio::test] +async fn maybe_pending_buf_writer_seek() { + struct MaybePendingSeek { + inner: Cursor<Vec<u8>>, + ready_write: bool, + ready_seek: bool, + seek_res: Option<io::Result<()>>, + } + + impl MaybePendingSeek { + fn new(inner: Vec<u8>) -> Self { + Self { + inner: Cursor::new(inner), + ready_write: false, + ready_seek: false, + seek_res: None, + } + } + } + + impl AsyncWrite for MaybePendingSeek { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + if self.ready_write { + self.ready_write = false; + Pin::new(&mut self.inner).poll_write(cx, buf) + } else { + self.ready_write = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.inner).poll_shutdown(cx) + } + } + + impl AsyncSeek for MaybePendingSeek { + fn start_seek(mut self: 
Pin<&mut Self>, pos: SeekFrom) -> io::Result<()> { + self.seek_res = Some(Pin::new(&mut self.inner).start_seek(pos)); + Ok(()) + } + fn poll_complete(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> { + if self.ready_seek { + self.ready_seek = false; + self.seek_res.take().unwrap_or(Ok(()))?; + Pin::new(&mut self.inner).poll_complete(cx) + } else { + self.ready_seek = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } + + let mut w = BufWriter::with_capacity(3, MaybePendingSeek::new(Vec::new())); + w.write_all(&[0, 1, 2, 3, 4, 5]).await.unwrap(); + w.write_all(&[6, 7]).await.unwrap(); + assert_eq!(w.seek(SeekFrom::Current(0)).await.unwrap(), 8); + assert_eq!( + &w.get_ref().inner.get_ref()[..], + &[0, 1, 2, 3, 4, 5, 6, 7][..] + ); + assert_eq!(w.seek(SeekFrom::Start(2)).await.unwrap(), 2); + w.write_all(&[8, 9]).await.unwrap(); + w.flush().await.unwrap(); + assert_eq!( + &w.into_inner().inner.into_inner()[..], + &[0, 1, 8, 9, 4, 5, 6, 7] + ); +} + +struct MockWriter { + data: Vec<u8>, + write_len: usize, + vectored: bool, +} + +impl MockWriter { + fn new(write_len: usize) -> Self { + MockWriter { + data: Vec::new(), + write_len, + vectored: false, + } + } + + fn vectored(write_len: usize) -> Self { + MockWriter { + data: Vec::new(), + write_len, + vectored: true, + } + } + + fn write_up_to(&mut self, buf: &[u8], limit: usize) -> usize { + let len = cmp::min(buf.len(), limit); + self.data.extend_from_slice(&buf[..len]); + len + } +} + +impl AsyncWrite for MockWriter { + fn poll_write( + self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + let this = self.get_mut(); + let n = this.write_up_to(buf, this.write_len); + Ok(n).into() + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + _: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll<Result<usize, io::Error>> { + let this = self.get_mut(); + let mut total_written = 0; + for buf in bufs { + let n = this.write_up_to(buf, 
this.write_len - total_written); + total_written += n; + if total_written == this.write_len { + break; + } + } + Ok(total_written).into() + } + + fn is_write_vectored(&self) -> bool { + self.vectored + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Ok(()).into() + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> { + Ok(()).into() + } +} + +#[tokio::test] +async fn write_vectored_empty_on_non_vectored() { + let mut w = BufWriter::new(MockWriter::new(4)); + let n = assert_ok!(write_vectored(&mut w, &[]).await); + assert_eq!(n, 0); + + let io_vec = [IoSlice::new(&[]); 3]; + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 0); + + assert_ok!(w.flush().await); + assert!(w.get_ref().data.is_empty()); +} + +#[tokio::test] +async fn write_vectored_empty_on_vectored() { + let mut w = BufWriter::new(MockWriter::vectored(4)); + let n = assert_ok!(write_vectored(&mut w, &[]).await); + assert_eq!(n, 0); + + let io_vec = [IoSlice::new(&[]); 3]; + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 0); + + assert_ok!(w.flush().await); + assert!(w.get_ref().data.is_empty()); +} + +#[tokio::test] +async fn write_vectored_basic_on_non_vectored() { + let msg = b"foo bar baz"; + let bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&msg[4..8]), + IoSlice::new(&msg[8..]), + ]; + let mut w = BufWriter::new(MockWriter::new(4)); + let n = assert_ok!(write_vectored(&mut w, &bufs).await); + assert_eq!(n, msg.len()); + assert!(w.buffer() == &msg[..]); + assert_ok!(w.flush().await); + assert_eq!(w.get_ref().data, msg); +} + +#[tokio::test] +async fn write_vectored_basic_on_vectored() { + let msg = b"foo bar baz"; + let bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&msg[4..8]), + IoSlice::new(&msg[8..]), + ]; + let mut w = BufWriter::new(MockWriter::vectored(4)); + let n = assert_ok!(write_vectored(&mut w, &bufs).await); + assert_eq!(n, 
msg.len()); + assert!(w.buffer() == &msg[..]); + assert_ok!(w.flush().await); + assert_eq!(w.get_ref().data, msg); +} + +#[tokio::test] +async fn write_vectored_large_total_on_non_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&msg[4..8]), + IoSlice::new(&msg[8..]), + ]; + let io_vec = IoBufs::new(&mut bufs); + let mut w = BufWriter::with_capacity(8, MockWriter::new(4)); + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 8); + assert!(w.buffer() == &msg[..8]); + let io_vec = io_vec.advance(n); + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 3); + assert!(w.get_ref().data.as_slice() == &msg[..8]); + assert!(w.buffer() == &msg[8..]); +} + +#[tokio::test] +async fn write_vectored_large_total_on_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&msg[4..8]), + IoSlice::new(&msg[8..]), + ]; + let io_vec = IoBufs::new(&mut bufs); + let mut w = BufWriter::with_capacity(8, MockWriter::vectored(10)); + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 10); + assert!(w.buffer().is_empty()); + let io_vec = io_vec.advance(n); + let n = assert_ok!(write_vectored(&mut w, &io_vec).await); + assert_eq!(n, 1); + assert!(w.get_ref().data.as_slice() == &msg[..10]); + assert!(w.buffer() == &msg[10..]); +} + +struct VectoredWriteHarness { + writer: BufWriter<MockWriter>, + buf_capacity: usize, +} + +impl VectoredWriteHarness { + fn new(buf_capacity: usize) -> Self { + VectoredWriteHarness { + writer: BufWriter::with_capacity(buf_capacity, MockWriter::new(4)), + buf_capacity, + } + } + + fn with_vectored_backend(buf_capacity: usize) -> Self { + VectoredWriteHarness { + writer: BufWriter::with_capacity(buf_capacity, MockWriter::vectored(4)), + buf_capacity, + } + } + + async fn write_all<'a, 'b>(&mut self, mut io_vec: IoBufs<'a, 'b>) -> usize { + let mut total_written = 0; + while !io_vec.is_empty() { + 
let n = assert_ok!(write_vectored(&mut self.writer, &io_vec).await); + assert!(n != 0); + assert!(self.writer.buffer().len() <= self.buf_capacity); + total_written += n; + io_vec = io_vec.advance(n); + } + total_written + } + + async fn flush(&mut self) -> &[u8] { + assert_ok!(self.writer.flush().await); + &self.writer.get_ref().data + } +} + +#[tokio::test] +async fn write_vectored_odd_on_non_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&[]), + IoSlice::new(&msg[4..9]), + IoSlice::new(&msg[9..]), + ]; + let mut h = VectoredWriteHarness::new(8); + let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await; + assert_eq!(bytes_written, msg.len()); + assert_eq!(h.flush().await, msg); +} + +#[tokio::test] +async fn write_vectored_odd_on_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&msg[0..4]), + IoSlice::new(&[]), + IoSlice::new(&msg[4..9]), + IoSlice::new(&msg[9..]), + ]; + let mut h = VectoredWriteHarness::with_vectored_backend(8); + let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await; + assert_eq!(bytes_written, msg.len()); + assert_eq!(h.flush().await, msg); +} + +#[tokio::test] +async fn write_vectored_large_slice_on_non_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&[]), + IoSlice::new(&msg[..9]), + IoSlice::new(&msg[9..]), + ]; + let mut h = VectoredWriteHarness::new(8); + let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await; + assert_eq!(bytes_written, msg.len()); + assert_eq!(h.flush().await, msg); +} + +#[tokio::test] +async fn write_vectored_large_slice_on_vectored() { + let msg = b"foo bar baz"; + let mut bufs = [ + IoSlice::new(&[]), + IoSlice::new(&msg[..9]), + IoSlice::new(&msg[9..]), + ]; + let mut h = VectoredWriteHarness::with_vectored_backend(8); + let bytes_written = h.write_all(IoBufs::new(&mut bufs)).await; + assert_eq!(bytes_written, msg.len()); + assert_eq!(h.flush().await, msg); +} diff --git 
a/tests/io_copy.rs b/tests/io_copy.rs index 9ed7995..005e170 100644 --- a/tests/io_copy.rs +++ b/tests/io_copy.rs @@ -1,7 +1,9 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::{self, AsyncRead, ReadBuf}; +use bytes::BytesMut; +use futures::ready; +use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio_test::assert_ok; use std::pin::Pin; @@ -34,3 +36,52 @@ async fn copy() { assert_eq!(n, 11); assert_eq!(wr, b"hello world"); } + +#[tokio::test] +async fn proxy() { + struct BufferedWd { + buf: BytesMut, + writer: io::DuplexStream, + } + + impl AsyncWrite for BufferedWd { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.get_mut().buf.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + let this = self.get_mut(); + + while !this.buf.is_empty() { + let n = ready!(Pin::new(&mut this.writer).poll_write(cx, &this.buf))?; + let _ = this.buf.split_to(n); + } + + Pin::new(&mut this.writer).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Pin::new(&mut self.writer).poll_shutdown(cx) + } + } + + let (rd, wd) = io::duplex(1024); + let mut rd = rd.take(1024); + let mut wd = BufferedWd { + buf: BytesMut::new(), + writer: wd, + }; + + // write start bytes + assert_ok!(wd.write_all(&[0x42; 512]).await); + assert_ok!(wd.flush().await); + + let n = assert_ok!(io::copy(&mut rd, &mut wd).await); + + assert_eq!(n, 1024); +} diff --git a/tests/io_copy_bidirectional.rs b/tests/io_copy_bidirectional.rs index 17c0597..0e82b29 100644 --- a/tests/io_copy_bidirectional.rs +++ b/tests/io_copy_bidirectional.rs @@ -26,7 +26,7 @@ async fn block_write(s: &mut TcpStream) -> usize { result = s.write(&BUF) => { copied += result.expect("write error") }, - _ = tokio::time::sleep(Duration::from_millis(100)) => { + _ = 
tokio::time::sleep(Duration::from_millis(10)) => { break; } } @@ -42,7 +42,7 @@ where { // We run the test twice, with streams passed to copy_bidirectional in // different orders, in order to ensure that the two arguments are - // interchangable. + // interchangeable. let (a, mut a1) = make_socketpair().await; let (b, mut b1) = make_socketpair().await; diff --git a/tests/io_fill_buf.rs b/tests/io_fill_buf.rs new file mode 100644 index 0000000..0b2ebd7 --- /dev/null +++ b/tests/io_fill_buf.rs @@ -0,0 +1,34 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tempfile::NamedTempFile; +use tokio::fs::File; +use tokio::io::{AsyncBufReadExt, BufReader}; +use tokio_test::assert_ok; + +#[tokio::test] +async fn fill_buf_file() { + let file = NamedTempFile::new().unwrap(); + + assert_ok!(std::fs::write(file.path(), b"hello")); + + let file = assert_ok!(File::open(file.path()).await); + let mut file = BufReader::new(file); + + let mut contents = Vec::new(); + + loop { + let consumed = { + let buffer = assert_ok!(file.fill_buf().await); + if buffer.is_empty() { + break; + } + contents.extend_from_slice(buffer); + buffer.len() + }; + + file.consume(consumed); + } + + assert_eq!(contents, b"hello"); +} diff --git a/tests/io_mem_stream.rs b/tests/io_mem_stream.rs index 3335214..520391a 100644 --- a/tests/io_mem_stream.rs +++ b/tests/io_mem_stream.rs @@ -63,6 +63,26 @@ async fn disconnect() { } #[tokio::test] +#[cfg(not(target_os = "android"))] +async fn disconnect_reader() { + let (a, mut b) = duplex(2); + + let t1 = tokio::spawn(async move { + // this will block, as not all data fits into duplex + b.write_all(b"ping").await.unwrap_err(); + }); + + let t2 = tokio::spawn(async move { + // here we drop the reader side, and we expect the writer in the other + // task to exit with an error + drop(a); + }); + + t2.await.unwrap(); + t1.await.unwrap(); +} + +#[tokio::test] async fn max_write_size() { let (mut a, mut b) = duplex(32); @@ -73,11 +93,11 @@ async fn 
max_write_size() { assert_eq!(n, 4); }); - let t2 = tokio::spawn(async move { - let mut buf = [0u8; 4]; - b.read_exact(&mut buf).await.unwrap(); - }); + let mut buf = [0u8; 4]; + b.read_exact(&mut buf).await.unwrap(); t1.await.unwrap(); - t2.await.unwrap(); + + // drop b only after task t1 finishes writing + drop(b); } diff --git a/tests/io_poll_aio.rs b/tests/io_poll_aio.rs new file mode 100644 index 0000000..f044af5 --- /dev/null +++ b/tests/io_poll_aio.rs @@ -0,0 +1,375 @@ +#![warn(rust_2018_idioms)] +#![cfg(all(target_os = "freebsd", feature = "net"))] + +use mio_aio::{AioCb, AioFsyncMode, LioCb}; +use std::{ + future::Future, + mem, + os::unix::io::{AsRawFd, RawFd}, + pin::Pin, + task::{Context, Poll}, +}; +use tempfile::tempfile; +use tokio::io::bsd::{Aio, AioSource}; +use tokio_test::assert_pending; + +mod aio { + use super::*; + + /// Adapts mio_aio::AioCb (which implements mio::event::Source) to AioSource + struct WrappedAioCb<'a>(AioCb<'a>); + impl<'a> AioSource for WrappedAioCb<'a> { + fn register(&mut self, kq: RawFd, token: usize) { + self.0.register_raw(kq, token) + } + fn deregister(&mut self) { + self.0.deregister_raw() + } + } + + /// A very crude implementation of an AIO-based future + struct FsyncFut(Aio<WrappedAioCb<'static>>); + + impl Future for FsyncFut { + type Output = std::io::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let poll_result = self.0.poll_ready(cx); + match poll_result { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Ready(Ok(_ev)) => { + // At this point, we could clear readiness. But there's no + // point, since we're about to drop the Aio. 
+ let result = (*self.0).0.aio_return(); + match result { + Ok(_) => Poll::Ready(Ok(())), + Err(e) => Poll::Ready(Err(e.into())), + } + } + } + } + } + + /// Low-level AIO Source + /// + /// An example bypassing mio_aio and Nix to demonstrate how the kevent + /// registration actually works, under the hood. + struct LlSource(Pin<Box<libc::aiocb>>); + + impl AioSource for LlSource { + fn register(&mut self, kq: RawFd, token: usize) { + let mut sev: libc::sigevent = unsafe { mem::MaybeUninit::zeroed().assume_init() }; + sev.sigev_notify = libc::SIGEV_KEVENT; + sev.sigev_signo = kq; + sev.sigev_value = libc::sigval { + sival_ptr: token as *mut libc::c_void, + }; + self.0.aio_sigevent = sev; + } + + fn deregister(&mut self) { + unsafe { + self.0.aio_sigevent = mem::zeroed(); + } + } + } + + struct LlFut(Aio<LlSource>); + + impl Future for LlFut { + type Output = std::io::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let poll_result = self.0.poll_ready(cx); + match poll_result { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Ready(Ok(_ev)) => { + let r = unsafe { libc::aio_return(self.0 .0.as_mut().get_unchecked_mut()) }; + assert_eq!(0, r); + Poll::Ready(Ok(())) + } + } + } + } + + /// A very simple object that can implement AioSource and can be reused. + /// + /// mio_aio normally assumes that each AioCb will be consumed on completion. + /// This somewhat contrived example shows how an Aio object can be reused + /// anyway. 
+ struct ReusableFsyncSource { + aiocb: Pin<Box<AioCb<'static>>>, + fd: RawFd, + token: usize, + } + impl ReusableFsyncSource { + fn fsync(&mut self) { + self.aiocb.register_raw(self.fd, self.token); + self.aiocb.fsync(AioFsyncMode::O_SYNC).unwrap(); + } + fn new(aiocb: AioCb<'static>) -> Self { + ReusableFsyncSource { + aiocb: Box::pin(aiocb), + fd: 0, + token: 0, + } + } + fn reset(&mut self, aiocb: AioCb<'static>) { + self.aiocb = Box::pin(aiocb); + } + } + impl AioSource for ReusableFsyncSource { + fn register(&mut self, kq: RawFd, token: usize) { + self.fd = kq; + self.token = token; + } + fn deregister(&mut self) { + self.fd = 0; + } + } + + struct ReusableFsyncFut<'a>(&'a mut Aio<ReusableFsyncSource>); + impl<'a> Future for ReusableFsyncFut<'a> { + type Output = std::io::Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let poll_result = self.0.poll_ready(cx); + match poll_result { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Ready(Ok(ev)) => { + // Since this future uses a reusable Aio, we must clear + // its readiness here. That makes the future + // non-idempotent; the caller can't poll it repeatedly after + // it has already returned Ready. But that's ok; most + // futures behave this way. 
+ self.0.clear_ready(ev); + let result = (*self.0).aiocb.aio_return(); + match result { + Ok(_) => Poll::Ready(Ok(())), + Err(e) => Poll::Ready(Err(e.into())), + } + } + } + } + } + + #[tokio::test] + async fn fsync() { + let f = tempfile().unwrap(); + let fd = f.as_raw_fd(); + let aiocb = AioCb::from_fd(fd, 0); + let source = WrappedAioCb(aiocb); + let mut poll_aio = Aio::new_for_aio(source).unwrap(); + (*poll_aio).0.fsync(AioFsyncMode::O_SYNC).unwrap(); + let fut = FsyncFut(poll_aio); + fut.await.unwrap(); + } + + #[tokio::test] + async fn ll_fsync() { + let f = tempfile().unwrap(); + let fd = f.as_raw_fd(); + let mut aiocb: libc::aiocb = unsafe { mem::MaybeUninit::zeroed().assume_init() }; + aiocb.aio_fildes = fd; + let source = LlSource(Box::pin(aiocb)); + let mut poll_aio = Aio::new_for_aio(source).unwrap(); + let r = unsafe { + let p = (*poll_aio).0.as_mut().get_unchecked_mut(); + libc::aio_fsync(libc::O_SYNC, p) + }; + assert_eq!(0, r); + let fut = LlFut(poll_aio); + fut.await.unwrap(); + } + + /// A suitably crafted future type can reuse an Aio object + #[tokio::test] + async fn reuse() { + let f = tempfile().unwrap(); + let fd = f.as_raw_fd(); + let aiocb0 = AioCb::from_fd(fd, 0); + let source = ReusableFsyncSource::new(aiocb0); + let mut poll_aio = Aio::new_for_aio(source).unwrap(); + poll_aio.fsync(); + let fut0 = ReusableFsyncFut(&mut poll_aio); + fut0.await.unwrap(); + + let aiocb1 = AioCb::from_fd(fd, 0); + poll_aio.reset(aiocb1); + let mut ctx = Context::from_waker(futures::task::noop_waker_ref()); + assert_pending!(poll_aio.poll_ready(&mut ctx)); + poll_aio.fsync(); + let fut1 = ReusableFsyncFut(&mut poll_aio); + fut1.await.unwrap(); + } +} + +mod lio { + use super::*; + + struct WrappedLioCb<'a>(LioCb<'a>); + impl<'a> AioSource for WrappedLioCb<'a> { + fn register(&mut self, kq: RawFd, token: usize) { + self.0.register_raw(kq, token) + } + fn deregister(&mut self) { + self.0.deregister_raw() + } + } + + /// A very crude lio_listio-based Future + 
struct LioFut(Option<Aio<WrappedLioCb<'static>>>); + + impl Future for LioFut { + type Output = std::io::Result<Vec<isize>>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let poll_result = self.0.as_mut().unwrap().poll_ready(cx); + match poll_result { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Ready(Ok(_ev)) => { + // At this point, we could clear readiness. But there's no + // point, since we're about to drop the Aio. + let r = self.0.take().unwrap().into_inner().0.into_results(|iter| { + iter.map(|lr| lr.result.unwrap()).collect::<Vec<isize>>() + }); + Poll::Ready(Ok(r)) + } + } + } + } + + /// Minimal example demonstrating reuse of an Aio object with lio + /// readiness. mio_aio::LioCb actually does something similar under the + /// hood. + struct ReusableLioSource { + liocb: Option<LioCb<'static>>, + fd: RawFd, + token: usize, + } + impl ReusableLioSource { + fn new(liocb: LioCb<'static>) -> Self { + ReusableLioSource { + liocb: Some(liocb), + fd: 0, + token: 0, + } + } + fn reset(&mut self, liocb: LioCb<'static>) { + self.liocb = Some(liocb); + } + fn submit(&mut self) { + self.liocb + .as_mut() + .unwrap() + .register_raw(self.fd, self.token); + self.liocb.as_mut().unwrap().submit().unwrap(); + } + } + impl AioSource for ReusableLioSource { + fn register(&mut self, kq: RawFd, token: usize) { + self.fd = kq; + self.token = token; + } + fn deregister(&mut self) { + self.fd = 0; + } + } + struct ReusableLioFut<'a>(&'a mut Aio<ReusableLioSource>); + impl<'a> Future for ReusableLioFut<'a> { + type Output = std::io::Result<Vec<isize>>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let poll_result = self.0.poll_ready(cx); + match poll_result { + Poll::Pending => Poll::Pending, + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Ready(Ok(ev)) => { + // Since this future uses a reusable Aio, we must clear + // its readiness here. 
That makes the future + // non-idempotent; the caller can't poll it repeatedly after + // it has already returned Ready. But that's ok; most + // futures behave this way. + self.0.clear_ready(ev); + let r = (*self.0).liocb.take().unwrap().into_results(|iter| { + iter.map(|lr| lr.result.unwrap()).collect::<Vec<isize>>() + }); + Poll::Ready(Ok(r)) + } + } + } + } + + /// An lio_listio operation with one write element + #[tokio::test] + async fn onewrite() { + const WBUF: &[u8] = b"abcdef"; + let f = tempfile().unwrap(); + + let mut builder = mio_aio::LioCbBuilder::with_capacity(1); + builder = builder.emplace_slice( + f.as_raw_fd(), + 0, + &WBUF[..], + 0, + mio_aio::LioOpcode::LIO_WRITE, + ); + let liocb = builder.finish(); + let source = WrappedLioCb(liocb); + let mut poll_aio = Aio::new_for_lio(source).unwrap(); + + // Send the operation to the kernel + (*poll_aio).0.submit().unwrap(); + let fut = LioFut(Some(poll_aio)); + let v = fut.await.unwrap(); + assert_eq!(v.len(), 1); + assert_eq!(v[0] as usize, WBUF.len()); + } + + /// A suitably crafted future type can reuse an Aio object + #[tokio::test] + async fn reuse() { + const WBUF: &[u8] = b"abcdef"; + let f = tempfile().unwrap(); + + let mut builder0 = mio_aio::LioCbBuilder::with_capacity(1); + builder0 = builder0.emplace_slice( + f.as_raw_fd(), + 0, + &WBUF[..], + 0, + mio_aio::LioOpcode::LIO_WRITE, + ); + let liocb0 = builder0.finish(); + let source = ReusableLioSource::new(liocb0); + let mut poll_aio = Aio::new_for_aio(source).unwrap(); + poll_aio.submit(); + let fut0 = ReusableLioFut(&mut poll_aio); + let v = fut0.await.unwrap(); + assert_eq!(v.len(), 1); + assert_eq!(v[0] as usize, WBUF.len()); + + // Now reuse the same Aio + let mut builder1 = mio_aio::LioCbBuilder::with_capacity(1); + builder1 = builder1.emplace_slice( + f.as_raw_fd(), + 0, + &WBUF[..], + 0, + mio_aio::LioOpcode::LIO_WRITE, + ); + let liocb1 = builder1.finish(); + poll_aio.reset(liocb1); + let mut ctx = 
Context::from_waker(futures::task::noop_waker_ref()); + assert_pending!(poll_aio.poll_ready(&mut ctx)); + poll_aio.submit(); + let fut1 = ReusableLioFut(&mut poll_aio); + let v = fut1.await.unwrap(); + assert_eq!(v.len(), 1); + assert_eq!(v[0] as usize, WBUF.len()); + } +} diff --git a/tests/io_split.rs b/tests/io_split.rs index db168e9..a012166 100644 --- a/tests/io_split.rs +++ b/tests/io_split.rs @@ -50,10 +50,10 @@ fn is_send_and_sync() { fn split_stream_id() { let (r1, w1) = split(RW); let (r2, w2) = split(RW); - assert_eq!(r1.is_pair_of(&w1), true); - assert_eq!(r1.is_pair_of(&w2), false); - assert_eq!(r2.is_pair_of(&w2), true); - assert_eq!(r2.is_pair_of(&w1), false); + assert!(r1.is_pair_of(&w1)); + assert!(!r1.is_pair_of(&w2)); + assert!(r2.is_pair_of(&w2)); + assert!(!r2.is_pair_of(&w1)); } #[test] diff --git a/tests/io_write_all_buf.rs b/tests/io_write_all_buf.rs new file mode 100644 index 0000000..7c8b619 --- /dev/null +++ b/tests/io_write_all_buf.rs @@ -0,0 +1,96 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio_test::{assert_err, assert_ok}; + +use bytes::{Buf, Bytes, BytesMut}; +use std::cmp; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +#[tokio::test] +async fn write_all_buf() { + struct Wr { + buf: BytesMut, + cnt: usize, + } + + impl AsyncWrite for Wr { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll<io::Result<usize>> { + let n = cmp::min(4, buf.len()); + dbg!(buf); + let buf = &buf[0..n]; + + self.cnt += 1; + self.buf.extend(buf); + Ok(buf.len()).into() + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Ok(()).into() + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Ok(()).into() + } + } + + let mut wr = Wr { + buf: BytesMut::with_capacity(64), + cnt: 0, + }; + + let mut buf = 
Bytes::from_static(b"hello").chain(Bytes::from_static(b"world")); + + assert_ok!(wr.write_all_buf(&mut buf).await); + assert_eq!(wr.buf, b"helloworld"[..]); + // expect 4 writes, [hell],[o],[worl],[d] + assert_eq!(wr.cnt, 4); + assert!(!buf.has_remaining()); +} + +#[tokio::test] +async fn write_buf_err() { + /// Error out after writing the first 4 bytes + struct Wr { + cnt: usize, + } + + impl AsyncWrite for Wr { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll<io::Result<usize>> { + self.cnt += 1; + if self.cnt == 2 { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "whoops"))); + } + Poll::Ready(Ok(4)) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Ok(()).into() + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> { + Ok(()).into() + } + } + + let mut wr = Wr { cnt: 0 }; + + let mut buf = Bytes::from_static(b"hello").chain(Bytes::from_static(b"world")); + + assert_err!(wr.write_all_buf(&mut buf).await); + assert_eq!( + buf.copy_to_bytes(buf.remaining()), + Bytes::from_static(b"oworld") + ); +} diff --git a/tests/macros_select.rs b/tests/macros_select.rs index ea06d51..b4f8544 100644 --- a/tests/macros_select.rs +++ b/tests/macros_select.rs @@ -359,11 +359,22 @@ async fn join_with_select() { async fn use_future_in_if_condition() { use tokio::time::{self, Duration}; - let sleep = time::sleep(Duration::from_millis(50)); - tokio::pin!(sleep); + tokio::select! { + _ = time::sleep(Duration::from_millis(10)), if false => { + panic!("if condition ignored") + } + _ = async { 1u32 } => { + } + } +} + +#[tokio::test] +async fn use_future_in_if_condition_biased() { + use tokio::time::{self, Duration}; tokio::select! 
{ - _ = time::sleep(Duration::from_millis(50)), if false => { + biased; + _ = time::sleep(Duration::from_millis(10)), if false => { panic!("if condition ignored") } _ = async { 1u32 } => { @@ -459,10 +470,7 @@ async fn require_mutable(_: &mut i32) {} async fn async_noop() {} async fn async_never() -> ! { - use tokio::time::Duration; - loop { - tokio::time::sleep(Duration::from_millis(10)).await; - } + futures::future::pending().await } // From https://github.com/tokio-rs/tokio/issues/2857 @@ -540,3 +548,39 @@ async fn biased_eventually_ready() { assert_eq!(count, 3); } + +// https://github.com/tokio-rs/tokio/issues/3830 +// https://github.com/rust-lang/rust-clippy/issues/7304 +#[warn(clippy::default_numeric_fallback)] +pub async fn default_numeric_fallback() { + tokio::select! { + _ = async {} => (), + else => (), + } +} + +// https://github.com/tokio-rs/tokio/issues/4182 +#[tokio::test] +async fn mut_ref_patterns() { + tokio::select! { + Some(mut foo) = async { Some("1".to_string()) } => { + assert_eq!(foo, "1"); + foo = "2".to_string(); + assert_eq!(foo, "2"); + }, + }; + + tokio::select! { + Some(ref foo) = async { Some("1".to_string()) } => { + assert_eq!(*foo, "1"); + }, + }; + + tokio::select! 
{ + Some(ref mut foo) = async { Some("1".to_string()) } => { + assert_eq!(*foo, "1"); + *foo = "2".to_string(); + assert_eq!(*foo, "2"); + }, + }; +} diff --git a/tests/macros_test.rs b/tests/macros_test.rs index 8396398..bca2c91 100644 --- a/tests/macros_test.rs +++ b/tests/macros_test.rs @@ -2,26 +2,47 @@ use tokio::test; #[test] async fn test_macro_can_be_used_via_use() { - tokio::spawn(async { - assert_eq!(1 + 1, 2); - }) - .await - .unwrap(); + tokio::spawn(async {}).await.unwrap(); } #[tokio::test] async fn test_macro_is_resilient_to_shadowing() { - tokio::spawn(async { - assert_eq!(1 + 1, 2); - }) - .await - .unwrap(); + tokio::spawn(async {}).await.unwrap(); } // https://github.com/tokio-rs/tokio/issues/3403 #[rustfmt::skip] // this `rustfmt::skip` is necessary because unused_braces does not warn if the block contains newline. #[tokio::main] -async fn unused_braces_main() { println!("hello") } +pub async fn unused_braces_main() { println!("hello") } #[rustfmt::skip] // this `rustfmt::skip` is necessary because unused_braces does not warn if the block contains newline. #[tokio::test] async fn unused_braces_test() { assert_eq!(1 + 1, 2) } + +// https://github.com/tokio-rs/tokio/pull/3766#issuecomment-835508651 +#[std::prelude::v1::test] +fn trait_method() { + trait A { + fn f(self); + } + impl A for () { + #[tokio::main] + async fn f(self) {} + } + ().f() +} + +// https://github.com/tokio-rs/tokio/issues/4175 +#[tokio::main] +pub async fn issue_4175_main_1() -> ! 
{ + panic!(); +} +#[tokio::main] +pub async fn issue_4175_main_2() -> std::io::Result<()> { + panic!(); +} +#[allow(unreachable_code)] +#[tokio::test] +pub async fn issue_4175_test() -> std::io::Result<()> { + return Ok(()); + panic!(); +} diff --git a/tests/named_pipe.rs b/tests/named_pipe.rs new file mode 100644 index 0000000..2055c3c --- /dev/null +++ b/tests/named_pipe.rs @@ -0,0 +1,393 @@ +#![cfg(feature = "full")] +#![cfg(all(windows))] + +use std::io; +use std::mem; +use std::os::windows::io::AsRawHandle; +use std::time::Duration; +use tokio::io::AsyncWriteExt; +use tokio::net::windows::named_pipe::{ClientOptions, PipeMode, ServerOptions}; +use tokio::time; +use winapi::shared::winerror; + +#[tokio::test] +async fn test_named_pipe_client_drop() -> io::Result<()> { + const PIPE_NAME: &str = r"\\.\pipe\test-named-pipe-client-drop"; + + let mut server = ServerOptions::new().create(PIPE_NAME)?; + + assert_eq!(num_instances("test-named-pipe-client-drop")?, 1); + + let client = ClientOptions::new().open(PIPE_NAME)?; + + server.connect().await?; + drop(client); + + // instance will be broken because client is gone + match server.write_all(b"ping").await { + Err(e) if e.raw_os_error() == Some(winerror::ERROR_NO_DATA as i32) => (), + x => panic!("{:?}", x), + } + + Ok(()) +} + +#[tokio::test] +async fn test_named_pipe_single_client() -> io::Result<()> { + use tokio::io::{AsyncBufReadExt as _, BufReader}; + + const PIPE_NAME: &str = r"\\.\pipe\test-named-pipe-single-client"; + + let server = ServerOptions::new().create(PIPE_NAME)?; + + let server = tokio::spawn(async move { + // Note: we wait for a client to connect. 
+ server.connect().await?; + + let mut server = BufReader::new(server); + + let mut buf = String::new(); + server.read_line(&mut buf).await?; + server.write_all(b"pong\n").await?; + Ok::<_, io::Error>(buf) + }); + + let client = tokio::spawn(async move { + let client = ClientOptions::new().open(PIPE_NAME)?; + + let mut client = BufReader::new(client); + + let mut buf = String::new(); + client.write_all(b"ping\n").await?; + client.read_line(&mut buf).await?; + Ok::<_, io::Error>(buf) + }); + + let (server, client) = tokio::try_join!(server, client)?; + + assert_eq!(server?, "ping\n"); + assert_eq!(client?, "pong\n"); + + Ok(()) +} + +#[tokio::test] +async fn test_named_pipe_multi_client() -> io::Result<()> { + use tokio::io::{AsyncBufReadExt as _, BufReader}; + + const PIPE_NAME: &str = r"\\.\pipe\test-named-pipe-multi-client"; + const N: usize = 10; + + // The first server needs to be constructed early so that clients can + // be correctly connected. Otherwise calling .wait will cause the client to + // error. + let mut server = ServerOptions::new().create(PIPE_NAME)?; + + let server = tokio::spawn(async move { + for _ in 0..N { + // Wait for client to connect. + server.connect().await?; + let mut inner = BufReader::new(server); + + // Construct the next server to be connected before sending the one + // we already have of onto a task. This ensures that the server + // isn't closed (after it's done in the task) before a new one is + // available. Otherwise the client might error with + // `io::ErrorKind::NotFound`. + server = ServerOptions::new().create(PIPE_NAME)?; + + let _ = tokio::spawn(async move { + let mut buf = String::new(); + inner.read_line(&mut buf).await?; + inner.write_all(b"pong\n").await?; + inner.flush().await?; + Ok::<_, io::Error>(()) + }); + } + + Ok::<_, io::Error>(()) + }); + + let mut clients = Vec::new(); + + for _ in 0..N { + clients.push(tokio::spawn(async move { + // This showcases a generic connect loop. 
+ // + // We immediately try to create a client, if it's not found or the + // pipe is busy we use the specialized wait function on the client + // builder. + let client = loop { + match ClientOptions::new().open(PIPE_NAME) { + Ok(client) => break client, + Err(e) if e.raw_os_error() == Some(winerror::ERROR_PIPE_BUSY as i32) => (), + Err(e) if e.kind() == io::ErrorKind::NotFound => (), + Err(e) => return Err(e), + } + + // Wait for a named pipe to become available. + time::sleep(Duration::from_millis(10)).await; + }; + + let mut client = BufReader::new(client); + + let mut buf = String::new(); + client.write_all(b"ping\n").await?; + client.flush().await?; + client.read_line(&mut buf).await?; + Ok::<_, io::Error>(buf) + })); + } + + for client in clients { + let result = client.await?; + assert_eq!(result?, "pong\n"); + } + + server.await??; + Ok(()) +} + +#[tokio::test] +async fn test_named_pipe_multi_client_ready() -> io::Result<()> { + use tokio::io::Interest; + + const PIPE_NAME: &str = r"\\.\pipe\test-named-pipe-multi-client-ready"; + const N: usize = 10; + + // The first server needs to be constructed early so that clients can + // be correctly connected. Otherwise calling .wait will cause the client to + // error. + let mut server = ServerOptions::new().create(PIPE_NAME)?; + + let server = tokio::spawn(async move { + for _ in 0..N { + // Wait for client to connect. + server.connect().await?; + + let inner_server = server; + + // Construct the next server to be connected before sending the one + // we already have of onto a task. This ensures that the server + // isn't closed (after it's done in the task) before a new one is + // available. Otherwise the client might error with + // `io::ErrorKind::NotFound`. 
+ server = ServerOptions::new().create(PIPE_NAME)?; + + let _ = tokio::spawn(async move { + let server = inner_server; + + { + let mut read_buf = [0u8; 5]; + let mut read_buf_cursor = 0; + + loop { + server.readable().await?; + + let buf = &mut read_buf[read_buf_cursor..]; + + match server.try_read(buf) { + Ok(n) => { + read_buf_cursor += n; + + if read_buf_cursor == read_buf.len() { + break; + } + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + continue; + } + Err(e) => { + return Err(e); + } + } + } + }; + + { + let write_buf = b"pong\n"; + let mut write_buf_cursor = 0; + + loop { + server.writable().await?; + let buf = &write_buf[write_buf_cursor..]; + + match server.try_write(buf) { + Ok(n) => { + write_buf_cursor += n; + + if write_buf_cursor == write_buf.len() { + break; + } + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + continue; + } + Err(e) => { + return Err(e); + } + } + } + } + + Ok::<_, io::Error>(()) + }); + } + + Ok::<_, io::Error>(()) + }); + + let mut clients = Vec::new(); + + for _ in 0..N { + clients.push(tokio::spawn(async move { + // This showcases a generic connect loop. + // + // We immediately try to create a client, if it's not found or the + // pipe is busy we use the specialized wait function on the client + // builder. + let client = loop { + match ClientOptions::new().open(PIPE_NAME) { + Ok(client) => break client, + Err(e) if e.raw_os_error() == Some(winerror::ERROR_PIPE_BUSY as i32) => (), + Err(e) if e.kind() == io::ErrorKind::NotFound => (), + Err(e) => return Err(e), + } + + // Wait for a named pipe to become available. 
+ time::sleep(Duration::from_millis(10)).await; + }; + + let mut read_buf = [0u8; 5]; + let mut read_buf_cursor = 0; + let write_buf = b"ping\n"; + let mut write_buf_cursor = 0; + + loop { + let mut interest = Interest::READABLE; + if write_buf_cursor < write_buf.len() { + interest |= Interest::WRITABLE; + } + + let ready = client.ready(interest).await?; + + if ready.is_readable() { + let buf = &mut read_buf[read_buf_cursor..]; + + match client.try_read(buf) { + Ok(n) => { + read_buf_cursor += n; + + if read_buf_cursor == read_buf.len() { + break; + } + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + continue; + } + Err(e) => { + return Err(e); + } + } + } + + if ready.is_writable() { + let buf = &write_buf[write_buf_cursor..]; + + if buf.is_empty() { + continue; + } + + match client.try_write(buf) { + Ok(n) => { + write_buf_cursor += n; + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + continue; + } + Err(e) => { + return Err(e); + } + } + } + } + + let buf = String::from_utf8_lossy(&read_buf).into_owned(); + + Ok::<_, io::Error>(buf) + })); + } + + for client in clients { + let result = client.await?; + assert_eq!(result?, "pong\n"); + } + + server.await??; + Ok(()) +} + +// This tests what happens when a client tries to disconnect. 
+#[tokio::test] +async fn test_named_pipe_mode_message() -> io::Result<()> { + const PIPE_NAME: &str = r"\\.\pipe\test-named-pipe-mode-message"; + + let server = ServerOptions::new() + .pipe_mode(PipeMode::Message) + .create(PIPE_NAME)?; + + let _ = ClientOptions::new().open(PIPE_NAME)?; + server.connect().await?; + Ok(()) +} + +fn num_instances(pipe_name: impl AsRef<str>) -> io::Result<u32> { + use ntapi::ntioapi; + use winapi::shared::ntdef; + + let mut name = pipe_name.as_ref().encode_utf16().collect::<Vec<_>>(); + let mut name = ntdef::UNICODE_STRING { + Length: (name.len() * mem::size_of::<u16>()) as u16, + MaximumLength: (name.len() * mem::size_of::<u16>()) as u16, + Buffer: name.as_mut_ptr(), + }; + let root = std::fs::File::open(r"\\.\Pipe\")?; + let mut io_status_block = unsafe { mem::zeroed() }; + let mut file_directory_information = [0_u8; 1024]; + + let status = unsafe { + ntioapi::NtQueryDirectoryFile( + root.as_raw_handle(), + std::ptr::null_mut(), + None, + std::ptr::null_mut(), + &mut io_status_block, + &mut file_directory_information as *mut _ as *mut _, + 1024, + ntioapi::FileDirectoryInformation, + 0, + &mut name, + 0, + ) + }; + + if status as u32 != winerror::NO_ERROR { + return Err(io::Error::last_os_error()); + } + + let info = unsafe { + mem::transmute::<_, &ntioapi::FILE_DIRECTORY_INFORMATION>(&file_directory_information) + }; + let raw_name = unsafe { + std::slice::from_raw_parts( + info.FileName.as_ptr(), + info.FileNameLength as usize / mem::size_of::<u16>(), + ) + }; + let name = String::from_utf16(raw_name).unwrap(); + let num_instances = unsafe { *info.EndOfFile.QuadPart() }; + + assert_eq!(name, pipe_name.as_ref()); + + Ok(num_instances as u32) +} diff --git a/tests/no_rt.rs b/tests/no_rt.rs index 8437b80..6845850 100644 --- a/tests/no_rt.rs +++ b/tests/no_rt.rs @@ -26,7 +26,7 @@ fn panics_when_no_reactor() { async fn timeout_value() { let (_tx, rx) = oneshot::channel::<()>(); - let dur = Duration::from_millis(20); + let dur = 
Duration::from_millis(10); let _ = timeout(dur, rx).await; } diff --git a/tests/process_arg0.rs b/tests/process_arg0.rs new file mode 100644 index 0000000..4fabea0 --- /dev/null +++ b/tests/process_arg0.rs @@ -0,0 +1,13 @@ +#![warn(rust_2018_idioms)] +#![cfg(all(feature = "full", unix))] + +use tokio::process::Command; + +#[tokio::test] +async fn arg0() { + let mut cmd = Command::new("sh"); + cmd.arg0("test_string").arg("-c").arg("echo $0"); + + let output = cmd.output().await.unwrap(); + assert_eq!(output.stdout, b"test_string\n"); +} diff --git a/tests/process_kill_on_drop.rs b/tests/process_kill_on_drop.rs index 00f5c6d..658e4ad 100644 --- a/tests/process_kill_on_drop.rs +++ b/tests/process_kill_on_drop.rs @@ -1,6 +1,7 @@ #![cfg(all(unix, feature = "process"))] #![warn(rust_2018_idioms)] +use std::io::ErrorKind; use std::process::Stdio; use std::time::Duration; use tokio::io::AsyncReadExt; @@ -24,11 +25,12 @@ async fn kill_on_drop() { ", ]); - let mut child = cmd - .kill_on_drop(true) - .stdout(Stdio::piped()) - .spawn() - .unwrap(); + let e = cmd.kill_on_drop(true).stdout(Stdio::piped()).spawn(); + if e.is_err() && e.as_ref().unwrap_err().kind() == ErrorKind::NotFound { + println!("bash not available; skipping test"); + return; + } + let mut child = e.unwrap(); sleep(Duration::from_secs(2)).await; diff --git a/tests/process_raw_handle.rs b/tests/process_raw_handle.rs new file mode 100644 index 0000000..727e66d --- /dev/null +++ b/tests/process_raw_handle.rs @@ -0,0 +1,23 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] +#![cfg(windows)] + +use tokio::process::Command; +use winapi::um::processthreadsapi::GetProcessId; + +#[tokio::test] +async fn obtain_raw_handle() { + let mut cmd = Command::new("cmd"); + cmd.kill_on_drop(true); + cmd.arg("/c"); + cmd.arg("pause"); + + let child = cmd.spawn().unwrap(); + + let orig_id = child.id().expect("missing id"); + assert!(orig_id > 0); + + let handle = child.raw_handle().expect("process stopped"); + let 
handled_id = unsafe { GetProcessId(handle as _) }; + assert_eq!(handled_id, orig_id); +} diff --git a/tests/rt_basic.rs b/tests/rt_basic.rs index 4b1bdad..70056b1 100644 --- a/tests/rt_basic.rs +++ b/tests/rt_basic.rs @@ -3,10 +3,14 @@ use tokio::runtime::Runtime; use tokio::sync::oneshot; +use tokio::time::{timeout, Duration}; use tokio_test::{assert_err, assert_ok}; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::task::{Context, Poll}; use std::thread; -use tokio::time::{timeout, Duration}; mod support { pub(crate) mod mpsc_stream; @@ -136,6 +140,35 @@ fn acquire_mutex_in_drop() { } #[test] +fn drop_tasks_in_context() { + static SUCCESS: AtomicBool = AtomicBool::new(false); + + struct ContextOnDrop; + + impl Future for ContextOnDrop { + type Output = (); + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> { + Poll::Pending + } + } + + impl Drop for ContextOnDrop { + fn drop(&mut self) { + if tokio::runtime::Handle::try_current().is_ok() { + SUCCESS.store(true, Ordering::SeqCst); + } + } + } + + let rt = rt(); + rt.spawn(ContextOnDrop); + drop(rt); + + assert!(SUCCESS.load(Ordering::SeqCst)); +} + +#[test] #[should_panic( expected = "A Tokio 1.x context was found, but timers are disabled. Call `enable_time` on the runtime builder to enable timers." )] diff --git a/tests/rt_common.rs b/tests/rt_common.rs index cb1d0f6..e5fc7a9 100644 --- a/tests/rt_common.rs +++ b/tests/rt_common.rs @@ -647,6 +647,7 @@ rt_test! { } #[test] + #[cfg(not(target_os = "android"))] fn panic_in_task() { let rt = rt(); let (tx, rx) = oneshot::channel(); diff --git a/tests/rt_handle_block_on.rs b/tests/rt_handle_block_on.rs index 5234258..17878c8 100644 --- a/tests/rt_handle_block_on.rs +++ b/tests/rt_handle_block_on.rs @@ -388,6 +388,28 @@ rt_test! 
{ rt.block_on(async { some_non_async_function() }); } + + #[test] + fn spawn_after_runtime_dropped() { + use futures::future::FutureExt; + + let rt = rt(); + + let handle = rt.block_on(async move { + Handle::current() + }); + + let jh1 = handle.spawn(futures::future::pending::<()>()); + + drop(rt); + + let jh2 = handle.spawn(futures::future::pending::<()>()); + + let err1 = jh1.now_or_never().unwrap().unwrap_err(); + let err2 = jh2.now_or_never().unwrap().unwrap_err(); + assert!(err1.is_cancelled()); + assert!(err2.is_cancelled()); + } } multi_threaded_rt_test! { diff --git a/tests/rt_threaded.rs b/tests/rt_threaded.rs index 19b381c..5f047a7 100644 --- a/tests/rt_threaded.rs +++ b/tests/rt_threaded.rs @@ -12,8 +12,8 @@ use std::future::Future; use std::pin::Pin; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::Relaxed; -use std::sync::{mpsc, Arc}; -use std::task::{Context, Poll}; +use std::sync::{mpsc, Arc, Mutex}; +use std::task::{Context, Poll, Waker}; #[test] fn single_thread() { @@ -54,6 +54,7 @@ fn many_oneshot_futures() { drop(rt); } } + #[test] fn many_multishot_futures() { const CHAIN: usize = 200; @@ -405,6 +406,98 @@ async fn hang_on_shutdown() { tokio::time::sleep(std::time::Duration::from_secs(1)).await; } +/// Demonstrates tokio-rs/tokio#3869 +#[test] +fn wake_during_shutdown() { + struct Shared { + waker: Option<Waker>, + } + + struct MyFuture { + shared: Arc<Mutex<Shared>>, + put_waker: bool, + } + + impl MyFuture { + fn new() -> (Self, Self) { + let shared = Arc::new(Mutex::new(Shared { waker: None })); + let f1 = MyFuture { + shared: shared.clone(), + put_waker: true, + }; + let f2 = MyFuture { + shared, + put_waker: false, + }; + (f1, f2) + } + } + + impl Future for MyFuture { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let me = Pin::into_inner(self); + let mut lock = me.shared.lock().unwrap(); + println!("poll {}", me.put_waker); + if me.put_waker { + println!("putting"); + 
lock.waker = Some(cx.waker().clone()); + } + Poll::Pending + } + } + + impl Drop for MyFuture { + fn drop(&mut self) { + println!("drop {} start", self.put_waker); + let mut lock = self.shared.lock().unwrap(); + if !self.put_waker { + lock.waker.take().unwrap().wake(); + } + drop(lock); + println!("drop {} stop", self.put_waker); + } + } + + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(1) + .enable_all() + .build() + .unwrap(); + + let (f1, f2) = MyFuture::new(); + + rt.spawn(f1); + rt.spawn(f2); + + rt.block_on(async { tokio::time::sleep(tokio::time::Duration::from_millis(20)).await }); +} + +#[should_panic] +#[tokio::test] +async fn test_block_in_place1() { + tokio::task::block_in_place(|| {}); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_block_in_place2() { + tokio::task::block_in_place(|| {}); +} + +#[should_panic] +#[tokio::main(flavor = "current_thread")] +#[test] +async fn test_block_in_place3() { + tokio::task::block_in_place(|| {}); +} + +#[tokio::main] +#[test] +async fn test_block_in_place4() { + tokio::task::block_in_place(|| {}); +} + fn rt() -> Runtime { Runtime::new().unwrap() } diff --git a/tests/support/io_vec.rs b/tests/support/io_vec.rs new file mode 100644 index 0000000..4ea47c7 --- /dev/null +++ b/tests/support/io_vec.rs @@ -0,0 +1,45 @@ +use std::io::IoSlice; +use std::ops::Deref; +use std::slice; + +pub struct IoBufs<'a, 'b>(&'b mut [IoSlice<'a>]); + +impl<'a, 'b> IoBufs<'a, 'b> { + pub fn new(slices: &'b mut [IoSlice<'a>]) -> Self { + IoBufs(slices) + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn advance(mut self, n: usize) -> IoBufs<'a, 'b> { + let mut to_remove = 0; + let mut remaining_len = n; + for slice in self.0.iter() { + if remaining_len < slice.len() { + break; + } else { + remaining_len -= slice.len(); + to_remove += 1; + } + } + self.0 = self.0.split_at_mut(to_remove).1; + if let Some(slice) = self.0.first_mut() { + let tail = &slice[remaining_len..]; + // 
Safety: recasts slice to the original lifetime + let tail = unsafe { slice::from_raw_parts(tail.as_ptr(), tail.len()) }; + *slice = IoSlice::new(tail); + } else if remaining_len != 0 { + panic!("advance past the end of the slice vector"); + } + self + } +} + +impl<'a, 'b> Deref for IoBufs<'a, 'b> { + type Target = [IoSlice<'a>]; + fn deref(&self) -> &[IoSlice<'a>] { + self.0 + } +} diff --git a/tests/support/mock_file.rs b/tests/support/mock_file.rs deleted file mode 100644 index 1ce326b..0000000 --- a/tests/support/mock_file.rs +++ /dev/null @@ -1,295 +0,0 @@ -#![allow(clippy::unnecessary_operation)] - -use std::collections::VecDeque; -use std::fmt; -use std::fs::{Metadata, Permissions}; -use std::io; -use std::io::prelude::*; -use std::io::SeekFrom; -use std::path::PathBuf; -use std::sync::{Arc, Mutex}; - -pub struct File { - shared: Arc<Mutex<Shared>>, -} - -pub struct Handle { - shared: Arc<Mutex<Shared>>, -} - -struct Shared { - calls: VecDeque<Call>, -} - -#[derive(Debug)] -enum Call { - Read(io::Result<Vec<u8>>), - Write(io::Result<Vec<u8>>), - Seek(SeekFrom, io::Result<u64>), - SyncAll(io::Result<()>), - SyncData(io::Result<()>), - SetLen(u64, io::Result<()>), -} - -impl Handle { - pub fn read(&self, data: &[u8]) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls.push_back(Call::Read(Ok(data.to_owned()))); - self - } - - pub fn read_err(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::Read(Err(io::ErrorKind::Other.into()))); - self - } - - pub fn write(&self, data: &[u8]) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls.push_back(Call::Write(Ok(data.to_owned()))); - self - } - - pub fn write_err(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::Write(Err(io::ErrorKind::Other.into()))); - self - } - - pub fn seek_start_ok(&self, offset: u64) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - 
.push_back(Call::Seek(SeekFrom::Start(offset), Ok(offset))); - self - } - - pub fn seek_current_ok(&self, offset: i64, ret: u64) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::Seek(SeekFrom::Current(offset), Ok(ret))); - self - } - - pub fn sync_all(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls.push_back(Call::SyncAll(Ok(()))); - self - } - - pub fn sync_all_err(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::SyncAll(Err(io::ErrorKind::Other.into()))); - self - } - - pub fn sync_data(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls.push_back(Call::SyncData(Ok(()))); - self - } - - pub fn sync_data_err(&self) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::SyncData(Err(io::ErrorKind::Other.into()))); - self - } - - pub fn set_len(&self, size: u64) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls.push_back(Call::SetLen(size, Ok(()))); - self - } - - pub fn set_len_err(&self, size: u64) -> &Self { - let mut s = self.shared.lock().unwrap(); - s.calls - .push_back(Call::SetLen(size, Err(io::ErrorKind::Other.into()))); - self - } - - pub fn remaining(&self) -> usize { - let s = self.shared.lock().unwrap(); - s.calls.len() - } -} - -impl Drop for Handle { - fn drop(&mut self) { - if !std::thread::panicking() { - let s = self.shared.lock().unwrap(); - assert_eq!(0, s.calls.len()); - } - } -} - -impl File { - pub fn open(_: PathBuf) -> io::Result<File> { - unimplemented!(); - } - - pub fn create(_: PathBuf) -> io::Result<File> { - unimplemented!(); - } - - pub fn mock() -> (Handle, File) { - let shared = Arc::new(Mutex::new(Shared { - calls: VecDeque::new(), - })); - - let handle = Handle { - shared: shared.clone(), - }; - let file = File { shared }; - - (handle, file) - } - - pub fn sync_all(&self) -> io::Result<()> { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match 
s.calls.pop_front() { - Some(SyncAll(ret)) => ret, - Some(op) => panic!("expected next call to be {:?}; was sync_all", op), - None => panic!("did not expect call"), - } - } - - pub fn sync_data(&self) -> io::Result<()> { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(SyncData(ret)) => ret, - Some(op) => panic!("expected next call to be {:?}; was sync_all", op), - None => panic!("did not expect call"), - } - } - - pub fn set_len(&self, size: u64) -> io::Result<()> { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(SetLen(arg, ret)) => { - assert_eq!(arg, size); - ret - } - Some(op) => panic!("expected next call to be {:?}; was sync_all", op), - None => panic!("did not expect call"), - } - } - - pub fn metadata(&self) -> io::Result<Metadata> { - unimplemented!(); - } - - pub fn set_permissions(&self, _perm: Permissions) -> io::Result<()> { - unimplemented!(); - } - - pub fn try_clone(&self) -> io::Result<Self> { - unimplemented!(); - } -} - -impl Read for &'_ File { - fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(Read(Ok(data))) => { - assert!(dst.len() >= data.len()); - assert!(dst.len() <= 16 * 1024, "actual = {}", dst.len()); // max buffer - - &mut dst[..data.len()].copy_from_slice(&data); - Ok(data.len()) - } - Some(Read(Err(e))) => Err(e), - Some(op) => panic!("expected next call to be {:?}; was a read", op), - None => panic!("did not expect call"), - } - } -} - -impl Write for &'_ File { - fn write(&mut self, src: &[u8]) -> io::Result<usize> { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(Write(Ok(data))) => { - assert_eq!(src, &data[..]); - Ok(src.len()) - } - Some(Write(Err(e))) => Err(e), - Some(op) => panic!("expected next call to be {:?}; was write", op), - None => 
panic!("did not expect call"), - } - } - - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - -impl Seek for &'_ File { - fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> { - use self::Call::*; - - let mut s = self.shared.lock().unwrap(); - - match s.calls.pop_front() { - Some(Seek(expect, res)) => { - assert_eq!(expect, pos); - res - } - Some(op) => panic!("expected call {:?}; was `seek`", op), - None => panic!("did not expect call; was `seek`"), - } - } -} - -impl fmt::Debug for File { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("mock::File").finish() - } -} - -#[cfg(unix)] -impl std::os::unix::io::AsRawFd for File { - fn as_raw_fd(&self) -> std::os::unix::io::RawFd { - unimplemented!(); - } -} - -#[cfg(unix)] -impl std::os::unix::io::FromRawFd for File { - unsafe fn from_raw_fd(_: std::os::unix::io::RawFd) -> Self { - unimplemented!(); - } -} - -#[cfg(windows)] -impl std::os::windows::io::AsRawHandle for File { - fn as_raw_handle(&self) -> std::os::windows::io::RawHandle { - unimplemented!(); - } -} - -#[cfg(windows)] -impl std::os::windows::io::FromRawHandle for File { - unsafe fn from_raw_handle(_: std::os::windows::io::RawHandle) -> Self { - unimplemented!(); - } -} diff --git a/tests/support/mock_pool.rs b/tests/support/mock_pool.rs deleted file mode 100644 index e1fdb42..0000000 --- a/tests/support/mock_pool.rs +++ /dev/null @@ -1,66 +0,0 @@ -use tokio::sync::oneshot; - -use std::cell::RefCell; -use std::collections::VecDeque; -use std::future::Future; -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; - -thread_local! 
{ - static QUEUE: RefCell<VecDeque<Box<dyn FnOnce() + Send>>> = RefCell::new(VecDeque::new()) -} - -#[derive(Debug)] -pub(crate) struct Blocking<T> { - rx: oneshot::Receiver<T>, -} - -pub(crate) fn run<F, R>(f: F) -> Blocking<R> -where - F: FnOnce() -> R + Send + 'static, - R: Send + 'static, -{ - let (tx, rx) = oneshot::channel(); - let task = Box::new(move || { - let _ = tx.send(f()); - }); - - QUEUE.with(|cell| cell.borrow_mut().push_back(task)); - - Blocking { rx } -} - -impl<T> Future for Blocking<T> { - type Output = Result<T, io::Error>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - use std::task::Poll::*; - - match Pin::new(&mut self.rx).poll(cx) { - Ready(Ok(v)) => Ready(Ok(v)), - Ready(Err(e)) => panic!("error = {:?}", e), - Pending => Pending, - } - } -} - -pub(crate) async fn asyncify<F, T>(f: F) -> io::Result<T> -where - F: FnOnce() -> io::Result<T> + Send + 'static, - T: Send + 'static, -{ - run(f).await? -} - -pub(crate) fn len() -> usize { - QUEUE.with(|cell| cell.borrow().len()) -} - -pub(crate) fn run_one() { - let task = QUEUE - .with(|cell| cell.borrow_mut().pop_front()) - .expect("expected task to run, but none ready"); - - task(); -} diff --git a/tests/sync_mpsc.rs b/tests/sync_mpsc.rs index cd43ad4..1947d26 100644 --- a/tests/sync_mpsc.rs +++ b/tests/sync_mpsc.rs @@ -5,7 +5,7 @@ use std::thread; use tokio::runtime::Runtime; use tokio::sync::mpsc; -use tokio::sync::mpsc::error::TrySendError; +use tokio::sync::mpsc::error::{TryRecvError, TrySendError}; use tokio_test::task; use tokio_test::{ assert_err, assert_ok, assert_pending, assert_ready, assert_ready_err, assert_ready_ok, @@ -328,6 +328,27 @@ async fn try_send_fail() { } #[tokio::test] +async fn try_send_fail_with_try_recv() { + let (tx, mut rx) = mpsc::channel(1); + + tx.try_send("hello").unwrap(); + + // This should fail + match assert_err!(tx.try_send("fail")) { + TrySendError::Full(..) 
=> {} + _ => panic!(), + } + + assert_eq!(rx.try_recv(), Ok("hello")); + + assert_ok!(tx.try_send("goodbye")); + drop(tx); + + assert_eq!(rx.try_recv(), Ok("goodbye")); + assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected)); +} + +#[tokio::test] async fn try_reserve_fails() { let (tx, mut rx) = mpsc::channel(1); @@ -389,13 +410,15 @@ fn dropping_rx_closes_channel_for_try() { drop(rx); - { - let err = assert_err!(tx.try_send(msg.clone())); - match err { - TrySendError::Closed(..) => {} - _ => panic!(), - } - } + assert!(matches!( + tx.try_send(msg.clone()), + Err(TrySendError::Closed(_)) + )); + assert!(matches!(tx.try_reserve(), Err(TrySendError::Closed(_)))); + assert!(matches!( + tx.try_reserve_owned(), + Err(TrySendError::Closed(_)) + )); assert_eq!(1, Arc::strong_count(&msg)); } @@ -494,3 +517,83 @@ async fn permit_available_not_acquired_close() { drop(permit2); assert!(rx.recv().await.is_none()); } + +#[test] +fn try_recv_bounded() { + let (tx, mut rx) = mpsc::channel(5); + + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + assert!(tx.try_send("hello").is_err()); + + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + assert_eq!(Ok("hello"), rx.try_recv()); + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + assert!(tx.try_send("hello").is_err()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + 
assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + tx.try_send("hello").unwrap(); + drop(tx); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Ok("hello"), rx.try_recv()); + assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); +} + +#[test] +fn try_recv_unbounded() { + for num in 0..100 { + let (tx, mut rx) = mpsc::unbounded_channel(); + + for i in 0..num { + tx.send(i).unwrap(); + } + + for i in 0..num { + assert_eq!(rx.try_recv(), Ok(i)); + } + + assert_eq!(rx.try_recv(), Err(TryRecvError::Empty)); + drop(tx); + assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected)); + } +} + +#[test] +fn try_recv_close_while_empty_bounded() { + let (tx, mut rx) = mpsc::channel::<()>(5); + + assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + drop(tx); + assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); +} + +#[test] +fn try_recv_close_while_empty_unbounded() { + let (tx, mut rx) = mpsc::unbounded_channel::<()>(); + + assert_eq!(Err(TryRecvError::Empty), rx.try_recv()); + drop(tx); + assert_eq!(Err(TryRecvError::Disconnected), rx.try_recv()); +} diff --git a/tests/sync_mutex.rs b/tests/sync_mutex.rs index 0ddb203..090db94 100644 --- a/tests/sync_mutex.rs +++ b/tests/sync_mutex.rs @@ -139,12 +139,12 @@ fn try_lock() { let m: Mutex<usize> = Mutex::new(0); { let g1 = m.try_lock(); - assert_eq!(g1.is_ok(), true); + assert!(g1.is_ok()); let g2 = m.try_lock(); - assert_eq!(g2.is_ok(), false); + assert!(!g2.is_ok()); } let g3 = m.try_lock(); - assert_eq!(g3.is_ok(), true); + assert!(g3.is_ok()); } #[tokio::test] diff --git a/tests/sync_mutex_owned.rs b/tests/sync_mutex_owned.rs index 0f1399c..898bf35 100644 --- a/tests/sync_mutex_owned.rs +++ b/tests/sync_mutex_owned.rs @@ -106,12 +106,12 @@ fn try_lock_owned() { let m: Arc<Mutex<usize>> = Arc::new(Mutex::new(0)); { let g1 = m.clone().try_lock_owned(); - assert_eq!(g1.is_ok(), true); + 
assert!(g1.is_ok()); let g2 = m.clone().try_lock_owned(); - assert_eq!(g2.is_ok(), false); + assert!(!g2.is_ok()); } let g3 = m.try_lock_owned(); - assert_eq!(g3.is_ok(), true); + assert!(g3.is_ok()); } #[tokio::test] diff --git a/tests/sync_once_cell.rs b/tests/sync_once_cell.rs index 60f50d2..18eaf93 100644 --- a/tests/sync_once_cell.rs +++ b/tests/sync_once_cell.rs @@ -266,3 +266,9 @@ fn drop_into_inner_new_with() { let count = NUM_DROPS.load(Ordering::Acquire); assert!(count == 1); } + +#[test] +fn from() { + let cell = OnceCell::from(2); + assert_eq!(*cell.get().unwrap(), 2); +} diff --git a/tests/sync_rwlock.rs b/tests/sync_rwlock.rs index e12052b..7d05086 100644 --- a/tests/sync_rwlock.rs +++ b/tests/sync_rwlock.rs @@ -50,8 +50,8 @@ fn read_exclusive_pending() { assert_pending!(t2.poll()); } -// If the max shared access is reached and subsquent shared access is pending -// should be made available when one of the shared acesses is dropped +// If the max shared access is reached and subsequent shared access is pending +// should be made available when one of the shared accesses is dropped #[test] fn exhaust_reading() { let rwlock = RwLock::with_max_readers(100, 1024); diff --git a/tests/sync_watch.rs b/tests/sync_watch.rs index 9dcb0c5..b7bbaf7 100644 --- a/tests/sync_watch.rs +++ b/tests/sync_watch.rs @@ -169,3 +169,35 @@ fn poll_close() { assert!(tx.send("two").is_err()); } + +#[test] +fn borrow_and_update() { + let (tx, mut rx) = watch::channel("one"); + + tx.send("two").unwrap(); + assert_ready!(spawn(rx.changed()).poll()).unwrap(); + assert_pending!(spawn(rx.changed()).poll()); + + tx.send("three").unwrap(); + assert_eq!(*rx.borrow_and_update(), "three"); + assert_pending!(spawn(rx.changed()).poll()); + + drop(tx); + assert_eq!(*rx.borrow_and_update(), "three"); + assert_ready!(spawn(rx.changed()).poll()).unwrap_err(); +} + +#[test] +fn reopened_after_subscribe() { + let (tx, rx) = watch::channel("one"); + assert!(!tx.is_closed()); + + drop(rx); + 
assert!(tx.is_closed()); + + let rx = tx.subscribe(); + assert!(!tx.is_closed()); + + drop(rx); + assert!(tx.is_closed()); +} diff --git a/tests/task_abort.rs b/tests/task_abort.rs index 1d72ac3..06c61dc 100644 --- a/tests/task_abort.rs +++ b/tests/task_abort.rs @@ -1,11 +1,25 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] +use std::sync::Arc; +use std::thread::sleep; +use tokio::time::Duration; + +use tokio::runtime::Builder; + +struct PanicOnDrop; + +impl Drop for PanicOnDrop { + fn drop(&mut self) { + panic!("Well what did you expect would happen..."); + } +} + /// Checks that a suspended task can be aborted without panicking as reported in /// issue #3157: <https://github.com/tokio-rs/tokio/issues/3157>. #[test] fn test_abort_without_panic_3157() { - let rt = tokio::runtime::Builder::new_multi_thread() + let rt = Builder::new_multi_thread() .enable_time() .worker_threads(1) .build() @@ -14,11 +28,11 @@ fn test_abort_without_panic_3157() { rt.block_on(async move { let handle = tokio::spawn(async move { println!("task started"); - tokio::time::sleep(std::time::Duration::new(100, 0)).await + tokio::time::sleep(Duration::new(100, 0)).await }); // wait for task to sleep. - tokio::time::sleep(std::time::Duration::new(1, 0)).await; + tokio::time::sleep(Duration::from_millis(10)).await; handle.abort(); let _ = handle.await; @@ -41,9 +55,7 @@ fn test_abort_without_panic_3662() { } } - let rt = tokio::runtime::Builder::new_current_thread() - .build() - .unwrap(); + let rt = Builder::new_current_thread().build().unwrap(); rt.block_on(async move { let drop_flag = Arc::new(AtomicBool::new(false)); @@ -62,18 +74,16 @@ fn test_abort_without_panic_3662() { // This runs in a separate thread so it doesn't have immediate // thread-local access to the executor. It does however transition // the underlying task to be completed, which will cause it to be - // dropped (in this thread no less). + // dropped (but not in this thread). 
assert!(!drop_flag2.load(Ordering::SeqCst)); j.abort(); - // TODO: is this guaranteed at this point? - // assert!(drop_flag2.load(Ordering::SeqCst)); j }) .join() .unwrap(); - assert!(drop_flag.load(Ordering::SeqCst)); let result = task.await; + assert!(drop_flag.load(Ordering::SeqCst)); assert!(result.unwrap_err().is_cancelled()); // Note: We do the following to trigger a deferred task cleanup. @@ -82,7 +92,7 @@ fn test_abort_without_panic_3662() { // `Inner::block_on` of `basic_scheduler.rs`. // // We cause the cleanup to happen by having a poll return Pending once - // so that the scheduler can go into the "auxilliary tasks" mode, at + // so that the scheduler can go into the "auxiliary tasks" mode, at // which point the task is removed from the scheduler. let i = tokio::spawn(async move { tokio::task::yield_now().await; @@ -91,3 +101,126 @@ fn test_abort_without_panic_3662() { i.await.unwrap(); }); } + +/// Checks that a suspended LocalSet task can be aborted from a remote thread +/// without panicking and without running the tasks destructor on the wrong thread. 
+/// <https://github.com/tokio-rs/tokio/issues/3929> +#[test] +fn remote_abort_local_set_3929() { + struct DropCheck { + created_on: std::thread::ThreadId, + not_send: std::marker::PhantomData<*const ()>, + } + + impl DropCheck { + fn new() -> Self { + Self { + created_on: std::thread::current().id(), + not_send: std::marker::PhantomData, + } + } + } + impl Drop for DropCheck { + fn drop(&mut self) { + if std::thread::current().id() != self.created_on { + panic!("non-Send value dropped in another thread!"); + } + } + } + + let rt = Builder::new_current_thread().build().unwrap(); + let local = tokio::task::LocalSet::new(); + + let check = DropCheck::new(); + let jh = local.spawn_local(async move { + futures::future::pending::<()>().await; + drop(check); + }); + + let jh2 = std::thread::spawn(move || { + sleep(Duration::from_millis(10)); + jh.abort(); + }); + + rt.block_on(local); + jh2.join().unwrap(); +} + +/// Checks that a suspended task can be aborted even if the `JoinHandle` is immediately dropped. +/// issue #3964: <https://github.com/tokio-rs/tokio/issues/3964>. +#[test] +fn test_abort_wakes_task_3964() { + let rt = Builder::new_current_thread().enable_time().build().unwrap(); + + rt.block_on(async move { + let notify_dropped = Arc::new(()); + let weak_notify_dropped = Arc::downgrade(¬ify_dropped); + + let handle = tokio::spawn(async move { + // Make sure the Arc is moved into the task + let _notify_dropped = notify_dropped; + println!("task started"); + tokio::time::sleep(Duration::new(100, 0)).await + }); + + // wait for task to sleep. + tokio::time::sleep(Duration::from_millis(10)).await; + + handle.abort(); + drop(handle); + + // wait for task to abort. + tokio::time::sleep(Duration::from_millis(10)).await; + + // Check that the Arc has been dropped. + assert!(weak_notify_dropped.upgrade().is_none()); + }); +} + +/// Checks that aborting a task whose destructor panics does not allow the +/// panic to escape the task. 
+#[test] +#[cfg(not(target_os = "android"))] +fn test_abort_task_that_panics_on_drop_contained() { + let rt = Builder::new_current_thread().enable_time().build().unwrap(); + + rt.block_on(async move { + let handle = tokio::spawn(async move { + // Make sure the Arc is moved into the task + let _panic_dropped = PanicOnDrop; + println!("task started"); + tokio::time::sleep(Duration::new(100, 0)).await + }); + + // wait for task to sleep. + tokio::time::sleep(Duration::from_millis(10)).await; + + handle.abort(); + drop(handle); + + // wait for task to abort. + tokio::time::sleep(Duration::from_millis(10)).await; + }); +} + +/// Checks that aborting a task whose destructor panics has the expected result. +#[test] +#[cfg(not(target_os = "android"))] +fn test_abort_task_that_panics_on_drop_returned() { + let rt = Builder::new_current_thread().enable_time().build().unwrap(); + + rt.block_on(async move { + let handle = tokio::spawn(async move { + // Make sure the Arc is moved into the task + let _panic_dropped = PanicOnDrop; + println!("task started"); + tokio::time::sleep(Duration::new(100, 0)).await + }); + + // wait for task to sleep. 
+ tokio::time::sleep(Duration::from_millis(10)).await; + + handle.abort(); + assert!(handle.await.unwrap_err().is_panic()); + }); +} diff --git a/tests/task_blocking.rs b/tests/task_blocking.rs index 82bef8a..e6cde25 100644 --- a/tests/task_blocking.rs +++ b/tests/task_blocking.rs @@ -114,6 +114,7 @@ fn can_enter_basic_rt_from_within_block_in_place() { } #[test] +#[cfg(not(target_os = "android"))] fn useful_panic_message_when_dropping_rt_in_rt() { use std::panic::{catch_unwind, AssertUnwindSafe}; @@ -132,7 +133,7 @@ fn useful_panic_message_when_dropping_rt_in_rt() { let err: &'static str = err.downcast_ref::<&'static str>().unwrap(); assert!( - err.find("Cannot drop a runtime").is_some(), + err.contains("Cannot drop a runtime"), "Wrong panic message: {:?}", err ); diff --git a/tests/task_builder.rs b/tests/task_builder.rs new file mode 100644 index 0000000..1499abf --- /dev/null +++ b/tests/task_builder.rs @@ -0,0 +1,67 @@ +#[cfg(all(tokio_unstable, feature = "tracing"))] +mod tests { + use std::rc::Rc; + use tokio::{ + task::{Builder, LocalSet}, + test, + }; + + #[test] + async fn spawn_with_name() { + let result = Builder::new() + .name("name") + .spawn(async { "task executed" }) + .await; + + assert_eq!(result.unwrap(), "task executed"); + } + + #[test] + async fn spawn_blocking_with_name() { + let result = Builder::new() + .name("name") + .spawn_blocking(|| "task executed") + .await; + + assert_eq!(result.unwrap(), "task executed"); + } + + #[test] + async fn spawn_local_with_name() { + let unsend_data = Rc::new("task executed"); + let result = LocalSet::new() + .run_until(async move { + Builder::new() + .name("name") + .spawn_local(async move { unsend_data }) + .await + }) + .await; + + assert_eq!(*result.unwrap(), "task executed"); + } + + #[test] + async fn spawn_without_name() { + let result = Builder::new().spawn(async { "task executed" }).await; + + assert_eq!(result.unwrap(), "task executed"); + } + + #[test] + async fn spawn_blocking_without_name() { + 
let result = Builder::new().spawn_blocking(|| "task executed").await; + + assert_eq!(result.unwrap(), "task executed"); + } + + #[test] + async fn spawn_local_without_name() { + let unsend_data = Rc::new("task executed"); + let result = LocalSet::new() + .run_until(async move { Builder::new().spawn_local(async move { unsend_data }).await }) + .await; + + assert_eq!(*result.unwrap(), "task executed"); + } +} diff --git a/tests/task_local_set.rs b/tests/task_local_set.rs index 8513609..f8a35d0 100644 --- a/tests/task_local_set.rs +++ b/tests/task_local_set.rs @@ -67,11 +67,11 @@ async fn localset_future_timers() { let local = LocalSet::new(); local.spawn_local(async move { - time::sleep(Duration::from_millis(10)).await; + time::sleep(Duration::from_millis(5)).await; RAN1.store(true, Ordering::SeqCst); }); local.spawn_local(async move { - time::sleep(Duration::from_millis(20)).await; + time::sleep(Duration::from_millis(10)).await; RAN2.store(true, Ordering::SeqCst); }); local.await; @@ -299,9 +299,7 @@ fn drop_cancels_tasks() { let _rc2 = rc2; started_tx.send(()).unwrap(); - loop { - time::sleep(Duration::from_secs(3600)).await; - } + futures::future::pending::<()>().await; }); local.block_on(&rt, async { @@ -334,7 +332,7 @@ fn with_timeout(timeout: Duration, f: impl FnOnce() + Send + 'static) { // something we can easily make assertions about, we'll run it in a // thread. When the test thread finishes, it will send a message on a // channel to this thread. We'll wait for that message with a fairly - // generous timeout, and if we don't recieve it, we assume the test + // generous timeout, and if we don't receive it, we assume the test // thread has hung. // // Note that it should definitely complete in under a minute, but just @@ -400,13 +398,32 @@ fn local_tasks_wake_join_all() { }); } -#[tokio::test] -async fn local_tasks_are_polled_after_tick() { +#[test] +fn local_tasks_are_polled_after_tick() { + // This test depends on timing, so we run it up to five times. 
+ for _ in 0..4 { + let res = std::panic::catch_unwind(local_tasks_are_polled_after_tick_inner); + if res.is_ok() { + // success + return; + } + } + + // Test failed 4 times. Try one more time without catching panics. If it + // fails again, the test fails. + local_tasks_are_polled_after_tick_inner(); +} + +#[tokio::main(flavor = "current_thread")] +async fn local_tasks_are_polled_after_tick_inner() { // Reproduces issues #1899 and #1900 static RX1: AtomicUsize = AtomicUsize::new(0); static RX2: AtomicUsize = AtomicUsize::new(0); - static EXPECTED: usize = 500; + const EXPECTED: usize = 500; + + RX1.store(0, SeqCst); + RX2.store(0, SeqCst); let (tx, mut rx) = mpsc::unbounded_channel(); @@ -416,7 +433,7 @@ async fn local_tasks_are_polled_after_tick() { .run_until(async { let task2 = task::spawn(async move { // Wait a bit - time::sleep(Duration::from_millis(100)).await; + time::sleep(Duration::from_millis(10)).await; let mut oneshots = Vec::with_capacity(EXPECTED); @@ -427,13 +444,13 @@ async fn local_tasks_are_polled_after_tick() { tx.send(oneshot_rx).unwrap(); } - time::sleep(Duration::from_millis(100)).await; + time::sleep(Duration::from_millis(10)).await; for tx in oneshots.drain(..) 
{ tx.send(()).unwrap(); } - time::sleep(Duration::from_millis(300)).await; + time::sleep(Duration::from_millis(20)).await; let rx1 = RX1.load(SeqCst); let rx2 = RX2.load(SeqCst); println!("EXPECT = {}; RX1 = {}; RX2 = {}", EXPECTED, rx1, rx2); diff --git a/tests/tcp_into_split.rs b/tests/tcp_into_split.rs index b4bb2ee..2e06643 100644 --- a/tests/tcp_into_split.rs +++ b/tests/tcp_into_split.rs @@ -116,7 +116,7 @@ async fn drop_write() -> Result<()> { // drop it while the read is in progress std::thread::spawn(move || { - thread::sleep(std::time::Duration::from_millis(50)); + thread::sleep(std::time::Duration::from_millis(10)); drop(write_half); }); diff --git a/tests/tcp_into_std.rs b/tests/tcp_into_std.rs index a46aace..4bf24c1 100644 --- a/tests/tcp_into_std.rs +++ b/tests/tcp_into_std.rs @@ -10,10 +10,11 @@ use tokio::net::TcpStream; #[tokio::test] async fn tcp_into_std() -> Result<()> { let mut data = [0u8; 12]; - let listener = TcpListener::bind("127.0.0.1:34254").await?; + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr().unwrap().to_string(); let handle = tokio::spawn(async { - let stream: TcpStream = TcpStream::connect("127.0.0.1:34254").await.unwrap(); + let stream: TcpStream = TcpStream::connect(addr).await.unwrap(); stream }); diff --git a/tests/tcp_stream.rs b/tests/tcp_stream.rs index e34c2bb..0b5d12a 100644 --- a/tests/tcp_stream.rs +++ b/tests/tcp_stream.rs @@ -55,7 +55,7 @@ async fn try_read_write() { tokio::task::yield_now().await; } - // Fill the write buffer + // Fill the write buffer using non-vectored I/O loop { // Still ready let mut writable = task::spawn(client.writable()); @@ -75,7 +75,7 @@ async fn try_read_write() { let mut writable = task::spawn(client.writable()); assert_pending!(writable.poll()); - // Drain the socket from the server end + // Drain the socket from the server end using non-vectored I/O let mut read = vec![0; written.len()]; let mut i = 0; @@ -92,6 +92,51 @@ async fn 
try_read_write() { assert_eq!(read, written); } + written.clear(); + client.writable().await.unwrap(); + + // Fill the write buffer using vectored I/O + let data_bufs: Vec<_> = DATA.chunks(10).map(io::IoSlice::new).collect(); + loop { + // Still ready + let mut writable = task::spawn(client.writable()); + assert_ready_ok!(writable.poll()); + + match client.try_write_vectored(&data_bufs) { + Ok(n) => written.extend(&DATA[..n]), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + break; + } + Err(e) => panic!("error = {:?}", e), + } + } + + { + // Write buffer full + let mut writable = task::spawn(client.writable()); + assert_pending!(writable.poll()); + + // Drain the socket from the server end using vectored I/O + let mut read = vec![0; written.len()]; + let mut i = 0; + + while i < read.len() { + server.readable().await.unwrap(); + + let mut bufs: Vec<_> = read[i..] + .chunks_mut(0x10000) + .map(io::IoSliceMut::new) + .collect(); + match server.try_read_vectored(&mut bufs) { + Ok(n) => i += n, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("error = {:?}", e), + } + } + + assert_eq!(read, written); + } + // Now, we listen for shutdown drop(client); diff --git a/tests/time_interval.rs b/tests/time_interval.rs index a3c7f08..5f7bf55 100644 --- a/tests/time_interval.rs +++ b/tests/time_interval.rs @@ -1,56 +1,173 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::time::{self, Duration, Instant}; +use tokio::time::{self, Duration, Instant, MissedTickBehavior}; use tokio_test::{assert_pending, assert_ready_eq, task}; -use std::future::Future; use std::task::Poll; +// Takes the `Interval` task, `start` variable, and optional time deltas +// For each time delta, it polls the `Interval` and asserts that the result is +// equal to `start` + the specific time delta. Then it asserts that the +// `Interval` is pending. +macro_rules! check_interval_poll { + ($i:ident, $start:ident, $($delta:expr),*$(,)?) 
=> { + $( + assert_ready_eq!(poll_next(&mut $i), $start + ms($delta)); + )* + assert_pending!(poll_next(&mut $i)); + }; + ($i:ident, $start:ident) => { + check_interval_poll!($i, $start,); + }; +} + #[tokio::test] #[should_panic] async fn interval_zero_duration() { let _ = time::interval_at(Instant::now(), ms(0)); } -#[tokio::test] -async fn usage() { - time::pause(); +// Expected ticks: | 1 | 2 | 3 | 4 | 5 | 6 | +// Actual ticks: | work -----| delay | work | work | work -| work -----| +// Poll behavior: | | | | | | | | +// | | | | | | | | +// Ready(s) | | Ready(s + 2p) | | | | +// Pending | Ready(s + 3p) | | | +// Ready(s + p) Ready(s + 4p) | | +// Ready(s + 5p) | +// Ready(s + 6p) +#[tokio::test(start_paused = true)] +async fn burst() { + let start = Instant::now(); + + // This is necessary because the timer is only so granular, and in order for + // all our ticks to resolve, the time needs to be 1ms ahead of what we + // expect, so that the runtime will see that it is time to resolve the timer + time::advance(ms(1)).await; + + let mut i = task::spawn(time::interval_at(start, ms(300))); + + check_interval_poll!(i, start, 0); + + time::advance(ms(100)).await; + check_interval_poll!(i, start); + + time::advance(ms(200)).await; + check_interval_poll!(i, start, 300); + + time::advance(ms(650)).await; + check_interval_poll!(i, start, 600, 900); + + time::advance(ms(200)).await; + check_interval_poll!(i, start); + + time::advance(ms(100)).await; + check_interval_poll!(i, start, 1200); + + time::advance(ms(250)).await; + check_interval_poll!(i, start, 1500); + + time::advance(ms(300)).await; + check_interval_poll!(i, start, 1800); +} +// Expected ticks: | 1 | 2 | 3 | 4 | 5 | 6 | +// Actual ticks: | work -----| delay | work -----| work -----| work -----| +// Poll behavior: | | | | | | | | +// | | | | | | | | +// Ready(s) | | Ready(s + 2p) | | | | +// Pending | Pending | | | +// Ready(s + p) Ready(s + 2p + d) | | +// Ready(s + 3p + d) | +// Ready(s + 4p + d) 
+#[tokio::test(start_paused = true)] +async fn delay() { let start = Instant::now(); - // TODO: Skip this + // This is necessary because the timer is only so granular, and in order for + // all our ticks to resolve, the time needs to be 1ms ahead of what we + // expect, so that the runtime will see that it is time to resolve the timer time::advance(ms(1)).await; let mut i = task::spawn(time::interval_at(start, ms(300))); + i.set_missed_tick_behavior(MissedTickBehavior::Delay); - assert_ready_eq!(poll_next(&mut i), start); - assert_pending!(poll_next(&mut i)); + check_interval_poll!(i, start, 0); time::advance(ms(100)).await; - assert_pending!(poll_next(&mut i)); + check_interval_poll!(i, start); time::advance(ms(200)).await; - assert_ready_eq!(poll_next(&mut i), start + ms(300)); - assert_pending!(poll_next(&mut i)); + check_interval_poll!(i, start, 300); + + time::advance(ms(650)).await; + check_interval_poll!(i, start, 600); + + time::advance(ms(100)).await; + check_interval_poll!(i, start); + + // We have to add one here for the same reason as is above. + // Because `Interval` has reset its timer according to `Instant::now()`, + // we have to go forward 1 more millisecond than is expected so that the + // runtime realizes that it's time to resolve the timer. + time::advance(ms(201)).await; + // We add one because when using the `Delay` behavior, `Interval` + // adds the `period` from `Instant::now()`, which will always be off by one + // because we have to advance time by 1 (see above). + check_interval_poll!(i, start, 1251); + + time::advance(ms(300)).await; + // Again, we add one. 
+ check_interval_poll!(i, start, 1551); + + time::advance(ms(300)).await; + check_interval_poll!(i, start, 1851); +} + +// Expected ticks: | 1 | 2 | 3 | 4 | 5 | 6 | +// Actual ticks: | work -----| delay | work ---| work -----| work -----| +// Poll behavior: | | | | | | | +// | | | | | | | +// Ready(s) | | Ready(s + 2p) | | | +// Pending | Ready(s + 4p) | | +// Ready(s + p) Ready(s + 5p) | +// Ready(s + 6p) +#[tokio::test(start_paused = true)] +async fn skip() { + let start = Instant::now(); + + // This is necessary because the timer is only so granular, and in order for + // all our ticks to resolve, the time needs to be 1ms ahead of what we + // expect, so that the runtime will see that it is time to resolve the timer + time::advance(ms(1)).await; + + let mut i = task::spawn(time::interval_at(start, ms(300))); + i.set_missed_tick_behavior(MissedTickBehavior::Skip); + + check_interval_poll!(i, start, 0); + + time::advance(ms(100)).await; + check_interval_poll!(i, start); + + time::advance(ms(200)).await; + check_interval_poll!(i, start, 300); + + time::advance(ms(650)).await; + check_interval_poll!(i, start, 600); + + time::advance(ms(250)).await; + check_interval_poll!(i, start, 1200); - time::advance(ms(400)).await; - assert_ready_eq!(poll_next(&mut i), start + ms(600)); - assert_pending!(poll_next(&mut i)); + time::advance(ms(300)).await; + check_interval_poll!(i, start, 1500); - time::advance(ms(500)).await; - assert_ready_eq!(poll_next(&mut i), start + ms(900)); - assert_ready_eq!(poll_next(&mut i), start + ms(1200)); - assert_pending!(poll_next(&mut i)); + time::advance(ms(300)).await; + check_interval_poll!(i, start, 1800); } fn poll_next(interval: &mut task::Spawn<time::Interval>) -> Poll<Instant> { - interval.enter(|cx, mut interval| { - tokio::pin! 
{ - let fut = interval.tick(); - } - fut.poll(cx) - }) + interval.enter(|cx, mut interval| interval.poll_tick(cx)) } fn ms(n: u64) -> Duration { diff --git a/tests/time_pause.rs b/tests/time_pause.rs index bc84ac5..02e050a 100644 --- a/tests/time_pause.rs +++ b/tests/time_pause.rs @@ -3,8 +3,14 @@ use rand::SeedableRng; use rand::{rngs::StdRng, Rng}; -use tokio::time::{self, Duration, Instant}; -use tokio_test::assert_err; +use tokio::time::{self, Duration, Instant, Sleep}; +use tokio_test::{assert_elapsed, assert_err, assert_pending, assert_ready, assert_ready_eq, task}; + +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; #[tokio::test] async fn pause_time_in_main() { @@ -57,3 +63,264 @@ async fn paused_time_stress_run() -> Vec<Duration> { times } + +#[tokio::test(start_paused = true)] +async fn advance_after_poll() { + time::sleep(ms(1)).await; + + let start = Instant::now(); + + let mut sleep = task::spawn(time::sleep_until(start + ms(300))); + + assert_pending!(sleep.poll()); + + let before = Instant::now(); + time::advance(ms(100)).await; + assert_elapsed!(before, ms(100)); + + assert_pending!(sleep.poll()); +} + +#[tokio::test(start_paused = true)] +async fn sleep_no_poll() { + let start = Instant::now(); + + // TODO: Skip this + time::advance(ms(1)).await; + + let mut sleep = task::spawn(time::sleep_until(start + ms(300))); + + let before = Instant::now(); + time::advance(ms(100)).await; + assert_elapsed!(before, ms(100)); + + assert_pending!(sleep.poll()); +} + +enum State { + Begin, + AwaitingAdvance(Pin<Box<dyn Future<Output = ()>>>), + AfterAdvance, +} + +struct Tester { + sleep: Pin<Box<Sleep>>, + state: State, + before: Option<Instant>, + poll: bool, +} + +impl Future for Tester { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + match &mut self.state { + State::Begin => { + if self.poll { + assert_pending!(self.sleep.as_mut().poll(cx)); + } + self.before = 
Some(Instant::now()); + let advance_fut = Box::pin(time::advance(ms(100))); + self.state = State::AwaitingAdvance(advance_fut); + self.poll(cx) + } + State::AwaitingAdvance(ref mut advance_fut) => match advance_fut.as_mut().poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(()) => { + self.state = State::AfterAdvance; + self.poll(cx) + } + }, + State::AfterAdvance => { + assert_elapsed!(self.before.unwrap(), ms(100)); + + assert_pending!(self.sleep.as_mut().poll(cx)); + + Poll::Ready(()) + } + } + } +} + +#[tokio::test(start_paused = true)] +async fn sleep_same_task() { + let start = Instant::now(); + + // TODO: Skip this + time::advance(ms(1)).await; + + let sleep = Box::pin(time::sleep_until(start + ms(300))); + + Tester { + sleep, + state: State::Begin, + before: None, + poll: true, + } + .await; +} + +#[tokio::test(start_paused = true)] +async fn sleep_same_task_no_poll() { + let start = Instant::now(); + + // TODO: Skip this + time::advance(ms(1)).await; + + let sleep = Box::pin(time::sleep_until(start + ms(300))); + + Tester { + sleep, + state: State::Begin, + before: None, + poll: false, + } + .await; +} + +#[tokio::test(start_paused = true)] +async fn interval() { + let start = Instant::now(); + + // TODO: Skip this + time::advance(ms(1)).await; + + let mut i = task::spawn(time::interval_at(start, ms(300))); + + assert_ready_eq!(poll_next(&mut i), start); + assert_pending!(poll_next(&mut i)); + + let before = Instant::now(); + time::advance(ms(100)).await; + assert_elapsed!(before, ms(100)); + assert_pending!(poll_next(&mut i)); + + let before = Instant::now(); + time::advance(ms(200)).await; + assert_elapsed!(before, ms(200)); + assert_ready_eq!(poll_next(&mut i), start + ms(300)); + assert_pending!(poll_next(&mut i)); + + let before = Instant::now(); + time::advance(ms(400)).await; + assert_elapsed!(before, ms(400)); + assert_ready_eq!(poll_next(&mut i), start + ms(600)); + assert_pending!(poll_next(&mut i)); + + let before = Instant::now(); + 
time::advance(ms(500)).await; + assert_elapsed!(before, ms(500)); + assert_ready_eq!(poll_next(&mut i), start + ms(900)); + assert_ready_eq!(poll_next(&mut i), start + ms(1200)); + assert_pending!(poll_next(&mut i)); +} + +#[tokio::test(start_paused = true)] +async fn test_time_advance_sub_ms() { + let now = Instant::now(); + + let dur = Duration::from_micros(51_592); + time::advance(dur).await; + + assert_eq!(now.elapsed(), dur); + + let now = Instant::now(); + let dur = Duration::from_micros(1); + time::advance(dur).await; + + assert_eq!(now.elapsed(), dur); +} + +#[tokio::test(start_paused = true)] +async fn test_time_advance_3ms_and_change() { + let now = Instant::now(); + + let dur = Duration::from_micros(3_141_592); + time::advance(dur).await; + + assert_eq!(now.elapsed(), dur); + + let now = Instant::now(); + let dur = Duration::from_micros(3_123_456); + time::advance(dur).await; + + assert_eq!(now.elapsed(), dur); +} + +#[tokio::test(start_paused = true)] +async fn regression_3710_with_submillis_advance() { + let start = Instant::now(); + + time::advance(Duration::from_millis(1)).await; + + let mut sleep = task::spawn(time::sleep_until(start + Duration::from_secs(60))); + + assert_pending!(sleep.poll()); + + let before = Instant::now(); + let dur = Duration::from_micros(51_592); + time::advance(dur).await; + assert_eq!(before.elapsed(), dur); + + assert_pending!(sleep.poll()); +} + +#[tokio::test(start_paused = true)] +async fn exact_1ms_advance() { + let now = Instant::now(); + + let dur = Duration::from_millis(1); + time::advance(dur).await; + + assert_eq!(now.elapsed(), dur); + + let now = Instant::now(); + let dur = Duration::from_millis(1); + time::advance(dur).await; + + assert_eq!(now.elapsed(), dur); +} + +#[tokio::test(start_paused = true)] +async fn advance_once_with_timer() { + let mut sleep = task::spawn(time::sleep(Duration::from_millis(1))); + assert_pending!(sleep.poll()); + + time::advance(Duration::from_micros(250)).await; + 
assert_pending!(sleep.poll()); + + time::advance(Duration::from_micros(1500)).await; + + assert!(sleep.is_woken()); + assert_ready!(sleep.poll()); +} + +#[tokio::test(start_paused = true)] +async fn advance_multi_with_timer() { + // Round to the nearest ms + // time::sleep(Duration::from_millis(1)).await; + + let mut sleep = task::spawn(time::sleep(Duration::from_millis(1))); + assert_pending!(sleep.poll()); + + time::advance(Duration::from_micros(250)).await; + assert_pending!(sleep.poll()); + + time::advance(Duration::from_micros(250)).await; + assert_pending!(sleep.poll()); + + time::advance(Duration::from_micros(250)).await; + assert_pending!(sleep.poll()); + + time::advance(Duration::from_micros(250)).await; + assert!(sleep.is_woken()); + assert_ready!(sleep.poll()); +} + +fn poll_next(interval: &mut task::Spawn<time::Interval>) -> Poll<Instant> { + interval.enter(|cx, mut interval| interval.poll_tick(cx)) +} + +fn ms(n: u64) -> Duration { + Duration::from_millis(n) +} diff --git a/tests/time_rt.rs b/tests/time_rt.rs index 0775343..23367be 100644 --- a/tests/time_rt.rs +++ b/tests/time_rt.rs @@ -13,7 +13,7 @@ fn timer_with_threaded_runtime() { let (tx, rx) = mpsc::channel(); rt.spawn(async move { - let when = Instant::now() + Duration::from_millis(100); + let when = Instant::now() + Duration::from_millis(10); sleep_until(when).await; assert!(Instant::now() >= when); @@ -32,7 +32,7 @@ fn timer_with_basic_scheduler() { let (tx, rx) = mpsc::channel(); rt.block_on(async move { - let when = Instant::now() + Duration::from_millis(100); + let when = Instant::now() + Duration::from_millis(10); sleep_until(when).await; assert!(Instant::now() >= when); @@ -67,7 +67,7 @@ async fn starving() { } } - let when = Instant::now() + Duration::from_millis(20); + let when = Instant::now() + Duration::from_millis(10); let starve = Starve(Box::pin(sleep_until(when)), 0); starve.await; @@ -81,7 +81,7 @@ async fn timeout_value() { let (_tx, rx) = oneshot::channel::<()>(); let now = 
Instant::now(); - let dur = Duration::from_millis(20); + let dur = Duration::from_millis(10); let res = timeout(dur, rx).await; assert!(res.is_err()); diff --git a/tests/time_sleep.rs b/tests/time_sleep.rs index 2736258..20477d2 100644 --- a/tests/time_sleep.rs +++ b/tests/time_sleep.rs @@ -7,22 +7,7 @@ use std::task::Context; use futures::task::noop_waker_ref; use tokio::time::{self, Duration, Instant}; -use tokio_test::{assert_pending, assert_ready, task}; - -macro_rules! assert_elapsed { - ($now:expr, $ms:expr) => {{ - let elapsed = $now.elapsed(); - let lower = ms($ms); - - // Handles ms rounding - assert!( - elapsed >= lower && elapsed <= lower + ms(1), - "actual = {:?}, expected = {:?}", - elapsed, - lower - ); - }}; -} +use tokio_test::{assert_elapsed, assert_pending, assert_ready, task}; #[tokio::test] async fn immediate_sleep() { @@ -32,14 +17,14 @@ async fn immediate_sleep() { // Ready! time::sleep_until(now).await; - assert_elapsed!(now, 0); + assert_elapsed!(now, ms(1)); } #[tokio::test] async fn is_elapsed() { time::pause(); - let sleep = time::sleep(Duration::from_millis(50)); + let sleep = time::sleep(Duration::from_millis(10)); tokio::pin!(sleep); @@ -60,10 +45,11 @@ async fn delayed_sleep_level_0() { for &i in &[1, 10, 60] { let now = Instant::now(); + let dur = ms(i); - time::sleep_until(now + ms(i)).await; + time::sleep_until(now + dur).await; - assert_elapsed!(now, i); + assert_elapsed!(now, dur); } } @@ -77,7 +63,7 @@ async fn sub_ms_delayed_sleep() { time::sleep_until(deadline).await; - assert_elapsed!(now, 1); + assert_elapsed!(now, ms(1)); } } @@ -90,7 +76,7 @@ async fn delayed_sleep_wrapping_level_0() { let now = Instant::now(); time::sleep_until(now + ms(60)).await; - assert_elapsed!(now, 60); + assert_elapsed!(now, ms(60)); } #[tokio::test] @@ -107,7 +93,7 @@ async fn reset_future_sleep_before_fire() { sleep.as_mut().reset(Instant::now() + ms(200)); sleep.await; - assert_elapsed!(now, 200); + assert_elapsed!(now, ms(200)); } 
#[tokio::test] @@ -124,7 +110,7 @@ async fn reset_past_sleep_before_turn() { sleep.as_mut().reset(now + ms(80)); sleep.await; - assert_elapsed!(now, 80); + assert_elapsed!(now, ms(80)); } #[tokio::test] @@ -143,7 +129,7 @@ async fn reset_past_sleep_before_fire() { sleep.as_mut().reset(now + ms(80)); sleep.await; - assert_elapsed!(now, 80); + assert_elapsed!(now, ms(80)); } #[tokio::test] @@ -154,11 +140,11 @@ async fn reset_future_sleep_after_fire() { let mut sleep = Box::pin(time::sleep_until(now + ms(100))); sleep.as_mut().await; - assert_elapsed!(now, 100); + assert_elapsed!(now, ms(100)); sleep.as_mut().reset(now + ms(110)); sleep.await; - assert_elapsed!(now, 110); + assert_elapsed!(now, ms(110)); } #[tokio::test] @@ -363,7 +349,7 @@ async fn drop_from_wake() { assert!( !panicked.load(Ordering::SeqCst), - "paniced when dropping timers" + "panicked when dropping timers" ); #[derive(Clone)] diff --git a/tests/udp.rs b/tests/udp.rs index 715d8eb..ec2a1e9 100644 --- a/tests/udp.rs +++ b/tests/udp.rs @@ -5,6 +5,7 @@ use futures::future::poll_fn; use std::io; use std::sync::Arc; use tokio::{io::ReadBuf, net::UdpSocket}; +use tokio_test::assert_ok; const MSG: &[u8] = b"hello"; const MSG_LEN: usize = MSG.len(); @@ -440,3 +441,46 @@ async fn try_recv_buf_from() { } } } + +#[tokio::test] +async fn poll_ready() { + // Create listener + let server = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let saddr = server.local_addr().unwrap(); + + // Create socket pair + let client = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let caddr = client.local_addr().unwrap(); + + for _ in 0..5 { + loop { + assert_ok!(poll_fn(|cx| client.poll_send_ready(cx)).await); + + match client.try_send_to(b"hello world", saddr) { + Ok(n) => { + assert_eq!(n, 11); + break; + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("{:?}", e), + } + } + + loop { + assert_ok!(poll_fn(|cx| server.poll_recv_ready(cx)).await); + + let mut buf = 
Vec::with_capacity(512); + + match server.try_recv_buf_from(&mut buf) { + Ok((n, addr)) => { + assert_eq!(n, 11); + assert_eq!(addr, caddr); + assert_eq!(&buf[0..11], &b"hello world"[..]); + break; + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("{:?}", e), + } + } + } +} diff --git a/tests/uds_cred.rs b/tests/uds_cred.rs index 5bd97fc..c2b3914 100644 --- a/tests/uds_cred.rs +++ b/tests/uds_cred.rs @@ -9,10 +9,6 @@ use libc::geteuid; #[tokio::test] #[cfg_attr( - target_os = "freebsd", - ignore = "Requires FreeBSD 12.0 or later. https://bugs.freebsd.org/bugzilla/show_bug.cgi?id=176419" -)] -#[cfg_attr( target_os = "netbsd", ignore = "NetBSD does not support getpeereid() for sockets created by socketpair()" )] diff --git a/tests/uds_datagram.rs b/tests/uds_datagram.rs index 10314be..5e5486b 100644 --- a/tests/uds_datagram.rs +++ b/tests/uds_datagram.rs @@ -87,9 +87,12 @@ async fn try_send_recv_never_block() -> io::Result<()> { dgram1.writable().await.unwrap(); match dgram1.try_send(payload) { - Err(err) => match err.kind() { - io::ErrorKind::WouldBlock | io::ErrorKind::Other => break, - _ => unreachable!("unexpected error {:?}", err), + Err(err) => match (err.kind(), err.raw_os_error()) { + (io::ErrorKind::WouldBlock, _) => break, + (_, Some(libc::ENOBUFS)) => break, + _ => { + panic!("unexpected error {:?}", err); + } }, Ok(len) => { assert_eq!(len, payload.len()); @@ -291,9 +294,12 @@ async fn try_recv_buf_never_block() -> io::Result<()> { dgram1.writable().await.unwrap(); match dgram1.try_send(payload) { - Err(err) => match err.kind() { - io::ErrorKind::WouldBlock | io::ErrorKind::Other => break, - _ => unreachable!("unexpected error {:?}", err), + Err(err) => match (err.kind(), err.raw_os_error()) { + (io::ErrorKind::WouldBlock, _) => break, + (_, Some(libc::ENOBUFS)) => break, + _ => { + panic!("unexpected error {:?}", err); + } }, Ok(len) => { assert_eq!(len, payload.len()); @@ -322,3 +328,50 @@ async fn 
try_recv_buf_never_block() -> io::Result<()> { Ok(()) } + +#[tokio::test] +async fn poll_ready() -> io::Result<()> { + let dir = tempfile::tempdir().unwrap(); + let server_path = dir.path().join("server.sock"); + let client_path = dir.path().join("client.sock"); + + // Create listener + let server = UnixDatagram::bind(&server_path)?; + + // Create socket pair + let client = UnixDatagram::bind(&client_path)?; + + for _ in 0..5 { + loop { + poll_fn(|cx| client.poll_send_ready(cx)).await?; + + match client.try_send_to(b"hello world", &server_path) { + Ok(n) => { + assert_eq!(n, 11); + break; + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("{:?}", e), + } + } + + loop { + poll_fn(|cx| server.poll_recv_ready(cx)).await?; + + let mut buf = Vec::with_capacity(512); + + match server.try_recv_buf_from(&mut buf) { + Ok((n, addr)) => { + assert_eq!(n, 11); + assert_eq!(addr.as_pathname(), Some(client_path.as_ref())); + assert_eq!(&buf[0..11], &b"hello world"[..]); + break; + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("{:?}", e), + } + } + } + + Ok(()) +} diff --git a/tests/uds_stream.rs b/tests/uds_stream.rs index c528620..5f1b4cf 100644 --- a/tests/uds_stream.rs +++ b/tests/uds_stream.rs @@ -90,7 +90,7 @@ async fn try_read_write() -> std::io::Result<()> { tokio::task::yield_now().await; } - // Fill the write buffer + // Fill the write buffer using non-vectored I/O loop { // Still ready let mut writable = task::spawn(client.writable()); @@ -110,7 +110,7 @@ async fn try_read_write() -> std::io::Result<()> { let mut writable = task::spawn(client.writable()); assert_pending!(writable.poll()); - // Drain the socket from the server end + // Drain the socket from the server end using non-vectored I/O let mut read = vec![0; written.len()]; let mut i = 0; @@ -127,6 +127,51 @@ async fn try_read_write() -> std::io::Result<()> { assert_eq!(read, written); } + written.clear(); + 
client.writable().await.unwrap(); + + // Fill the write buffer using vectored I/O + let msg_bufs: Vec<_> = msg.chunks(3).map(io::IoSlice::new).collect(); + loop { + // Still ready + let mut writable = task::spawn(client.writable()); + assert_ready_ok!(writable.poll()); + + match client.try_write_vectored(&msg_bufs) { + Ok(n) => written.extend(&msg[..n]), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + break; + } + Err(e) => panic!("error = {:?}", e), + } + } + + { + // Write buffer full + let mut writable = task::spawn(client.writable()); + assert_pending!(writable.poll()); + + // Drain the socket from the server end using vectored I/O + let mut read = vec![0; written.len()]; + let mut i = 0; + + while i < read.len() { + server.readable().await?; + + let mut bufs: Vec<_> = read[i..] + .chunks_mut(0x10000) + .map(io::IoSliceMut::new) + .collect(); + match server.try_read_vectored(&mut bufs) { + Ok(n) => i += n, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => continue, + Err(e) => panic!("error = {:?}", e), + } + } + + assert_eq!(read, written); + } + // Now, we listen for shutdown drop(client); @@ -334,3 +379,33 @@ async fn try_read_buf() -> std::io::Result<()> { Ok(()) } + +// https://github.com/tokio-rs/tokio/issues/3879 +#[tokio::test] +#[cfg(not(target_os = "macos"))] +async fn epollhup() -> io::Result<()> { + let dir = tempfile::Builder::new() + .prefix("tokio-uds-tests") + .tempdir() + .unwrap(); + let sock_path = dir.path().join("connect.sock"); + + let listener = UnixListener::bind(&sock_path)?; + let connect = UnixStream::connect(&sock_path); + tokio::pin!(connect); + + // Poll `connect` once. + poll_fn(|cx| { + use std::future::Future; + + assert_pending!(connect.as_mut().poll(cx)); + Poll::Ready(()) + }) + .await; + + drop(listener); + + let err = connect.await.unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::ConnectionReset); + Ok(()) +} |