one_shot_mutex/sync/
rwlock.rs

1use core::sync::atomic::{AtomicUsize, Ordering};
2
3use lock_api::{
4    GuardSend, RawRwLock, RawRwLockDowngrade, RawRwLockRecursive, RawRwLockUpgrade,
5    RawRwLockUpgradeDowngrade,
6};
7
/// A one-shot readers-writer lock that panics instead of (dead)locking on contention.
///
/// This lock allows no contention and panics on [`lock_shared`], [`lock_exclusive`],
/// [`lock_upgradable`], and [`upgrade`] if it is already locked conflictingly.
/// The non-blocking `try_*` counterparts return `false` instead of panicking.
/// This is useful in situations where contention would be a bug,
/// such as in single-threaded programs that would deadlock on contention.
///
/// [`lock_shared`]: RawOneShotRwLock::lock_shared
/// [`lock_exclusive`]: RawOneShotRwLock::lock_exclusive
/// [`lock_upgradable`]: RawOneShotRwLock::lock_upgradable
/// [`upgrade`]: RawOneShotRwLock::upgrade
///
/// # Examples
///
/// ```
/// use one_shot_mutex::sync::OneShotRwLock;
///
/// static X: OneShotRwLock<i32> = OneShotRwLock::new(42);
///
/// // This is equivalent to `X.try_write().unwrap()`.
/// let x = X.write();
///
/// // This panics instead of deadlocking.
/// // let x2 = X.write();
///
/// // Once we unlock the lock, we can lock it again.
/// drop(x);
/// let x = X.write();
/// ```
#[derive(Debug)]
pub struct RawOneShotRwLock {
    /// Lock state: bit 0 is the exclusive flag, bit 1 is the upgradable flag,
    /// and the remaining bits count the shared locks currently held.
    lock: AtomicUsize,
}
39
/// Exclusive (write) lock flag, stored in bit 0.
const EXCLUSIVE: usize = 0b001;
/// Upgradable shared lock flag, stored in bit 1; at most one holder at a time.
const UPGRADABLE: usize = 0b010;
/// Amount added per shared (read) lock; the shared count occupies bits 2 and above.
const SHARED: usize = 0b100;
46
47impl RawOneShotRwLock {
48    pub const fn new() -> Self {
49        Self::INIT
50    }
51
52    #[inline]
53    fn is_locked_shared(&self) -> bool {
54        self.lock.load(Ordering::Relaxed) & !(EXCLUSIVE | UPGRADABLE) != 0
55    }
56
57    #[inline]
58    fn is_locked_upgradable(&self) -> bool {
59        self.lock.load(Ordering::Relaxed) & UPGRADABLE == UPGRADABLE
60    }
61
62    /// Acquire a shared lock, returning the new lock value.
63    #[inline]
64    fn acquire_shared(&self) -> usize {
65        let value = self.lock.fetch_add(SHARED, Ordering::Acquire);
66
67        // An arbitrary cap that allows us to catch overflows long before they happen
68        if value > usize::MAX / 2 {
69            self.lock.fetch_sub(SHARED, Ordering::Relaxed);
70            panic!("Too many shared locks, cannot safely proceed");
71        }
72
73        value
74    }
75}
76
77impl Default for RawOneShotRwLock {
78    fn default() -> Self {
79        Self::new()
80    }
81}
82
83unsafe impl RawRwLock for RawOneShotRwLock {
84    #[allow(clippy::declare_interior_mutable_const)]
85    const INIT: Self = Self {
86        lock: AtomicUsize::new(0),
87    };
88
89    type GuardMarker = GuardSend;
90
91    #[inline]
92    fn lock_shared(&self) {
93        assert!(
94            self.try_lock_shared(),
95            "called `lock_shared` on a `RawOneShotRwLock` that is already locked exclusively"
96        );
97    }
98
99    #[inline]
100    fn try_lock_shared(&self) -> bool {
101        let value = self.acquire_shared();
102
103        let acquired = value & EXCLUSIVE != EXCLUSIVE;
104
105        if !acquired {
106            unsafe {
107                self.unlock_shared();
108            }
109        }
110
111        acquired
112    }
113
114    #[inline]
115    unsafe fn unlock_shared(&self) {
116        debug_assert!(self.is_locked_shared());
117
118        self.lock.fetch_sub(SHARED, Ordering::Release);
119    }
120
121    #[inline]
122    fn lock_exclusive(&self) {
123        assert!(
124            self.try_lock_exclusive(),
125            "called `lock_exclusive` on a `RawOneShotRwLock` that is already locked"
126        );
127    }
128
129    #[inline]
130    fn try_lock_exclusive(&self) -> bool {
131        self.lock
132            .compare_exchange(0, EXCLUSIVE, Ordering::Acquire, Ordering::Relaxed)
133            .is_ok()
134    }
135
136    #[inline]
137    unsafe fn unlock_exclusive(&self) {
138        debug_assert!(self.is_locked_exclusive());
139
140        self.lock.fetch_and(!EXCLUSIVE, Ordering::Release);
141    }
142
143    #[inline]
144    fn is_locked(&self) -> bool {
145        self.lock.load(Ordering::Relaxed) != 0
146    }
147
148    #[inline]
149    fn is_locked_exclusive(&self) -> bool {
150        self.lock.load(Ordering::Relaxed) & EXCLUSIVE == EXCLUSIVE
151    }
152}
153
154unsafe impl RawRwLockRecursive for RawOneShotRwLock {
155    #[inline]
156    fn lock_shared_recursive(&self) {
157        self.lock_shared();
158    }
159
160    #[inline]
161    fn try_lock_shared_recursive(&self) -> bool {
162        self.try_lock_shared()
163    }
164}
165
166unsafe impl RawRwLockDowngrade for RawOneShotRwLock {
167    #[inline]
168    unsafe fn downgrade(&self) {
169        // Reserve the shared guard for ourselves
170        self.acquire_shared();
171
172        unsafe {
173            self.unlock_exclusive();
174        }
175    }
176}
177
178unsafe impl RawRwLockUpgrade for RawOneShotRwLock {
179    #[inline]
180    fn lock_upgradable(&self) {
181        assert!(
182            self.try_lock_upgradable(),
183            "called `lock_upgradable` on a `RawOneShotRwLock` that is already locked upgradably or exclusively"
184        );
185    }
186
187    #[inline]
188    fn try_lock_upgradable(&self) -> bool {
189        let value = self.lock.fetch_or(UPGRADABLE, Ordering::Acquire);
190
191        let acquired = value & (UPGRADABLE | EXCLUSIVE) == 0;
192
193        if !acquired && value & UPGRADABLE == 0 {
194            unsafe {
195                self.unlock_upgradable();
196            }
197        }
198
199        acquired
200    }
201
202    #[inline]
203    unsafe fn unlock_upgradable(&self) {
204        debug_assert!(self.is_locked_upgradable());
205
206        self.lock.fetch_and(!UPGRADABLE, Ordering::Release);
207    }
208
209    #[inline]
210    unsafe fn upgrade(&self) {
211        assert!(
212            self.try_upgrade(),
213            "called `upgrade` on a `RawOneShotRwLock` that is also locked shared by others"
214        );
215    }
216
217    #[inline]
218    unsafe fn try_upgrade(&self) -> bool {
219        self.lock
220            .compare_exchange(UPGRADABLE, EXCLUSIVE, Ordering::Acquire, Ordering::Relaxed)
221            .is_ok()
222    }
223}
224
225unsafe impl RawRwLockUpgradeDowngrade for RawOneShotRwLock {
226    #[inline]
227    unsafe fn downgrade_upgradable(&self) {
228        self.acquire_shared();
229
230        unsafe {
231            self.unlock_upgradable();
232        }
233    }
234
235    #[inline]
236    unsafe fn downgrade_to_upgradable(&self) {
237        debug_assert!(self.is_locked_exclusive());
238
239        self.lock
240            .fetch_xor(UPGRADABLE | EXCLUSIVE, Ordering::Release);
241    }
242}
243
244/// A [`lock_api::RwLock`] based on [`RawOneShotRwLock`].
245pub type OneShotRwLock<T> = lock_api::RwLock<RawOneShotRwLock, T>;
246
247/// A [`lock_api::RwLockReadGuard`] based on [`RawOneShotRwLock`].
248pub type OneShotRwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, RawOneShotRwLock, T>;
249
250/// A [`lock_api::RwLockUpgradableReadGuard`] based on [`RawOneShotRwLock`].
251pub type OneShotRwLockUpgradableReadGuard<'a, T> =
252    lock_api::RwLockUpgradableReadGuard<'a, RawOneShotRwLock, T>;
253
254/// A [`lock_api::RwLockWriteGuard`] based on [`RawOneShotRwLock`].
255pub type OneShotRwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, RawOneShotRwLock, T>;
256
#[cfg(test)]
mod tests {
    use lock_api::{RwLockUpgradableReadGuard, RwLockWriteGuard};

