futures_util/stream/stream/flatten_unordered.rs

1use alloc::sync::Arc;
2use core::{
3    cell::UnsafeCell,
4    convert::identity,
5    fmt,
6    marker::PhantomData,
7    num::NonZeroUsize,
8    pin::Pin,
9    sync::atomic::{AtomicU8, Ordering},
10};
11
12use pin_project_lite::pin_project;
13
14use futures_core::{
15    future::Future,
16    ready,
17    stream::{FusedStream, Stream},
18    task::{Context, Poll, Waker},
19};
20#[cfg(feature = "sink")]
21use futures_sink::Sink;
22use futures_task::{waker, ArcWake};
23
24use crate::stream::FuturesUnordered;
25
/// Stream for the [`flatten_unordered`](super::StreamExt::flatten_unordered)
/// method.
///
/// This is `FlattenUnorderedWithFlowController` with the `()` flow controller, whose
/// `next_step` always returns `FlowStep::Continue`, so every item of the base stream
/// is treated as an inner stream to be flattened.
pub type FlattenUnordered<St> = FlattenUnorderedWithFlowController<St, ()>;
29
/// Idle: there is nothing to poll, and the stream is not being polled, woken,
/// or in the middle of a wake-up right now.
const NONE: u8 = 0b0000_0000;

/// Bit set when the inner streams (the `FuturesUnordered` bucket) have to be polled.
const NEED_TO_POLL_INNER_STREAMS: u8 = 0b0000_0001;

/// Bit set when the base stream has to be polled.
const NEED_TO_POLL_STREAM: u8 = 0b0000_0010;

/// Union of both "need to poll" bits: the base stream and the inner streams.
const NEED_TO_POLL_ALL: u8 = NEED_TO_POLL_STREAM | NEED_TO_POLL_INNER_STREAMS;

/// Bit set while the stream itself is being polled.
const POLLING: u8 = 0b0000_0100;

/// Bit set while a wrapped waker is in the middle of waking the stream.
const WAKING: u8 = 0b0000_1000;

