1use core::{
5 cell::{Cell, UnsafeCell},
6 fmt,
7 future::Future,
8 marker::PhantomPinned,
9 mem,
10 pin::Pin,
11 ptr::{self, NonNull},
12 task::{Context, Poll, Waker},
13};
14
15pub struct Semaphore {
18 shared: UnsafeCell<Shared>,
19}
20
21impl Semaphore {
22 pub const fn new(permits: usize) -> Self {
24 Self {
25 shared: UnsafeCell::new(Shared { waiters: WaiterQueue::new(), permits, closed: false }),
26 }
27 }
28
29 pub fn close(&self) -> usize {
34 unsafe { (*self.shared.get()).close() }
36 }
37
38 pub fn is_closed(&self) -> bool {
40 unsafe { (*self.shared.get()).is_closed() }
42 }
43
44 pub fn waiters(&self) -> usize {
47 unsafe { (*self.shared.get()).waiters.len() }
49 }
50
51 pub fn available_permits(&self) -> usize {
53 unsafe { (*self.shared.get()).permits }
55 }
56
57 pub fn add_permits(&self, n: usize) {
59 unsafe { (*self.shared.get()).add_permits(n) };
61 }
62
63 pub fn remove_permits(&self, n: usize) {
65 let shared = unsafe { &mut (*self.shared.get()) };
67 shared.permits = shared.permits.saturating_sub(n);
68 }
69
70 pub fn try_acquire(&self) -> Result<Permit<'_>, TryAcquireError> {
77 self.try_acquire_many(1)
78 }
79
80 pub fn try_acquire_many(&self, n: usize) -> Result<Permit<'_>, TryAcquireError> {
88 unsafe { (*self.shared.get()).try_acquire::<true>(n) }.map(|_| Permit::new(&self.shared, n))
90 }
91
92 pub fn acquire(&self) -> Acquire<'_> {
99 self.build_acquire(1)
100 }
101
102 pub fn acquire_many(&self, n: usize) -> Acquire<'_> {
109 self.build_acquire(n)
110 }
111
112 fn build_acquire(&self, wants: usize) -> Acquire<'_> {
115 Acquire { shared: &self.shared, waiter: Waiter::new(wants) }
116 }
117}
118
119impl fmt::Debug for Semaphore {
120 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
121 f.debug_struct("Semaphore")
122 .field("is_closed", &self.is_closed())
123 .field("available_permits", &self.available_permits())
124 .field("waiters", &self.waiters())
125 .finish_non_exhaustive()
126 }
127}
128
129pub struct Permit<'a> {
131 shared: &'a UnsafeCell<Shared>,
132 count: usize,
133}
134
135impl<'a> Permit<'a> {
136 fn new(shared: &'a UnsafeCell<Shared>, count: usize) -> Self {
141 Self { shared, count }
142 }
143
144 pub fn forget(self) {
148 mem::forget(self);
149 }
150}
151
152impl Drop for Permit<'_> {
153 fn drop(&mut self) {
154 let shared = unsafe { &mut (*self.shared.get()) };
156 shared.add_permits(self.count);
157 }
158}
159
160impl fmt::Debug for Permit<'_> {
161 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162 f.debug_struct("Permit").finish_non_exhaustive()
163 }
164}
165
/// Error returned by an [`Acquire`] future when the semaphore has been closed.
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub struct AcquireError(());

impl fmt::Display for AcquireError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // the message is written verbatim; width/fill flags are ignored
        write!(f, "semaphore closed")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for AcquireError {}
178
/// Error returned by [`Semaphore::try_acquire`] and [`Semaphore::try_acquire_many`].
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub enum TryAcquireError {
    /// The semaphore has been closed.
    Closed,
    /// Not enough permits are currently available.
    NoPermits,
}

#[cfg(feature = "alloc")]
impl From<TryAcquireError> for crate::error::TrySendError<()> {
    /// Maps a closed semaphore to `Closed` and a lack of permits to `Full`.
    fn from(err: TryAcquireError) -> Self {
        match err {
            TryAcquireError::NoPermits => Self::Full(()),
            TryAcquireError::Closed => Self::Closed(()),
        }
    }
}

impl fmt::Display for TryAcquireError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match *self {
            Self::Closed => "semaphore closed",
            Self::NoPermits => "no permits available",
        };
        f.write_str(msg)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for TryAcquireError {}
