1#![no_std]
8
9use core::marker::PhantomData;
10use core::mem;
11use core::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, Not, Range, Shl, Shr};
12
13pub trait UnsignedPrimInt:
14 UnsignedPrimIntSealed
15 + Copy
16 + Eq
17 + Not<Output = Self>
18 + BitAnd<Output = Self>
19 + BitOr<Output = Self>
20 + BitAndAssign
21 + BitOrAssign
22 + Shl<usize, Output = Self>
23 + Shr<usize, Output = Self>
24 + From<bool> {
26 const NUM_BITS: usize = mem::size_of::<Self>() * 8;
27
28 fn zero() -> Self {
29 false.into()
30 }
31
32 fn one() -> Self {
33 true.into()
34 }
35}
36
37pub trait PrimInt: PrimIntSealed {
38 type Unsigned: UnsignedPrimInt;
39
40 fn cast_from_unsigned(val: Self::Unsigned) -> Self;
41 fn cast_to_unsigned(val: Self) -> Self::Unsigned;
42}
43
44impl<T> PrimInt for T
45where
46 T: UnsignedPrimInt + PrimIntSealed,
47{
48 type Unsigned = Self;
49
50 fn cast_from_unsigned(val: Self::Unsigned) -> Self {
51 val
52 }
53
54 fn cast_to_unsigned(val: Self) -> Self::Unsigned {
55 val
56 }
57}
58
59use sealing::{PrimIntSealed, UnsignedPrimIntSealed};
60
/// Private marker traits used as supertraits of the public integer traits.
///
/// Code outside this crate cannot name this module, so it cannot implement
/// `UnsignedPrimInt` or `PrimInt`.
mod sealing {
    /// Marker for the unsigned primitive integer types.
    pub trait UnsignedPrimIntSealed {}

    /// Marker for every primitive integer type handled by the crate.
    pub trait PrimIntSealed {}