/// Bit set once the stream has been woken and is going to be polled.
const WOKEN: u8 = 0b0001_0000;
50
/// Polling/waking bit set shared between the stream and its wrapped wakers.
///
/// Cloning is cheap: every clone points at the same atomic, so an update made
/// through one handle is visible through all of them.
#[derive(Debug, Clone)]
struct SharedPollState {
    state: Arc<AtomicU8>,
}
56
57impl SharedPollState {
58    /// Constructs new `SharedPollState` with the given state.
59    fn new(value: u8) -> Self {
60        Self { state: Arc::new(AtomicU8::new(value)) }
61    }
62
63    /// Attempts to start polling, returning stored state in case of success.
64    /// Returns `None` if either waker is waking at the moment.
65    fn start_polling(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&Self) -> u8>)> {
66        let value = self
67            .state
68            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
69                if value & WAKING == NONE {
70                    Some(POLLING)
71                } else {
72                    None
73                }
74            })
75            .ok()?;
76        let bomb = PollStateBomb::new(self, Self::reset);
77
78        Some((value, bomb))
79    }
80
81    /// Attempts to start the waking process and performs bitwise or with the given value.
82    ///
83    /// If some waker is already in progress or stream is already woken/being polled, waking process won't start, however
84    /// state will be disjuncted with the given value.
85    fn start_waking(
86        &self,
87        to_poll: u8,
88    ) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&Self) -> u8>)> {
89        let value = self
90            .state
91            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
92                let mut next_value = value | to_poll;
93                if value & (WOKEN | POLLING) == NONE {
94                    next_value |= WAKING;
95                }
96
97                if next_value != value {
98                    Some(next_value)
99                } else {
100                    None
101                }
102            })
103            .ok()?;
104
105        // Only start the waking process if we're not in the polling/waking phase and the stream isn't woken already
106        if value & (WOKEN | POLLING | WAKING) == NONE {
107            let bomb = PollStateBomb::new(self, Self::stop_waking);
108
109            Some((value, bomb))
110        } else {
111            None
112        }
113    }
114
115    /// Sets current state to
116    /// - `!POLLING` allowing to use wakers
117    /// - `WOKEN` if the state was changed during `POLLING` phase as waker will be called,
118    ///   or `will_be_woken` flag supplied
119    /// - `!WAKING` as
120    ///   * Wakers called during the `POLLING` phase won't propagate their calls
121    ///   * `POLLING` phase can't start if some of the wakers are active
122    ///     So no wrapped waker can touch the inner waker's cell, it's safe to poll again.
123    fn stop_polling(&self, to_poll: u8, will_be_woken: bool) -> u8 {
124        self.state
125            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |mut value| {
126                let mut next_value = to_poll;
127
128                value &= NEED_TO_POLL_ALL;
129                if value != NONE || will_be_woken {
130                    next_value |= WOKEN;
131                }
132                next_value |= value;
133
134                Some(next_value & !POLLING & !WAKING)
135            })
136            .unwrap()
137    }
138
139    /// Toggles state to non-waking, allowing to start polling.
140    fn stop_waking(&self) -> u8 {
141        let value = self
142            .state
143            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |value| {
144                let next_value = value & !WAKING | WOKEN;
145
146                if next_value != value {
147                    Some(next_value)
148                } else {
149                    None
150                }
151            })
152            .unwrap_or_else(identity);
153
154        debug_assert!(value & (WOKEN | POLLING | WAKING) == WAKING);
155        value
156    }
157
158    /// Resets current state allowing to poll the stream and wake up wakers.
159    fn reset(&self) -> u8 {
160        self.state.swap(NEED_TO_POLL_ALL, Ordering::SeqCst)
161    }
162}
163
164/// Used to execute some function on the given state when dropped.
165struct PollStateBomb<'a, F: FnOnce(&SharedPollState) -> u8> {
166    state: &'a SharedPollState,
167    drop: Option<F>,
168}
169
170impl<'a, F: FnOnce(&SharedPollState) -> u8> PollStateBomb<'a, F> {
171    /// Constructs new bomb with the given state.
172    fn new(state: &'a SharedPollState, drop: F) -> Self {
173        Self { state, drop: Some(drop) }
174    }
175
176    /// Deactivates bomb, forces it to not call provided function when dropped.
177    fn deactivate(mut self) {
178        self.drop.take();
179    }
180}
181
182impl<F: FnOnce(&SharedPollState) -> u8> Drop for PollStateBomb<'_, F> {
183    fn drop(&mut self) {
184        if let Some(drop) = self.drop.take() {
185            (drop)(self.state);
186        }
187    }
188}
189
/// Waker handed to the base stream or to the inner streams.
///
/// On `wake_by_ref` it ORs `need_to_poll` into the shared poll state. Then, if this
/// call is the one that starts the waking process, it wakes `inner_waker`.
struct WrappedWaker {
    // Waker of the task that polled the stream most recently. Written by
    // `replace_waker` during the `POLLING` phase, read by `wake_by_ref`.
    inner_waker: UnsafeCell<Option<Waker>>,
    // State shared with the owning stream and the other wrapped waker.
    poll_state: SharedPollState,
    // Bit this waker records in `poll_state` when woken.
    need_to_poll: u8,
}

