sel4_bitfield_ops/
lib.rs

//
// Copyright 2023, Colias Group, LLC
//
// SPDX-License-Identifier: BSD-2-Clause
//
6
7#![no_std]
8
9use core::marker::PhantomData;
10use core::mem;
11use core::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, Not, Range, Shl, Shr};
12
13pub trait UnsignedPrimInt:
14    UnsignedPrimIntSealed
15    + Copy
16    + Eq
17    + Not<Output = Self>
18    + BitAnd<Output = Self>
19    + BitOr<Output = Self>
20    + BitAndAssign
21    + BitOrAssign
22    + Shl<usize, Output = Self>
23    + Shr<usize, Output = Self>
24    + From<bool> // HACK for generic 0 and 1
25{
26    const NUM_BITS: usize = mem::size_of::<Self>() * 8;
27
28    fn zero() -> Self {
29        false.into()
30    }
31
32    fn one() -> Self {
33        true.into()
34    }
35}
36
37pub trait PrimInt: PrimIntSealed {
38    type Unsigned: UnsignedPrimInt;
39
40    fn cast_from_unsigned(val: Self::Unsigned) -> Self;
41    fn cast_to_unsigned(val: Self) -> Self::Unsigned;
42}
43
44impl<T> PrimInt for T
45where
46    T: UnsignedPrimInt + PrimIntSealed,
47{
48    type Unsigned = Self;
49
50    fn cast_from_unsigned(val: Self::Unsigned) -> Self {
51        val
52    }
53
54    fn cast_to_unsigned(val: Self) -> Self::Unsigned {
55        val
56    }
57}
58
59use sealing::{PrimIntSealed, UnsignedPrimIntSealed};
60
/// Seal traits. They are `pub`, but this module is private, so code outside the crate cannot
/// name them and therefore cannot implement `UnsignedPrimInt` or `PrimInt`.
mod sealing {
    /// Marker implemented for the unsigned primitive integer types.
    pub trait UnsignedPrimIntSealed {}

    /// Marker implemented for primitive integer types; every `UnsignedPrimIntSealed` type gets it.
    pub trait PrimIntSealed {}

