From dd2305def6fff1d9f1149ecb73b02e37eb8a18b0 Mon Sep 17 00:00:00 2001 From: Joel Galenson Date: Wed, 19 May 2021 16:31:56 -0700 Subject: Upgrade rust/crates/rayon to 1.5.1 Test: make Change-Id: I40c1a4538832871d1f4cd09daf6904d094b5615e --- src/iter/collect/consumer.rs | 72 +++++++++++++++++++-------------------- src/iter/collect/mod.rs | 78 +++++++++++++++++++++--------------------- src/iter/collect/test.rs | 2 +- src/iter/mod.rs | 81 +++++++++++++++++++++++++++++++++++++++++--- src/iter/par_bridge.rs | 38 +++++++++++++++------ src/iter/unzip.rs | 61 +++++++++++++++++++++++++++++++++ 6 files changed, 242 insertions(+), 90 deletions(-) (limited to 'src/iter') diff --git a/src/iter/collect/consumer.rs b/src/iter/collect/consumer.rs index 689f29c..3a8eea0 100644 --- a/src/iter/collect/consumer.rs +++ b/src/iter/collect/consumer.rs @@ -1,26 +1,18 @@ use super::super::plumbing::*; use std::marker::PhantomData; +use std::mem::MaybeUninit; use std::ptr; use std::slice; pub(super) struct CollectConsumer<'c, T: Send> { /// A slice covering the target memory, not yet initialized! - target: &'c mut [T], -} - -pub(super) struct CollectFolder<'c, T: Send> { - /// The folder writes into `result` and must extend the result - /// up to exactly this number of elements. - final_len: usize, - - /// The current written-to part of our slice of the target - result: CollectResult<'c, T>, + target: &'c mut [MaybeUninit], } impl<'c, T: Send + 'c> CollectConsumer<'c, T> { /// The target memory is considered uninitialized, and will be /// overwritten without reading or dropping existing values. - pub(super) fn new(target: &'c mut [T]) -> Self { + pub(super) fn new(target: &'c mut [MaybeUninit]) -> Self { CollectConsumer { target } } } @@ -31,8 +23,12 @@ impl<'c, T: Send + 'c> CollectConsumer<'c, T> { /// the elements will be dropped, unless its ownership is released before then. #[must_use] pub(super) struct CollectResult<'c, T> { - start: *mut T, + /// A slice covering the target memory, initialized up to our separate `len`. + target: &'c mut [MaybeUninit], + /// The current initialized length in `target` len: usize, + /// Lifetime invariance guarantees that the data flows from consumer to result, + /// especially for the `scope_fn` callback in `Collect::with_consumer`. invariant_lifetime: PhantomData<&'c mut &'c mut [T]>, } @@ -57,13 +53,15 @@ impl<'c, T> Drop for CollectResult<'c, T> { // Drop the first `self.len` elements, which have been recorded // to be initialized by the folder. unsafe { - ptr::drop_in_place(slice::from_raw_parts_mut(self.start, self.len)); + // TODO: use `MaybeUninit::slice_as_mut_ptr` + let start = self.target.as_mut_ptr() as *mut T; + ptr::drop_in_place(slice::from_raw_parts_mut(start, self.len)); } } } impl<'c, T: Send + 'c> Consumer for CollectConsumer<'c, T> { - type Folder = CollectFolder<'c, T>; + type Folder = CollectResult<'c, T>; type Reducer = CollectReducer; type Result = CollectResult<'c, T>; @@ -80,16 +78,13 @@ impl<'c, T: Send + 'c> Consumer for CollectConsumer<'c, T> { ) } - fn into_folder(self) -> CollectFolder<'c, T> { - // Create a folder that consumes values and writes them + fn into_folder(self) -> Self::Folder { + // Create a result/folder that consumes values and writes them // into target. The initial result has length 0. - CollectFolder { - final_len: self.target.len(), - result: CollectResult { - start: self.target.as_mut_ptr(), - len: 0, - invariant_lifetime: PhantomData, - }, + CollectResult { + target: self.target, + len: 0, + invariant_lifetime: PhantomData, } } @@ -98,19 +93,19 @@ impl<'c, T: Send + 'c> Consumer for CollectConsumer<'c, T> { } } -impl<'c, T: Send + 'c> Folder for CollectFolder<'c, T> { - type Result = CollectResult<'c, T>; +impl<'c, T: Send + 'c> Folder for CollectResult<'c, T> { + type Result = Self; - fn consume(mut self, item: T) -> CollectFolder<'c, T> { - if self.result.len >= self.final_len { - panic!("too many values pushed to consumer"); - } + fn consume(mut self, item: T) -> Self { + let dest = self + .target + .get_mut(self.len) + .expect("too many values pushed to consumer"); - // Compute target pointer and write to it, and - // extend the current result by one element + // Write item and increase the initialized length unsafe { - self.result.start.add(self.result.len).write(item); - self.result.len += 1; + dest.as_mut_ptr().write(item); + self.len += 1; } self @@ -119,7 +114,7 @@ impl<'c, T: Send + 'c> Folder for CollectFolder<'c, T> { fn complete(self) -> Self::Result { // NB: We don't explicitly check that the local writes were complete, // but Collect will assert the total result length in the end. - self.result + self } fn full(&self) -> bool { @@ -151,8 +146,13 @@ impl<'c, T> Reducer> for CollectReducer { // Merge if the CollectResults are adjacent and in left to right order // else: drop the right piece now and total length will end up short in the end, // when the correctness of the collected result is asserted. - if left.start.wrapping_add(left.len) == right.start { - left.len += right.release_ownership(); + let left_end = left.target[left.len..].as_ptr(); + if left_end == right.target.as_ptr() { + let len = left.len + right.release_ownership(); + unsafe { + left.target = slice::from_raw_parts_mut(left.target.as_mut_ptr(), len); + } + left.len = len; } left } diff --git a/src/iter/collect/mod.rs b/src/iter/collect/mod.rs index e18298e..7cbf215 100644 --- a/src/iter/collect/mod.rs +++ b/src/iter/collect/mod.rs @@ -1,4 +1,5 @@ use super::{IndexedParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator}; +use std::mem::MaybeUninit; use std::slice; mod consumer; @@ -88,55 +89,56 @@ impl<'c, T: Send + 'c> Collect<'c, T> { where F: FnOnce(CollectConsumer<'_, T>) -> CollectResult<'_, T>, { + let slice = Self::reserve_get_tail_slice(&mut self.vec, self.len); + let result = scope_fn(CollectConsumer::new(slice)); + + // The CollectResult represents a contiguous part of the + // slice, that has been written to. + // On unwind here, the CollectResult will be dropped. + // If some producers on the way did not produce enough elements, + // partial CollectResults may have been dropped without + // being reduced to the final result, and we will see + // that as the length coming up short. + // + // Here, we assert that `slice` is fully initialized. This is + // checked by the following assert, which verifies if a + // complete CollectResult was produced; if the length is + // correct, it is necessarily covering the target slice. + // Since we know that the consumer cannot have escaped from + // `drive` (by parametricity, essentially), we know that any + // stores that will happen, have happened. Unless some code is buggy, + // that means we should have seen `len` total writes. + let actual_writes = result.len(); + assert!( + actual_writes == self.len, + "expected {} total writes, but got {}", + self.len, + actual_writes + ); + + // Release the result's mutable borrow and "proxy ownership" + // of the elements, before the vector takes it over. + result.release_ownership(); + + let new_len = self.vec.len() + self.len; + unsafe { - let slice = Self::reserve_get_tail_slice(&mut self.vec, self.len); - let result = scope_fn(CollectConsumer::new(slice)); - - // The CollectResult represents a contiguous part of the - // slice, that has been written to. - // On unwind here, the CollectResult will be dropped. - // If some producers on the way did not produce enough elements, - // partial CollectResults may have been dropped without - // being reduced to the final result, and we will see - // that as the length coming up short. - // - // Here, we assert that `slice` is fully initialized. This is - // checked by the following assert, which verifies if a - // complete CollectResult was produced; if the length is - // correct, it is necessarily covering the target slice. - // Since we know that the consumer cannot have escaped from - // `drive` (by parametricity, essentially), we know that any - // stores that will happen, have happened. Unless some code is buggy, - // that means we should have seen `len` total writes. - let actual_writes = result.len(); - assert!( - actual_writes == self.len, - "expected {} total writes, but got {}", - self.len, - actual_writes - ); - - // Release the result's mutable borrow and "proxy ownership" - // of the elements, before the vector takes it over. - result.release_ownership(); - - let new_len = self.vec.len() + self.len; self.vec.set_len(new_len); } } /// Reserve space for `len` more elements in the vector, /// and return a slice to the uninitialized tail of the vector - /// - /// Safety: The tail slice is uninitialized - unsafe fn reserve_get_tail_slice(vec: &mut Vec, len: usize) -> &mut [T] { + fn reserve_get_tail_slice(vec: &mut Vec, len: usize) -> &mut [MaybeUninit] { // Reserve the new space. vec.reserve(len); - // Get a correct borrow, then extend it for the newly added length. + // TODO: use `Vec::spare_capacity_mut` instead + // SAFETY: `MaybeUninit` is guaranteed to have the same layout + // as `T`, and we already made sure to have the additional space. let start = vec.len(); - let slice = &mut vec[start..]; - slice::from_raw_parts_mut(slice.as_mut_ptr(), len) + let tail_ptr = vec[start..].as_mut_ptr() as *mut MaybeUninit; + unsafe { slice::from_raw_parts_mut(tail_ptr, len) } } } diff --git a/src/iter/collect/test.rs b/src/iter/collect/test.rs index 00c16c4..ddf7757 100644 --- a/src/iter/collect/test.rs +++ b/src/iter/collect/test.rs @@ -24,7 +24,7 @@ fn produce_too_many_items() { let mut folder = consumer.into_folder(); folder = folder.consume(22); folder = folder.consume(23); - folder.consume(24); + folder = folder.consume(24); unreachable!("folder does not complete") }); } diff --git a/src/iter/mod.rs b/src/iter/mod.rs index 0c82933..edff1a6 100644 --- a/src/iter/mod.rs +++ b/src/iter/mod.rs @@ -61,7 +61,7 @@ //! If you'd like to build a custom parallel iterator, or to write your own //! combinator, then check out the [split] function and the [plumbing] module. //! -//! [regular iterator]: http://doc.rust-lang.org/std/iter/trait.Iterator.html +//! [regular iterator]: https://doc.rust-lang.org/std/iter/trait.Iterator.html //! [`ParallelIterator`]: trait.ParallelIterator.html //! [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html //! [split]: fn.split.html @@ -1966,6 +1966,79 @@ pub trait ParallelIterator: Sized + Send { /// /// assert_eq!(sync_vec, async_vec); /// ``` + /// + /// You can collect a pair of collections like [`unzip`](#method.unzip) + /// for paired items: + /// + /// ``` + /// use rayon::prelude::*; + /// + /// let a = [(0, 1), (1, 2), (2, 3), (3, 4)]; + /// let (first, second): (Vec<_>, Vec<_>) = a.into_par_iter().collect(); + /// + /// assert_eq!(first, [0, 1, 2, 3]); + /// assert_eq!(second, [1, 2, 3, 4]); + /// ``` + /// + /// Or like [`partition_map`](#method.partition_map) for `Either` items: + /// + /// ``` + /// use rayon::prelude::*; + /// use rayon::iter::Either; + /// + /// let (left, right): (Vec<_>, Vec<_>) = (0..8).into_par_iter().map(|x| { + /// if x % 2 == 0 { + /// Either::Left(x * 4) + /// } else { + /// Either::Right(x * 3) + /// } + /// }).collect(); + /// + /// assert_eq!(left, [0, 8, 16, 24]); + /// assert_eq!(right, [3, 9, 15, 21]); + /// ``` + /// + /// You can even collect an arbitrarily-nested combination of pairs and `Either`: + /// + /// ``` + /// use rayon::prelude::*; + /// use rayon::iter::Either; + /// + /// let (first, (left, right)): (Vec<_>, (Vec<_>, Vec<_>)) + /// = (0..8).into_par_iter().map(|x| { + /// if x % 2 == 0 { + /// (x, Either::Left(x * 4)) + /// } else { + /// (-x, Either::Right(x * 3)) + /// } + /// }).collect(); + /// + /// assert_eq!(first, [0, -1, 2, -3, 4, -5, 6, -7]); + /// assert_eq!(left, [0, 8, 16, 24]); + /// assert_eq!(right, [3, 9, 15, 21]); + /// ``` + /// + /// All of that can _also_ be combined with short-circuiting collection of + /// `Result` or `Option` types: + /// + /// ``` + /// use rayon::prelude::*; + /// use rayon::iter::Either; + /// + /// let result: Result<(Vec<_>, (Vec<_>, Vec<_>)), _> + /// = (0..8).into_par_iter().map(|x| { + /// if x > 5 { + /// Err(x) + /// } else if x % 2 == 0 { + /// Ok((x, Either::Left(x * 4))) + /// } else { + /// Ok((-x, Either::Right(x * 3))) + /// } + /// }).collect(); + /// + /// let error = result.unwrap_err(); + /// assert!(error == 6 || error == 7); + /// ``` fn collect(self) -> C where C: FromParallelIterator, @@ -2130,7 +2203,7 @@ pub trait ParallelIterator: Sized + Send { /// See the [README] for more details on the internals of parallel /// iterators. /// - /// [README]: README.md + /// [README]: https://github.com/rayon-rs/rayon/blob/master/src/iter/plumbing/README.md fn drive_unindexed(self, consumer: C) -> C::Result where C: UnindexedConsumer; @@ -2817,7 +2890,7 @@ pub trait IndexedParallelIterator: ParallelIterator { /// See the [README] for more details on the internals of parallel /// iterators. /// - /// [README]: README.md + /// [README]: https://github.com/rayon-rs/rayon/blob/master/src/iter/plumbing/README.md fn drive>(self, consumer: C) -> C::Result; /// Internal method used to define the behavior of this parallel @@ -2834,7 +2907,7 @@ pub trait IndexedParallelIterator: ParallelIterator { /// See the [README] for more details on the internals of parallel /// iterators. /// - /// [README]: README.md + /// [README]: https://github.com/rayon-rs/rayon/blob/master/src/iter/plumbing/README.md fn with_producer>(self, callback: CB) -> CB::Output; } diff --git a/src/iter/par_bridge.rs b/src/iter/par_bridge.rs index 4c2b96e..339ac1a 100644 --- a/src/iter/par_bridge.rs +++ b/src/iter/par_bridge.rs @@ -125,16 +125,19 @@ where let mut count = self.split_count.load(Ordering::SeqCst); loop { - let done = self.done.load(Ordering::SeqCst); + // Check if the iterator is exhausted *and* we've consumed every item from it. + let done = self.done.load(Ordering::SeqCst) && self.items.is_empty(); + match count.checked_sub(1) { Some(new_count) if !done => { - let last_count = - self.split_count - .compare_and_swap(count, new_count, Ordering::SeqCst); - if last_count == count { - return (self.clone(), Some(self)); - } else { - count = last_count; + match self.split_count.compare_exchange_weak( + count, + new_count, + Ordering::SeqCst, + Ordering::SeqCst, + ) { + Ok(_) => return (self.clone(), Some(self)), + Err(last_count) => count = last_count, } } _ => { @@ -157,13 +160,26 @@ where } } Steal::Empty => { + // Don't storm the mutex if we're already done. if self.done.load(Ordering::SeqCst) { - // the iterator is out of items, no use in continuing - return folder; + // Someone might have pushed more between our `steal()` and `done.load()` + if self.items.is_empty() { + // The iterator is out of items, no use in continuing + return folder; + } } else { // our cache is out of items, time to load more from the iterator match self.iter.try_lock() { Ok(mut guard) => { + // Check `done` again in case we raced with the previous lock + // holder on its way out. + if self.done.load(Ordering::SeqCst) { + if self.items.is_empty() { + return folder; + } + continue; + } + let count = current_num_threads(); let count = (count * count) * 2; @@ -184,7 +200,7 @@ where } Err(TryLockError::WouldBlock) => { // someone else has the mutex, just sit tight until it's ready - yield_now(); //TODO: use a thread=pool-aware yield? (#548) + yield_now(); //TODO: use a thread-pool-aware yield? (#548) } Err(TryLockError::Poisoned(_)) => { // any panics from other threads will have been caught by the pool, diff --git a/src/iter/unzip.rs b/src/iter/unzip.rs index 219b909..0b7074e 100644 --- a/src/iter/unzip.rs +++ b/src/iter/unzip.rs @@ -462,3 +462,64 @@ where } } } + +impl FromParallelIterator<(A, B)> for (FromA, FromB) +where + A: Send, + B: Send, + FromA: Send + FromParallelIterator, + FromB: Send + FromParallelIterator, +{ + fn from_par_iter(pi: I) -> Self + where + I: IntoParallelIterator, + { + let (a, b): (Collector, Collector) = pi.into_par_iter().unzip(); + (a.result.unwrap(), b.result.unwrap()) + } +} + +impl FromParallelIterator> for (A, B) +where + L: Send, + R: Send, + A: Send + FromParallelIterator, + B: Send + FromParallelIterator, +{ + fn from_par_iter(pi: I) -> Self + where + I: IntoParallelIterator>, + { + fn identity(x: T) -> T { + x + } + + let (a, b): (Collector, Collector) = pi.into_par_iter().partition_map(identity); + (a.result.unwrap(), b.result.unwrap()) + } +} + +/// Shim to implement a one-time `ParallelExtend` using `FromParallelIterator`. +struct Collector { + result: Option, +} + +impl Default for Collector { + fn default() -> Self { + Collector { result: None } + } +} + +impl ParallelExtend for Collector +where + T: Send, + FromT: Send + FromParallelIterator, +{ + fn par_extend(&mut self, pi: I) + where + I: IntoParallelIterator, + { + debug_assert!(self.result.is_none()); + self.result = Some(pi.into_par_iter().collect()); + } +} -- cgit v1.2.3