async_unsync/
semaphore.rs

1//! A simple asynchronous semaphore for limiting and sequencing access
2//! to arbitrary shared resources.
3
4use core::{
5    cell::{Cell, UnsafeCell},
6    fmt,
7    future::Future,
8    marker::PhantomPinned,
9    mem,
10    pin::Pin,
11    ptr::{self, NonNull},
12    task::{Context, Poll, Waker},
13};
14
15/// An unsynchronized (`!Sync`), simple semaphore for asynchronous permit
16/// acquisition.
17pub struct Semaphore {
18    shared: UnsafeCell<Shared>,
19}
20
21impl Semaphore {
22    /// Creates a new semaphore with the initial number of permits.
23    pub const fn new(permits: usize) -> Self {
24        Self {
25            shared: UnsafeCell::new(Shared { waiters: WaiterQueue::new(), permits, closed: false }),
26        }
27    }
28
29    /// Closes the semaphore and returns the number of notified pending waiters.
30    ///
31    /// This prevents the semaphore from issuing new permits and notifies all
32    /// pending waiters.
33    pub fn close(&self) -> usize {
34        // SAFETY: no mutable or aliased access to shared possible
35        unsafe { (*self.shared.get()).close() }
36    }
37
38    /// Returns `true` if the semaphore has been closed
39    pub fn is_closed(&self) -> bool {
40        // SAFETY: no mutable or aliased access to shared possible
41        unsafe { (*self.shared.get()).is_closed() }
42    }
43
44    /// Returns the number of currently registered [`Future`]s waiting for a
45    /// [`Permit`].
46    pub fn waiters(&self) -> usize {
47        // SAFETY: no mutable or aliased access to shared possible
48        unsafe { (*self.shared.get()).waiters.len() }
49    }
50
51    /// Returns the current number of available permits.
52    pub fn available_permits(&self) -> usize {
53        // SAFETY: no mutable or aliased access to shared possible
54        unsafe { (*self.shared.get()).permits }
55    }
56
57    /// Adds `n` new permits to the semaphore.
58    pub fn add_permits(&self, n: usize) {
59        // SAFETY: no mutable or aliased access to shared possible
60        unsafe { (*self.shared.get()).add_permits(n) };
61    }
62
63    /// Permanently reduces the number of available permits by `n`.
64    pub fn remove_permits(&self, n: usize) {
65        // SAFETY: no mutable or aliased access to shared possible
66        let shared = unsafe { &mut (*self.shared.get()) };
67        shared.permits = shared.permits.saturating_sub(n);
68    }
69
70    /// Acquires a single [`Permit`] or returns an [error](TryAcquireError), if
71    /// there are no available permits.
72    ///
73    /// # Errors
74    ///
75    /// Fails, if the semaphore has been closed or has no available permits.
76    pub fn try_acquire(&self) -> Result<Permit<'_>, TryAcquireError> {
77        self.try_acquire_many(1)
78    }
79
80    /// Acquires `n` [`Permit`]s or returns an [error](TryAcquireError), if
81    /// there are not enough available permits.
82    ///
83    /// # Errors
84    ///
85    /// Fails, if the semaphore has been closed or has not enough available
86    /// permits.
87    pub fn try_acquire_many(&self, n: usize) -> Result<Permit<'_>, TryAcquireError> {
88        // SAFETY: no mutable or aliased access to shared possible
89        unsafe { (*self.shared.get()).try_acquire::<true>(n) }.map(|_| Permit::new(&self.shared, n))
90    }
91
92    /// Acquires a single [`Permit`], potentially blocking until one becomes
93    /// available.
94    ///
95    /// # Errors
96    ///
97    /// Awaiting the [`Future`] fails, if the semaphore has been closed.
98    pub fn acquire(&self) -> Acquire<'_> {
99        self.build_acquire(1)
100    }
101
102    /// Acquires `n` [`Permit`]s, potentially blocking until they become
103    /// available.
104    ///
105    /// # Errors
106    ///
107    /// Awaiting the [`Future`] fails, if the semaphore has been closed.
108    pub fn acquire_many(&self, n: usize) -> Acquire<'_> {
109        self.build_acquire(n)
110    }
111
112    /// Returns an correctly initialized [`Acquire`] future instance for
113    /// acquiring `wants` permits.
114    fn build_acquire(&self, wants: usize) -> Acquire<'_> {
115        Acquire { shared: &self.shared, waiter: Waiter::new(wants) }
116    }
117}
118
119impl fmt::Debug for Semaphore {
120    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
121        f.debug_struct("Semaphore")
122            .field("is_closed", &self.is_closed())
123            .field("available_permits", &self.available_permits())
124            .field("waiters", &self.waiters())
125            .finish_non_exhaustive()
126    }
127}
128
129/// A permit representing access to the [`Semaphore`]'s guarded resource.
130pub struct Permit<'a> {
131    shared: &'a UnsafeCell<Shared>,
132    count: usize,
133}
134
135impl<'a> Permit<'a> {
136    /// Returns a new [`Permit`] without actually acquiring it.
137    ///
138    /// NOTE: Only use this to "revive" a Permit that has been explicitly
139    /// [forgotten](Permit::forget)!
140    fn new(shared: &'a UnsafeCell<Shared>, count: usize) -> Self {
141        Self { shared, count }
142    }
143
144    /// Drops the permit without returning it to the [`Semaphore`].
145    ///
146    /// This permanently reduces the number of available permits.
147    pub fn forget(self) {
148        mem::forget(self);
149    }
150}
151
152impl Drop for Permit<'_> {
153    fn drop(&mut self) {
154        // SAFETY: no mutable or aliased access to shared possible
155        let shared = unsafe { &mut (*self.shared.get()) };
156        shared.add_permits(self.count);
157    }
158}
159
160impl fmt::Debug for Permit<'_> {
161    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162        f.debug_struct("Permit").finish_non_exhaustive()
163    }
164}
165
/// The error returned when waiting for permits of a [`Semaphore`] that has
/// been closed.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct AcquireError(());