    impl<T> PrimIntSealed for T where T: UnsignedPrimIntSealed {}
}
68
/// Implements the integer traits for one unsigned/signed pair.
///
/// - `$unsigned` becomes an `UnsignedPrimInt`.
/// - `$signed` becomes a `PrimInt` whose `Unsigned` is `$unsigned`; its conversions are plain
///   `as` casts.
///
/// NOTE(review): the casts only reinterpret bits losslessly if both types have the same width;
/// callers are expected to pass matching pairs.
macro_rules! impl_prim_int {
    ($unsigned:ty, $signed:ty) => {
        impl UnsignedPrimIntSealed for $unsigned {}

        impl PrimIntSealed for $signed {}

        impl UnsignedPrimInt for $unsigned {}

        impl PrimInt for $signed {
            type Unsigned = $unsigned;

            fn cast_from_unsigned(val: $unsigned) -> $signed {
                val as $signed
            }

            fn cast_to_unsigned(val: $signed) -> $unsigned {
                val as $unsigned
            }
        }
    };
}
90
// Register every unsigned primitive as an `UnsignedPrimInt`, and pair it with the signed primitive
// of the same width, which becomes a `PrimInt` converted via `as` casts.
impl_prim_int!(u8, i8);
impl_prim_int!(u16, i16);
impl_prim_int!(u32, i32);
impl_prim_int!(u64, i64);
impl_prim_int!(u128, i128);
impl_prim_int!(usize, isize);
97
98// // //
99
100trait UnsignedPrimIntExt: UnsignedPrimInt {
101    fn mask(range: Range<usize>) -> Self {
102        debug_assert!(range.start <= range.end);
103        debug_assert!(range.end <= Self::NUM_BITS);
104        let num_bits = range.end - range.start;
105        // avoid overflow
106        match num_bits {
107            0 => Self::zero(),
108            _ if num_bits == Self::NUM_BITS => !Self::zero(),
109            _ => !(!Self::zero() << num_bits) << range.start,
110        }
111    }
112
113    fn take(self, num_bits: usize) -> Self {
114        self & Self::mask(0..num_bits)
115    }
116}
117
118impl<T: UnsignedPrimInt> UnsignedPrimIntExt for T {}
119
120// // //
121
122pub fn get_bit<T: UnsignedPrimInt>(src: &[T], i: usize) -> bool {
123    assert!(i < src.len() * T::NUM_BITS);
124    src[i / T::NUM_BITS] & (T::one() << (i % T::NUM_BITS)) != T::zero()
125}
126
127pub fn get_bits<T: UnsignedPrimInt, U: UnsignedPrimInt + TryFrom<T>>(
128    src: &[T],
129    src_range: Range<usize>,
130) -> U {
131    check_range::<T, U>(src, &src_range);
132
133    let num_bits = src_range.end - src_range.start;
134    let index_of_first_primitive = src_range.start / T::NUM_BITS;
135    let offset_into_first_primitive = src_range.start % T::NUM_BITS;
136    let num_bits_from_first_primitive = (T::NUM_BITS - offset_into_first_primitive).min(num_bits);
137
138    let bits_from_first_primitive = (src[index_of_first_primitive] >> offset_into_first_primitive)
139        .take(num_bits_from_first_primitive);
140
141    let mut bits = checked_cast::<T, U>(bits_from_first_primitive);
142    let mut num_bits_so_far = num_bits_from_first_primitive;
143    let mut index_of_cur_primitive = index_of_first_primitive + 1;
144
145    while num_bits_so_far < num_bits {
146        let num_bits_from_cur_primitive = (num_bits - num_bits_so_far).min(T::NUM_BITS);
147        let bits_from_cur_primitive = src[index_of_cur_primitive].take(num_bits_from_cur_primitive);
148        bits |= checked_cast::<T, U>(bits_from_cur_primitive) << num_bits_so_far;
149        num_bits_so_far += num_bits_from_cur_primitive;
150        index_of_cur_primitive += 1;
151    }
152
153    bits
154}
155
156pub fn set_bits<T: UnsignedPrimInt, U: UnsignedPrimInt + TryInto<T>>(
157    dst: &mut [T],
158    dst_range: Range<usize>,
159    src: U,
160) {
161    check_range::<T, U>(dst, &dst_range);
162
163    let num_bits = dst_range.end - dst_range.start;
164
165    assert!(num_bits == U::NUM_BITS || src >> num_bits == U::zero());
166
167    let index_of_first_primitive = dst_range.start / T::NUM_BITS;
168    let offset_into_first_primitive = dst_range.start % T::NUM_BITS;
169    let num_bits_for_first_primitive = (T::NUM_BITS - offset_into_first_primitive).min(num_bits);
170    let bits_for_first_primitive = src.take(num_bits_for_first_primitive);
171
172    dst[index_of_first_primitive] = (dst[index_of_first_primitive]
173        & !T::mask(
174            offset_into_first_primitive
175                ..(offset_into_first_primitive + num_bits_for_first_primitive),
176        ))
177        | checked_cast(bits_for_first_primitive) << offset_into_first_primitive;
178
179    let mut num_bits_so_far = num_bits_for_first_primitive;
180    let mut index_of_cur_primitive = index_of_first_primitive + 1;
181
182    while num_bits_so_far < num_bits {
183        let num_bits_for_cur_primitive = (num_bits - num_bits_so_far).min(T::NUM_BITS);
184        let bits_for_cur_primitive = (src >> num_bits_so_far).take(num_bits_for_cur_primitive);
185        dst[index_of_cur_primitive] = (dst[index_of_cur_primitive]
186            & T::mask(num_bits_for_cur_primitive..T::NUM_BITS))
187            | checked_cast(bits_for_cur_primitive);
188        num_bits_so_far += num_bits_for_cur_primitive;
189        index_of_cur_primitive += 1;
190    }
191}
192
/// Asserts the preconditions shared by `get_bits` and `set_bits`.
///
/// - `range` is not reversed.
/// - `range` lies within the `arr.len() * T::NUM_BITS` bits of `arr`.
/// - `range` is no wider than `U`.
///
/// A zero-width range is accepted at any position up to and including the end of `arr`.
fn check_range<T: UnsignedPrimInt, U: UnsignedPrimInt>(arr: &[T], range: &Range<usize>) {
    assert!(range.start <= range.end);
    assert!(range.end <= arr.len() * T::NUM_BITS);
    assert!(range.end - range.start <= U::NUM_BITS);
}
198
/// Converts `val` with `TryInto` for conversions the caller has already shown to be lossless.
///
/// # Panics
///
/// Panics via `unreachable!` if the conversion fails, which would mean the caller's invariant was
/// violated.
fn checked_cast<T: TryInto<U>, U>(val: T) -> U {
    match val.try_into() {
        Ok(converted) => converted,
        Err(_) => unreachable!(),
    }
}
202
203pub fn set_bits_from_slice<T, U>(
204    dst: &mut [T],
205    dst_range: Range<usize>,
206    src: &[U],
207    src_start: usize,
208) where
209    T: UnsignedPrimInt + TryFrom<usize>,
210    U: UnsignedPrimInt,
211    usize: TryFrom<U>,
212{
213    set_bits_from_slice_via::<_, _, usize>(dst, dst_range, src, src_start)
214}
215
216fn set_bits_from_slice_via<T, U, V>(
217    dst: &mut [T],
218    dst_range: Range<usize>,
219    src: &[U],
220    src_start: usize,
221) where
222    T: UnsignedPrimInt + TryFrom<V>,
223    U: UnsignedPrimInt,
224    V: UnsignedPrimInt + TryFrom<U>,
225{
226    let num_bits = dst_range.len();
227
228    assert!(dst_range.start <= dst_range.end);
229    assert!(dst_range.end <= dst.len() * T::NUM_BITS);
230    assert!(src_start + num_bits <= src.len() * U::NUM_BITS);
231
232    let mut cur_xfer_start = 0;
233    while cur_xfer_start < num_bits {
234        let cur_xfer_end = num_bits.min(cur_xfer_start + V::NUM_BITS);
235        let cur_xfer_src_range = (src_start + cur_xfer_start)..(src_start + cur_xfer_end);
236        let cur_xfer_dst_range =
237            (dst_range.start + cur_xfer_start)..(dst_range.start + cur_xfer_end);
238        let xfer: V = get_bits(src, cur_xfer_src_range);
239        set_bits(dst, cur_xfer_dst_range, xfer);
240        cur_xfer_start = cur_xfer_end;
241    }
242}
243
244// // //
245
246pub fn get<T: UnsignedPrimInt, U: PrimInt>(src: &[T], src_start_bit: usize) -> U
247where
248    U::Unsigned: TryFrom<T>,
249{
250    let src_range = src_start_bit..(src_start_bit + U::Unsigned::NUM_BITS);
251    U::cast_from_unsigned(get_bits(src, src_range))
252}
253
254pub fn set<T: UnsignedPrimInt, U: PrimInt>(dst: &mut [T], dst_start_bit: usize, src: U)
255where
256    U::Unsigned: TryInto<T>,
257{
258    let dst_range = dst_start_bit..(dst_start_bit + U::Unsigned::NUM_BITS);
259    set_bits(dst, dst_range, U::cast_to_unsigned(src))
260}
261
262// // //
263
/// Bitfield storage `T` (for example an array of words), tagged with a phantom word type `U`.
///
/// `U` is the word type used to interpret the storage.
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
pub struct Bitfield<T, U> {
    inner: T,
    _phantom: PhantomData<U>,
}

