diff --git a/src/libstd/iterator.rs b/src/libstd/iterator.rs index 6828de51622..33d863f3716 100644 --- a/src/libstd/iterator.rs +++ b/src/libstd/iterator.rs @@ -529,7 +529,7 @@ impl> IteratorUtil for T { #[inline] fn flat_map_<'r, B, U: Iterator>(self, f: &'r fn(A) -> U) -> FlatMap<'r, A, T, U> { - FlatMap{iter: self, f: f, subiter: None } + FlatMap{iter: self, f: f, frontiter: None, backiter: None } } // FIXME: #5898: should be called `peek` @@ -1251,7 +1251,8 @@ impl<'self, A, B, T: Iterator, St> Iterator for Scan<'self, A, B, T, St> { pub struct FlatMap<'self, A, T, U> { priv iter: T, priv f: &'self fn(A) -> U, - priv subiter: Option, + priv frontiter: Option, + priv backiter: Option, } impl<'self, A, T: Iterator, B, U: Iterator> Iterator for @@ -1259,14 +1260,35 @@ impl<'self, A, T: Iterator, B, U: Iterator> Iterator for #[inline] fn next(&mut self) -> Option { loop { - for self.subiter.mut_iter().advance |inner| { + for self.frontiter.mut_iter().advance |inner| { for inner.advance |x| { return Some(x) } } match self.iter.next().map_consume(|x| (self.f)(x)) { - None => return None, - next => self.subiter = next, + None => return self.backiter.chain_mut_ref(|it| it.next()), + next => self.frontiter = next, + } + } + } +} + +impl<'self, + A, T: DoubleEndedIterator, + B, U: DoubleEndedIterator> DoubleEndedIterator + for FlatMap<'self, A, T, U> { + #[inline] + fn next_back(&mut self) -> Option { + loop { + for self.backiter.mut_iter().advance |inner| { + match inner.next_back() { + None => (), + y => return y + } + } + match self.iter.next_back().map_consume(|x| (self.f)(x)) { + None => return self.frontiter.chain_mut_ref(|it| it.next_back()), + next => self.backiter = next, } } } @@ -1768,6 +1790,23 @@ mod tests { assert_eq!(it.next_back(), None) } + #[test] + fn test_double_ended_flat_map() { + let u = [0u,1]; + let v = [5,6,7,8]; + let mut it = u.iter().flat_map_(|x| v.slice(*x, v.len()).iter()); + assert_eq!(it.next_back().unwrap(), &8); + assert_eq!(it.next().unwrap(), &5); + assert_eq!(it.next_back().unwrap(), &7); + assert_eq!(it.next_back().unwrap(), &6); + assert_eq!(it.next_back().unwrap(), &8); + assert_eq!(it.next().unwrap(), &6); + assert_eq!(it.next_back().unwrap(), &7); + assert_eq!(it.next_back(), None); + assert_eq!(it.next(), None); + assert_eq!(it.next_back(), None); + } + #[test] fn test_random_access_chain() { let xs = [1, 2, 3, 4, 5];