blob: d947c7b0e30130617b7c4303fe8fcd33c2f9d9d1 [file] [log] [blame]
use crate::cmp;
use crate::iter::{adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedLen};
use crate::ops::{ControlFlow, Try};
/// An iterator that yields at most the first `n` elements of `iter`.
///
/// This `struct` is created by the [`take`] method on [`Iterator`]. See its
/// documentation for more.
///
/// [`take`]: Iterator::take
#[derive(Clone, Debug)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
#[stable(feature = "rust1", since = "1.0.0")]
pub struct Take<I> {
    // The wrapped iterator that elements are pulled from.
    iter: I,
    // Upper bound on how many more elements may be yielded from the front;
    // `next` returns `None` without polling `iter` once this reaches zero.
    n: usize,
}
impl<I> Take<I> {
    // Wraps `iter` so that no more than `n` of its elements will be yielded.
    pub(in crate::iter) fn new(iter: I, n: usize) -> Take<I> {
        Self { iter, n }
    }
}
#[stable(feature = "rust1", since = "1.0.0")]
impl<I> Iterator for Take<I>
where
    I: Iterator,
{
    type Item = <I as Iterator>::Item;

    // Hands out one inner element per unit of budget; once the budget is
    // gone the inner iterator is no longer polled.
    #[inline]
    fn next(&mut self) -> Option<<I as Iterator>::Item> {
        if self.n == 0 {
            return None;
        }
        self.n -= 1;
        self.iter.next()
    }

    #[inline]
    fn nth(&mut self, n: usize) -> Option<I::Item> {
        if n < self.n {
            // The target lies inside the window: spend `n + 1` units of
            // budget and let the inner iterator do the skipping.
            self.n -= n + 1;
            return self.iter.nth(n);
        }
        // The target lies past the window. Advance the inner iterator by the
        // remaining budget only (exactly as that many `next` calls would),
        // then mark the budget as exhausted.
        if self.n != 0 {
            let last_index = self.n - 1;
            self.iter.nth(last_index);
            self.n = 0;
        }
        None
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        if self.n == 0 {
            return (0, Some(0));
        }
        let (inner_lower, inner_upper) = self.iter.size_hint();
        let lower = cmp::min(inner_lower, self.n);
        // The budget caps the upper bound, even for an unbounded inner iterator.
        let upper = match inner_upper {
            Some(bound) => cmp::min(bound, self.n),
            None => self.n,
        };
        (lower, Some(upper))
    }

    #[inline]
    fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
    where
        Fold: FnMut(Acc, Self::Item) -> R,
        R: Try<Output = Acc>,
    {
        // Adapts `fold` so the inner `try_fold` stops right after the last
        // element the budget allows, passing `fold`'s result through as-is.
        // Callers guarantee `*budget > 0` on entry, and the closure breaks as
        // soon as it reaches zero, so the decrement cannot underflow.
        fn check<'a, T, Acc, R: Try<Output = Acc>>(
            budget: &'a mut usize,
            mut fold: impl FnMut(Acc, T) -> R + 'a,
        ) -> impl FnMut(Acc, T) -> ControlFlow<R, Acc> + 'a {
            move |acc, item| {
                *budget -= 1;
                let res = fold(acc, item);
                match *budget {
                    0 => ControlFlow::Break(res),
                    _ => ControlFlow::from_try(res),
                }
            }
        }

        if self.n != 0 {
            let budget = &mut self.n;
            self.iter.try_fold(init, check(budget, fold)).into_try()
        } else {
            try { init }
        }
    }

    impl_fold_via_try_fold! { fold -> try_fold }

    #[inline]
    fn for_each<F: FnMut(Self::Item)>(mut self, f: F) {
        // Instead of threading a unit accumulator through `try_fold` and
        // keeping a mutable counter in the closure, the accumulator itself
        // counts how many further items may still be handed to `f`;
        // `checked_sub` turns "none left" into `None`, which ends the fold.
        fn check<'a, Item>(
            mut action: impl FnMut(Item) + 'a,
        ) -> impl FnMut(usize, Item) -> Option<usize> + 'a {
            move |left, item| {
                action(item);
                left.checked_sub(1)
            }
        }

        if let Some(left) = self.n.checked_sub(1) {
            self.iter.try_fold(left, check(f));
        }
    }

    #[inline]
    #[rustc_inherit_overflow_checks]
    fn advance_by(&mut self, n: usize) -> Result<(), usize> {
        // Never ask the inner iterator for more than the remaining budget.
        let step = cmp::min(self.n, n);
        match self.iter.advance_by(step) {
            Err(advanced) => {
                // The inner iterator ran dry early; charge only what it skipped.
                self.n -= advanced;
                Err(advanced)
            }
            Ok(()) => {
                self.n -= step;
                // Falling short of `n` means the budget, not `iter`, ran out.
                if step == n { Ok(()) } else { Err(step) }
            }
        }
    }
}
#[unstable(issue = "none", feature = "inplace_iteration")]
unsafe impl<I: SourceIter> SourceIter for Take<I> {
    type Source = I::Source;