    use super::*;

    #[test]
    fn lock_exclusive() {
        let lock = OneShotRwLock::new(42);
        let mut guard = lock.write();
        assert_eq!(*guard, 42);

        *guard += 1;
        drop(guard);
        let guard = lock.write();
        assert_eq!(*guard, 43);
    }

    #[test]
    #[should_panic(expected = "called `lock_exclusive`")]
    fn lock_exclusive_panic() {
        let lock = OneShotRwLock::new(42);
        let _guard = lock.write();
        let _guard2 = lock.write();
    }

    #[test]
    #[should_panic(expected = "called `lock_shared`")]
    fn lock_exclusive_shared_panic() {
        let lock = OneShotRwLock::new(42);
        let _guard = lock.write();
        let _guard2 = lock.read();
    }

    #[test]
    fn try_lock_exclusive() {
        let lock = OneShotRwLock::new(42);
        let mut guard = lock.try_write().unwrap();
        assert_eq!(*guard, 42);
        assert!(lock.try_write().is_none());

        *guard += 1;
        drop(guard);
        let guard = lock.try_write().unwrap();
        assert_eq!(*guard, 43);
    }

    #[test]
    fn lock_shared() {
        let lock = OneShotRwLock::new(42);
        let guard = lock.read();
        assert_eq!(*guard, 42);
        let guard2 = lock.read();
        assert_eq!(*guard2, 42);
    }

