async_unsync/
oneshot.rs

1//! An unsync **oneshot** channel implementation.
2
3use core::{
4    cell::UnsafeCell,
5    fmt,
6    future::Future,
7    pin::Pin,
8    task::{Context, Poll, Waker},
9};
10
11#[cfg(feature = "alloc")]
12use crate::alloc::rc::Rc;
13
14use crate::error::{SendError, TryRecvError};
15
16/// Creates a new oneshot channel.
17pub const fn channel<T>() -> OneshotChannel<T> {
18    OneshotChannel(UnsafeCell::new(Slot {
19        value: None,
20        recv_waker: None,
21        close_waker: None,
22        closed: false,
23    }))
24}
25
/// The error a receiver resolves with when the channel has been closed and no
/// value is left to receive.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct RecvError;
29
30/// An unsynchronized (`!Sync`), asynchronous oneshot channel.
31///
32/// This is useful for asynchronously handing a single value from one future to
33/// another.
34pub struct OneshotChannel<T>(UnsafeCell<Slot<T>>);
35
36impl<T> OneshotChannel<T> {
37    /// Splits the channel into borrowing [`SenderRef`] and [`ReceiverRef`]
38    /// handles.
39    pub fn split(&mut self) -> (SenderRef<'_, T>, ReceiverRef<'_, T>) {
40        let slot = &self.0;
41        (SenderRef { slot }, ReceiverRef { slot })
42    }
43
44    #[cfg(feature = "alloc")]
45    /// Splits the channel into owning [`Sender`] and
46    /// [`Receiver`] handles.
47    ///
48    /// This requires one additional allocation over
49    /// [`split`](OneshotChannel::split), but avoids potential lifetime
50    /// restrictions.
51    pub fn into_split(self) -> (Sender<T>, Receiver<T>) {
52        let slot = Rc::new(self.0);
53        (Sender { slot: Rc::clone(&slot) }, Receiver { slot })
54    }
55}
56
#[cfg(feature = "alloc")]
/// An owning handle for sending a single value through a split
/// [`OneshotChannel`].
pub struct Sender<T> {
    slot: Rc<UnsafeCell<Slot<T>>>,
}

#[cfg(feature = "alloc")]
impl<T> Sender<T> {
    /// Returns `true` once the channel has been closed, after which
    /// [`send`](Sender::send) fails.
    pub fn is_closed(&self) -> bool {
        // SAFETY: the slot is never shared across threads and no mutable
        // reference into it is live during this read
        let slot = unsafe { &*self.slot.get() };
        slot.closed
    }