211
212pub struct Acquire<'a> {
215 shared: &'a UnsafeCell<Shared>,
217 waiter: Waiter,
219}
220
221impl<'a> Future for Acquire<'a> {
222 type Output = Result<Permit<'a>, AcquireError>;
223
224 #[inline]
225 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
226 let waiter = unsafe { Pin::map_unchecked(self.as_ref(), |acquire| &acquire.waiter) };
228
229 match unsafe { (*self.shared.get()).poll_acquire(waiter, cx) } {
231 Poll::Ready(res) => {
232 waiter.state.set(WaiterState::Woken);
236 match res {
237 Ok(_) => {
238 let shared = self.as_ref().shared;
239 let count = waiter.permits.take();
240 Poll::Ready(Ok(Permit::new(shared, count)))
241 }
242 Err(e) => Poll::Ready(Err(e)),
243 }
244 }
245 Poll::Pending => Poll::Pending,
246 }
247 }
248}
249
250impl Drop for Acquire<'_> {
251 fn drop(&mut self) {
252 let shared = unsafe { &mut (*self.shared.get()) };
254
255 if let WaiterState::Waiting = self.waiter.state.get() {
257 unsafe { shared.waiters.try_remove(&self.waiter) };
262 }
263
264 let permits = self.waiter.permits.get();
267 shared.add_permits(permits);
270 }
271}
272
273struct Shared {
275 waiters: WaiterQueue,
277 permits: usize,
279 closed: bool,
281}
282
283impl Shared {
284 #[cold]
286 fn close(&mut self) -> usize {
287 let woken = unsafe { self.waiters.wake_all() };
290 self.closed = true;
291 self.waiters = WaiterQueue::new();
292
293 woken
294 }
295
296 fn is_closed(&self) -> bool {
298 self.closed
299 }
300
301 fn add_permits(&mut self, mut n: usize) {
304 while n > 0 {
305 if let Some(waiter) = self.waiters.front() {
307 let waiter = unsafe { waiter.as_ref() };
309 let diff = waiter.wants - waiter.permits.get();
312 if diff > n {
313 waiter.permits.set(diff - n);
317 return;
318 } else {
319 waiter.permits.set(waiter.wants);
322 n -= diff;
323
324 unsafe {
328 waiter.state.set(WaiterState::Woken);
329 waiter.waker.get().wake_by_ref();
330 self.waiters.pop_front(waiter);
332 };
333 }
334 } else {
335 self.permits = self.permits.saturating_add(n);
336 return;
337 }
338 }
339 }
340
341 fn try_acquire<const STRICT: bool>(&mut self, n: usize) -> Result<usize, TryAcquireError> {
344 if self.is_closed() {
345 return Err(TryAcquireError::Closed);
346 }
347
348 if n > self.permits {
349 if STRICT || self.permits == 0 {
350 return Err(TryAcquireError::NoPermits);
351 }
352
353 let count = self.permits;
355 self.permits = 0;
356 Ok(count)
357 } else {
358 self.permits -= n;
360 Ok(n)
361 }
362 }
363
364 fn poll_acquire(
365 &mut self,
366 waiter: Pin<&Waiter>,
367 cx: &mut Context<'_>,
368 ) -> Poll<Result<(), AcquireError>> {
369 if self.closed {
370 return Poll::Ready(Err(AcquireError(())));
373 }
374
375 match waiter.state.get() {
376 WaiterState::Woken => Poll::Ready(Ok(())),
377 WaiterState::Inert => self.poll_acquire_initial(waiter, cx),
378 WaiterState::Waiting => Poll::Pending,
379 }
380 }
381
382 fn poll_acquire_initial(
383 &mut self,
384 waiter: Pin<&Waiter>,
385 cx: &mut Context<'_>,
386 ) -> Poll<Result<(), AcquireError>> {
387 match self.try_acquire::<false>(waiter.wants) {
391 Ok(n) => {
392 waiter.permits.set(n);
394 if n == waiter.wants {
395 return Poll::Ready(Ok(()));
396 }
397 }
398 Err(TryAcquireError::Closed) => return Poll::Ready(Err(AcquireError(()))),
399 _ => {}
400 };
401
402 waiter.state.set(WaiterState::Waiting);
406 waiter.waker.set(cx.waker().clone());
407 unsafe { self.waiters.push_back(waiter.get_ref()) }
424 Poll::Pending
425 }
426}
427
428struct WaiterQueue {
429 head: *const Waiter,
430 tail: *const Waiter,
431}
432
433impl WaiterQueue {
434 const fn new() -> Self {
436 Self { head: ptr::null(), tail: ptr::null() }
437 }
438
439 fn front(&self) -> Option<NonNull<Waiter>> {
441 NonNull::new(self.head as *mut Waiter)
442 }
443
444 #[cold]
450 unsafe fn len(&self) -> usize {
451 let mut curr = self.head;
454 let mut waiting = 0;
455 while !curr.is_null() {
456 curr = unsafe { (*curr).next.get() };
458 waiting += 1;
459 }
460
461 waiting
462 }
463
464 unsafe fn push_back(&mut self, waiter: &Waiter) {
470 if self.tail.is_null() {
471 self.head = waiter;
473 self.tail = waiter;
474 } else {
475 unsafe { (*self.tail).next.set(waiter) };
479 waiter.prev.set(self.tail);
480 self.tail = waiter;
481 }
482 }
483
484 #[cold]
490 unsafe fn try_remove(&mut self, waiter: &Waiter) {
491 let prev = waiter.prev.get();
492 if prev.is_null() {
493 self.head = waiter.next.get();
494 } else {
495 unsafe { (*prev).next.set(waiter.next.get()) };
497 }
498
499 let next = waiter.next.get();
500 if next.is_null() {
501 self.tail = waiter.prev.get();
502 } else {
503 unsafe { (*next).prev.set(waiter.prev.get()) };
505 }
506 }
507
508 #[inline]
515 unsafe fn pop_front(&mut self, head: &Waiter) {
516 self.head = head.next.get();
517 if self.head.is_null() {
518 self.tail = ptr::null();
519 } else {
520 unsafe { (*self.head).prev.set(ptr::null()) };
521 }
522 }
523
524 #[cold]
525 unsafe fn wake_all(&mut self) -> usize {
526 let mut curr = self.head;
527 let mut woken = 0;
528
529 while !curr.is_null() {
530 unsafe {
533 let waiter = &*curr;
534 waiter.state.set(WaiterState::Woken);
535 waiter.waker.get().wake_by_ref();
536 curr = waiter.next.get();
537 }
538
539 woken += 1;
540 }
541
542 woken
543 }
544}
545
546struct Waiter {
549 wants: usize,
551 waker: LateInitWaker,
556 state: Cell<WaiterState>,
558 permits: Cell<usize>,
560 next: Cell<*const Self>,
562 prev: Cell<*const Self>,
564 _marker: PhantomPinned,
567}
568
569impl Waiter {
570 const fn new(wants: usize) -> Self {
571 Self {
572 wants,
573 waker: LateInitWaker::new(),
574 state: Cell::new(WaiterState::Inert),
575 permits: Cell::new(0),
576 next: Cell::new(ptr::null()),
577 prev: Cell::new(ptr::null()),
578 _marker: PhantomPinned,
579 }
580 }
581}
582
/// Lifecycle of a [`Waiter`].
#[derive(Clone, Copy)]
enum WaiterState {
    /// Not yet polled, and not linked into any queue.
    Inert,
    /// Linked into the wait queue, waiting for permits.
    Waiting,
    /// No longer queued: all requested permits were assigned, the semaphore was closed,
    /// or the future has completed.
    Woken,
}
595
/// A lazily initialized [`Waker`] slot.
///
/// The slot starts out empty and is filled by [`LateInitWaker::set`].
/// [`LateInitWaker::get`] may only be called once the slot has been filled.
struct LateInitWaker(UnsafeCell<Option<Waker>>);