impl fmt::Display for AcquireError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MESSAGE: &str = "semaphore closed";
        f.write_str(MESSAGE)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for AcquireError {}
178
/// The error returned by the `try_acquire*` methods of a [`Semaphore`] when it
/// has been closed or can not hand out the requested permits.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub enum TryAcquireError {
    /// The semaphore has been [closed](Semaphore::close) and can not issue new
    /// permits.
    Closed,
    /// The semaphore has no available permits.
    NoPermits,
}

#[cfg(feature = "alloc")]
impl From<TryAcquireError> for crate::error::TrySendError<()> {
    fn from(err: TryAcquireError) -> Self {
        match err {
            TryAcquireError::NoPermits => Self::Full(()),
            TryAcquireError::Closed => Self::Closed(()),
        }
    }
}

impl fmt::Display for TryAcquireError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::Closed => "semaphore closed",
            Self::NoPermits => "no permits available",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for TryAcquireError {}
211
212/// The [`Future`] returned by [`acquire`](Semaphore::acquire), which
213/// resolves when the required number of permits becomes available.
214pub struct Acquire<'a> {
215    /// The shared [`Semaphore`] state.
216    shared: &'a UnsafeCell<Shared>,
217    /// The state for waiting and resolving the future.
218    waiter: Waiter,
219}
220
221impl<'a> Future for Acquire<'a> {
222    type Output = Result<Permit<'a>, AcquireError>;
223
224    #[inline]
225    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
226        // SAFETY: The `Acquire` future can not be moved before being dropped
227        let waiter = unsafe { Pin::map_unchecked(self.as_ref(), |acquire| &acquire.waiter) };
228
229        // SAFETY: no mutable or aliased access to shared possible
230        match unsafe { (*self.shared.get()).poll_acquire(waiter, cx) } {
231            Poll::Ready(res) => {
232                // unconditionally setting waiting to false here avoids having
233                // to traverse the waiter queue again when the future is
234                // dropped.
235                waiter.state.set(WaiterState::Woken);
236                match res {
237                    Ok(_) => {
238                        let shared = self.as_ref().shared;
239                        let count = waiter.permits.take();
240                        Poll::Ready(Ok(Permit::new(shared, count)))
241                    }
242                    Err(e) => Poll::Ready(Err(e)),
243                }
244            }
245            Poll::Pending => Poll::Pending,
246        }
247    }
248}
249
250impl Drop for Acquire<'_> {
251    fn drop(&mut self) {
252        // SAFETY: no mutable or aliased access to shared possible
253        let shared = unsafe { &mut (*self.shared.get()) };
254
255        // remove the queued waker, if it was already enqueued
256        if let WaiterState::Waiting = self.waiter.state.get() {
257            // check, if there exists some entry in queue of waiters with the
258            // same ID as this future
259            // SAFETY: non-live waiters did not exist in queue, no aliased
260            // access possible
261            unsafe { shared.waiters.try_remove(&self.waiter) };
262        }
263
264        // return all "unused" (i.e., not passed on into a [`Permit`]) permits
265        // back to the semaphore
266        let permits = self.waiter.permits.get();
267        // the order is important here, because `add_permits` may mark permits
268        // as handed out again, if they are transfered to other waiters
269        shared.add_permits(permits);
270    }
271}
272
273/// The shared [`Semaphore`] accounting state.
274struct Shared {
275    /// The queue of registered `Waker`s.
276    waiters: WaiterQueue,
277    /// The number of currently available permits.
278    permits: usize,
279    /// The flag indicating if the semaphore has been closed.
280    closed: bool,
281}
282
283impl Shared {
284    /// Closes the semaphore and notifies all remaining waiters.
285    #[cold]
286    fn close(&mut self) -> usize {
287        // SAFETY: non-live waiters di not exist in queue, no aliased access
288        // possible
289        let woken = unsafe { self.waiters.wake_all() };
290        self.closed = true;
291        self.waiters = WaiterQueue::new();
292
293        woken
294    }
295
296    /// Returns `true` if the semaphore has been closed.
297    fn is_closed(&self) -> bool {
298        self.closed
299    }
300
301    /// Adds `n` permits and wakes all waiters whose requests can now be
302    /// completed.
303    fn add_permits(&mut self, mut n: usize) {
304        while n > 0 {
305            // keep checking the waiter queue until are permits are distributed
306            if let Some(waiter) = self.waiters.front() {
307                // SAFETY: All waiters remain valid while they are enqueued.
308                let waiter = unsafe { waiter.as_ref() };
309                // check, how many permits have already been assigned and
310                // how many were requested
311                let diff = waiter.wants - waiter.permits.get();
312                if diff > n {
313                    // waiter wants more permits than are still available
314                    // the waiter gets all available permits & the loop
315                    // terminated (n = 0)
316                    waiter.permits.set(diff - n);
317                    return;
318                } else {
319                    // the waiters request can be completed, assign all
320                    // missing permits, wake the waiter, continue the loop
321                    waiter.permits.set(waiter.wants);
322                    n -= diff;
323
324                    // SAFETY: All wakers are initialized when the `Waiter`s
325                    // are enqueued and all waiters remain valid while they are
326                    // enqueued.
327                    unsafe {
328                        waiter.state.set(WaiterState::Woken);
329                        waiter.waker.get().wake_by_ref();
330                        // ...finally, dequeue the notified waker
331                        self.waiters.pop_front(waiter);
332                    };
333                }
334            } else {
335                self.permits = self.permits.saturating_add(n);
336                return;
337            }
338        }
339    }
340
341    /// Attempts to reduce available permits by up to `n` or returns an error,
342    /// if the semaphore has been closed or has no available permits.
343    fn try_acquire<const STRICT: bool>(&mut self, n: usize) -> Result<usize, TryAcquireError> {
344        if self.is_closed() {
345            return Err(TryAcquireError::Closed);
346        }
347
348        if n > self.permits {
349            if STRICT || self.permits == 0 {
350                return Err(TryAcquireError::NoPermits);
351            }
352
353            // hand out all available permits
354            let count = self.permits;
355            self.permits = 0;
356            Ok(count)
357        } else {
358            // can not underflow because n <= permits
359            self.permits -= n;
360            Ok(n)
361        }
362    }
363
364    fn poll_acquire(
365        &mut self,
366        waiter: Pin<&Waiter>,
367        cx: &mut Context<'_>,
368    ) -> Poll<Result<(), AcquireError>> {
369        if self.closed {
370            // a waiter *may* or *may not* be in the queue, but `Acquire::drop`
371            // will take care of this eventually
372            return Poll::Ready(Err(AcquireError(())));
373        }
374
375        match waiter.state.get() {
376            WaiterState::Woken => Poll::Ready(Ok(())),
377            WaiterState::Inert => self.poll_acquire_initial(waiter, cx),
378            WaiterState::Waiting => Poll::Pending,
379        }
380    }
381
382    fn poll_acquire_initial(
383        &mut self,
384        waiter: Pin<&Waiter>,
385        cx: &mut Context<'_>,
386    ) -> Poll<Result<(), AcquireError>> {
387        // on first poll, check if there are enough permits to resolve
388        // immediately or enqueue a waiter ticket to be notified (i.e. polled
389        // again) later
390        match self.try_acquire::<false>(waiter.wants) {
391            Ok(n) => {
392                // check if we got the desired amount or less
393                waiter.permits.set(n);
394                if n == waiter.wants {
395                    return Poll::Ready(Ok(()));
396                }
397            }
398            Err(TryAcquireError::Closed) => return Poll::Ready(Err(AcquireError(()))),
399            _ => {}
400        };
401
402        // if no or not enough permits are currently available, enqueue a
403        // waiter request ticket, to be notified when capacity becomes
404        // available
405        waiter.state.set(WaiterState::Waiting);
406        waiter.waker.set(cx.waker().clone());
407        // SAFETY: All waiters remain valid while they are enqueued.
408        //
409        // Each `Acquire` future contains (owns) a `Waiter` and may either live
410        // on the stack or the heap.
411        // Each future *must* be pinned before it can be polled and therefore
412        // both the future and the waiter will remain in-place for their entire
413        // lifetime.
414        // When the future/waiter are cancelled or dropped, they will dequeue
415        // themselves to ensure no iteration over freed data is possible.
416        // Since they must be pinned, leaking or "forgetting" the futures does
417        // not break this invariant:
418        // In case of a heap-pinned future, the destructor will not be run, but
419        // the data will still remain valid for the program duration.
420        // In case of a future safely pinned to the stack, there is no way to
421        // actually prevent the destructor from running, since only the pinned
422        // reference can be leaked.
423        unsafe { self.waiters.push_back(waiter.get_ref()) }
424        Poll::Pending
425    }
426}
427
/// An intrusive, doubly-linked FIFO queue of [`Waiter`]s.
///
/// The queue does not own its entries. It only stores raw pointers to
/// `Waiter`s, which are linked through each waiter's `next`/`prev` cells.
/// `head` and `tail` are either both null (empty queue) or both non-null.
struct WaiterQueue {
    /// The first (oldest) enqueued waiter, or null if the queue is empty.
    head: *const Waiter,
    /// The last (most recently enqueued) waiter, or null if the queue is empty.
    tail: *const Waiter,
}

