diff options
Diffstat (limited to 'src/sync/rwlock.rs')
-rw-r--r-- | src/sync/rwlock.rs | 367 |
1 files changed, 360 insertions, 7 deletions
diff --git a/src/sync/rwlock.rs b/src/sync/rwlock.rs index 120bc72..0492e4e 100644 --- a/src/sync/rwlock.rs +++ b/src/sync/rwlock.rs @@ -1,5 +1,7 @@ use crate::sync::batch_semaphore::{Semaphore, TryAcquireError}; use crate::sync::mutex::TryLockError; +#[cfg(all(tokio_unstable, feature = "tracing"))] +use crate::util::trace; use std::cell::UnsafeCell; use std::marker; use std::marker::PhantomData; @@ -86,6 +88,9 @@ const MAX_READS: u32 = 10; /// [_write-preferring_]: https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock#Priority_policies #[derive(Debug)] pub struct RwLock<T: ?Sized> { + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span, + // maximum number of concurrent readers mr: u32, @@ -197,14 +202,55 @@ impl<T: ?Sized> RwLock<T> { /// /// let lock = RwLock::new(5); /// ``` + #[track_caller] pub fn new(value: T) -> RwLock<T> where T: Sized, { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "RwLock", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + max_readers = MAX_READS, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 0, + ); + }); + + resource_span + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let s = resource_span.in_scope(|| Semaphore::new(MAX_READS as usize)); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let s = Semaphore::new(MAX_READS as usize); + RwLock { mr: MAX_READS, c: UnsafeCell::new(value), - s: Semaphore::new(MAX_READS as usize), + s, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, } } @@ -222,6 +268,7 @@ impl<T: ?Sized> RwLock<T> { /// # Panics /// /// Panics if `max_reads` is more than `u32::MAX >> 3`. + #[track_caller] pub fn with_max_readers(value: T, max_reads: u32) -> RwLock<T> where T: Sized, @@ -231,10 +278,52 @@ impl<T: ?Sized> RwLock<T> { "a RwLock may not be created with more than {} readers", MAX_READS ); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = { + let location = std::panic::Location::caller(); + + let resource_span = tracing::trace_span!( + "runtime.resource", + concrete_type = "RwLock", + kind = "Sync", + loc.file = location.file(), + loc.line = location.line(), + loc.col = location.column(), + ); + + resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + max_readers = max_reads, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = false, + ); + + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 0, + ); + }); + + resource_span + }; + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let s = resource_span.in_scope(|| Semaphore::new(max_reads as usize)); + + #[cfg(any(not(tokio_unstable), not(feature = "tracing")))] + let s = Semaphore::new(max_reads as usize); + RwLock { mr: max_reads, c: UnsafeCell::new(value), - s: Semaphore::new(max_reads as usize), + s, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, } } @@ -257,6 +346,8 @@ impl<T: ?Sized> RwLock<T> { mr: MAX_READS, c: UnsafeCell::new(value), s: Semaphore::const_new(MAX_READS as usize), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span::none(), } } @@ -281,6 +372,8 @@ impl<T: ?Sized> RwLock<T> { mr: max_reads, c: UnsafeCell::new(value), s: Semaphore::const_new(max_reads as usize), + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: tracing::Span::none(), } } @@ -327,21 +420,98 @@ impl<T: ?Sized> RwLock<T> { /// /// // Drop the guard after the spawned task finishes. /// drop(n); - ///} + /// } /// ``` pub async fn read(&self) -> RwLockReadGuard<'_, T> { - self.s.acquire(1).await.unwrap_or_else(|_| { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.s.acquire(1), + self.resource_span.clone(), + "RwLock::read", + "poll", + false, + ); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.s.acquire(1); + + inner.await.unwrap_or_else(|_| { // The semaphore was closed. but, we never explicitly close it, and we have a // handle to it through the Arc, which means that this can never happen. unreachable!() }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + RwLockReadGuard { s: &self.s, data: self.c.get(), marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), } } + /// Blockingly locks this `RwLock` with shared read access. + /// + /// This method is intended for use cases where you + /// need to use this rwlock in asynchronous code as well as in synchronous code. + /// + /// Returns an RAII guard which will drop the read access of this `RwLock` when dropped. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution context. + /// + /// - If you find yourself in an asynchronous execution context and needing + /// to call some (synchronous) function which performs one of these + /// `blocking_` operations, then consider wrapping that call inside + /// [`spawn_blocking()`][crate::runtime::Handle::spawn_blocking] + /// (or [`block_in_place()`][crate::task::block_in_place]). + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::sync::RwLock; + /// + /// #[tokio::main] + /// async fn main() { + /// let rwlock = Arc::new(RwLock::new(1)); + /// let mut write_lock = rwlock.write().await; + /// + /// let blocking_task = tokio::task::spawn_blocking({ + /// let rwlock = Arc::clone(&rwlock); + /// move || { + /// // This shall block until the `write_lock` is released. + /// let read_lock = rwlock.blocking_read(); + /// assert_eq!(*read_lock, 0); + /// } + /// }); + /// + /// *write_lock -= 1; + /// drop(write_lock); // release the lock. + /// + /// // Await the completion of the blocking task. + /// blocking_task.await.unwrap(); + /// + /// // Assert uncontended. + /// assert!(rwlock.try_write().is_ok()); + /// } + /// ``` + #[track_caller] + #[cfg(feature = "sync")] + pub fn blocking_read(&self) -> RwLockReadGuard<'_, T> { + crate::future::block_on(self.read()) + } + /// Locks this `RwLock` with shared read access, causing the current task /// to yield until the lock has been acquired. /// @@ -394,15 +564,42 @@ impl<T: ?Sized> RwLock<T> { ///} /// ``` pub async fn read_owned(self: Arc<Self>) -> OwnedRwLockReadGuard<T> { - self.s.acquire(1).await.unwrap_or_else(|_| { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.s.acquire(1), + self.resource_span.clone(), + "RwLock::read_owned", + "poll", + false, + ); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.s.acquire(1); + + inner.await.unwrap_or_else(|_| { // The semaphore was closed. but, we never explicitly close it, and we have a // handle to it through the Arc, which means that this can never happen. unreachable!() }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + OwnedRwLockReadGuard { data: self.c.get(), lock: ManuallyDrop::new(self), _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, } } @@ -445,10 +642,21 @@ impl<T: ?Sized> RwLock<T> { Err(TryAcquireError::Closed) => unreachable!(), } + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + Ok(RwLockReadGuard { s: &self.s, data: self.c.get(), marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), }) } @@ -497,10 +705,24 @@ impl<T: ?Sized> RwLock<T> { Err(TryAcquireError::Closed) => unreachable!(), } + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + current_readers = 1, + current_readers.op = "add", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + Ok(OwnedRwLockReadGuard { data: self.c.get(), lock: ManuallyDrop::new(self), _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, }) } @@ -533,19 +755,98 @@ impl<T: ?Sized> RwLock<T> { ///} /// ``` pub async fn write(&self) -> RwLockWriteGuard<'_, T> { - self.s.acquire(self.mr).await.unwrap_or_else(|_| { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.s.acquire(self.mr), + self.resource_span.clone(), + "RwLock::write", + "poll", + false, + ); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.s.acquire(self.mr); + + inner.await.unwrap_or_else(|_| { // The semaphore was closed. but, we never explicitly close it, and we have a // handle to it through the Arc, which means that this can never happen. unreachable!() }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) + }); + RwLockWriteGuard { permits_acquired: self.mr, s: &self.s, data: self.c.get(), marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), } } + /// Blockingly locks this `RwLock` with exclusive write access. + /// + /// This method is intended for use cases where you + /// need to use this rwlock in asynchronous code as well as in synchronous code. + /// + /// Returns an RAII guard which will drop the write access of this `RwLock` when dropped. + /// + /// # Panics + /// + /// This function panics if called within an asynchronous execution context. + /// + /// - If you find yourself in an asynchronous execution context and needing + /// to call some (synchronous) function which performs one of these + /// `blocking_` operations, then consider wrapping that call inside + /// [`spawn_blocking()`][crate::runtime::Handle::spawn_blocking] + /// (or [`block_in_place()`][crate::task::block_in_place]). + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// use tokio::{sync::RwLock}; + /// + /// #[tokio::main] + /// async fn main() { + /// let rwlock = Arc::new(RwLock::new(1)); + /// let read_lock = rwlock.read().await; + /// + /// let blocking_task = tokio::task::spawn_blocking({ + /// let rwlock = Arc::clone(&rwlock); + /// move || { + /// // This shall block until the `read_lock` is released. + /// let mut write_lock = rwlock.blocking_write(); + /// *write_lock = 2; + /// } + /// }); + /// + /// assert_eq!(*read_lock, 1); + /// // Release the last outstanding read lock. + /// drop(read_lock); + /// + /// // Await the completion of the blocking task. + /// blocking_task.await.unwrap(); + /// + /// // Assert uncontended. + /// let read_lock = rwlock.try_read().unwrap(); + /// assert_eq!(*read_lock, 2); + /// } + /// ``` + #[track_caller] + #[cfg(feature = "sync")] + pub fn blocking_write(&self) -> RwLockWriteGuard<'_, T> { + crate::future::block_on(self.write()) + } + /// Locks this `RwLock` with exclusive write access, causing the current /// task to yield until the lock has been acquired. /// @@ -582,16 +883,43 @@ impl<T: ?Sized> RwLock<T> { ///} /// ``` pub async fn write_owned(self: Arc<Self>) -> OwnedRwLockWriteGuard<T> { - self.s.acquire(self.mr).await.unwrap_or_else(|_| { + #[cfg(all(tokio_unstable, feature = "tracing"))] + let inner = trace::async_op( + || self.s.acquire(self.mr), + self.resource_span.clone(), + "RwLock::write_owned", + "poll", + false, + ); + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let inner = self.s.acquire(self.mr); + + inner.await.unwrap_or_else(|_| { // The semaphore was closed. but, we never explicitly close it, and we have a // handle to it through the Arc, which means that this can never happen. unreachable!() }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + OwnedRwLockWriteGuard { permits_acquired: self.mr, data: self.c.get(), lock: ManuallyDrop::new(self), _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, } } @@ -625,11 +953,22 @@ impl<T: ?Sized> RwLock<T> { Err(TryAcquireError::Closed) => unreachable!(), } + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) + }); + Ok(RwLockWriteGuard { permits_acquired: self.mr, s: &self.s, data: self.c.get(), marker: marker::PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span: self.resource_span.clone(), }) } @@ -670,11 +1009,25 @@ impl<T: ?Sized> RwLock<T> { Err(TryAcquireError::Closed) => unreachable!(), } + #[cfg(all(tokio_unstable, feature = "tracing"))] + self.resource_span.in_scope(|| { + tracing::trace!( + target: "runtime::resource::state_update", + write_locked = true, + write_locked.op = "override", + ) + }); + + #[cfg(all(tokio_unstable, feature = "tracing"))] + let resource_span = self.resource_span.clone(); + Ok(OwnedRwLockWriteGuard { permits_acquired: self.mr, data: self.c.get(), lock: ManuallyDrop::new(self), _p: PhantomData, + #[cfg(all(tokio_unstable, feature = "tracing"))] + resource_span, }) } |