    /// Polls the channel, resolving once the channel has been closed.
    pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        // SAFETY: the slot is never shared across threads and no other
        // reference into it is live during this call
        let slot = unsafe { &mut *self.slot.get() };
        slot.poll_closed(cx)
    }

    /// Waits until the channel is closed.
    pub async fn closed(&mut self) {
        core::future::poll_fn(|ctx| self.poll_closed(ctx)).await
    }

    /// Sends `value` through the channel, consuming the sender.
    ///
    /// # Errors
    ///
    /// Fails and hands `value` back if the channel is already closed.
    pub fn send(self, value: T) -> Result<(), SendError<T>> {
        // SAFETY: the slot is never shared across threads and no other
        // reference into it is live during this call
        let slot = unsafe { &mut *self.slot.get() };
        slot.send(value)
    }
}
92
#[cfg(feature = "alloc")]
impl<T> Drop for Sender<T> {
    /// Closes the channel and wakes a receiver that may be parked waiting for
    /// a value, so it observes either the value sent just before or the
    /// disconnect instead of waiting forever.
    fn drop(&mut self) {
        let waker = {
            // SAFETY: the slot is never shared across threads and no other
            // reference into it is live while this block runs
            let slot = unsafe { &mut *self.slot.get() };
            if slot.closed {
                // a receiver polling a closed slot resolves immediately, so
                // none can be parked and there is nobody to notify
                None
            } else {
                slot.closed = true;
                slot.recv_waker.take()
            }
        };

        // wake only after the borrow of the slot has ended, since waking may
        // run arbitrary code
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}
100
#[cfg(feature = "alloc")]
impl<T: fmt::Debug> fmt::Debug for Sender<T> {
    /// Shows whether the channel is closed and the value it currently holds.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let closed = self.is_closed();
        // SAFETY: the slot is never shared across threads and only shared
        // references into it are live while formatting
        let held: &Option<T> = unsafe { &(*self.slot.get()).value };
        let mut out = f.debug_struct("Sender");
        out.field("is_closed", &closed);
        out.field("value", held);
        out.finish_non_exhaustive()
    }
}
115
116/// A borrowing handle for sending an element through a split
117/// [`OneshotChannel`].
118pub struct SenderRef<'a, T> {
119    slot: &'a UnsafeCell<Slot<T>>,
120}
121
122impl<'a, T> SenderRef<'a, T> {
123    /// Returns `true` if the channel has been closed.
124    pub fn is_closed(&self) -> bool {
125        // SAFETY: no mutable or aliased access to slot possible
126        unsafe { (*self.slot.get()).closed }
127    }
128
129    /// Polls the channel, resolving if the channel has been closed.
130    pub fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> {
131        // SAFETY: no mutable or aliased access to slot possible
132        unsafe { (*self.slot.get()).poll_closed(cx) }
133    }
134
135    /// Resolves when the channel is closed.
136    pub async fn closed(&mut self) {
137        core::future::poll_fn(|cx| self.poll_closed(cx)).await
138    }
139
140    /// Sends a value through the channel.
141    ///
142    /// # Errors
143    ///
144    /// Fails, if the channel is closed.
145    pub fn send(self, value: T) -> Result<(), SendError<T>> {
146        // SAFETY: no mutable or aliased access to slot possible
147        unsafe { (*self.slot.get()).send(value) }
148    }
149}
150
151impl<T> Drop for SenderRef<'_, T> {
152    fn drop(&mut self) {
153        // SAFETY: no mutable or aliased access to slot possible
154        unsafe { (*self.slot.get()).closed = true }
155    }
156}
157
158impl<T> fmt::Debug for SenderRef<'_, T>
159where
160    T: fmt::Debug,
161{
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        // SAFETY: no mutable or aliased access to slot possible
164        let value = unsafe { &(*self.slot.get()).value };
165        f.debug_struct("SenderRef")
166            .field("is_closed", &self.is_closed())
167            .field("value", value)
168            .finish_non_exhaustive()
169    }
170}
171
#[cfg(feature = "alloc")]
/// An owning handle for receiving the value sent through a split
/// [`OneshotChannel`].
///
/// Being a [`Future`] itself, the receiver can be awaited directly:
///
/// ```
/// use async_unsync::oneshot;
///
/// # async fn example_receiver() {
/// let (tx, rx) = oneshot::channel().into_split();
/// tx.send(()).unwrap();
/// let _ = rx.await;
/// # }
/// ```
pub struct Receiver<T> {
    slot: Rc<UnsafeCell<Slot<T>>>,
}
189
#[cfg(feature = "alloc")]
impl<T> Receiver<T> {
    /// Returns `true` once the channel has been closed by either side.
    pub fn is_closed(&self) -> bool {
        // SAFETY: the slot is never shared across threads and no mutable
        // reference into it is live during this read
        let slot = unsafe { &*self.slot.get() };
        slot.closed
    }

    /// Closes the channel from the receiving side.
    ///
    /// Any [`closed`](Sender::closed) or subsequent
    /// [`poll_closed`](Sender::poll_closed) call on the corresponding
    /// [`Sender`] resolves, and any subsequent [`send`s](Sender::send) fail.
    /// A value that was already sent can still be received afterwards.
    pub fn close(&mut self) {
        // SAFETY: the slot is never shared across threads and no other
        // reference into it is live during this call
        let slot = unsafe { &mut *self.slot.get() };
        slot.close_and_wake();
    }

    /// Attempts to take the sent value without waiting.
    ///
    /// # Errors
    ///
    /// Fails with [`TryRecvError::Empty`] if the channel is still open but
    /// holds no value, and with [`TryRecvError::Disconnected`] if it is closed
    /// and holds no value.
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        // SAFETY: the slot is never shared across threads and no other
        // reference into it is live during this call
        let slot = unsafe { &mut *self.slot.get() };
        slot.try_recv()
    }
}
217
// `Receiver` is a future in its own right, so it can be awaited without a
// separate `recv` method.
#[cfg(feature = "alloc")]
impl<T> Future for Receiver<T> {
    type Output = Result<T, RecvError>;

