heapless/pool/arc.rs

//! `std::sync::Arc`-like API on top of a lock-free memory pool
//!
//! # Example usage
//!
//! ```
//! use heapless::{arc_pool, pool::arc::{Arc, ArcBlock}};
//!
//! arc_pool!(P: u128);
//!
//! // cannot allocate without first giving memory blocks to the pool
//! assert!(P.alloc(42).is_err());
//!
//! // (some `no_std` runtimes have safe APIs to create `&'static mut` references)
//! let block: &'static mut ArcBlock<u128> = unsafe {
//!     static mut B: ArcBlock<u128> = ArcBlock::new();
//!     &mut B
//! };
//!
//! P.manage(block);
//!
//! let arc = P.alloc(1).unwrap();
//!
//! // number of smart pointers is limited to the number of blocks managed by the pool
//! let res = P.alloc(2);
//! assert!(res.is_err());
//!
//! // but cloning does not consume an `ArcBlock`
//! let arc2 = arc.clone();
//!
//! assert_eq!(1, *arc2);
//!
//! // `arc`'s destructor returns the memory block to the pool
//! drop(arc2); // decrease reference counter
//! drop(arc); // release memory
//!
//! // it's now possible to allocate a new `Arc` smart pointer
//! let res = P.alloc(3);
//!
//! assert!(res.is_ok());
//! ```
//!
//! # Array block initialization
//!
//! You can create a static variable that contains an array of memory blocks and give all the blocks
//! to the `ArcPool`. This requires an intermediate `const` value as shown below:
//!
//! ```
//! use heapless::{arc_pool, pool::arc::ArcBlock};
//!
//! arc_pool!(P: u128);
//!
//! const POOL_CAPACITY: usize = 8;
//!
//! let blocks: &'static mut [ArcBlock<u128>] = {
//!     const BLOCK: ArcBlock<u128> = ArcBlock::new(); // <=
//!     static mut BLOCKS: [ArcBlock<u128>; POOL_CAPACITY] = [BLOCK; POOL_CAPACITY];
//!     unsafe { &mut BLOCKS }
//! };
//!
//! for block in blocks {
//!     P.manage(block);
//! }
//! ```

// reference counting logic is based on version 1.63.0 of the Rust standard library (`alloc` crate)
// which is licensed under 'MIT or APACHE-2.0'
// https://github.com/rust-lang/rust/blob/1.63.0/library/alloc/src/sync.rs#L235 (last visited
// 2022-09-05)

use core::{
    fmt,
    hash::{Hash, Hasher},
    mem::{ManuallyDrop, MaybeUninit},
    ops, ptr,
    sync::atomic::{self, AtomicUsize, Ordering},
};

use super::treiber::{NonNullPtr, Stack, UnionNode};

/// Creates a new `ArcPool` singleton with the given `$name` that manages the specified `$data_type`
///
/// For more extensive documentation see the [module level documentation](crate::pool::arc)
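///
/// # Example
///
/// A minimal sketch; the pool name `MyPool` and the element type are only illustrative:
///
/// ```
/// use heapless::arc_pool;
///
/// // declares a pool singleton named `MyPool` whose blocks hold `[u8; 32]` values
/// arc_pool!(MyPool: [u8; 32]);
///
/// // the freshly declared pool manages no blocks yet, so allocation fails
/// assert!(MyPool.alloc([0; 32]).is_err());
/// ```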
#[macro_export]
macro_rules! arc_pool {
    ($name:ident: $data_type:ty) => {
        pub struct $name;

        impl $crate::pool::arc::ArcPool for $name {
            type Data = $data_type;

            fn singleton() -> &'static $crate::pool::arc::ArcPoolImpl<$data_type> {
                static $name: $crate::pool::arc::ArcPoolImpl<$data_type> =
                    $crate::pool::arc::ArcPoolImpl::new();

                &$name
            }
        }

        impl $name {
            /// Inherent method version of `ArcPool::alloc`
            #[allow(dead_code)]
            pub fn alloc(
                &self,
                value: $data_type,
            ) -> Result<$crate::pool::arc::Arc<$name>, $data_type> {
                <$name as $crate::pool::arc::ArcPool>::alloc(value)
            }

            /// Inherent method version of `ArcPool::manage`
            #[allow(dead_code)]
            pub fn manage(&self, block: &'static mut $crate::pool::arc::ArcBlock<$data_type>) {
                <$name as $crate::pool::arc::ArcPool>::manage(block)
            }
        }
    };
}

/// A singleton that manages `pool::arc::Arc` smart pointers
pub trait ArcPool: Sized {
    /// The data type managed by the memory pool
    type Data: 'static;

    /// `arc_pool!` implementation detail
    #[doc(hidden)]
    fn singleton() -> &'static ArcPoolImpl<Self::Data>;

    /// Allocate a new `Arc` smart pointer initialized to the given `value`
    ///
    /// `manage` should be called at least once before calling `alloc`
    ///
    /// # Errors
    ///
    /// The `Err` variant is returned when the memory pool has run out of memory blocks
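    ///
    /// # Example
    ///
    /// A minimal sketch (the pool name `Q` is only illustrative):
    ///
    /// ```
    /// use heapless::arc_pool;
    ///
    /// arc_pool!(Q: u8);
    ///
    /// // no `ArcBlock` has been given to the pool yet, so the value is handed back in `Err`
    /// assert_eq!(Q.alloc(7), Err(7));
    /// ```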
    fn alloc(value: Self::Data) -> Result<Arc<Self>, Self::Data> {
        Ok(Arc {
            node_ptr: Self::singleton().alloc(value)?,
        })
    }

    /// Add a statically allocated memory block to the memory pool
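    ///
    /// # Example
    ///
    /// A minimal sketch, mirroring the module-level example:
    ///
    /// ```
    /// use heapless::{arc_pool, pool::arc::ArcBlock};
    ///
    /// arc_pool!(P: u8);
    ///
    /// let block: &'static mut ArcBlock<u8> = unsafe {
    ///     static mut B: ArcBlock<u8> = ArcBlock::new();
    ///     &mut B
    /// };
    ///
    /// P.manage(block);
    /// assert!(P.alloc(0).is_ok());
    /// ```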
    fn manage(block: &'static mut ArcBlock<Self::Data>) {
        Self::singleton().manage(block)
    }
}

/// `arc_pool!` implementation detail
// newtype to avoid having to make field types public
#[doc(hidden)]
pub struct ArcPoolImpl<T> {
    stack: Stack<UnionNode<MaybeUninit<ArcInner<T>>>>,
}

impl<T> ArcPoolImpl<T> {
    /// `arc_pool!` implementation detail
    #[doc(hidden)]
    pub const fn new() -> Self {
        Self {
            stack: Stack::new(),
        }
    }

    fn alloc(&self, value: T) -> Result<NonNullPtr<UnionNode<MaybeUninit<ArcInner<T>>>>, T> {
        if let Some(node_ptr) = self.stack.try_pop() {
            // initialize the popped node in place with a strong count of 1
            let inner = ArcInner {
                data: value,
                strong: AtomicUsize::new(1),
            };
            unsafe { node_ptr.as_ptr().cast::<ArcInner<T>>().write(inner) }

            Ok(node_ptr)
        } else {
            // the free list is empty; hand the value back to the caller
            Err(value)
        }
    }

    fn manage(&self, block: &'static mut ArcBlock<T>) {
        let node: &'static mut _ = &mut block.node;

        unsafe { self.stack.push(NonNullPtr::from_static_mut_ref(node)) }
    }
}

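// SAFETY: the pool is only a lock-free treiber `Stack` of free memory blocks; its `push` and
// `try_pop` operations are designed for concurrent use, so the pool may be shared between threads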
unsafe impl<T> Sync for ArcPoolImpl<T> {}

/// Like `std::sync::Arc` but managed by memory pool `P`
pub struct Arc<P>
where
    P: ArcPool,
{
    node_ptr: NonNullPtr<UnionNode<MaybeUninit<ArcInner<P::Data>>>>,
}

impl<P> Arc<P>
where
    P: ArcPool,
{
    fn inner(&self) -> &ArcInner<P::Data> {
        unsafe { &*self.node_ptr.as_ptr().cast::<ArcInner<P::Data>>() }
    }

    fn from_inner(node_ptr: NonNullPtr<UnionNode<MaybeUninit<ArcInner<P::Data>>>>) -> Self {
        Self { node_ptr }
    }

    unsafe fn get_mut_unchecked(this: &mut Self) -> &mut P::Data {
        &mut *ptr::addr_of_mut!((*this.node_ptr.as_ptr().cast::<ArcInner<P::Data>>()).data)
    }

    #[inline(never)]
    unsafe fn drop_slow(&mut self) {
        // run `P::Data`'s destructor
        ptr::drop_in_place(Self::get_mut_unchecked(self));

        // return memory to pool
        P::singleton().stack.push(self.node_ptr);
    }
}

impl<P> AsRef<P::Data> for Arc<P>
where
    P: ArcPool,
{
    fn as_ref(&self) -> &P::Data {
        &**self
    }
}

// maximum allowed reference count; exceeding it makes `clone` panic, mirroring the overflow
// guard used by `std::sync::Arc`
const MAX_REFCOUNT: usize = (isize::MAX) as usize;

impl<P> Clone for Arc<P>
where
    P: ArcPool,
{
    fn clone(&self) -> Self {
        // `Relaxed` suffices for the increment: the existence of the reference being cloned
        // already keeps the object alive (same reasoning as `std::sync::Arc`)
        let old_size = self.inner().strong.fetch_add(1, Ordering::Relaxed);

        if old_size > MAX_REFCOUNT {
            // XXX original code calls `intrinsics::abort` which is unstable API
            panic!();
        }

        Self::from_inner(self.node_ptr)
    }
}

impl<A> fmt::Debug for Arc<A>
where
    A: ArcPool,
    A::Data: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        A::Data::fmt(self, f)
    }
}

impl<P> ops::Deref for Arc<P>
where
    P: ArcPool,
{
    type Target = P::Data;

    fn deref(&self) -> &Self::Target {
        unsafe { &*ptr::addr_of!((*self.node_ptr.as_ptr().cast::<ArcInner<P::Data>>()).data) }
    }
}

impl<A> fmt::Display for Arc<A>
where
    A: ArcPool,
    A::Data: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        A::Data::fmt(self, f)
    }
}

impl<A> Drop for Arc<A>
where
    A: ArcPool,
{
    fn drop(&mut self) {
        // `Release` ensures all uses of the data through this clone happen-before the block is
        // reused
        if self.inner().strong.fetch_sub(1, Ordering::Release) != 1 {
            return;
        }

        // the `Acquire` fence synchronizes with the `Release` decrements above so the destructor
        // sees all writes made through other clones (same scheme as `std::sync::Arc`)
        atomic::fence(Ordering::Acquire);

        unsafe { self.drop_slow() }
    }
}

impl<A> Eq for Arc<A>
where
    A: ArcPool,
    A::Data: Eq,
{
}

impl<A> Hash for Arc<A>
where
    A: ArcPool,
    A::Data: Hash,
{
    fn hash<H>(&self, state: &mut H)
    where
        H: Hasher,
    {
        (**self).hash(state)
    }
}

impl<A> Ord for Arc<A>
where
    A: ArcPool,
    A::Data: Ord,
{
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        A::Data::cmp(self, other)
    }
}

impl<A, B> PartialEq<Arc<B>> for Arc<A>
where
    A: ArcPool,
    B: ArcPool,
    A::Data: PartialEq<B::Data>,
{
    fn eq(&self, other: &Arc<B>) -> bool {
        A::Data::eq(self, &**other)
    }
}

impl<A, B> PartialOrd<Arc<B>> for Arc<A>
where
    A: ArcPool,
    B: ArcPool,
    A::Data: PartialOrd<B::Data>,
{
    fn partial_cmp(&self, other: &Arc<B>) -> Option<core::cmp::Ordering> {
        A::Data::partial_cmp(self, &**other)
    }
}

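// SAFETY: as with `std::sync::Arc`, sharing or sending the pointer can expose `&P::Data` to other
// threads and lets any thread run `P::Data`'s destructor, hence the `Send + Sync` bounds on the data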
unsafe impl<A> Send for Arc<A>
where
    A: ArcPool,
    A::Data: Sync + Send,
{
}

unsafe impl<A> Sync for Arc<A>
where
    A: ArcPool,
    A::Data: Sync + Send,
{
}

impl<A> Unpin for Arc<A> where A: ArcPool {}

struct ArcInner<T> {
    data: T,
    strong: AtomicUsize,
}

/// A chunk of memory that an `ArcPool` can manage
pub struct ArcBlock<T> {
    node: UnionNode<MaybeUninit<ArcInner<T>>>,
}

impl<T> ArcBlock<T> {
    /// Creates a new memory block
    pub const fn new() -> Self {
        Self {
            node: UnionNode {
                data: ManuallyDrop::new(MaybeUninit::uninit()),
            },
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cannot_alloc_if_empty() {
        arc_pool!(P: i32);

        assert_eq!(Err(42), P.alloc(42));
    }

    #[test]
    fn can_alloc_if_manages_one_block() {
        arc_pool!(P: i32);

        let block = unsafe {
            static mut B: ArcBlock<i32> = ArcBlock::new();
            &mut B
        };
        P.manage(block);

        assert_eq!(42, *P.alloc(42).unwrap());
    }

    #[test]
    fn alloc_drop_alloc() {
        arc_pool!(P: i32);

        let block = unsafe {
            static mut B: ArcBlock<i32> = ArcBlock::new();
            &mut B
        };
        P.manage(block);

        let arc = P.alloc(1).unwrap();

        drop(arc);

        assert_eq!(2, *P.alloc(2).unwrap());
    }

    #[test]
    fn strong_count_starts_at_one() {
        arc_pool!(P: i32);

        let block = unsafe {
            static mut B: ArcBlock<i32> = ArcBlock::new();
            &mut B
        };
        P.manage(block);

        let arc = P.alloc(1).ok().unwrap();

        assert_eq!(1, arc.inner().strong.load(Ordering::Relaxed));
    }

    #[test]
    fn clone_increases_strong_count() {
        arc_pool!(P: i32);

        let block = unsafe {
            static mut B: ArcBlock<i32> = ArcBlock::new();
            &mut B
        };
        P.manage(block);

        let arc = P.alloc(1).ok().unwrap();

        let before = arc.inner().strong.load(Ordering::Relaxed);

        let arc2 = arc.clone();

        let expected = before + 1;
        assert_eq!(expected, arc.inner().strong.load(Ordering::Relaxed));
        assert_eq!(expected, arc2.inner().strong.load(Ordering::Relaxed));
    }

    #[test]
    fn drop_decreases_strong_count() {
        arc_pool!(P: i32);

        let block = unsafe {
            static mut B: ArcBlock<i32> = ArcBlock::new();
            &mut B
        };
        P.manage(block);

        let arc = P.alloc(1).ok().unwrap();
        let arc2 = arc.clone();

        let before = arc.inner().strong.load(Ordering::Relaxed);

        drop(arc);

        let expected = before - 1;
        assert_eq!(expected, arc2.inner().strong.load(Ordering::Relaxed));
    }

    #[test]
    fn runs_destructor_exactly_once_when_strong_count_reaches_zero() {
        static COUNT: AtomicUsize = AtomicUsize::new(0);

        pub struct S;

        impl Drop for S {
            fn drop(&mut self) {
                COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }

        arc_pool!(P: S);

        let block = unsafe {
            static mut B: ArcBlock<S> = ArcBlock::new();
            &mut B
        };
        P.manage(block);

        let arc = P.alloc(S).ok().unwrap();

        assert_eq!(0, COUNT.load(Ordering::Relaxed));

        drop(arc);

        assert_eq!(1, COUNT.load(Ordering::Relaxed));
    }

    #[test]
    fn zst_is_well_aligned() {
        #[repr(align(4096))]
        pub struct Zst4096;

        arc_pool!(P: Zst4096);

        let block = unsafe {
            static mut B: ArcBlock<Zst4096> = ArcBlock::new();
            &mut B
        };
        P.manage(block);

        let arc = P.alloc(Zst4096).ok().unwrap();

        let raw = &*arc as *const Zst4096;
        assert_eq!(0, raw as usize % 4096);
    }
}