heapless/pool/
arc.rs
1use core::{
71 fmt,
72 hash::{Hash, Hasher},
73 mem::{ManuallyDrop, MaybeUninit},
74 ops, ptr,
75 sync::atomic::{self, AtomicUsize, Ordering},
76};
77
78use super::treiber::{NonNullPtr, Stack, UnionNode};
79
/// Declares an arc pool singleton named `$name` whose blocks hold values of `$data_type`.
///
/// The expansion contains:
/// - a public unit struct `$name`,
/// - an `ArcPool` implementation for it, backed by a function-local `static` `ArcPoolImpl`,
/// - inherent `alloc` / `manage` methods that forward to the trait so callers can write
///   `$name.alloc(..)` without importing `ArcPool`.
#[macro_export]
macro_rules! arc_pool {
    ($name:ident: $data_type:ty) => {
        pub struct $name;

        impl $crate::pool::arc::ArcPool for $name {
            type Data = $data_type;

            fn singleton() -> &'static $crate::pool::arc::ArcPoolImpl<$data_type> {
                // The static shares the identifier of the pool type, which is normally
                // written in UpperCamelCase. Silence the naming lint so that user code does
                // not get a warning for a static it never wrote itself.
                #[allow(non_upper_case_globals)]
                static $name: $crate::pool::arc::ArcPoolImpl<$data_type> =
                    $crate::pool::arc::ArcPoolImpl::new();

                &$name
            }
        }

        impl $name {
            /// Inherent-method form of `ArcPool::alloc`.
            ///
            /// Returns `value` back inside `Err` when the pool cannot provide a block.
            // Callers may use only one of the two inherent helpers.
            #[allow(dead_code)]
            pub fn alloc(
                &self,
                value: $data_type,
            ) -> Result<$crate::pool::arc::Arc<$name>, $data_type> {
                <$name as $crate::pool::arc::ArcPool>::alloc(value)
            }

            /// Inherent-method form of `ArcPool::manage`.
            // Callers may use only one of the two inherent helpers.
            #[allow(dead_code)]
            pub fn manage(&self, block: &'static mut $crate::pool::arc::ArcBlock<$data_type>) {
                <$name as $crate::pool::arc::ArcPool>::manage(block)
            }
        }
    };
}
117
118pub trait ArcPool: Sized {
120 type Data: 'static;
122
123 #[doc(hidden)]
125 fn singleton() -> &'static ArcPoolImpl<Self::Data>;
126
127 fn alloc(value: Self::Data) -> Result<Arc<Self>, Self::Data> {
135 Ok(Arc {
136 node_ptr: Self::singleton().alloc(value)?,
137 })
138 }
139
140 fn manage(block: &'static mut ArcBlock<Self::Data>) {
142 Self::singleton().manage(block)
143 }
144}
145
146#[doc(hidden)]
149pub struct ArcPoolImpl<T> {
150 stack: Stack<UnionNode<MaybeUninit<ArcInner<T>>>>,
151}
152
153impl<T> ArcPoolImpl<T> {
154 #[doc(hidden)]
156 pub const fn new() -> Self {
157 Self {
158 stack: Stack::new(),
159 }
160 }
161
162 fn alloc(&self, value: T) -> Result<NonNullPtr<UnionNode<MaybeUninit<ArcInner<T>>>>, T> {
163 if let Some(node_ptr) = self.stack.try_pop() {
164 let inner = ArcInner {
165 data: value,
166 strong: AtomicUsize::new(1),
167 };
168 unsafe { node_ptr.as_ptr().cast::<ArcInner<T>>().write(inner) }
169
170 Ok(node_ptr)
171 } else {
172 Err(value)
173 }
174 }
175
176 fn manage(&self, block: &'static mut ArcBlock<T>) {
177 let node: &'static mut _ = &mut block.node;
178
179 unsafe { self.stack.push(NonNullPtr::from_static_mut_ref(node)) }
180 }
181}
182
183unsafe impl<T> Sync for ArcPoolImpl<T> {}
184
185pub struct Arc<P>
187where
188 P: ArcPool,
189{
190 node_ptr: NonNullPtr<UnionNode<MaybeUninit<ArcInner<P::Data>>>>,
191}
192
193impl<P> Arc<P>
194where
195 P: ArcPool,
196{
197 fn inner(&self) -> &ArcInner<P::Data> {
198 unsafe { &*self.node_ptr.as_ptr().cast::<ArcInner<P::Data>>() }
199 }
200
201 fn from_inner(node_ptr: NonNullPtr<UnionNode<MaybeUninit<ArcInner<P::Data>>>>) -> Self {
202 Self { node_ptr }
203 }
204
205 unsafe fn get_mut_unchecked(this: &mut Self) -> &mut P::Data {
206 &mut *ptr::addr_of_mut!((*this.node_ptr.as_ptr().cast::<ArcInner<P::Data>>()).data)
207 }
208
209 #[inline(never)]
210 unsafe fn drop_slow(&mut self) {
211 ptr::drop_in_place(Self::get_mut_unchecked(self));
213
214 P::singleton().stack.push(self.node_ptr);
216 }
217}
218
219impl<P> AsRef<P::Data> for Arc<P>
220where
221 P: ArcPool,
222{
223 fn as_ref(&self) -> &P::Data {
224 &**self
225 }
226}
227
/// Largest strong count that `Clone` tolerates before panicking (`isize::MAX`).
const MAX_REFCOUNT: usize = isize::MAX as usize;
229
230impl<P> Clone for Arc<P>
231where
232 P: ArcPool,
233{
234 fn clone(&self) -> Self {
235 let old_size = self.inner().strong.fetch_add(1, Ordering::Relaxed);
236
237 if old_size > MAX_REFCOUNT {
238 panic!();
240 }
241
242 Self::from_inner(self.node_ptr)
243 }
244}
245
246impl<A> fmt::Debug for Arc<A>
247where
248 A: ArcPool,
249 A::Data: fmt::Debug,
250{
251 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
252 A::Data::fmt(self, f)
253 }
254}
255
256impl<P> ops::Deref for Arc<P>
257where
258 P: ArcPool,
259{
260 type Target = P::Data;
261
262 fn deref(&self) -> &Self::Target {
263 unsafe { &*ptr::addr_of!((*self.node_ptr.as_ptr().cast::<ArcInner<P::Data>>()).data) }
264 }
265}
266
267impl<A> fmt::Display for Arc<A>
268where
269 A: ArcPool,
270 A::Data: fmt::Display,
271{
272 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273 A::Data::fmt(self, f)
274 }
275}
276
277impl<A> Drop for Arc<A>
278where
279 A: ArcPool,
280{
281 fn drop(&mut self) {
282 if self.inner().strong.fetch_sub(1, Ordering::Release) != 1 {
283 return;
284 }
285
286 atomic::fence(Ordering::Acquire);
287
288 unsafe { self.drop_slow() }
289 }
290}
291
292impl<A> Eq for Arc<A>
293where
294 A: ArcPool,
295 A::Data: Eq,
296{
297}
298
299impl<A> Hash for Arc<A>
300where
301 A: ArcPool,
302 A::Data: Hash,
303{
304 fn hash<H>(&self, state: &mut H)
305 where
306 H: Hasher,
307 {
308 (**self).hash(state)
309 }
310}
311
312impl<A> Ord for Arc<A>
313where
314 A: ArcPool,
315 A::Data: Ord,
316{
317 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
318 A::Data::cmp(self, other)
319 }
320}
321
322impl<A, B> PartialEq<Arc<B>> for Arc<A>
323where
324 A: ArcPool,
325 B: ArcPool,
326 A::Data: PartialEq<B::Data>,
327{
328 fn eq(&self, other: &Arc<B>) -> bool {
329 A::Data::eq(self, &**other)
330 }
331}
332
333impl<A, B> PartialOrd<Arc<B>> for Arc<A>
334where
335 A: ArcPool,
336 B: ArcPool,
337 A::Data: PartialOrd<B::Data>,
338{
339 fn partial_cmp(&self, other: &Arc<B>) -> Option<core::cmp::Ordering> {
340 A::Data::partial_cmp(self, &**other)
341 }
342}
343
344unsafe impl<A> Send for Arc<A>
345where
346 A: ArcPool,
347 A::Data: Sync + Send,
348{
349}
350
351unsafe impl<A> Sync for Arc<A>
352where
353 A: ArcPool,
354 A::Data: Sync + Send,
355{
356}
357
358impl<A> Unpin for Arc<A> where A: ArcPool {}
359
360struct ArcInner<T> {
361 data: T,
362 strong: AtomicUsize,
363}
364
365pub struct ArcBlock<T> {
367 node: UnionNode<MaybeUninit<ArcInner<T>>>,
368}
369
370impl<T> ArcBlock<T> {
371 pub const fn new() -> Self {
373 Self {
374 node: UnionNode {
375 data: ManuallyDrop::new(MaybeUninit::uninit()),
376 },
377 }
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    // Every test declares its own pool type and its own `static mut` block. The block is
    // turned into a `&'static mut` through `addr_of_mut!` rather than `&mut B`, because
    // taking a reference directly to a `static mut` trips the `static_mut_refs` lint.

    /// A pool that manages no block must hand the value back.
    #[test]
    fn cannot_alloc_if_empty() {
        arc_pool!(P: i32);

        assert_eq!(Err(42), P.alloc(42));
    }

    /// One managed block allows exactly the allocation performed here.
    #[test]
    fn can_alloc_if_manages_one_block() {
        arc_pool!(P: i32);

        let block = unsafe {
            static mut B: ArcBlock<i32> = ArcBlock::new();
            // SAFETY: `B` is local to this test and referenced only once.
            &mut *core::ptr::addr_of_mut!(B)
        };
        P.manage(block);

        assert_eq!(42, *P.alloc(42).unwrap());
    }

    /// Dropping the only handle makes the block available again.
    #[test]
    fn alloc_drop_alloc() {
        arc_pool!(P: i32);

        let block = unsafe {
            static mut B: ArcBlock<i32> = ArcBlock::new();
            // SAFETY: `B` is local to this test and referenced only once.
            &mut *core::ptr::addr_of_mut!(B)
        };
        P.manage(block);

        let arc = P.alloc(1).unwrap();

        drop(arc);

        assert_eq!(2, *P.alloc(2).unwrap());
    }

    /// A freshly allocated `Arc` has a strong count of one.
    #[test]
    fn strong_count_starts_at_one() {
        arc_pool!(P: i32);

        let block = unsafe {
            static mut B: ArcBlock<i32> = ArcBlock::new();
            // SAFETY: `B` is local to this test and referenced only once.
            &mut *core::ptr::addr_of_mut!(B)
        };
        P.manage(block);

        let arc = P.alloc(1).ok().unwrap();

        assert_eq!(1, arc.inner().strong.load(Ordering::Relaxed));
    }

    /// Cloning bumps the count, and both handles observe the same counter.
    #[test]
    fn clone_increases_strong_count() {
        arc_pool!(P: i32);

        let block = unsafe {
            static mut B: ArcBlock<i32> = ArcBlock::new();
            // SAFETY: `B` is local to this test and referenced only once.
            &mut *core::ptr::addr_of_mut!(B)
        };
        P.manage(block);

        let arc = P.alloc(1).ok().unwrap();

        let before = arc.inner().strong.load(Ordering::Relaxed);

        let arc2 = arc.clone();

        let expected = before + 1;
        assert_eq!(expected, arc.inner().strong.load(Ordering::Relaxed));
        assert_eq!(expected, arc2.inner().strong.load(Ordering::Relaxed));
    }

    /// Dropping one of two handles lowers the count seen by the survivor.
    #[test]
    fn drop_decreases_strong_count() {
        arc_pool!(P: i32);

        let block = unsafe {
            static mut B: ArcBlock<i32> = ArcBlock::new();
            // SAFETY: `B` is local to this test and referenced only once.
            &mut *core::ptr::addr_of_mut!(B)
        };
        P.manage(block);

        let arc = P.alloc(1).ok().unwrap();
        let arc2 = arc.clone();

        let before = arc.inner().strong.load(Ordering::Relaxed);

        drop(arc);

        let expected = before - 1;
        assert_eq!(expected, arc2.inner().strong.load(Ordering::Relaxed));
    }

    /// The payload's destructor runs once, and only when the last handle goes away.
    #[test]
    fn runs_destructor_exactly_once_when_strong_count_reaches_zero() {
        static COUNT: AtomicUsize = AtomicUsize::new(0);

        pub struct S;

        impl Drop for S {
            fn drop(&mut self) {
                COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }

        arc_pool!(P: S);

        let block = unsafe {
            static mut B: ArcBlock<S> = ArcBlock::new();
            // SAFETY: `B` is local to this test and referenced only once.
            &mut *core::ptr::addr_of_mut!(B)
        };
        P.manage(block);

        let arc = P.alloc(S).ok().unwrap();

        assert_eq!(0, COUNT.load(Ordering::Relaxed));

        drop(arc);

        assert_eq!(1, COUNT.load(Ordering::Relaxed));
    }

    /// The payload reference honors the alignment of an over-aligned zero-sized type.
    #[test]
    fn zst_is_well_aligned() {
        #[repr(align(4096))]
        pub struct Zst4096;

        arc_pool!(P: Zst4096);

        let block = unsafe {
            static mut B: ArcBlock<Zst4096> = ArcBlock::new();
            // SAFETY: `B` is local to this test and referenced only once.
            &mut *core::ptr::addr_of_mut!(B)
        };
        P.manage(block);

        let arc = P.alloc(Zst4096).ok().unwrap();

        let raw = &*arc as *const Zst4096;
        assert_eq!(0, raw as usize % 4096);
    }
}