impl WaiterQueue {
    /// Returns a new empty queue.
    const fn new() -> Self {
        Self { head: ptr::null(), tail: ptr::null() }
    }

    /// Returns the first `Waiter`, or `None` if the queue is empty.
    fn front(&self) -> Option<NonNull<Waiter>> {
        NonNull::new(self.head as *mut Waiter)
    }

    /// Returns the number of currently enqueued `Waiter`s.
    ///
    /// This walks the entire list, so it takes time linear in the queue
    /// length.
    ///
    /// # Safety
    ///
    /// All pointers must reference valid, live and non-aliased `Waiter`s.
    #[cold]
    unsafe fn len(&self) -> usize {
        // this is only reached through `Semaphore::waiters` (also used by the
        // `Debug` output), never on the acquire/release paths, so counting
        // each waiter one by one here is irrelevant to performance
        let mut curr = self.head;
        let mut waiting = 0;
        while !curr.is_null() {
            // SAFETY: curr is non-null, validity is required by function safety
            curr = unsafe { (*curr).next.get() };
            waiting += 1;
        }

        waiting
    }

    /// Enqueues `waiter` at the back of the queue.
    ///
    /// # Safety
    ///
    /// All pointers must reference valid, live and non-aliased `Waiter`s.
    /// `waiter` must not already be enqueued. Its own `next`/`prev` links are
    /// not reset here, so they are expected to still be null (as set by
    /// `Waiter::new`).
    unsafe fn push_back(&mut self, waiter: &Waiter) {
        if self.tail.is_null() {
            // queue is empty, insert waiter at head
            self.head = waiter;
            self.tail = waiter;
        } else {
            // queue is not empty, insert at tail
            // SAFETY: non-live waiters did not exist in queue, no aliased
            // access possible
            unsafe { (*self.tail).next.set(waiter) };
            waiter.prev.set(self.tail);
            self.tail = waiter;
        }
    }