impl<T, U> Bitfield<T, U> {
    /// Wraps `inner` as a bitfield.
    pub fn new(inner: T) -> Self {
        Bitfield {
            _phantom: PhantomData,
            inner,
        }
    }

    /// Consumes the bitfield and returns the underlying storage.
    pub fn into_inner(self) -> T {
        let Self { inner, .. } = self;
        inner
    }

    /// Borrows the underlying storage.
    pub fn inner(&self) -> &T {
        &self.inner
    }

    /// Mutably borrows the underlying storage.
    pub fn inner_mut(&mut self) -> &mut T {
        &mut self.inner
    }
}
291
292impl<T: UnsignedPrimInt, const N: usize> Bitfield<[T; N], T> {
293    pub fn zeroed() -> Self {
294        Self::new([T::zero(); N])
295    }
296}
297
298impl<T: AsRef<[U]>, U: UnsignedPrimInt> Bitfield<T, U> {
299    pub fn bits(&self) -> &[U] {
300        self.inner().as_ref()
301    }
302
303    pub fn get_bits<V: UnsignedPrimInt + TryFrom<U>>(&self, range: Range<usize>) -> V {
304        get_bits(self.bits(), range)
305    }
306
307    pub fn get_bits_into_slice<V>(&self, range: Range<usize>, dst: &mut [V], dst_start: usize)
308    where
309        V: UnsignedPrimInt + TryFrom<usize>,
310        usize: TryFrom<U>,
311    {
312        let dst_range = dst_start..(dst_start + range.len());
313        set_bits_from_slice(dst, dst_range, self.bits(), range.start)
314    }
315
316    pub fn get<V: PrimInt>(&self, start_bit: usize) -> V
317    where
318        V::Unsigned: TryFrom<U>,
319    {
320        get(self.bits(), start_bit)
321    }
322}
323
324impl<T: AsMut<[U]>, U: UnsignedPrimInt> Bitfield<T, U> {
325    pub fn bits_mut(&mut self) -> &mut [U] {
326        self.inner_mut().as_mut()
327    }
328
329    pub fn set_bits<V: UnsignedPrimInt + TryInto<U>>(&mut self, range: Range<usize>, src: V) {
330        set_bits(self.bits_mut(), range, src)
331    }
332
333    pub fn set_bits_from_slice<V: UnsignedPrimInt>(
334        &mut self,
335        range: Range<usize>,
336        src: &[V],
337        src_start: usize,
338    ) where
339        U: TryFrom<usize>,
340        usize: TryFrom<V>,
341    {
342        set_bits_from_slice(self.bits_mut(), range, src, src_start)
343    }
344
345    pub fn set<V: PrimInt>(&mut self, start_bit: usize, src: V)
346    where
347        V::Unsigned: TryInto<U>,
348    {
349        set(self.bits_mut(), start_bit, src)
350    }
351}
352
353// // //
354
#[cfg(test)]
#[allow(unused_imports)]
mod test {