// SAFETY: `Waker`, `SharedPollState` and `u8` can all be sent between threads. The
// only field needing care is the `UnsafeCell`, whose accesses are serialized through
// `poll_state` (see the `Sync` impl below).
unsafe impl Send for WrappedWaker {}
// SAFETY: per the `replace_waker` contract, `inner_waker` is written only while
// `POLLING` is held. It is read only by the single waker whose `start_waking` call
// moved the state from idle to `WAKING`. `start_waking` does not set `WAKING`
// while `POLLING` is held, and `start_polling` fails while `WAKING` is set. So a
// read and a write of the cell never overlap.
unsafe impl Sync for WrappedWaker {}
200
201impl WrappedWaker {
202    /// Replaces given waker's inner_waker for polling stream/futures which will
203    /// update poll state on `wake_by_ref` call. Use only if you need several
204    /// contexts.
205    ///
206    /// ## Safety
207    ///
208    /// This function will modify waker's `inner_waker` via `UnsafeCell`, so
209    /// it should be used only during `POLLING` phase by one thread at the time.
210    unsafe fn replace_waker(self_arc: &mut Arc<Self>, cx: &Context<'_>) {
211        unsafe { *self_arc.inner_waker.get() = cx.waker().clone().into() }
212    }
213
214    /// Attempts to start the waking process for the waker with the given value.
215    /// If succeeded, then the stream isn't yet woken and not being polled at the moment.
216    fn start_waking(&self) -> Option<(u8, PollStateBomb<'_, impl FnOnce(&SharedPollState) -> u8>)> {
217        self.poll_state.start_waking(self.need_to_poll)
218    }
219}
220
221impl ArcWake for WrappedWaker {
222    fn wake_by_ref(self_arc: &Arc<Self>) {
223        if let Some((_, state_bomb)) = self_arc.start_waking() {
224            // Safety: now state is not `POLLING`
225            let waker_opt = unsafe { self_arc.inner_waker.get().as_ref().unwrap() };
226
227            if let Some(inner_waker) = waker_opt.clone() {
228                // Stop waking to allow polling stream
229                drop(state_bomb);
230
231                // Wake up inner waker
232                inner_waker.wake();
233            }
234        }
235    }
236}
237
pin_project! {
    /// Future which polls optional inner stream.
    ///
    /// If it's `Some`, it will attempt to call `poll_next` on it,
    /// returning `Some((item, next_item_fut))` in case of `Poll::Ready(Some(...))`
    /// or `None` in case of `Poll::Ready(None)`.
    /// If it's `None`, the future resolves to `None` right away.
    ///
    /// `next_item_fut` is a new `PollStreamFut` that owns the same stream, so it
    /// can be pushed back into the `FuturesUnordered` bucket to poll the next item.
    ///
    /// If `poll_next` will return `Poll::Pending`, it will be forwarded to
    /// the future and current task will be notified by waker.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    struct PollStreamFut<St> {
        // The stream being polled; it is moved out of this future (into
        // `next_item_fut`, or dropped) as soon as the future resolves.
        #[pin]
        stream: Option<St>,
    }
}
253
254impl<St> PollStreamFut<St> {
255    /// Constructs new `PollStreamFut` using given `stream`.
256    fn new(stream: impl Into<Option<St>>) -> Self {
257        Self { stream: stream.into() }
258    }
259}
260
261impl<St: Stream + Unpin> Future for PollStreamFut<St> {
262    type Output = Option<(St::Item, Self)>;
263
264    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
265        let mut stream = self.project().stream;
266
267        let item = if let Some(stream) = stream.as_mut().as_pin_mut() {
268            ready!(stream.poll_next(cx))
269        } else {
270            None
271        };
272        let next_item_fut = Self::new(stream.get_mut().take());
273        let out = item.map(|item| (item, next_item_fut));
274
275        Poll::Ready(out)
276    }
277}
278
pin_project! {
    /// Stream for the [`flatten_unordered`](super::StreamExt::flatten_unordered)
    /// method with ability to specify flow controller.
    ///
    /// For every item of the base stream, the flow controller `Fc` decides whether
    /// the item is flattened as an inner stream or returned from the stream directly.
    #[project = FlattenUnorderedWithFlowControllerProj]
    #[must_use = "streams do nothing unless polled"]
    pub struct FlattenUnorderedWithFlowController<St, Fc> where St: Stream {
        // Futures polling the inner streams, one `PollStreamFut` per inner stream.
        #[pin]
        inner_streams: FuturesUnordered<PollStreamFut<St::Item>>,
        // The base stream that produces the inner streams.
        #[pin]
        stream: St,
        // Polling/waking state shared with both wrapped wakers.
        poll_state: SharedPollState,
        // Optional cap on the number of futures in `inner_streams`; `None` means
        // unlimited.
        limit: Option<NonZeroUsize>,
        // Set once the base stream has returned `Poll::Ready(None)`.
        is_stream_done: bool,
        // Waker used when polling `inner_streams`.
        inner_streams_waker: Arc<WrappedWaker>,
        // Waker used when polling the base `stream`.
        stream_waker: Arc<WrappedWaker>,
        flow_controller: PhantomData<Fc>
    }
}
297
298impl<St, Fc> fmt::Debug for FlattenUnorderedWithFlowController<St, Fc>
299where
300    St: Stream + fmt::Debug,
301    St::Item: Stream + fmt::Debug,
302{
303    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304        f.debug_struct("FlattenUnorderedWithFlowController")
305            .field("poll_state", &self.poll_state)
306            .field("inner_streams", &self.inner_streams)
307            .field("limit", &self.limit)
308            .field("stream", &self.stream)
309            .field("is_stream_done", &self.is_stream_done)
310            .field("flow_controller", &self.flow_controller)
311            .finish()
312    }
313}
314
315impl<St, Fc> FlattenUnorderedWithFlowController<St, Fc>
316where
317    St: Stream,
318    Fc: FlowController<St::Item, <St::Item as Stream>::Item>,
319    St::Item: Stream + Unpin,
320{
321    pub(crate) fn new(stream: St, limit: Option<usize>) -> Self {
322        let poll_state = SharedPollState::new(NEED_TO_POLL_STREAM);
323
324        Self {
325            inner_streams: FuturesUnordered::new(),
326            stream,
327            is_stream_done: false,
328            limit: limit.and_then(NonZeroUsize::new),
329            inner_streams_waker: Arc::new(WrappedWaker {
330                inner_waker: UnsafeCell::new(None),
331                poll_state: poll_state.clone(),
332                need_to_poll: NEED_TO_POLL_INNER_STREAMS,
333            }),
334            stream_waker: Arc::new(WrappedWaker {
335                inner_waker: UnsafeCell::new(None),
336                poll_state: poll_state.clone(),
337                need_to_poll: NEED_TO_POLL_STREAM,
338            }),
339            poll_state,
340            flow_controller: PhantomData,
341        }
342    }
343
344    delegate_access_inner!(stream, St, ());
345}
346
/// Decides, for every item produced by the base stream, how the flattening
/// adapter treats it.
pub trait FlowController<I, O> {
    /// Maps an item of the base stream to the next [`FlowStep`]: either `Continue`
    /// with the item to flatten, or `Return` with a value to yield immediately.
    fn next_step(item: I) -> FlowStep<I, O>;
}