    // An unsigned marker type is automatically a primitive marker type as well.
    impl<T> PrimIntSealed for T where T: UnsignedPrimIntSealed {}
}
68
/// Implements the integer traits for an unsigned primitive (`$unsigned`) and
/// the signed primitive of the same width (`$signed`).
macro_rules! impl_prim_int {
    ($unsigned:ty, $signed:ty) => {
        impl UnsignedPrimIntSealed for $unsigned {}
        impl UnsignedPrimInt for $unsigned {}

        impl PrimIntSealed for $signed {}
        impl PrimInt for $signed {
            type Unsigned = $unsigned;

            // `as` between same-width integers keeps the bit pattern unchanged.
            fn cast_to_unsigned(val: Self) -> Self::Unsigned {
                val as Self::Unsigned
            }

            fn cast_from_unsigned(val: Self::Unsigned) -> Self {
                val as Self
            }
        }
    };
}
90
91impl_prim_int!(u8, i8);
92impl_prim_int!(u16, i16);
93impl_prim_int!(u32, i32);
94impl_prim_int!(u64, i64);
95impl_prim_int!(u128, i128);
96impl_prim_int!(usize, isize);
97
98trait UnsignedPrimIntExt: UnsignedPrimInt {
101 fn mask(range: Range<usize>) -> Self {
102 debug_assert!(range.start <= range.end);
103 debug_assert!(range.end <= Self::NUM_BITS);
104 let num_bits = range.end - range.start;
105 match num_bits {
107 0 => Self::zero(),
108 _ if num_bits == Self::NUM_BITS => !Self::zero(),
109 _ => !(!Self::zero() << num_bits) << range.start,
110 }
111 }
112
113 fn take(self, num_bits: usize) -> Self {
114 self & Self::mask(0..num_bits)
115 }
116}
117
118impl<T: UnsignedPrimInt> UnsignedPrimIntExt for T {}
119
120pub fn get_bit<T: UnsignedPrimInt>(src: &[T], i: usize) -> bool {
123 assert!(i < src.len() * T::NUM_BITS);
124 src[i / T::NUM_BITS] & (T::one() << (i % T::NUM_BITS)) != T::zero()
125}
126
127pub fn get_bits<T: UnsignedPrimInt, U: UnsignedPrimInt + TryFrom<T>>(
128 src: &[T],
129 src_range: Range<usize>,
130) -> U {
131 check_range::<T, U>(src, &src_range);
132
133 let num_bits = src_range.end - src_range.start;
134 let index_of_first_primitive = src_range.start / T::NUM_BITS;
135 let offset_into_first_primitive = src_range.start % T::NUM_BITS;
136 let num_bits_from_first_primitive = (T::NUM_BITS - offset_into_first_primitive).min(num_bits);
137
138 let bits_from_first_primitive = (src[index_of_first_primitive] >> offset_into_first_primitive)
139 .take(num_bits_from_first_primitive);
140
141 let mut bits = checked_cast::<T, U>(bits_from_first_primitive);
142 let mut num_bits_so_far = num_bits_from_first_primitive;
143 let mut index_of_cur_primitive = index_of_first_primitive + 1;
144
145 while num_bits_so_far < num_bits {
146 let num_bits_from_cur_primitive = (num_bits - num_bits_so_far).min(T::NUM_BITS);
147 let bits_from_cur_primitive = src[index_of_cur_primitive].take(num_bits_from_cur_primitive);
148 bits |= checked_cast::<T, U>(bits_from_cur_primitive) << num_bits_so_far;
149 num_bits_so_far += num_bits_from_cur_primitive;
150 index_of_cur_primitive += 1;
151 }
152
153 bits
154}
155
156pub fn set_bits<T: UnsignedPrimInt, U: UnsignedPrimInt + TryInto<T>>(
157 dst: &mut [T],
158 dst_range: Range<usize>,
159 src: U,
160) {
161 check_range::<T, U>(dst, &dst_range);
162
163 let num_bits = dst_range.end - dst_range.start;
164
165 assert!(num_bits == U::NUM_BITS || src >> num_bits == U::zero());
166
167 let index_of_first_primitive = dst_range.start / T::NUM_BITS;
168 let offset_into_first_primitive = dst_range.start % T::NUM_BITS;
169 let num_bits_for_first_primitive = (T::NUM_BITS - offset_into_first_primitive).min(num_bits);
170 let bits_for_first_primitive = src.take(num_bits_for_first_primitive);
171
172 dst[index_of_first_primitive] = (dst[index_of_first_primitive]
173 & !T::mask(
174 offset_into_first_primitive
175 ..(offset_into_first_primitive + num_bits_for_first_primitive),
176 ))
177 | checked_cast(bits_for_first_primitive) << offset_into_first_primitive;
178
179 let mut num_bits_so_far = num_bits_for_first_primitive;
180 let mut index_of_cur_primitive = index_of_first_primitive + 1;
181
182 while num_bits_so_far < num_bits {
183 let num_bits_for_cur_primitive = (num_bits - num_bits_so_far).min(T::NUM_BITS);
184 let bits_for_cur_primitive = (src >> num_bits_so_far).take(num_bits_for_cur_primitive);
185 dst[index_of_cur_primitive] = (dst[index_of_cur_primitive]
186 & T::mask(num_bits_for_cur_primitive..T::NUM_BITS))
187 | checked_cast(bits_for_cur_primitive);
188 num_bits_so_far += num_bits_for_cur_primitive;
189 index_of_cur_primitive += 1;
190 }
191}
192
/// Validates `range` as a bit range into `arr` whose width fits in a `U`.
///
/// # Panics
///
/// Panics if `range.start > range.end`, if `range.end` exceeds the total
/// number of bits in `arr` (`arr.len() * T::NUM_BITS`), or if the range spans
/// more than `U::NUM_BITS` bits.
fn check_range<T: UnsignedPrimInt, U: UnsignedPrimInt>(arr: &[T], range: &Range<usize>) {
    assert!(range.start <= range.end);
    assert!(range.end <= arr.len() * T::NUM_BITS);
    assert!(range.end - range.start <= U::NUM_BITS);
}
198
/// Converts `val` to `U` in a context where the caller has already established
/// that the value fits.
///
/// `T::Error` is not required to implement `Debug`, so the error is matched
/// away rather than passed to `unwrap`.
///
/// # Panics
///
/// Panics via `unreachable!` if the conversion fails, which indicates a bug in
/// the caller.
fn checked_cast<T: TryInto<U>, U>(val: T) -> U {
    match val.try_into() {
        Ok(converted) => converted,
        Err(_) => unreachable!(),
    }
}
202
203pub fn set_bits_from_slice<T, U>(
204 dst: &mut [T],
205 dst_range: Range<usize>,
206 src: &[U],
207 src_start: usize,
208) where
209 T: UnsignedPrimInt + TryFrom<usize>,
210 U: UnsignedPrimInt,
211 usize: TryFrom<U>,
212{
213 set_bits_from_slice_via::<_, _, usize>(dst, dst_range, src, src_start)
214}
215
216fn set_bits_from_slice_via<T, U, V>(
217 dst: &mut [T],
218 dst_range: Range<usize>,
219 src: &[U],
220 src_start: usize,
221) where
222 T: UnsignedPrimInt + TryFrom<V>,
223 U: UnsignedPrimInt,
224 V: UnsignedPrimInt + TryFrom<U>,
225{
226 let num_bits = dst_range.len();
227
228 assert!(dst_range.start <= dst_range.end);
229 assert!(dst_range.end <= dst.len() * T::NUM_BITS);
230 assert!(src_start + num_bits <= src.len() * U::NUM_BITS);
231
232 let mut cur_xfer_start = 0;
233 while cur_xfer_start < num_bits {
234 let cur_xfer_end = num_bits.min(cur_xfer_start + V::NUM_BITS);
235 let cur_xfer_src_range = (src_start + cur_xfer_start)..(src_start + cur_xfer_end);
236 let cur_xfer_dst_range =
237 (dst_range.start + cur_xfer_start)..(dst_range.start + cur_xfer_end);
238 let xfer: V = get_bits(src, cur_xfer_src_range);
239 set_bits(dst, cur_xfer_dst_range, xfer);
240 cur_xfer_start = cur_xfer_end;
241 }
242}
243
244pub fn get<T: UnsignedPrimInt, U: PrimInt>(src: &[T], src_start_bit: usize) -> U
247where
248 U::Unsigned: TryFrom<T>,
249{
250 let src_range = src_start_bit..(src_start_bit + U::Unsigned::NUM_BITS);
251 U::cast_from_unsigned(get_bits(src, src_range))
252}
253
254pub fn set<T: UnsignedPrimInt, U: PrimInt>(dst: &mut [T], dst_start_bit: usize, src: U)
255where
256 U::Unsigned: TryInto<T>,
257{
258 let dst_range = dst_start_bit..(dst_start_bit + U::Unsigned::NUM_BITS);
259 set_bits(dst, dst_range, U::cast_to_unsigned(src))
260}
261
/// A bitfield stored in `inner`, tagged with the storage word type `U`.
///
/// `U` exists only at the type level (as `PhantomData`), and
/// `repr(transparent)` gives the struct the same layout as `T`.
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
pub struct Bitfield<T, U> {
    inner: T,
    _phantom: PhantomData<U>,
}