    extern crate std;

    use std::eprintln;
    use std::fmt;

    use super::*;

    #[test]
    fn zero_gets_zero() {
        let field = Bitfield::<[u64; 2], _>::zeroed();
        let got: u64 = field.get_bits(50..80);
        assert_eq!(got, 0);
    }

    /// Writes `val` into `range` of an all-zero `[T; N]` and checks that it reads back unchanged.
    fn set_and_get<
        T: UnsignedPrimInt,
        const N: usize,
        U: UnsignedPrimInt + TryInto<T> + TryFrom<T> + fmt::Debug,
    >(
        range: Range<usize>,
        val: U,
    ) {
        let mut field = Bitfield::<[T; N], _>::zeroed();
        set_bits(field.inner_mut(), range.clone(), val);
        let read_back: U = get_bits(field.inner(), range);
        assert_eq!(read_back, val);
    }

    #[test]
    fn get_returns_what_was_set() {
        // One whole, aligned byte.
        set_and_get::<u8, 3, _>(8..16, !0u8);
        // Sixteen bits straddling three `u8` words.
        set_and_get::<u8, 3, _>(2..18, !0u32 >> 16);
        // The same field inside a single wide word.
        set_and_get::<u128, 1, _>(2..18, !0u32 >> 16);
    }

    #[test]
    fn multiple_gets_return_what_was_set_with_multiple_sets() {
        let writes: [(Range<usize>, u64); 3] = [(0..2, 0b11), (60..64, 0b1111), (10..11, 0b1)];
        for init in [0, !0] {
            let mut field = Bitfield::<[u64; 1], u64>::new([init]);
            for (range, value) in writes.iter().cloned() {
                field.set_bits::<u64>(range, value);
            }
            for (range, value) in writes.iter().cloned() {
                assert_eq!(field.get_bits::<u64>(range), value);
            }
        }
    }
}
405
#[cfg(kani)]
mod verification {
    use super::*;