    /// A shared lock must also be refused when the exclusive lock was obtained by upgrading.
    #[test]
    #[should_panic(expected = "called `lock_shared`")]
    fn lock_shared_panic() {
        let lock = OneShotRwLock::new(42);
        let guard = lock.upgradable_read();
        let _upgraded = RwLockUpgradableReadGuard::upgrade(guard);
        let _guard2 = lock.read();
    }

    #[test]
    fn try_lock_shared() {
        let lock = OneShotRwLock::new(42);
        let guard = lock.try_read().unwrap();
        assert_eq!(*guard, 42);
        assert!(lock.try_write().is_none());

        let guard2 = lock.try_read().unwrap();
        assert_eq!(*guard2, 42);
    }

    #[test]
    fn lock_shared_recursive() {
        let lock = OneShotRwLock::new(42);
        let guard = lock.read();
        let guard2 = lock.read_recursive();
        assert_eq!(*guard, 42);
        assert_eq!(*guard2, 42);
    }

    #[test]
    fn lock_upgradable() {
        let lock = OneShotRwLock::new(42);
        let guard = lock.upgradable_read();
        assert_eq!(*guard, 42);
        assert!(lock.try_write().is_none());

        let mut upgraded = RwLockUpgradableReadGuard::upgrade(guard);
        *upgraded += 1;
        drop(upgraded);
        let guard2 = lock.upgradable_read();
        assert_eq!(*guard2, 43);
    }