impl<T, U> Bitfield<T, U> {
    /// Wraps `inner` as a bitfield without modifying it.
    pub fn new(inner: T) -> Self {
        let _phantom = PhantomData;
        Bitfield { inner, _phantom }
    }

    /// Consumes the bitfield and returns its storage.
    pub fn into_inner(self) -> T {
        let Bitfield { inner, .. } = self;
        inner
    }

    /// Borrows the storage.
    pub fn inner(&self) -> &T {
        let Bitfield { inner, .. } = self;
        inner
    }

    /// Mutably borrows the storage.
    pub fn inner_mut(&mut self) -> &mut T {
        let Bitfield { inner, .. } = self;
        inner
    }
}
291
292impl<T: UnsignedPrimInt, const N: usize> Bitfield<[T; N], T> {
293 pub fn zeroed() -> Self {
294 Self::new([T::zero(); N])
295 }
296}
297
298impl<T: AsRef<[U]>, U: UnsignedPrimInt> Bitfield<T, U> {
299 pub fn bits(&self) -> &[U] {
300 self.inner().as_ref()
301 }
302
303 pub fn get_bits<V: UnsignedPrimInt + TryFrom<U>>(&self, range: Range<usize>) -> V {
304 get_bits(self.bits(), range)
305 }
306
307 pub fn get_bits_into_slice<V>(&self, range: Range<usize>, dst: &mut [V], dst_start: usize)
308 where
309 V: UnsignedPrimInt + TryFrom<usize>,
310 usize: TryFrom<U>,
311 {
312 let dst_range = dst_start..(dst_start + range.len());
313 set_bits_from_slice(dst, dst_range, self.bits(), range.start)
314 }
315
316 pub fn get<V: PrimInt>(&self, start_bit: usize) -> V
317 where
318 V::Unsigned: TryFrom<U>,
319 {
320 get(self.bits(), start_bit)
321 }
322}
323
324impl<T: AsMut<[U]>, U: UnsignedPrimInt> Bitfield<T, U> {
325 pub fn bits_mut(&mut self) -> &mut [U] {
326 self.inner_mut().as_mut()
327 }
328
329 pub fn set_bits<V: UnsignedPrimInt + TryInto<U>>(&mut self, range: Range<usize>, src: V) {
330 set_bits(self.bits_mut(), range, src)
331 }
332
333 pub fn set_bits_from_slice<V: UnsignedPrimInt>(
334 &mut self,
335 range: Range<usize>,
336 src: &[V],
337 src_start: usize,
338 ) where
339 U: TryFrom<usize>,
340 usize: TryFrom<V>,
341 {
342 set_bits_from_slice(self.bits_mut(), range, src, src_start)
343 }
344
345 pub fn set<V: PrimInt>(&mut self, start_bit: usize, src: V)
346 where
347 V::Unsigned: TryInto<U>,
348 {
349 set(self.bits_mut(), start_bit, src)
350 }
351}
352
#[cfg(test)]
#[allow(unused_imports)]
mod test {