impl LateInitWaker {
    /// Returns an empty slot.
    const fn new() -> Self {
        Self(UnsafeCell::new(None))
    }

    /// Stores `waker`, dropping any previously stored waker.
    ///
    /// Callers must not hold a reference returned by [`LateInitWaker::get`] across this call.
    fn set(&self, waker: Waker) {
        // SAFETY: the type is not `Sync`, and by contract no reference obtained from `get`
        // is alive here. Replacing (instead of a raw `write`) ensures a previous waker is
        // dropped rather than leaked.
        let previous = unsafe { self.0.get().replace(Some(waker)) };
        // run the old waker's drop only after the slot is consistent again
        drop(previous);
    }

    /// Returns a reference to the stored waker.
    ///
    /// # Safety
    ///
    /// The slot must have been filled by a prior call to [`LateInitWaker::set`], and
    /// `set` must not be called while the returned reference is alive.
    unsafe fn get(&self) -> &Waker {
        match unsafe { &*self.0.get() } {
            Some(waker) => waker,
            // SAFETY: the caller guarantees the slot has been initialized.
            None => unsafe { core::hint::unreachable_unchecked() },
        }
    }
}
623
#[cfg(test)]
mod tests {
    use futures_lite::future;

    use core::{
        future::Future as _,
        ptr,
        task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
    };