    /// Resolves with the sent value, or with [`RecvError`] once the channel is
    /// closed and holds no value.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `Receiver` only holds an `Rc`, so it is `Unpin` and may be unpinned
        let this: &mut Self = Pin::into_inner(self);
        // SAFETY: the slot is never shared across threads and no other
        // reference into it is live during this call
        let slot = unsafe { &mut *this.slot.get() };
        slot.poll_recv(cx)
    }
}
229
#[cfg(feature = "alloc")]
impl<T> Drop for Receiver<T> {
    /// Dropping the receiver closes the channel, so a pending
    /// [`Sender::closed`] resolves and later sends fail.
    fn drop(&mut self) {
        Self::close(self);
    }
}
236
#[cfg(feature = "alloc")]
impl<T: fmt::Debug> fmt::Debug for Receiver<T> {
    /// Shows whether the channel is closed and the value it currently holds.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let closed = self.is_closed();
        // SAFETY: the slot is never shared across threads and only shared
        // references into it are live while formatting
        let held: &Option<T> = unsafe { &(*self.slot.get()).value };
        let mut out = f.debug_struct("Receiver");
        out.field("is_closed", &closed);
        out.field("value", held);
        out.finish_non_exhaustive()
    }
}
251
252/// A borrowing handle for receiving elements through a split
253/// [`OneshotChannel`].
254///
255/// # Note
256///
257///
258///
259/// This receiver implements [`Future`] and can be awaited directly:
260///
261/// ```
262/// # async fn example_receiver_ref() {
263/// let mut chan = async_unsync::oneshot::channel();
264/// let (tx, rx) = chan.split();
265/// tx.send(()).unwrap();
266/// let _ = rx.await;
267/// # }
268/// ```
269pub struct ReceiverRef<'a, T> {
270    slot: &'a UnsafeCell<Slot<T>>,
271}
272
273impl<T> ReceiverRef<'_, T> {
274    /// Returns `true` if the channel has been closed.
275    pub fn is_closed(&self) -> bool {
276        // SAFETY: no mutable or aliased access to slot possible
277        unsafe { (*self.slot.get()).closed }
278    }
279
280    /// Closes the channel, causing any [`closed`](SenderRef::closed) or
281    /// subsequent [`poll_closed`](SenderRef::poll_closed) calls to resolve and
282    /// any subsequent [`send`s](SenderRef::send) to fail on the corresponding
283    /// [`SenderRef`].
284    pub fn close(&mut self) {
285        // SAFETY: no mutable or aliased access to slot possible
286        unsafe { (*self.slot.get()).close_and_wake() }
287    }
288
289    /// Receives an element through the channel.
290    ///
291    /// # Errors
292    ///
293    /// Fails, if the channel is [empty](TryRecvError::Empty) or
294    /// [disconnected](TryRecvError::Disconnected).
295    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
296        // SAFETY: no mutable or aliased access to slot possible
297        unsafe { (*self.slot.get()).try_recv() }
298    }
299}
300
301impl<T> Future for ReceiverRef<'_, T> {
302    type Output = Result<T, RecvError>;
303
304    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
305        let slot = self.get_mut().slot;
306        // SAFETY: no mutable or aliased access to slot possible
307        unsafe { &mut *slot.get() }.poll_recv(cx)
308    }
309}
310
311impl<T> Drop for ReceiverRef<'_, T> {
312    fn drop(&mut self) {
313        self.close();
314    }
315}
316
317impl<T> fmt::Debug for ReceiverRef<'_, T>
318where
319    T: fmt::Debug,
320{
321    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
322        // SAFETY: no mutable or aliased access to slot possible
323        let value = unsafe { &(*self.slot.get()).value };
324        f.debug_struct("ReceiverRef")
325            .field("is_closed", &self.is_closed())
326            .field("value", value)
327            .finish_non_exhaustive()
328    }
329}
330
331/// A shared underlying data structure for the internal state of a
332/// [`OneshotChannel`].
333struct Slot<T> {
334    // HINT: it's not worth squeezing all Option tags into a single byte and
335    // using MaybeUninits instead. Source: I tried
336    value: Option<T>,
337    recv_waker: Option<Waker>,
338    close_waker: Option<Waker>,
339    closed: bool,
340}
341
342impl<T> Slot<T> {
343    fn send(&mut self, value: T) -> Result<(), SendError<T>> {
344        // check, if channel has been closed
345        if self.closed {
346            return Err(SendError(value));
347        }
348
349        // store sent value & wake a possibly registered receiver
350        self.value = Some(value);
351        if let Some(waker) = &self.recv_waker {
352            waker.wake_by_ref();
353        }
354
355        Ok(())
356    }
357
358    fn close_and_wake(&mut self) {
359        if self.closed {
360            return;
361        }
362
363        self.closed = true;
364        if let Some(waker) = &self.close_waker {
365            waker.wake_by_ref();
366        }
367    }
368
369    fn poll_closed(&mut self, cx: &mut Context<'_>) -> Poll<()> {
370        if self.closed {
371            Poll::Ready(())
372        } else {
373            self.close_waker = Some(cx.waker().clone());
374            Poll::Pending
375        }
376    }
377
378    fn try_recv(&mut self) -> Result<T, TryRecvError> {
379        match self.value.take() {
380            Some(value) => Ok(value),
381            None => match self.closed {
382                true => Err(TryRecvError::Disconnected),
383                false => Err(TryRecvError::Empty),
384            },
385        }
386    }
387
388    fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Result<T, RecvError>> {
389        match self.try_recv() {
390            Ok(value) => Poll::Ready(Ok(value)),
391            Err(TryRecvError::Disconnected) => Poll::Ready(Err(RecvError)),
392            Err(TryRecvError::Empty) => {
393                self.recv_waker = Some(cx.waker().clone());
394                Poll::Pending
395            }
396        }
397    }
398}
399
#[cfg(test)]
mod tests {
    use std::{future::Future as _, task::Poll};