    extern crate std;

    use std::eprintln;
    use std::fmt;

    use super::*;

    #[test]
    fn zero_gets_zero() {
        let field = Bitfield::<[u64; 2], _>::zeroed();
        let read: u64 = field.get_bits(50..80);
        assert_eq!(read, 0);
    }

    /// Writes `val` into `range` of a zeroed `[T; N]`, then checks that reading
    /// the same range back yields `val`.
    fn set_and_get<T, const N: usize, U>(range: Range<usize>, val: U)
    where
        T: UnsignedPrimInt,
        U: UnsignedPrimInt + TryInto<T> + TryFrom<T> + fmt::Debug,
    {
        let mut field = Bitfield::<[T; N], _>::zeroed();
        set_bits(field.inner_mut(), range.clone(), val);
        let round_tripped: U = get_bits(field.inner(), range);
        assert_eq!(round_tripped, val);
    }

    #[test]
    fn get_returns_what_was_set() {
        // Word-aligned byte.
        set_and_get::<u8, 3, _>(8..16, u8::MAX);
        // 16 bits straddling three `u8` words.
        set_and_get::<u8, 3, _>(2..18, u32::MAX >> 16);
        // The same value inside a single wide word.
        set_and_get::<u128, 1, _>(2..18, u32::MAX >> 16);
    }