    // `try_acquire` fails on an empty semaphore; dropped permits are returned to it.
    #[test]
    fn try_acquire_one() {
        let sem = super::Semaphore::new(0);
        assert!(sem.try_acquire().is_err());
        sem.add_permits(2);
        let p1 = sem.try_acquire().unwrap();
        let p2 = sem.try_acquire().unwrap();
        assert_eq!(sem.available_permits(), 0);

        drop((p1, p2));
        assert_eq!(sem.available_permits(), 2);
    }

    // `try_acquire_many` is all-or-nothing: it only succeeds once enough permits exist.
    #[test]
    fn try_acquire_many() {
        let sem = super::Semaphore::new(0);
        assert!(sem.try_acquire_many(3).is_err());
        sem.add_permits(2);
        assert!(sem.try_acquire_many(3).is_err());
        sem.add_permits(1);
        let permit = sem.try_acquire_many(3).unwrap();
        assert_eq!(permit.count, 3);
        drop(permit);
        assert_eq!(sem.available_permits(), 3);
    }

    // An acquire future on an empty semaphore stays pending.
    #[test]
    fn acquire_never() {
        future::block_on(async {
            let sem = super::Semaphore::new(0);
            let mut fut = core::pin::pin!(sem.acquire());

            core::future::poll_fn(|cx| {
                assert!(fut.as_mut().poll(cx).is_pending());
                Poll::Ready(())
            })
            .await;

            assert_eq!(sem.available_permits(), 0);
        });
    }

    // A pending acquire future completes after a permit is added; the permit is returned
    // when dropped.
    #[test]
    fn acquire() {
        future::block_on(async {
            let sem = super::Semaphore::new(0);
            let mut fut = core::pin::pin!(sem.acquire());
            core::future::poll_fn(|cx| {
                assert!(fut.as_mut().poll(cx).is_pending());
                Poll::Ready(())
            })
            .await;

            sem.add_permits(1);
            let permit = fut.await.unwrap();
            drop(permit);
            assert_eq!(sem.available_permits(), 1);
        });
    }

    // Adding more permits than a queued waiter needs leaves the surplus available.
    #[test]
    fn acquire_one() {
        future::block_on(async {
            let sem = super::Semaphore::new(0);
            let mut fut = core::pin::pin!(sem.acquire());

            core::future::poll_fn(|cx| {
                // the first poll enqueues the waiter
                assert!(fut.as_mut().poll(cx).is_pending());
                assert_eq!(sem.waiters(), 1);
                sem.add_permits(2);
                Poll::Ready(())
            })
            .await;

            let permit = fut.await.unwrap();
            // one permit went to the waiter, the other one remains available
            assert_eq!(sem.available_permits(), 1);
            drop(permit);
            assert_eq!(sem.available_permits(), 2);
        });
    }

    // Polling a completed acquire future again keeps returning `Ready`.
    #[test]
    fn poll_acquire_after_completion() {
        future::block_on(async {
            let sem = super::Semaphore::new(0);
            let mut fut = core::pin::pin!(sem.acquire());
            core::future::poll_fn(|cx| {
                assert!(fut.as_mut().poll(cx).is_pending());
                Poll::Ready(())
            })
            .await;

            sem.add_permits(1);

            core::future::poll_fn(|cx| {
                // the permit from the first `Ready` is dropped immediately, returning it;
                // the second poll yields a permit that holds nothing
                assert!(fut.as_mut().poll(cx).is_ready());
                assert!(fut.as_mut().poll(cx).is_ready());
                Poll::Ready(())
            })
            .await;

            assert_eq!(sem.available_permits(), 1);
        });
    }