    use futures_lite::future;

    use crate::error::TryRecvError;

    #[test]
    fn recv() {
        future::block_on(async {
            let mut chan = super::channel::<i32>();
            let (tx, rx) = chan.split();

            tx.send(-1).unwrap();
            assert_eq!(rx.await, Ok(-1));
        });
    }

    #[test]
    fn split_twice() {
        future::block_on(async {
            let mut chan = super::channel::<()>();
            let (tx, rx) = chan.split();

            tx.send(()).unwrap();
            assert!(rx.await.is_ok());

            // the channel stays closed once the first pair of handles is gone
            let (tx, rx) = chan.split();
            assert!(tx.send(()).is_err());
            assert!(rx.await.is_err());
        });
    }

    #[test]
    fn try_recv_empty_then_disconnected() {
        let mut chan = super::channel::<i32>();
        let (tx, mut rx) = chan.split();

        // open channel without a value
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));

        // sender dropped without sending
        drop(tx);
        assert!(rx.is_closed());
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Disconnected)));
    }

    #[test]
    fn value_survives_receiver_close() {
        let mut chan = super::channel::<i32>();
        let (tx, mut rx) = chan.split();

        tx.send(3).unwrap();
        rx.close();
        assert!(matches!(rx.try_recv(), Ok(3)));
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn recv_owned() {
        future::block_on(async {
            let (tx, rx) = super::channel::<i32>().into_split();

            tx.send(1).unwrap();
            assert_eq!(rx.await, Ok(1));
        });
    }

    #[test]
    fn wake_on_close() {
        future::block_on(async {
            let mut chan = super::channel::<i32>();
            let (tx, rx) = chan.split();
            let mut rx = core::pin::pin!(rx);

            // poll once: pending
            core::future::poll_fn(|cx| {
                assert!(rx.as_mut().poll(cx).is_pending());
                Poll::Ready(())
            })
            .await;

            // drop tx & close channel
            drop(tx);

            // receiver should return ready + error
            core::future::poll_fn(move |cx| {
                assert_eq!(rx.as_mut().poll(cx), Poll::Ready(Err(super::RecvError)));
                Poll::Ready(())
            })
            .await;
        });
    }
}