    #[test]
    fn multiple_gets_return_what_was_set_with_multiple_sets() {
        // Every (range, value) pair is written first, then read back.
        let writes: [(Range<usize>, u64); 3] = [(0..2, 0b11), (60..64, 0b1111), (10..11, 0b1)];
        for init in [u64::MIN, u64::MAX] {
            let mut field = Bitfield::<[u64; 1], u64>::new([init]);
            for (range, value) in writes.iter().cloned() {
                field.set_bits::<u64>(range, value);
            }
            for (range, value) in writes.iter().cloned() {
                assert_eq!(field.get_bits::<u64>(range), value);
            }
        }
    }
}
405
// Kani model-checking harnesses. They are compiled only under `cfg(kani)`.
#[cfg(kani)]
mod verification {
    use super::*;

    /// Proof entry point. It checks `set_bits_from_slice_via` for several
    /// combinations of destination word type (`T`), source word type (`U`),
    /// and intermediate transfer type (`V`), each over arrays of 3 words.
    #[kani::proof]
    #[kani::unwind(4)]
    fn slice_ops() {
        slice_ops_generic::<u64, 3, u8, 3, u8>(kani::any());
        slice_ops_generic::<u8, 3, u64, 3, u8>(kani::any());
        slice_ops_generic::<u64, 3, u8, 3, u32>(kani::any());
        slice_ops_generic::<u8, 3, u64, 3, u32>(kani::any());
    }

    /// Copies a symbolic `n`-bit range of `b` (starting at `start_b`) into a
    /// copy of `a` (starting at `start_a`). It then checks, for a symbolic bit
    /// index `i`, that the result holds the corresponding bit of `b` when `i`
    /// is inside the destination range, and bit `i` of the original `a`
    /// otherwise.
    fn slice_ops_generic<T, const N: usize, U, const M: usize, V>((a, b): ([T; N], [U; M]))
    where
        T: UnsignedPrimInt + TryFrom<V>,
        U: UnsignedPrimInt + TryFrom<V>,
        V: UnsignedPrimInt + TryFrom<T> + TryFrom<U>,
    {
        let n: usize = kani::any();
        let start_a: usize = kani::any();
        let start_b: usize = kani::any();

        // NOTE(review): `n`, `start_a` and `start_b` are bit counts/offsets, but
        // they are bounded here by the element counts `a.len()` / `b.len()`.
        // Only ranges within the first `N` (resp. `M`) bits are explored, while
        // `i` below ranges over all `a.len() * T::NUM_BITS` bits. Confirm whether
        // this is a deliberate bound to keep the proof tractable under
        // `unwind(4)`, or whether `* T::NUM_BITS` / `* U::NUM_BITS` was intended.
        kani::assume(n <= a.len());
        kani::assume(n <= b.len());
        kani::assume(
            start_a
                .checked_add(n)
                .map(|end| end <= a.len())
                .unwrap_or(false),
        );
        kani::assume(
            start_b
                .checked_add(n)
                .map(|end| end <= b.len())
                .unwrap_or(false),
        );

        let range_a = start_a..(start_a + n);

        // `a` is kept unmodified as the reference for bits outside `range_a`.
        let mut a_mut = a;

        set_bits_from_slice_via::<_, _, V>(&mut a_mut, range_a.clone(), &b, start_b);

        let i: usize = kani::any();
        kani::assume(i < a.len() * T::NUM_BITS);

        let val = get_bit(&a_mut, i);

        // Specification: copied bits come from `b`, all other bits are untouched.
        let val_expected = if range_a.contains(&i) {
            get_bit(&b, i - start_a + start_b)
        } else {
            get_bit(&a, i)
        };

        assert_eq!(val, val_expected);
    }
}
464}