/// The default controller: every item continues down the regular flow.
impl<Input, Output> FlowController<Input, Output> for () {
    fn next_step(item: Input) -> FlowStep<Input, Output> {
        FlowStep::Continue(item)
    }
}

/// Outcome of [`FlowController::next_step`].
#[derive(Clone, Debug)]
pub enum FlowStep<C, R> {
    /// Keep the item in the standard flow, where it is yielded as usual.
    Continue(C),
    /// Return the contained value from the stream right away.
    Return(R),
}
367
368impl<St, Fc> FlattenUnorderedWithFlowControllerProj<'_, St, Fc>
369where
370    St: Stream,
371{
372    /// Checks if current `inner_streams` bucket size is greater than optional limit.
373    fn is_exceeded_limit(&self) -> bool {
374        self.limit.map_or(false, |limit| self.inner_streams.len() >= limit.get())
375    }
376}
377
378impl<St, Fc> FusedStream for FlattenUnorderedWithFlowController<St, Fc>
379where
380    St: FusedStream,
381    Fc: FlowController<St::Item, <St::Item as Stream>::Item>,
382    St::Item: Stream + Unpin,
383{
384    fn is_terminated(&self) -> bool {
385        self.stream.is_terminated() && self.inner_streams.is_empty()
386    }
387}
388
impl<St, Fc> Stream for FlattenUnorderedWithFlowController<St, Fc>
where
    St: Stream,
    Fc: FlowController<St::Item, <St::Item as Stream>::Item>,
    St::Item: Stream + Unpin,
{
    type Item = <St::Item as Stream>::Item;

    /// Polls the base stream and/or the inner streams and yields at most one item.
    ///
    /// Which of the two are polled depends on the "need to poll" bits recorded in
    /// the shared poll state. When flagged, the base stream is polled first. Every
    /// inner stream it yields is pushed into the `inner_streams` bucket. This
    /// continues until the base stream is pending or exhausted, the optional limit
    /// is reached, or the flow controller returns a value directly. When flagged,
    /// the bucket is then polled once. `Poll::Ready(None)` is returned only when the
    /// base stream is done and the bucket is empty.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut next_item = None;
        let mut need_to_poll_next = NONE;

        let mut this = self.as_mut().project();

        // Attempt to start polling, in case some waker is holding the lock, wait in loop
        // (a waker holds `WAKING` only while it clones its stored waker in
        // `WrappedWaker::wake_by_ref`).
        let (mut poll_state_value, state_bomb) = loop {
            if let Some(value) = this.poll_state.start_polling() {
                break value;
            }
        };

        // Safety: now state is `POLLING`.
        // Point both wrapped wakers at the current task, so wake-ups from the base
        // stream or the inner streams reach the task that is polling right now.
        unsafe {
            WrappedWaker::replace_waker(this.stream_waker, cx);
            WrappedWaker::replace_waker(this.inner_streams_waker, cx)
        };

        if poll_state_value & NEED_TO_POLL_STREAM != NONE {
            // Created lazily, only if the base stream actually gets polled.
            let mut stream_waker = None;

            // Here we need to poll the base stream.
            //
            // To improve performance, we will attempt to place as many items as we can
            // to the `FuturesUnordered` bucket before polling inner streams
            loop {
                if this.is_exceeded_limit() || *this.is_stream_done {
                    // We either exceeded the limit or the stream is exhausted
                    if !*this.is_stream_done {
                        // The stream needs to be polled in the next iteration
                        need_to_poll_next |= NEED_TO_POLL_STREAM;
                    }

                    break;
                } else {
                    let mut cx = Context::from_waker(
                        stream_waker.get_or_insert_with(|| waker(this.stream_waker.clone())),
                    );

                    match this.stream.as_mut().poll_next(&mut cx) {
                        Poll::Ready(Some(item)) => {
                            let next_item_fut = match Fc::next_step(item) {
                                // Propagates an item immediately (the main use-case is for errors)
                                FlowStep::Return(item) => {
                                    // We return before polling the inner streams, so
                                    // carry their pending flag over to the next poll.
                                    need_to_poll_next |= NEED_TO_POLL_STREAM
                                        | (poll_state_value & NEED_TO_POLL_INNER_STREAMS);
                                    poll_state_value &= !NEED_TO_POLL_INNER_STREAMS;

                                    next_item = Some(item);

                                    break;
                                }
                                // Yields an item and continues processing (normal case)
                                FlowStep::Continue(inner_stream) => {
                                    PollStreamFut::new(inner_stream)
                                }
                            };
                            // Add new stream to the inner streams bucket
                            this.inner_streams.as_mut().push(next_item_fut);
                            // Inner streams must be polled afterward
                            poll_state_value |= NEED_TO_POLL_INNER_STREAMS;
                        }
                        Poll::Ready(None) => {
                            // Mark the base stream as done
                            *this.is_stream_done = true;
                        }
                        Poll::Pending => {
                            break;
                        }
                    }
                }
            }
        }

        if poll_state_value & NEED_TO_POLL_INNER_STREAMS != NONE {
            let inner_streams_waker = waker(this.inner_streams_waker.clone());
            let mut cx = Context::from_waker(&inner_streams_waker);

            match this.inner_streams.as_mut().poll_next(&mut cx) {
                Poll::Ready(Some(Some((item, next_item_fut)))) => {
                    // Push next inner stream item future to the list of inner streams futures
                    this.inner_streams.as_mut().push(next_item_fut);
                    // Take the received item
                    next_item = Some(item);
                    // On the next iteration, inner streams must be polled again
                    need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS;
                }
                Poll::Ready(Some(None)) => {
                    // On the next iteration, inner streams must be polled again
                    need_to_poll_next |= NEED_TO_POLL_INNER_STREAMS;
                }
                // Pending, or the bucket is empty: nothing to record here.
                _ => {}
            }
        }

        // We didn't have any `poll_next` panic, so it's time to deactivate the bomb
        state_bomb.deactivate();

        // Call the waker at the end of polling if
        let mut force_wake =
            // we need to poll the stream and didn't reach the limit yet
            need_to_poll_next & NEED_TO_POLL_STREAM != NONE && !this.is_exceeded_limit()
            // or we need to poll the inner streams again
            || need_to_poll_next & NEED_TO_POLL_INNER_STREAMS != NONE;

        // Stop polling and swap the latest state
        poll_state_value = this.poll_state.stop_polling(need_to_poll_next, force_wake);
        // If state was changed during `POLLING` phase, we also need to manually call a waker
        force_wake |= poll_state_value & NEED_TO_POLL_ALL != NONE;

        let is_done = *this.is_stream_done && this.inner_streams.is_empty();

        if next_item.is_some() || is_done {
            Poll::Ready(next_item)
        } else {
            if force_wake {
                cx.waker().wake_by_ref();
            }

            Poll::Pending
        }
    }
}
521
// `Sink` is forwarded to the base stream, so the adapter is a sink whenever the
// underlying stream is one.
#[cfg(feature = "sink")]
impl<St, SinkItem, Fc> Sink<SinkItem> for FlattenUnorderedWithFlowController<St, Fc>
where
    St: Stream + Sink<SinkItem>,
{
    type Error = St::Error;

    delegate_sink!(stream, SinkItem);
}