    /// Kani proof harness: checks `set_bits_from_slice_via` bit-by-bit for several combinations
    /// of destination word type, source word type, and transfer type.
    #[kani::proof]
    #[kani::unwind(4)]
    fn slice_ops() {
        // Type arguments: <dst word, dst len, src word, src len, transfer type V>.
        slice_ops_generic::<u64, 3, u8, 3, u8>(kani::any());
        slice_ops_generic::<u8, 3, u64, 3, u8>(kani::any());
        slice_ops_generic::<u64, 3, u8, 3, u32>(kani::any());
        slice_ops_generic::<u8, 3, u64, 3, u32>(kani::any());
    }

    /// Copies `n` bits of `b` (from bit `start_b`) into a copy of `a` (at bit `start_a`) via
    /// `set_bits_from_slice_via`, then checks one arbitrary bit `i` of the result:
    /// - inside the destination range, it must equal the corresponding bit of `b`;
    /// - outside the destination range, it must be unchanged from `a`.
    // The type of kani::any() can't depend on generic parameters, so we pass the arrays as args.
    fn slice_ops_generic<T, const N: usize, U, const M: usize, V>((a, b): ([T; N], [U; M]))
    where
        T: UnsignedPrimInt + TryFrom<V>,
        U: UnsignedPrimInt + TryFrom<V>,
        V: UnsignedPrimInt + TryFrom<T> + TryFrom<U>,
    {
        let n: usize = kani::any();
        let start_a: usize = kani::any();
        let start_b: usize = kani::any();

        // NOTE(review): these bounds compare bit offsets against `a.len()` / `b.len()`, which
        // are element counts (N and M), not bit counts. As a result only ranges within the first
        // N / M bits are explored. This may be deliberate, to keep the copy loop within
        // `kani::unwind(4)` — confirm.
        kani::assume(n <= a.len());
        kani::assume(n <= b.len());
        kani::assume(
            start_a
                .checked_add(n)
                .map(|end| end <= a.len())
                .unwrap_or(false),
        );
        kani::assume(
            start_b
                .checked_add(n)
                .map(|end| end <= b.len())
                .unwrap_or(false),
        );

        let range_a = start_a..(start_a + n);

        // `a` itself stays untouched as the reference for bits outside `range_a`.
        let mut a_mut = a;

        set_bits_from_slice_via::<_, _, V>(&mut a_mut, range_a.clone(), &b, start_b);

        // `i` may be any bit of the destination array.
        let i: usize = kani::any();
        kani::assume(i < a.len() * T::NUM_BITS);

        let val = get_bit(&a_mut, i);

        let val_expected = if range_a.contains(&i) {
            get_bit(&b, i - start_a + start_b)
        } else {
            get_bit(&a, i)
        };

        assert_eq!(val, val_expected);
    }
}