    // Drives the future by hand with a no-op waker; the waiter leaves the queue once served.
    #[test]
    fn poll_future() {
        static RAW_VTABLE: RawWakerVTable = RawWakerVTable::new(
            |_| RawWaker::new(ptr::null(), &RAW_VTABLE),
            |_| {},
            |_| {},
            |_| {},
        );

        let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &RAW_VTABLE)) };
        let mut cx = Context::from_waker(&waker);

        let sem = super::Semaphore::new(0);
        let mut fut = Box::pin(sem.build_acquire(1));

        assert!(fut.as_mut().poll(&mut cx).is_pending());
        assert_eq!(sem.waiters(), 1);
        sem.add_permits(1);

        assert!(fut.as_mut().poll(&mut cx).is_ready());
        drop(fut);
        assert_eq!(sem.waiters(), 0);
    }

    // Waiters are served in FIFO order, even when a later waiter needs fewer permits.
    #[test]
    fn acquire_many() {
        future::block_on(async {
            let sem = super::Semaphore::new(0);
            let mut f1 = Box::pin(sem.acquire_many(2));
            let mut f2 = Box::pin(sem.acquire_many(1));

            core::future::poll_fn(|cx| {
                // f1 (wants 2) is queued ahead of f2 (wants 1)
                assert!(f1.as_mut().poll(cx).is_pending());
                assert!(f2.as_mut().poll(cx).is_pending());

                assert_eq!(sem.waiters(), 2);
                sem.add_permits(1);

                // the single permit is assigned to f1 at the head, so f2 stays pending
                assert!(f2.as_mut().poll(cx).is_pending());

                // the second permit completes f1, leaving only f2 queued
                sem.add_permits(1);
                assert_eq!(sem.waiters(), 1);
                Poll::Ready(())
            })
            .await;

            let permit = f1.await.unwrap();
            assert_eq!(sem.waiters(), 1);

            // releasing f1's two permits serves f2 and leaves one surplus permit
            drop(permit);
            assert_eq!(sem.waiters(), 0);
            assert_eq!(sem.available_permits(), 1);

            let permit = f2.await.unwrap();
            assert_eq!(sem.available_permits(), 1);
            drop(permit);

            assert_eq!(sem.available_permits(), 2);
        });
    }

    // Dropping a pending future dequeues it. A waiter that was already served is not
    // counted by `close`, and its permit comes back once the failed future is dropped.
    #[test]
    fn cleanup() {
        future::block_on(async {
            let sem = super::Semaphore::new(0);

            let mut fut = Box::pin(sem.acquire());
            core::future::poll_fn(|cx| {
                assert!(fut.as_mut().poll(cx).is_pending());
                Poll::Ready(())
            })
            .await;

            drop(fut);
            assert_eq!(sem.waiters(), 0);

            let mut fut = Box::pin(sem.acquire());
            core::future::poll_fn(|cx| {
                assert!(fut.as_mut().poll(cx).is_pending());
                Poll::Ready(())
            })
            .await;

            // the permit is assigned to the waiter, which is dequeued before `close`
            sem.add_permits(1);
            assert_eq!(sem.close(), 0);

            // the future fails after closing; dropping it returns its assigned permit
            assert!(fut.await.is_err());
            assert_eq!(sem.waiters(), 0);
            assert_eq!(sem.available_permits(), 1);
        });
    }

    // A woken but never re-polled future returns its assigned permit when dropped.
    #[test]
    fn cleanup_after_wake() {
        future::block_on(async {
            let sem = super::Semaphore::new(0);
            let mut fut = Box::pin(sem.acquire());

            core::future::poll_fn(|cx| {
                assert!(fut.as_mut().poll(cx).is_pending());
                Poll::Ready(())
            })
            .await;

            sem.add_permits(1);
            drop(fut);

            assert_eq!(sem.waiters(), 0);
            assert_eq!(sem.available_permits(), 1);
        });
    }

    // Closing wakes all queued waiters with an error; outstanding permits can still be
    // returned afterwards, but nothing can be acquired anymore.
    #[test]
    fn close() {
        future::block_on(async {
            let sem = super::Semaphore::new(1);
            let permit = sem.acquire().await.unwrap();

            let mut f1 = Box::pin(sem.acquire());
            let mut f2 = Box::pin(sem.acquire());
            core::future::poll_fn(|cx| {
                assert!(f1.as_mut().poll(cx).is_pending());
                assert!(f2.as_mut().poll(cx).is_pending());
                Poll::Ready(())
            })
            .await;

            assert_eq!(sem.waiters(), 2);
            assert_eq!(sem.close(), 2);
            assert_eq!(sem.waiters(), 0);

            core::future::poll_fn(|cx| {
                match f1.as_mut().poll(cx) {
                    Poll::Ready(Err(_)) => Poll::Ready(()),
                    _ => panic!("acquire future should have resolved"),
                }
            })
            .await;

            drop(f1);
            assert_eq!(sem.available_permits(), 0);
            assert!(f2.await.is_err());

            // returning the outstanding permit still works after closing
            drop(permit);
            assert_eq!(sem.available_permits(), 1);

            assert!(sem.try_acquire().is_err());
            assert!(sem.acquire().await.is_err());
        });
    }

    // A permit handed to a waiter before `close` comes back once that future resolves
    // with an error.
    #[test]
    fn return_outstanding_permit_on_close() {
        future::block_on(async {
            let sem = super::Semaphore::new(1);
            let permit = sem.acquire().await.unwrap();

            let mut fut = Box::pin(sem.acquire());
            assert!(future::poll_once(&mut fut).await.is_none());
            assert_eq!(sem.waiters(), 1);

            // releasing the permit hands it straight to the queued waiter
            drop(permit);
            assert_eq!(sem.waiters(), 0);
            assert_eq!(sem.available_permits(), 0);

            // closing does not reclaim the permit assigned to the waiter
            sem.close();
            assert_eq!(sem.available_permits(), 0);

            assert!(fut.await.is_err());
            assert_eq!(sem.available_permits(), 1);
        });
    }

    // Dropping a woken future before polling it again returns its permit.
    #[test]
    fn return_outstanding_permit_on_cancel() {
        future::block_on(async {
            let sem = super::Semaphore::new(0);

            let mut fut = Box::pin(sem.acquire());
            assert!(future::poll_once(&mut fut).await.is_none());
            assert_eq!(sem.waiters(), 1);

            sem.add_permits(1);
            assert_eq!(sem.waiters(), 0);

            drop(fut);

            assert_eq!(sem.waiters(), 0);
            assert_eq!(sem.available_permits(), 1);
        });
    }

    // `mem::forget` on the `Pin<&mut _>` only forgets the reference. The pinned future is
    // still dropped at the end of `acquire_and_forget`, which dequeues its waiter.
    #[test]
    fn forget_acquire_future() {
        future::block_on(async {
            async fn acquire_and_forget(sem: &super::Semaphore) {
                let waiters = sem.waiters();
                let mut fut = std::pin::pin!(sem.acquire());
                assert!(future::poll_once(&mut fut).await.is_none());
                assert_eq!(sem.waiters(), waiters + 1);

                std::mem::forget(fut);
            }

            let sem = super::Semaphore::new(0);
            acquire_and_forget(&sem).await;
            assert_eq!(sem.waiters(), 0);

            // NOTE(review): presumably scribbles over memory so that a dangling waiter
            // pointer left behind would be more likely to misbehave -- intent inferred
            let mut arr = [0u8; 1000];
            for v in &mut arr {
                *v = 255;
            }

            let mut f1 = std::pin::pin!(sem.acquire());
            assert!(future::poll_once(&mut f1).await.is_none());
            let mut f2 = std::pin::pin!(sem.acquire());
            assert!(future::poll_once(&mut f2).await.is_none());
            let mut f3 = std::pin::pin!(sem.acquire());
            assert!(future::poll_once(&mut f3).await.is_none());

            assert_eq!(sem.waiters(), 3);
            assert_eq!(sem.available_permits(), 0);
            sem.add_permits(3);

            // each resulting permit is dropped at the end of its statement, returning it
            assert!(matches!(future::poll_once(&mut f1).await, Some(Ok(_))));
            assert!(matches!(future::poll_once(&mut f2).await, Some(Ok(_))));
            assert!(matches!(future::poll_once(&mut f3).await, Some(Ok(_))));

            assert_eq!(sem.waiters(), 0);
            assert_eq!(sem.available_permits(), 3);
        });
    }
}
1009}