author      Joel Galenson <jgalenson@google.com>   2021-05-19 16:31:56 -0700
committer   Joel Galenson <jgalenson@google.com>   2021-05-19 16:31:56 -0700
commit      dd2305def6fff1d9f1149ecb73b02e37eb8a18b0 (patch)
tree        a915bf2398b238dbd70ae338a6bb643a7372de84 /src/iter
parent      698313e1f73ce48237108f63ba85b3000f9d0ff1 (diff)
download    rayon-dd2305def6fff1d9f1149ecb73b02e37eb8a18b0.tar.gz
Upgrade rust/crates/rayon to 1.5.1
Test: make
Change-Id: I40c1a4538832871d1f4cd09daf6904d094b5615e
Diffstat (limited to 'src/iter')
-rw-r--r--  src/iter/collect/consumer.rs    72
-rw-r--r--  src/iter/collect/mod.rs         78
-rw-r--r--  src/iter/collect/test.rs         2
-rw-r--r--  src/iter/mod.rs                 81
-rw-r--r--  src/iter/par_bridge.rs          38
-rw-r--r--  src/iter/unzip.rs               61
6 files changed, 242 insertions, 90 deletions
diff --git a/src/iter/collect/consumer.rs b/src/iter/collect/consumer.rs
index 689f29c..3a8eea0 100644
--- a/src/iter/collect/consumer.rs
+++ b/src/iter/collect/consumer.rs
@@ -1,26 +1,18 @@
use super::super::plumbing::*;
use std::marker::PhantomData;
+use std::mem::MaybeUninit;
use std::ptr;
use std::slice;
pub(super) struct CollectConsumer<'c, T: Send> {
/// A slice covering the target memory, not yet initialized!
- target: &'c mut [T],
-}
-
-pub(super) struct CollectFolder<'c, T: Send> {
- /// The folder writes into `result` and must extend the result
- /// up to exactly this number of elements.
- final_len: usize,
-
- /// The current written-to part of our slice of the target
- result: CollectResult<'c, T>,
+ target: &'c mut [MaybeUninit<T>],
}
impl<'c, T: Send + 'c> CollectConsumer<'c, T> {
/// The target memory is considered uninitialized, and will be
/// overwritten without reading or dropping existing values.
- pub(super) fn new(target: &'c mut [T]) -> Self {
+ pub(super) fn new(target: &'c mut [MaybeUninit<T>]) -> Self {
CollectConsumer { target }
}
}
@@ -31,8 +23,12 @@ impl<'c, T: Send + 'c> CollectConsumer<'c, T> {
/// the elements will be dropped, unless its ownership is released before then.
#[must_use]
pub(super) struct CollectResult<'c, T> {
- start: *mut T,
+ /// A slice covering the target memory, initialized up to our separate `len`.
+ target: &'c mut [MaybeUninit<T>],
+ /// The current initialized length in `target`
len: usize,
+ /// Lifetime invariance guarantees that the data flows from consumer to result,
+ /// especially for the `scope_fn` callback in `Collect::with_consumer`.
invariant_lifetime: PhantomData<&'c mut &'c mut [T]>,
}
@@ -57,13 +53,15 @@ impl<'c, T> Drop for CollectResult<'c, T> {
// Drop the first `self.len` elements, which have been recorded
// to be initialized by the folder.
unsafe {
- ptr::drop_in_place(slice::from_raw_parts_mut(self.start, self.len));
+ // TODO: use `MaybeUninit::slice_as_mut_ptr`
+ let start = self.target.as_mut_ptr() as *mut T;
+ ptr::drop_in_place(slice::from_raw_parts_mut(start, self.len));
}
}
}
impl<'c, T: Send + 'c> Consumer<T> for CollectConsumer<'c, T> {
- type Folder = CollectFolder<'c, T>;
+ type Folder = CollectResult<'c, T>;
type Reducer = CollectReducer;
type Result = CollectResult<'c, T>;
@@ -80,16 +78,13 @@ impl<'c, T: Send + 'c> Consumer<T> for CollectConsumer<'c, T> {
)
}
- fn into_folder(self) -> CollectFolder<'c, T> {
- // Create a folder that consumes values and writes them
+ fn into_folder(self) -> Self::Folder {
+ // Create a result/folder that consumes values and writes them
// into target. The initial result has length 0.
- CollectFolder {
- final_len: self.target.len(),
- result: CollectResult {
- start: self.target.as_mut_ptr(),
- len: 0,
- invariant_lifetime: PhantomData,
- },
+ CollectResult {
+ target: self.target,
+ len: 0,
+ invariant_lifetime: PhantomData,
}
}
@@ -98,19 +93,19 @@ impl<'c, T: Send + 'c> Consumer<T> for CollectConsumer<'c, T> {
}
}
-impl<'c, T: Send + 'c> Folder<T> for CollectFolder<'c, T> {
- type Result = CollectResult<'c, T>;
+impl<'c, T: Send + 'c> Folder<T> for CollectResult<'c, T> {
+ type Result = Self;
- fn consume(mut self, item: T) -> CollectFolder<'c, T> {
- if self.result.len >= self.final_len {
- panic!("too many values pushed to consumer");
- }
+ fn consume(mut self, item: T) -> Self {
+ let dest = self
+ .target
+ .get_mut(self.len)
+ .expect("too many values pushed to consumer");
- // Compute target pointer and write to it, and
- // extend the current result by one element
+ // Write item and increase the initialized length
unsafe {
- self.result.start.add(self.result.len).write(item);
- self.result.len += 1;
+ dest.as_mut_ptr().write(item);
+ self.len += 1;
}
self
@@ -119,7 +114,7 @@ impl<'c, T: Send + 'c> Folder<T> for CollectFolder<'c, T> {
fn complete(self) -> Self::Result {
// NB: We don't explicitly check that the local writes were complete,
// but Collect will assert the total result length in the end.
- self.result
+ self
}
fn full(&self) -> bool {
@@ -151,8 +146,13 @@ impl<'c, T> Reducer<CollectResult<'c, T>> for CollectReducer {
// Merge if the CollectResults are adjacent and in left to right order
// else: drop the right piece now and total length will end up short in the end,
// when the correctness of the collected result is asserted.
- if left.start.wrapping_add(left.len) == right.start {
- left.len += right.release_ownership();
+ let left_end = left.target[left.len..].as_ptr();
+ if left_end == right.target.as_ptr() {
+ let len = left.len + right.release_ownership();
+ unsafe {
+ left.target = slice::from_raw_parts_mut(left.target.as_mut_ptr(), len);
+ }
+ left.len = len;
}
left
}
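
The consumer.rs changes above switch the collect machinery from raw-pointer bookkeeping to a borrowed `&mut [MaybeUninit<T>]` plus an initialized-length counter, so that on unwind only the written prefix is dropped. A minimal standalone sketch of that pattern (the `PartialInit` type and its methods are illustrative only, not rayon's API):

    use std::mem::MaybeUninit;
    use std::{ptr, slice};

    /// Tracks the initialized prefix of a borrowed uninitialized slice.
    struct PartialInit<'a, T> {
        target: &'a mut [MaybeUninit<T>],
        len: usize,
    }

    impl<'a, T> PartialInit<'a, T> {
        fn new(target: &'a mut [MaybeUninit<T>]) -> Self {
            PartialInit { target, len: 0 }
        }

        fn push(&mut self, item: T) {
            let dest = self.target.get_mut(self.len).expect("slice is full");
            // Write without reading or dropping the uninitialized slot.
            unsafe { dest.as_mut_ptr().write(item) };
            self.len += 1;
        }
    }

    impl<'a, T> Drop for PartialInit<'a, T> {
        fn drop(&mut self) {
            // Only the first `len` slots were written; drop exactly those,
            // so an early unwind never touches uninitialized memory.
            unsafe {
                let start = self.target.as_mut_ptr() as *mut T;
                ptr::drop_in_place(slice::from_raw_parts_mut(start, self.len));
            }
        }
    }

In the patch, `CollectResult::release_ownership` plays the hand-over role: the result gives up its drop responsibility before the destination vector's `set_len` takes over, so elements are never dropped twice.
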
diff --git a/src/iter/collect/mod.rs b/src/iter/collect/mod.rs
index e18298e..7cbf215 100644
--- a/src/iter/collect/mod.rs
+++ b/src/iter/collect/mod.rs
@@ -1,4 +1,5 @@
use super::{IndexedParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator};
+use std::mem::MaybeUninit;
use std::slice;
mod consumer;
@@ -88,55 +89,56 @@ impl<'c, T: Send + 'c> Collect<'c, T> {
where
F: FnOnce(CollectConsumer<'_, T>) -> CollectResult<'_, T>,
{
+ let slice = Self::reserve_get_tail_slice(&mut self.vec, self.len);
+ let result = scope_fn(CollectConsumer::new(slice));
+
+ // The CollectResult represents a contiguous part of the
+ // slice, that has been written to.
+ // On unwind here, the CollectResult will be dropped.
+ // If some producers on the way did not produce enough elements,
+ // partial CollectResults may have been dropped without
+ // being reduced to the final result, and we will see
+ // that as the length coming up short.
+ //
+ // Here, we assert that `slice` is fully initialized. This is
+ // checked by the following assert, which verifies if a
+ // complete CollectResult was produced; if the length is
+ // correct, it is necessarily covering the target slice.
+ // Since we know that the consumer cannot have escaped from
+ // `drive` (by parametricity, essentially), we know that any
+ // stores that will happen, have happened. Unless some code is buggy,
+ // that means we should have seen `len` total writes.
+ let actual_writes = result.len();
+ assert!(
+ actual_writes == self.len,
+ "expected {} total writes, but got {}",
+ self.len,
+ actual_writes
+ );
+
+ // Release the result's mutable borrow and "proxy ownership"
+ // of the elements, before the vector takes it over.
+ result.release_ownership();
+
+ let new_len = self.vec.len() + self.len;
+
unsafe {
- let slice = Self::reserve_get_tail_slice(&mut self.vec, self.len);
- let result = scope_fn(CollectConsumer::new(slice));
-
- // The CollectResult represents a contiguous part of the
- // slice, that has been written to.
- // On unwind here, the CollectResult will be dropped.
- // If some producers on the way did not produce enough elements,
- // partial CollectResults may have been dropped without
- // being reduced to the final result, and we will see
- // that as the length coming up short.
- //
- // Here, we assert that `slice` is fully initialized. This is
- // checked by the following assert, which verifies if a
- // complete CollectResult was produced; if the length is
- // correct, it is necessarily covering the target slice.
- // Since we know that the consumer cannot have escaped from
- // `drive` (by parametricity, essentially), we know that any
- // stores that will happen, have happened. Unless some code is buggy,
- // that means we should have seen `len` total writes.
- let actual_writes = result.len();
- assert!(
- actual_writes == self.len,
- "expected {} total writes, but got {}",
- self.len,
- actual_writes
- );
-
- // Release the result's mutable borrow and "proxy ownership"
- // of the elements, before the vector takes it over.
- result.release_ownership();
-
- let new_len = self.vec.len() + self.len;
self.vec.set_len(new_len);
}
}
/// Reserve space for `len` more elements in the vector,
/// and return a slice to the uninitialized tail of the vector
- ///
- /// Safety: The tail slice is uninitialized
- unsafe fn reserve_get_tail_slice(vec: &mut Vec<T>, len: usize) -> &mut [T] {
+ fn reserve_get_tail_slice(vec: &mut Vec<T>, len: usize) -> &mut [MaybeUninit<T>] {
// Reserve the new space.
vec.reserve(len);
- // Get a correct borrow, then extend it for the newly added length.
+ // TODO: use `Vec::spare_capacity_mut` instead
+ // SAFETY: `MaybeUninit<T>` is guaranteed to have the same layout
+ // as `T`, and we already made sure to have the additional space.
let start = vec.len();
- let slice = &mut vec[start..];
- slice::from_raw_parts_mut(slice.as_mut_ptr(), len)
+ let tail_ptr = vec[start..].as_mut_ptr() as *mut MaybeUninit<T>;
+ unsafe { slice::from_raw_parts_mut(tail_ptr, len) }
}
}
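
The rewritten `reserve_get_tail_slice` is the manual equivalent of the API its TODO points at: reserve capacity, view the unused tail as `&mut [MaybeUninit<T>]`, fill it, then commit with `set_len`. A hedged sketch of that sequence using the standard-library methods (the `extend_with` helper is illustrative; it assumes `Vec::spare_capacity_mut` and `MaybeUninit::write` are available, i.e. Rust 1.60+):

    fn extend_with<T>(vec: &mut Vec<T>, len: usize, mut f: impl FnMut(usize) -> T) {
        vec.reserve(len);
        // The reserved-but-unused tail is exposed as `&mut [MaybeUninit<T>]`.
        for (i, slot) in vec.spare_capacity_mut()[..len].iter_mut().enumerate() {
            slot.write(f(i));
        }
        // SAFETY: the first `len` spare slots are now initialized.
        let new_len = vec.len() + len;
        unsafe { vec.set_len(new_len) };
    }

    // e.g. extend_with(&mut v, 3, |i| i * 10) appends 0, 10, 20 to `v`.
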
diff --git a/src/iter/collect/test.rs b/src/iter/collect/test.rs
index 00c16c4..ddf7757 100644
--- a/src/iter/collect/test.rs
+++ b/src/iter/collect/test.rs
@@ -24,7 +24,7 @@ fn produce_too_many_items() {
let mut folder = consumer.into_folder();
folder = folder.consume(22);
folder = folder.consume(23);
- folder.consume(24);
+ folder = folder.consume(24);
unreachable!("folder does not complete")
});
}
diff --git a/src/iter/mod.rs b/src/iter/mod.rs
index 0c82933..edff1a6 100644
--- a/src/iter/mod.rs
+++ b/src/iter/mod.rs
@@ -61,7 +61,7 @@
//! If you'd like to build a custom parallel iterator, or to write your own
//! combinator, then check out the [split] function and the [plumbing] module.
//!
-//! [regular iterator]: http://doc.rust-lang.org/std/iter/trait.Iterator.html
+//! [regular iterator]: https://doc.rust-lang.org/std/iter/trait.Iterator.html
//! [`ParallelIterator`]: trait.ParallelIterator.html
//! [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html
//! [split]: fn.split.html
@@ -1966,6 +1966,79 @@ pub trait ParallelIterator: Sized + Send {
///
/// assert_eq!(sync_vec, async_vec);
/// ```
+ ///
+ /// You can collect a pair of collections like [`unzip`](#method.unzip)
+ /// for paired items:
+ ///
+ /// ```
+ /// use rayon::prelude::*;
+ ///
+ /// let a = [(0, 1), (1, 2), (2, 3), (3, 4)];
+ /// let (first, second): (Vec<_>, Vec<_>) = a.into_par_iter().collect();
+ ///
+ /// assert_eq!(first, [0, 1, 2, 3]);
+ /// assert_eq!(second, [1, 2, 3, 4]);
+ /// ```
+ ///
+ /// Or like [`partition_map`](#method.partition_map) for `Either` items:
+ ///
+ /// ```
+ /// use rayon::prelude::*;
+ /// use rayon::iter::Either;
+ ///
+ /// let (left, right): (Vec<_>, Vec<_>) = (0..8).into_par_iter().map(|x| {
+ /// if x % 2 == 0 {
+ /// Either::Left(x * 4)
+ /// } else {
+ /// Either::Right(x * 3)
+ /// }
+ /// }).collect();
+ ///
+ /// assert_eq!(left, [0, 8, 16, 24]);
+ /// assert_eq!(right, [3, 9, 15, 21]);
+ /// ```
+ ///
+ /// You can even collect an arbitrarily-nested combination of pairs and `Either`:
+ ///
+ /// ```
+ /// use rayon::prelude::*;
+ /// use rayon::iter::Either;
+ ///
+ /// let (first, (left, right)): (Vec<_>, (Vec<_>, Vec<_>))
+ /// = (0..8).into_par_iter().map(|x| {
+ /// if x % 2 == 0 {
+ /// (x, Either::Left(x * 4))
+ /// } else {
+ /// (-x, Either::Right(x * 3))
+ /// }
+ /// }).collect();
+ ///
+ /// assert_eq!(first, [0, -1, 2, -3, 4, -5, 6, -7]);
+ /// assert_eq!(left, [0, 8, 16, 24]);
+ /// assert_eq!(right, [3, 9, 15, 21]);
+ /// ```
+ ///
+ /// All of that can _also_ be combined with short-circuiting collection of
+ /// `Result` or `Option` types:
+ ///
+ /// ```
+ /// use rayon::prelude::*;
+ /// use rayon::iter::Either;
+ ///
+ /// let result: Result<(Vec<_>, (Vec<_>, Vec<_>)), _>
+ /// = (0..8).into_par_iter().map(|x| {
+ /// if x > 5 {
+ /// Err(x)
+ /// } else if x % 2 == 0 {
+ /// Ok((x, Either::Left(x * 4)))
+ /// } else {
+ /// Ok((-x, Either::Right(x * 3)))
+ /// }
+ /// }).collect();
+ ///
+ /// let error = result.unwrap_err();
+ /// assert!(error == 6 || error == 7);
+ /// ```
fn collect<C>(self) -> C
where
C: FromParallelIterator<Self::Item>,
@@ -2130,7 +2203,7 @@ pub trait ParallelIterator: Sized + Send {
/// See the [README] for more details on the internals of parallel
/// iterators.
///
- /// [README]: README.md
+ /// [README]: https://github.com/rayon-rs/rayon/blob/master/src/iter/plumbing/README.md
fn drive_unindexed<C>(self, consumer: C) -> C::Result
where
C: UnindexedConsumer<Self::Item>;
@@ -2817,7 +2890,7 @@ pub trait IndexedParallelIterator: ParallelIterator {
/// See the [README] for more details on the internals of parallel
/// iterators.
///
- /// [README]: README.md
+ /// [README]: https://github.com/rayon-rs/rayon/blob/master/src/iter/plumbing/README.md
fn drive<C: Consumer<Self::Item>>(self, consumer: C) -> C::Result;
/// Internal method used to define the behavior of this parallel
@@ -2834,7 +2907,7 @@ pub trait IndexedParallelIterator: ParallelIterator {
/// See the [README] for more details on the internals of parallel
/// iterators.
///
- /// [README]: README.md
+ /// [README]: https://github.com/rayon-rs/rayon/blob/master/src/iter/plumbing/README.md
fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output;
}
diff --git a/src/iter/par_bridge.rs b/src/iter/par_bridge.rs
index 4c2b96e..339ac1a 100644
--- a/src/iter/par_bridge.rs
+++ b/src/iter/par_bridge.rs
@@ -125,16 +125,19 @@ where
let mut count = self.split_count.load(Ordering::SeqCst);
loop {
- let done = self.done.load(Ordering::SeqCst);
+ // Check if the iterator is exhausted *and* we've consumed every item from it.
+ let done = self.done.load(Ordering::SeqCst) && self.items.is_empty();
+
match count.checked_sub(1) {
Some(new_count) if !done => {
- let last_count =
- self.split_count
- .compare_and_swap(count, new_count, Ordering::SeqCst);
- if last_count == count {
- return (self.clone(), Some(self));
- } else {
- count = last_count;
+ match self.split_count.compare_exchange_weak(
+ count,
+ new_count,
+ Ordering::SeqCst,
+ Ordering::SeqCst,
+ ) {
+ Ok(_) => return (self.clone(), Some(self)),
+ Err(last_count) => count = last_count,
}
}
_ => {
@@ -157,13 +160,26 @@ where
}
}
Steal::Empty => {
+ // Don't storm the mutex if we're already done.
if self.done.load(Ordering::SeqCst) {
- // the iterator is out of items, no use in continuing
- return folder;
+ // Someone might have pushed more between our `steal()` and `done.load()`
+ if self.items.is_empty() {
+ // The iterator is out of items, no use in continuing
+ return folder;
+ }
} else {
// our cache is out of items, time to load more from the iterator
match self.iter.try_lock() {
Ok(mut guard) => {
+ // Check `done` again in case we raced with the previous lock
+ // holder on its way out.
+ if self.done.load(Ordering::SeqCst) {
+ if self.items.is_empty() {
+ return folder;
+ }
+ continue;
+ }
+
let count = current_num_threads();
let count = (count * count) * 2;
@@ -184,7 +200,7 @@ where
}
Err(TryLockError::WouldBlock) => {
// someone else has the mutex, just sit tight until it's ready
- yield_now(); //TODO: use a thread=pool-aware yield? (#548)
+ yield_now(); //TODO: use a thread-pool-aware yield? (#548)
}
Err(TryLockError::Poisoned(_)) => {
// any panics from other threads will have been caught by the pool,
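
The first par_bridge.rs hunk swaps the deprecated `AtomicUsize::compare_and_swap` for a `compare_exchange_weak` retry loop. A minimal standalone sketch of that migration pattern on a bare counter (the `try_claim_split` function below is illustrative, not rayon's API):

    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Atomically decrement `count` if it is non-zero; report whether we did.
    fn try_claim_split(count: &AtomicUsize) -> bool {
        let mut current = count.load(Ordering::SeqCst);
        loop {
            let new = match current.checked_sub(1) {
                Some(n) => n,
                None => return false, // nothing left to claim
            };
            match count.compare_exchange_weak(current, new, Ordering::SeqCst, Ordering::SeqCst) {
                Ok(_) => return true,
                // `Err` carries the freshly observed value (or reflects a
                // spurious failure); retry with it.
                Err(observed) => current = observed,
            }
        }
    }

`compare_exchange_weak` may fail spuriously even when the expected value matches, which is why it only makes sense inside a retry loop like the one in the patch.
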
diff --git a/src/iter/unzip.rs b/src/iter/unzip.rs
index 219b909..0b7074e 100644
--- a/src/iter/unzip.rs
+++ b/src/iter/unzip.rs
@@ -462,3 +462,64 @@ where
}
}
}
+
+impl<A, B, FromA, FromB> FromParallelIterator<(A, B)> for (FromA, FromB)
+where
+ A: Send,
+ B: Send,
+ FromA: Send + FromParallelIterator<A>,
+ FromB: Send + FromParallelIterator<B>,
+{
+ fn from_par_iter<I>(pi: I) -> Self
+ where
+ I: IntoParallelIterator<Item = (A, B)>,
+ {
+ let (a, b): (Collector<FromA>, Collector<FromB>) = pi.into_par_iter().unzip();
+ (a.result.unwrap(), b.result.unwrap())
+ }
+}
+
+impl<L, R, A, B> FromParallelIterator<Either<L, R>> for (A, B)
+where
+ L: Send,
+ R: Send,
+ A: Send + FromParallelIterator<L>,
+ B: Send + FromParallelIterator<R>,
+{
+ fn from_par_iter<I>(pi: I) -> Self
+ where
+ I: IntoParallelIterator<Item = Either<L, R>>,
+ {
+ fn identity<T>(x: T) -> T {
+ x
+ }
+
+ let (a, b): (Collector<A>, Collector<B>) = pi.into_par_iter().partition_map(identity);
+ (a.result.unwrap(), b.result.unwrap())
+ }
+}
+
+/// Shim to implement a one-time `ParallelExtend` using `FromParallelIterator`.
+struct Collector<FromT> {
+ result: Option<FromT>,
+}
+
+impl<FromT> Default for Collector<FromT> {
+ fn default() -> Self {
+ Collector { result: None }
+ }
+}
+
+impl<T, FromT> ParallelExtend<T> for Collector<FromT>
+where
+ T: Send,
+ FromT: Send + FromParallelIterator<T>,
+{
+ fn par_extend<I>(&mut self, pi: I)
+ where
+ I: IntoParallelIterator<Item = T>,
+ {
+ debug_assert!(self.result.is_none());
+ self.result = Some(pi.into_par_iter().collect());
+ }
+}
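
Together with the doc examples added in src/iter/mod.rs, these unzip.rs impls are what let a parallel iterator of pairs or `Either` values collect straight into two collections. A short usage sketch against rayon 1.5.1 (the variable names are ordinary user code; only the shown crate APIs are assumed):

    use rayon::prelude::*;
    use rayon::iter::Either;

    fn main() {
        // Pairs collect into a pair of collections via the unzip-based impl.
        let (squares, cubes): (Vec<i32>, Vec<i32>) = (1..5)
            .into_par_iter()
            .map(|x| (x * x, x * x * x))
            .collect();
        assert_eq!(squares, [1, 4, 9, 16]);
        assert_eq!(cubes, [1, 8, 27, 64]);

        // `Either` items collect into two collections via the partition_map-based impl.
        let (evens, odds): (Vec<i32>, Vec<i32>) = (0..6)
            .into_par_iter()
            .map(|x| if x % 2 == 0 { Either::Left(x) } else { Either::Right(x) })
            .collect();
        assert_eq!(evens, [0, 2, 4]);
        assert_eq!(odds, [1, 3, 5]);
    }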