aboutsummaryrefslogtreecommitdiff
path: root/src/sync/poll_semaphore.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/sync/poll_semaphore.rs')
-rw-r--r--src/sync/poll_semaphore.rs171
1 files changed, 171 insertions, 0 deletions
diff --git a/src/sync/poll_semaphore.rs b/src/sync/poll_semaphore.rs
new file mode 100644
index 0000000..6b44574
--- /dev/null
+++ b/src/sync/poll_semaphore.rs
@@ -0,0 +1,171 @@
+use futures_core::{ready, Stream};
+use std::fmt;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore, TryAcquireError};
+
+use super::ReusableBoxFuture;
+
+/// A wrapper around [`Semaphore`] that provides a `poll_acquire` method.
+///
+/// [`Semaphore`]: tokio::sync::Semaphore
+pub struct PollSemaphore {
+ semaphore: Arc<Semaphore>,
+ permit_fut: Option<(
+ u32, // The number of permits requested.
+ ReusableBoxFuture<'static, Result<OwnedSemaphorePermit, AcquireError>>,
+ )>,
+}
+
+impl PollSemaphore {
+ /// Create a new `PollSemaphore`.
+ pub fn new(semaphore: Arc<Semaphore>) -> Self {
+ Self {
+ semaphore,
+ permit_fut: None,
+ }
+ }
+
+ /// Closes the semaphore.
+ pub fn close(&self) {
+ self.semaphore.close()
+ }
+
+ /// Obtain a clone of the inner semaphore.
+ pub fn clone_inner(&self) -> Arc<Semaphore> {
+ self.semaphore.clone()
+ }
+
+ /// Get back the inner semaphore.
+ pub fn into_inner(self) -> Arc<Semaphore> {
+ self.semaphore
+ }
+
+ /// Poll to acquire a permit from the semaphore.
+ ///
+ /// This can return the following values:
+ ///
+ /// - `Poll::Pending` if a permit is not currently available.
+ /// - `Poll::Ready(Some(permit))` if a permit was acquired.
+ /// - `Poll::Ready(None)` if the semaphore has been closed.
+ ///
+ /// When this method returns `Poll::Pending`, the current task is scheduled
+ /// to receive a wakeup when a permit becomes available, or when the
+ /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only
+ /// the `Waker` from the `Context` passed to the most recent call is
+ /// scheduled to receive a wakeup.
+ pub fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
+ self.poll_acquire_many(cx, 1)
+ }
+
+ /// Poll to acquire many permits from the semaphore.
+ ///
+ /// This can return the following values:
+ ///
+ /// - `Poll::Pending` if a permit is not currently available.
+ /// - `Poll::Ready(Some(permit))` if a permit was acquired.
+ /// - `Poll::Ready(None)` if the semaphore has been closed.
+ ///
+ /// When this method returns `Poll::Pending`, the current task is scheduled
+ /// to receive a wakeup when the permits become available, or when the
+ /// semaphore is closed. Note that on multiple calls to `poll_acquire`, only
+ /// the `Waker` from the `Context` passed to the most recent call is
+ /// scheduled to receive a wakeup.
+ pub fn poll_acquire_many(
+ &mut self,
+ cx: &mut Context<'_>,
+ permits: u32,
+ ) -> Poll<Option<OwnedSemaphorePermit>> {
+ let permit_future = match self.permit_fut.as_mut() {
+ Some((prev_permits, fut)) if *prev_permits == permits => fut,
+ Some((old_permits, fut_box)) => {
+ // We're requesting a different number of permits, so replace the future
+ // and record the new amount.
+ let fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
+ fut_box.set(fut);
+ *old_permits = permits;
+ fut_box
+ }
+ None => {
+ // avoid allocations completely if we can grab a permit immediately
+ match Arc::clone(&self.semaphore).try_acquire_many_owned(permits) {
+ Ok(permit) => return Poll::Ready(Some(permit)),
+ Err(TryAcquireError::Closed) => return Poll::Ready(None),
+ Err(TryAcquireError::NoPermits) => {}
+ }
+
+ let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
+ &mut self
+ .permit_fut
+ .get_or_insert((permits, ReusableBoxFuture::new(next_fut)))
+ .1
+ }
+ };
+
+ let result = ready!(permit_future.poll(cx));
+
+ // Assume we'll request the same amount of permits in a subsequent call.
+ let next_fut = Arc::clone(&self.semaphore).acquire_many_owned(permits);
+ permit_future.set(next_fut);
+
+ match result {
+ Ok(permit) => Poll::Ready(Some(permit)),
+ Err(_closed) => {
+ self.permit_fut = None;
+ Poll::Ready(None)
+ }
+ }
+ }
+
+ /// Returns the current number of available permits.
+ ///
+ /// This is equivalent to the [`Semaphore::available_permits`] method on the
+ /// `tokio::sync::Semaphore` type.
+ ///
+ /// [`Semaphore::available_permits`]: tokio::sync::Semaphore::available_permits
+ pub fn available_permits(&self) -> usize {
+ self.semaphore.available_permits()
+ }
+
+ /// Adds `n` new permits to the semaphore.
+ ///
+ /// The maximum number of permits is [`Semaphore::MAX_PERMITS`], and this function
+ /// will panic if the limit is exceeded.
+ ///
+ /// This is equivalent to the [`Semaphore::add_permits`] method on the
+ /// `tokio::sync::Semaphore` type.
+ ///
+ /// [`Semaphore::add_permits`]: tokio::sync::Semaphore::add_permits
+ pub fn add_permits(&self, n: usize) {
+ self.semaphore.add_permits(n);
+ }
+}
+
+impl Stream for PollSemaphore {
+ type Item = OwnedSemaphorePermit;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<OwnedSemaphorePermit>> {
+ Pin::into_inner(self).poll_acquire(cx)
+ }
+}
+
+impl Clone for PollSemaphore {
+ fn clone(&self) -> PollSemaphore {
+ PollSemaphore::new(self.clone_inner())
+ }
+}
+
+impl fmt::Debug for PollSemaphore {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("PollSemaphore")
+ .field("semaphore", &self.semaphore)
+ .finish()
+ }
+}
+
+impl AsRef<Semaphore> for PollSemaphore {
+ fn as_ref(&self) -> &Semaphore {
+ &self.semaphore
+ }
+}