    /// Unlinks `waiter` from the queue.
    ///
    /// Despite its name, this does not search the queue. `waiter` is unlinked
    /// through its own `prev`/`next` pointers, so it must actually be
    /// enqueued. The waiter's own links are left unchanged.
    ///
    /// # Safety
    ///
    /// All pointers must reference valid, live and non-aliased `Waiter`s, and
    /// `waiter` must currently be an element of this queue.
    #[cold]
    unsafe fn try_remove(&mut self, waiter: &Waiter) {
        let prev = waiter.prev.get();
        if prev.is_null() {
            // `waiter` is the head, its successor becomes the new head
            self.head = waiter.next.get();
        } else {
            // SAFETY: prev is non-null, liveness required by function invariant
            unsafe { (*prev).next.set(waiter.next.get()) };
        }

        let next = waiter.next.get();
        if next.is_null() {
            // `waiter` is the tail, its predecessor becomes the new tail
            self.tail = waiter.prev.get();
        } else {
            // SAFETY: next non-null, liveness required by function invariant
            unsafe { (*next).prev.set(waiter.prev.get()) };
        }
    }

    /// Removes `head` from the front of the queue.
    ///
    /// Only the queue and the new head's `prev` link are updated. `head`'s
    /// own `next` link is left unchanged.
    ///
    /// # Safety
    ///
    /// All pointers must reference valid, live and non-aliased `Waiter`s and
    /// `head` must be the current queue head.
    #[inline]
    unsafe fn pop_front(&mut self, head: &Waiter) {
        self.head = head.next.get();
        if self.head.is_null() {
            // the queue is now empty
            self.tail = ptr::null();
        } else {
            // SAFETY: the new head is non-null and, like every enqueued
            // waiter, valid and live per the function's safety requirements
            unsafe { (*self.head).prev.set(ptr::null()) };
        }
    }