    /// Readers do not block the upgradable lock, but they do block upgrading it.
    #[test]
    fn lock_upgradable_with_readers() {
        let lock = OneShotRwLock::new(42);
        let reader = lock.read();
        let upgradable = lock.upgradable_read();
        assert_eq!(*reader, *upgradable);

        let upgradable = RwLockUpgradableReadGuard::try_upgrade(upgradable).unwrap_err();
        drop(reader);
        let upgraded = RwLockUpgradableReadGuard::upgrade(upgradable);
        drop(upgraded);
        assert!(lock.try_write().is_some());
    }

    #[test]
    #[should_panic(expected = "called `lock_upgradable`")]
    fn lock_upgradable_panic() {
        let lock = OneShotRwLock::new(42);
        let _guard = lock.upgradable_read();
        let _guard2 = lock.upgradable_read();
    }

    #[test]
    #[should_panic(expected = "called `lock_upgradable`")]
    fn lock_upgradable_write_panic() {
        let lock = OneShotRwLock::new(42);
        let _guard = lock.write();
        let _guard2 = lock.upgradable_read();
    }

    #[test]
    fn try_lock_upgradable() {
        let lock = OneShotRwLock::new(42);
        let guard = lock.try_upgradable_read().unwrap();
        assert_eq!(*guard, 42);
        assert!(lock.try_write().is_none());

        let mut upgraded = RwLockUpgradableReadGuard::try_upgrade(guard).unwrap();
        *upgraded += 1;
        drop(upgraded);
        let guard2 = lock.try_upgradable_read().unwrap();
        assert_eq!(*guard2, 43);
    }

    /// A failed upgradable attempt under a writer must not leave the upgradable flag behind.
    #[test]
    fn try_lock_upgradable_while_exclusive() {
        let lock = OneShotRwLock::new(42);
        let guard = lock.write();
        assert!(lock.try_upgradable_read().is_none());
        drop(guard);

        assert!(lock.try_write().is_some());
        assert!(lock.try_upgradable_read().is_some());
    }

    #[test]
    #[should_panic(expected = "called `upgrade`")]
    fn upgrade_panic() {
        let lock = OneShotRwLock::new(42);
        let guard = lock.upgradable_read();
        let _guard2 = lock.read();
        let _guard3 = RwLockUpgradableReadGuard::upgrade(guard);
    }

    #[test]
    fn downgrade() {
        let lock = OneShotRwLock::new(42);
        let mut guard = lock.write();
        *guard += 1;

        let guard = RwLockWriteGuard::downgrade(guard);
        assert_eq!(*guard, 43);
        assert!(lock.try_write().is_none());
        assert!(lock.try_upgradable_read().is_some());

        let guard2 = lock.read();
        assert_eq!(*guard2, 43);
        drop(guard);
        drop(guard2);
        assert!(lock.try_write().is_some());
    }

    #[test]
    fn downgrade_upgradable() {
        let lock = OneShotRwLock::new(42);
        let guard = RwLockUpgradableReadGuard::downgrade(lock.upgradable_read());
        assert_eq!(*guard, 42);

        // The downgrade releases the upgradable slot but keeps a shared lock.
        assert!(lock.try_upgradable_read().is_some());
        assert!(lock.try_write().is_none());
        drop(guard);
        assert!(lock.try_write().is_some());
    }

    #[test]
    fn downgrade_to_upgradable() {
        let lock = OneShotRwLock::new(42);
        let guard = RwLockWriteGuard::downgrade_to_upgradable(lock.write());
        assert_eq!(*guard, 42);
        assert!(lock.try_write().is_none());
        assert!(lock.try_upgradable_read().is_none());
        // Plain readers may join an upgradable holder.
        assert!(lock.try_read().is_some());

        let mut upgraded = RwLockUpgradableReadGuard::upgrade(guard);
        *upgraded += 1;
        drop(upgraded);
        assert!(lock.try_write().is_some());
        assert_eq!(*lock.read(), 43);
    }
}