    // `Take` holds its wrapped iterator directly, so the source is reached by
    // delegating to `iter`.
    #[inline]
    unsafe fn as_inner(&mut self) -> &mut I::Source {
        // SAFETY: this merely forwards to the inner iterator's `as_inner`,
        // which carries exactly the same safety requirements as this method.
        unsafe { <I as SourceIter>::as_inner(&mut self.iter) }
    }
}
// NOTE(review): `Take` only forwards items pulled from `iter` (at most one per
// `next` call) and never yields more items than `I` would; presumably that is
// the property `InPlaceIterable` requires of adapters — confirm against the
// trait's documented contract.
#[unstable(issue = "none", feature = "inplace_iteration")]
unsafe impl<I> InPlaceIterable for Take<I> where I: InPlaceIterable {}
// Back-iteration relies on the inner iterator's `ExactSizeIterator::len`: when
// it still holds more than `self.n` elements, that surplus at its back is not
// part of this `Take`, so it is skipped (consumed) before anything is yielded.
#[stable(feature = "double_ended_take_iterator", since = "1.38.0")]
impl<I> DoubleEndedIterator for Take<I>
where
    I: DoubleEndedIterator + ExactSizeIterator,
{
    #[inline]
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.n == 0 {
            None
        } else {
            let n = self.n;
            self.n -= 1;
            // Skip the `len - n` trailing elements outside the window (none
            // if the inner iterator is already no longer than `n`).
            self.iter.nth_back(self.iter.len().saturating_sub(n))
        }
    }

    #[inline]
    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
        let len = self.iter.len();
        if self.n > n {
            // Skip the surplus tail plus the `n` elements the caller asked
            // to skip within the window.
            let m = len.saturating_sub(self.n) + n;
            self.n -= n + 1;
            self.iter.nth_back(m)
        } else {
            // The requested element lies before the start of the window, so
            // the inner iterator is drained completely. Note this also runs
            // when `self.n == 0`, and `self.n` is left unchanged here.
            if len > 0 {
                self.iter.nth_back(len - 1);
            }
            None
        }
    }

    #[inline]
    fn try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
    where
        Self: Sized,
        Fold: FnMut(Acc, Self::Item) -> R,
        R: Try<Output = Acc>,
    {
        if self.n == 0 {
            try { init }
        } else {
            let len = self.iter.len();
            // Trim the surplus tail first; if the inner iterator runs out
            // while trimming there is nothing left to fold.
            if len > self.n && self.iter.nth_back(len - self.n - 1).is_none() {
                try { init }
            } else {
                // `self.n` is not decremented for folded items: after trimming,
                // `self.iter.len() <= self.n`, so the inner iterator's length is
                // what bounds the remaining elements.
                self.iter.try_rfold(init, fold)
            }
        }
    }

    #[inline]
    fn rfold<Acc, Fold>(mut self, init: Acc, fold: Fold) -> Acc
    where
        Self: Sized,
        Fold: FnMut(Acc, Self::Item) -> Acc,
    {
        if self.n == 0 {
            init
        } else {
            let len = self.iter.len();
            // Same tail-trimming as `try_rfold`, then fold what remains.
            if len > self.n && self.iter.nth_back(len - self.n - 1).is_none() {
                init
            } else {
                self.iter.rfold(init, fold)
            }
        }
    }

    #[inline]
    #[rustc_inherit_overflow_checks]
    fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
        // The amount by which the inner iterator needs to be shortened for it to be
        // at most as long as the take() amount.
        let trim_inner = self.iter.len().saturating_sub(self.n);
        // The amount we need to advance inner to fulfill the caller's request.
        // take(), advance_by() and len() all can be at most usize, so we don't have to worry
        // about having to advance more than usize::MAX here.
        let advance_by = trim_inner.saturating_add(n);

        // Report progress in terms of this `Take`, i.e. excluding the trimmed tail.
        let advanced = match self.iter.advance_back_by(advance_by) {
            Ok(_) => advance_by - trim_inner,
            Err(advanced) => advanced - trim_inner,
        };
        self.n -= advanced;
        if advanced < n { Err(advanced) } else { Ok(()) }
    }
}
// Empty impl: `len` uses the trait's provided method, which is computed from
// the exact bounds `Take::size_hint` reports for an exact-size inner iterator.
#[stable(feature = "rust1", since = "1.0.0")]
impl<I: ExactSizeIterator> ExactSizeIterator for Take<I> {}
// Once `next` returns `None`, either the budget is spent (and `iter` is never
// polled again) or `iter` itself returned `None`, which a fused `iter` repeats.
#[stable(feature = "fused", since = "1.26.0")]
impl<I: FusedIterator> FusedIterator for Take<I> {}
// SAFETY: `size_hint` clamps both bounds of the inner hint to `n`. An unbounded
// upper bound is replaced by `n`, which the `usize::MAX` lower bound a
// `TrustedLen` iterator reports in that case also clamps to. So an exact inner
// hint yields an exact hint here.
#[unstable(feature = "trusted_len", issue = "37572")]
unsafe impl<I> TrustedLen for Take<I> where I: TrustedLen {}