    /// Marks every enqueued waiter as woken and invokes its waker, returning
    /// the number of woken waiters.
    ///
    /// The queue itself is left unchanged: `head`/`tail` still point at the
    /// woken waiters. The caller (`Shared::close`) resets it afterwards.
    ///
    /// # Safety
    ///
    /// All pointers must reference valid, live and non-aliased `Waiter`s, and
    /// every enqueued waiter must have its waker initialized.
    #[cold]
    unsafe fn wake_all(&mut self) -> usize {
        let mut curr = self.head;
        let mut woken = 0;

        while !curr.is_null() {
            // SAFETY: liveness/non-aliasedness required for all waiters by
            // function invariant, curr is non-null and valid
            unsafe {
                let waiter = &*curr;
                waiter.state.set(WaiterState::Woken);
                waiter.waker.get().wake_by_ref();
                curr = waiter.next.get();
            }

            woken += 1;
        }

        woken
    }
}
545
546/// A queue-able waiter that will be notified, when its requested number of
547/// semaphore permits has been granted.
548struct Waiter {
549    /// The number of requested permits.
550    wants: usize,
551    /// The waker to be woken if the future is enqueued as waiting.
552    ///
553    /// This field is **never** used, if the waiter does not get enqueued,
554    /// because its request can be fulfilled immediately.
555    waker: LateInitWaker,
556    /// The flag indicating the waiter's state.
557    state: Cell<WaiterState>,
558    /// The counter of already collected permits.
559    permits: Cell<usize>,
560    /// The pointer to the next enqueued waiter
561    next: Cell<*const Self>,
562    /// The pointer to the previous enqueued waiter
563    prev: Cell<*const Self>,
564    // see: https://gist.github.com/Darksonn/1567538f56af1a8038ecc3c664a42462
565    // this marker lets miri pass the self-referential nature of this struct
566    _marker: PhantomPinned,
567}
568
569impl Waiter {
570    const fn new(wants: usize) -> Self {
571        Self {
572            wants,
573            waker: LateInitWaker::new(),
574            state: Cell::new(WaiterState::Inert),
575            permits: Cell::new(0),
576            next: Cell::new(ptr::null()),
577            prev: Cell::new(ptr::null()),
578            _marker: PhantomPinned,
579        }
580    }
581}
582
583/// The current state of a [`Waiter`].
584#[derive(Clone, Copy)]
585enum WaiterState {
586    /// The waiter is inert and its future has not yet been polled.
587    Inert,
588    /// The waiter's future has been polled and the waiter was enqueued.
589    Waiting,
590    /// The waiter's future has been polled to completion.
591    ///
592    /// If the waiter had been queued it is now no longer queued.
593    Woken,
594}
595
/// Lazily initialized storage for the `Waker` of an `Acquire` future.
///
/// The waker is only needed if the future gets enqueued in the `waiters`
/// list, so it is not set until then (if at all). `get` is only called during
/// traversal of that list, so the waker is guaranteed to have been initialized
/// at that point.
struct LateInitWaker(UnsafeCell<Option<Waker>>);

impl LateInitWaker {
    /// Returns empty (not yet initialized) waker storage.
    const fn new() -> Self {
        Self(UnsafeCell::new(None))
    }

    /// Stores `waker`, dropping any previously stored waker.
    fn set(&self, waker: Waker) {
        // SAFETY: the cell makes this type `!Sync`, and references returned
        // by `get` are only used transiently (to wake), so none is live here.
        // Assigning (instead of `ptr::write`) drops a previously stored waker
        // instead of leaking it.
        unsafe { *self.0.get() = Some(waker) };
    }

    /// Returns a reference to the stored waker.
    ///
    /// # Safety
    ///
    /// A waker must have been stored with [`set`](Self::set) beforehand, and
    /// the returned reference must not be held across a subsequent `set`.
    unsafe fn get(&self) -> &Waker {
        // SAFETY: initialization and the absence of a concurrent `set` are
        // required by this function's safety contract
        unsafe {
            match &*self.0.get() {
                Some(waker) => waker,
                None => core::hint::unreachable_unchecked(),
            }
        }
    }
}
623
624#[cfg(test)]
625mod tests {
626    use futures_lite::future;
627
628    use core::{
629        future::Future as _,
630        ptr,
631        task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
632    };
633
634    #[test]
635    fn try_acquire_one() {
636        let sem = super::Semaphore::new(0);
637        assert!(sem.try_acquire().is_err());
638        sem.add_permits(2);
639        let p1 = sem.try_acquire().unwrap();
640        let p2 = sem.try_acquire().unwrap();
641        assert_eq!(sem.available_permits(), 0);
642
643        drop((p1, p2));
644        assert_eq!(sem.available_permits(), 2);
645    }
646
647    #[test]
648    fn try_acquire_many() {
649        let sem = super::Semaphore::new(0);
650        assert!(sem.try_acquire_many(3).is_err());
651        sem.add_permits(2);
652        assert!(sem.try_acquire_many(3).is_err());
653        sem.add_permits(1);
654        let permit = sem.try_acquire_many(3).unwrap();
655        assert_eq!(permit.count, 3);
656        drop(permit);
657        assert_eq!(sem.available_permits(), 3);
658    }
659
660    #[test]
661    fn acquire_never() {
662        future::block_on(async {
663            let sem = super::Semaphore::new(0);
664            let mut fut = core::pin::pin!(sem.acquire());
665
666            core::future::poll_fn(|cx| {
667                assert!(fut.as_mut().poll(cx).is_pending());
668                Poll::Ready(())
669            })
670            .await;
671
672            assert_eq!(sem.available_permits(), 0);
673        });
674    }
675
676    #[test]
677    fn acquire() {
678        future::block_on(async {
679            let sem = super::Semaphore::new(0);
680            let mut fut = core::pin::pin!(sem.acquire());
681            core::future::poll_fn(|cx| {
682                assert!(fut.as_mut().poll(cx).is_pending());
683                Poll::Ready(())
684            })
685            .await;
686
687            sem.add_permits(1);
688            let permit = fut.await.unwrap();
689            drop(permit);
690            assert_eq!(sem.available_permits(), 1);
691        });
692    }
693
694    #[test]
695    fn acquire_one() {
696        future::block_on(async {
697            let sem = super::Semaphore::new(0);
698            let mut fut = core::pin::pin!(sem.acquire());
699
700            // poll future once to enqueue waiter
701            core::future::poll_fn(|cx| {
702                assert!(fut.as_mut().poll(cx).is_pending());
703                assert_eq!(sem.waiters(), 1);
704                // add 2 permits, one goes directly to the enqueued waiter and
705                // wakes it, one goes into the semaphore
706                sem.add_permits(2);
707                Poll::Ready(())
708            })
709            .await;
710
711            // future must resolve now, since it has been woken
712            let permit = fut.await.unwrap();
713            assert_eq!(sem.available_permits(), 1);
714            drop(permit);
715            assert_eq!(sem.available_permits(), 2);
716        });
717    }
718
719    #[test]
720    fn poll_acquire_after_completion() {
721        future::block_on(async {
722            let sem = super::Semaphore::new(0);
723            let mut fut = core::pin::pin!(sem.acquire());
724            core::future::poll_fn(|cx| {
725                assert!(fut.as_mut().poll(cx).is_pending());
726                Poll::Ready(())
727            })
728            .await;
729
730            sem.add_permits(1);
731
732            core::future::poll_fn(|cx| {
733                assert!(fut.as_mut().poll(cx).is_ready());
734                // polling again after completion, works in this case, but might
735                // cause a Waker leak under other circumstances.
736                assert!(fut.as_mut().poll(cx).is_ready());
737                Poll::Ready(())
738            })
739            .await;
740
741            assert_eq!(sem.available_permits(), 1);
742        });
743    }
744
745    #[test]
746    fn poll_future() {
747        static RAW_VTABLE: RawWakerVTable = RawWakerVTable::new(
748            |_| RawWaker::new(ptr::null(), &RAW_VTABLE),
749            |_| {},
750            |_| {},
751            |_| {},
752        );
753
754        let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &RAW_VTABLE)) };
755        let mut cx = Context::from_waker(&waker);
756
757        let sem = super::Semaphore::new(0);
758        let mut fut = Box::pin(sem.build_acquire(1));
759
760        assert!(fut.as_mut().poll(&mut cx).is_pending());
761        assert_eq!(sem.waiters(), 1);
762        sem.add_permits(1);
763
764        assert!(fut.as_mut().poll(&mut cx).is_ready());
765        drop(fut);
766        assert_eq!(sem.waiters(), 0);
767    }
768
769    #[test]
770    fn acquire_many() {
771        future::block_on(async {
772            let sem = super::Semaphore::new(0);
773            let mut f1 = Box::pin(sem.acquire_many(2));
774            let mut f2 = Box::pin(sem.acquire_many(1));
775
776            core::future::poll_fn(|cx| {
777                // poll both futures once to establish order
778                assert!(f1.as_mut().poll(cx).is_pending());
779                assert!(f2.as_mut().poll(cx).is_pending());
780
781                assert_eq!(sem.waiters(), 2);
782                sem.add_permits(1);
783
784                // due to established order, f2 must not resolve before f1
785                assert!(f2.as_mut().poll(cx).is_pending());
786
787                // adding another permit must wake f1
788                sem.add_permits(1);
789                assert_eq!(sem.waiters(), 1);
790                Poll::Ready(())
791            })
792            .await;
793
794            // f1 should resolve now
795            let permit = f1.await.unwrap();
796            assert_eq!(sem.waiters(), 1);
797
798            // dropping the permit must pass one permit to the next waiter,
799            // wake it and return the other permit back to the semaphore
800            drop(permit);
801            assert_eq!(sem.waiters(), 0);
802            assert_eq!(sem.available_permits(), 1);
803
804            let permit = f2.await.unwrap();
805            assert_eq!(sem.available_permits(), 1);
806            drop(permit);
807
808            assert_eq!(sem.available_permits(), 2);
809        });
810    }
811
812    #[test]
813    fn cleanup() {
814        future::block_on(async {
815            let sem = super::Semaphore::new(0);
816
817            let mut fut = Box::pin(sem.acquire());
818            // poll once to enqueue the future as waiting
819            core::future::poll_fn(|cx| {
820                assert!(fut.as_mut().poll(cx).is_pending());
821                Poll::Ready(())
822            })
823            .await;
824
825            // dropping the future should clear up its queue entry immediately
826            drop(fut);
827            assert_eq!(sem.waiters(), 0);
828
829            let mut fut = Box::pin(sem.acquire());
830            // poll once to enqueue the future as waiting
831            core::future::poll_fn(|cx| {
832                assert!(fut.as_mut().poll(cx).is_pending());
833                Poll::Ready(())
834            })
835            .await;
836
837            // add 1 permit to wake future
838            sem.add_permits(1);
839            // ..and close semaphore
840            assert_eq!(sem.close(), 0);
841
842            assert!(fut.await.is_err());
843            assert_eq!(sem.waiters(), 0);
844            assert_eq!(sem.available_permits(), 1);
845        });
846    }
847
848    #[test]
849    fn cleanup_after_wake() {
850        future::block_on(async {
851            let sem = super::Semaphore::new(0);
852            let mut fut = Box::pin(sem.acquire());
853
854            core::future::poll_fn(|cx| {
855                // poll once to enque the future as waiting
856                assert!(fut.as_mut().poll(cx).is_pending());
857                Poll::Ready(())
858            })
859            .await;
860
861            // adding a permit will wake the Acquire future instead of
862            // increasing the amount of available permits
863            sem.add_permits(1);
864            // dropping the future should return the added permit instead of
865            // removing the waker from the queue
866            drop(fut);
867
868            assert_eq!(sem.waiters(), 0);
869            assert_eq!(sem.available_permits(), 1);
870        });
871    }
872
873    #[test]
874    fn close() {
875        future::block_on(async {
876            let sem = super::Semaphore::new(1);
877            let permit = sem.acquire().await.unwrap();
878
879            let mut f1 = Box::pin(sem.acquire());
880            let mut f2 = Box::pin(sem.acquire());
881            core::future::poll_fn(|cx| {
882                // poll once to enqueue the futures as waiting
883                assert!(f1.as_mut().poll(cx).is_pending());
884                assert!(f2.as_mut().poll(cx).is_pending());
885                Poll::Ready(())
886            })
887            .await;
888
889            assert_eq!(sem.waiters(), 2);
890            assert_eq!(sem.close(), 2);
891            assert_eq!(sem.waiters(), 0);
892
893            core::future::poll_fn(|cx| {
894                // closing the semaphore should have woken the future
895                match f1.as_mut().poll(cx) {
896                    Poll::Ready(Err(_)) => Poll::Ready(()),
897                    _ => panic!("acquire future should have resolved"),
898                }
899            })
900            .await;
901
902            // dropping the resolved future should have no effect
903            drop(f1);
904            assert_eq!(sem.available_permits(), 0);
905            // awaiting f2 must not deadlock, even if not polled manually
906            assert!(f2.await.is_err());
907
908            // dropping the permit must return even though the semaphore has
909            // been closed
910            drop(permit);
911            assert_eq!(sem.available_permits(), 1);
912
913            // no further permits must be acquirable
914            assert!(sem.try_acquire().is_err());
915            assert!(sem.acquire().await.is_err());
916        });
917    }
918
919    #[test]
920    fn return_outstanding_permit_on_close() {
921        future::block_on(async {
922            let sem = super::Semaphore::new(1);
923            let permit = sem.acquire().await.unwrap();
924
925            let mut fut = Box::pin(sem.acquire());
926            assert!(future::poll_once(&mut fut).await.is_none());
927            assert_eq!(sem.waiters(), 1);
928
929            // dropping a permit will transfer it to the next waiter, waking it
930            drop(permit);
931            assert_eq!(sem.waiters(), 0);
932            assert_eq!(sem.available_permits(), 0);
933
934            // closing by itself will not return the outstanding permit
935            sem.close();
936            assert_eq!(sem.available_permits(), 0);
937
938            // ... but awaiting the Acquire future should!
939            assert!(fut.await.is_err());
940            assert_eq!(sem.available_permits(), 1);
941        });
942    }
943
944    #[test]
945    fn return_outstanding_permit_on_cancel() {
946        future::block_on(async {
947            let sem = super::Semaphore::new(0);
948
949            let mut fut = Box::pin(sem.acquire());
950            assert!(future::poll_once(&mut fut).await.is_none());
951            assert_eq!(sem.waiters(), 1);
952
953            sem.add_permits(1);
954            assert_eq!(sem.waiters(), 0);
955
956            // dropping the unresolved future must return the already
957            // transferred permit ownership back to the semaphore
958            drop(fut);
959
960            assert_eq!(sem.waiters(), 0);
961            assert_eq!(sem.available_permits(), 1);
962        });
963    }
964
965    #[test]
966    fn forget_acquire_future() {
967        future::block_on(async {
968            async fn acquire_and_forget(sem: &super::Semaphore) {
969                let waiters = sem.waiters();
970                let mut fut = std::pin::pin!(sem.acquire());
971                assert!(future::poll_once(&mut fut).await.is_none());
972                assert_eq!(sem.waiters(), waiters + 1);
973
974                // this will not leak the future itself, but only the pinned
975                // reference to it, so the actual future will still be dropped
976                // correctly
977                std::mem::forget(fut);
978            }
979
980            let sem = super::Semaphore::new(0);
981            acquire_and_forget(&sem).await;
982            assert_eq!(sem.waiters(), 0);
983
984            // trash previously used stack space
985            let mut arr = [0u8; 1000];
986            for v in &mut arr {
987                *v = 255;
988            }
989
990            let mut f1 = std::pin::pin!(sem.acquire());
991            assert!(future::poll_once(&mut f1).await.is_none());
992            let mut f2 = std::pin::pin!(sem.acquire());
993            assert!(future::poll_once(&mut f2).await.is_none());
994            let mut f3 = std::pin::pin!(sem.acquire());
995            assert!(future::poll_once(&mut f3).await.is_none());
996
997            assert_eq!(sem.waiters(), 3);
998            assert_eq!(sem.available_permits(), 0);
999            sem.add_permits(3);
1000
1001            assert!(matches!(future::poll_once(&mut f1).await, Some(Ok(_))));
1002            assert!(matches!(future::poll_once(&mut f2).await, Some(Ok(_))));
1003            assert!(matches!(future::poll_once(&mut f3).await, Some(Ok(_))));
1004
1005            assert_eq!(sem.waiters(), 0);
1006            assert_eq!(sem.available_permits(), 3);
1007        });
1008    }
1009}