sel4_shared_ring_buffer_block_io/
owned.rs

1//
2// Copyright 2023, Colias Group, LLC
3//
4// SPDX-License-Identifier: BSD-2-Clause
5//
6
7use core::alloc::Layout;
8use core::marker::PhantomData;
9use core::task::{Poll, Waker};
10
11use sel4_abstract_allocator::{AbstractAllocator, AbstractAllocatorAllocation};
12use sel4_async_block_io::{access::Access, Operation};
13use sel4_shared_memory::SharedMemoryRef;
14use sel4_shared_ring_buffer::{
15    roles::Provide, Descriptor, PeerMisbehaviorError as SharedRingBuffersPeerMisbehaviorError,
16    RingBuffers,
17};
18use sel4_shared_ring_buffer_block_io_types::{
19    BlockIORequest, BlockIORequestStatus, BlockIORequestType,
20};
21use sel4_shared_ring_buffer_bookkeeping::{slot_set_semaphore::*, slot_tracker::*};
22
23pub use crate::errors::{Error, ErrorOrUserError, IOError, PeerMisbehaviorError, UserError};
24
25pub struct OwnedSharedRingBufferBlockIO<S, A: AbstractAllocator, F> {
26    dma_region: SharedMemoryRef<'static, [u8]>,
27    bounce_buffer_allocator: A,
28    ring_buffers: RingBuffers<'static, Provide, F, BlockIORequest>,
29    requests: SlotTracker<StateTypesImpl<A>>,
30    slot_set_semaphore: SlotSetSemaphore<S, NUM_SLOT_POOLS>,
31}
32
/// Semaphore pool index used for ring-buffer slot accounting.
const RING_BUFFERS_SLOT_POOL_INDEX: usize = 0;
/// Semaphore pool index used for request-table slot accounting.
const REQUESTS_SLOT_POOL_INDEX: usize = RING_BUFFERS_SLOT_POOL_INDEX + 1;
/// Total number of semaphore pools (one past the last pool index).
const NUM_SLOT_POOLS: usize = REQUESTS_SLOT_POOL_INDEX + 1;
36
/// Zero-sized marker selecting the per-slot state types of the request tracker for allocator `A`.
///
/// It is only ever used as a type parameter and is never constructed.
struct StateTypesImpl<A>(PhantomData<A>);
40
41impl<A: AbstractAllocator> SlotStateTypes for StateTypesImpl<A> {
42    type Common = ();
43    type Free = ();
44    type Occupied = Occupied<A>;
45}
46
47struct Occupied<A: AbstractAllocator> {
48    req: BlockIORequest,
49    state: OccupiedState,
50    allocation: A::Allocation,
51}
52
53enum OccupiedState {
54    Pending { waker: Option<Waker> },
55    Canceled,
56    Complete { error: Option<IOError> },
57}
58
/// Payload description used when issuing a request.
///
/// A read only needs to know how many bytes to transfer. A write carries the bytes that are
/// copied into the bounce buffer before submission.
#[derive(Debug, Clone, Copy)]
pub enum IssueRequestBuf<'a> {
    /// Read `len` bytes.
    Read { len: usize },
    /// Write the contents of `buf`.
    Write { buf: &'a [u8] },
}
63
64impl<'a> IssueRequestBuf<'a> {
65    pub fn new<A: Access>(operation: &'a Operation<'a, A>) -> Self {
66        match operation {
67            Operation::Read { buf, .. } => Self::Read { len: buf.len() },
68            Operation::Write { buf, .. } => Self::Write { buf },
69        }
70    }
71
72    fn len(&self) -> usize {
73        match self {
74            Self::Read { len } => *len,
75            Self::Write { buf } => buf.len(),
76        }
77    }
78
79    fn ty(&self) -> BlockIORequestType {
80        match self {
81            Self::Read { .. } => BlockIORequestType::Read,
82            Self::Write { .. } => BlockIORequestType::Write,
83        }
84    }
85}
86
/// Destination used when polling a request for completion.
///
/// For a read, the completed bounce buffer is copied into `buf`. A write needs no destination.
#[derive(Debug)]
pub enum PollRequestBuf<'a> {
    /// Copy the read result into `buf`.
    Read { buf: &'a mut [u8] },
    /// Nothing to copy back.
    Write,
}
91
92impl<'a> PollRequestBuf<'a> {
93    pub fn new<'b, A: Access>(operation: &'a mut Operation<'b, A>) -> Self
94    where
95        'b: 'a,
96    {
97        match operation {
98            Operation::Read { buf, .. } => Self::Read { buf },
99            Operation::Write { .. } => Self::Write,
100        }
101    }
102}
103
104impl<S: SlotSemaphore, A: AbstractAllocator, F: FnMut()> OwnedSharedRingBufferBlockIO<S, A, F> {
105    pub fn new(
106        dma_region: SharedMemoryRef<'static, [u8]>,
107        bounce_buffer_allocator: A,
108        mut ring_buffers: RingBuffers<'static, Provide, F, BlockIORequest>,
109    ) -> Self {
110        assert!(ring_buffers.free_mut().is_empty().unwrap());
111        assert!(ring_buffers.used_mut().is_empty().unwrap());
112        let n = ring_buffers.free().capacity();
113        Self {
114            dma_region,
115            bounce_buffer_allocator,
116            ring_buffers,
117            requests: SlotTracker::new_with_capacity((), (), n),
118            slot_set_semaphore: SlotSetSemaphore::new([n, n]),
119        }
120    }
121
122    pub fn slot_set_semaphore(&self) -> &SlotSetSemaphoreHandle<S, NUM_SLOT_POOLS> {
123        self.slot_set_semaphore.handle()
124    }
125
126    fn report_current_num_free_current_num_free_ring_buffers_slots(
127        &mut self,
128    ) -> Result<(), ErrorOrUserError> {
129        let current_num_free = self.requests.num_free();
130        self.slot_set_semaphore
131            .report_current_num_free_slots(RING_BUFFERS_SLOT_POOL_INDEX, current_num_free)
132            .unwrap();
133        Ok(())
134    }
135
136    fn report_current_num_free_current_num_free_requests_slots(
137        &mut self,
138    ) -> Result<(), ErrorOrUserError> {
139        let current_num_free = self.ring_buffers.free_mut().num_empty_slots()?;
140        self.slot_set_semaphore
141            .report_current_num_free_slots(REQUESTS_SLOT_POOL_INDEX, current_num_free)
142            .unwrap();
143        Ok(())
144    }
145
146    fn can_issue_requests(
147        &mut self,
148        n: usize,
149    ) -> Result<bool, SharedRingBuffersPeerMisbehaviorError> {
150        let can =
151            self.ring_buffers.free_mut().num_empty_slots()? >= n && self.requests.num_free() >= n;
152        Ok(can)
153    }
154
155    pub fn issue_read_request(
156        &mut self,
157        reservation: &mut SlotSetReservation<'_, S, NUM_SLOT_POOLS>,
158        start_block_idx: u64,
159        num_bytes: usize,
160    ) -> Result<usize, ErrorOrUserError> {
161        self.issue_request(
162            reservation,
163            start_block_idx,
164            &mut IssueRequestBuf::Read { len: num_bytes },
165        )
166    }
167
168    pub fn issue_write_request(
169        &mut self,
170        reservation: &mut SlotSetReservation<'_, S, NUM_SLOT_POOLS>,
171        start_block_idx: u64,
172        buf: &[u8],
173    ) -> Result<usize, ErrorOrUserError> {
174        self.issue_request(
175            reservation,
176            start_block_idx,
177            &mut IssueRequestBuf::Write { buf },
178        )
179    }
180
181    pub fn issue_request(
182        &mut self,
183        reservation: &mut SlotSetReservation<'_, S, NUM_SLOT_POOLS>,
184        start_block_idx: u64,
185        buf: &mut IssueRequestBuf,
186    ) -> Result<usize, ErrorOrUserError> {
187        if reservation.count() < 1 {
188            return Err(UserError::TooManyOutstandingRequests.into());
189        }
190
191        assert!(self.can_issue_requests(1)?);
192
193        let request_index = self.requests.peek_next_free_index().unwrap();
194
195        let allocation = self
196            .bounce_buffer_allocator
197            .allocate(Layout::from_size_align(buf.len(), 1).unwrap())
198            .map_err(|_| Error::BounceBufferAllocationError)?;
199
200        if let IssueRequestBuf::Write { buf } = buf {
201            self.dma_region
202                .as_mut_ptr()
203                .index(allocation.range())
204                .copy_from_slice(buf);
205        }
206
207        let req = BlockIORequest::new(
208            BlockIORequestStatus::Pending,
209            buf.ty(),
210            start_block_idx.try_into().unwrap(),
211            Descriptor::from_encoded_addr_range(allocation.range(), request_index),
212        );
213
214        self.requests
215            .occupy(Occupied {
216                req,
217                state: OccupiedState::Pending { waker: None },
218                allocation,
219            })
220            .unwrap();
221
222        self.ring_buffers
223            .free_mut()
224            .enqueue_and_commit(req)?
225            .unwrap();
226
227        self.ring_buffers.notify_mut();
228
229        self.slot_set_semaphore.consume(reservation, 1).unwrap();
230
231        Ok(request_index)
232    }
233
234    pub fn cancel_request(&mut self, request_index: usize) -> Result<(), ErrorOrUserError> {
235        let state_value = self.requests.get_state_value_mut(request_index)?;
236        let occupied = state_value.as_occupied()?;
237        match &occupied.state {
238            OccupiedState::Pending { .. } => {
239                occupied.state = OccupiedState::Canceled;
240            }
241            OccupiedState::Complete { .. } => {
242                let occupied = self.requests.free(request_index, ()).unwrap();
243                self.bounce_buffer_allocator.deallocate(occupied.allocation);
244                self.report_current_num_free_current_num_free_requests_slots()?;
245            }
246            _ => {
247                return Err(UserError::RequestStateMismatch.into());
248            }
249        }
250        Ok(())
251    }
252
253    pub fn poll_read_request(
254        &mut self,
255        request_index: usize,
256        buf: &mut [u8],
257        waker: Option<Waker>,
258    ) -> Result<Poll<Result<(), IOError>>, ErrorOrUserError> {
259        self.poll_request(request_index, &mut PollRequestBuf::Read { buf }, waker)
260    }
261
262    pub fn poll_write_request(
263        &mut self,
264        request_index: usize,
265        waker: Option<Waker>,
266    ) -> Result<Poll<Result<(), IOError>>, ErrorOrUserError> {
267        self.poll_request(request_index, &mut PollRequestBuf::Write, waker)
268    }
269
270    pub fn poll_request(
271        &mut self,
272        request_index: usize,
273        buf: &mut PollRequestBuf,
274        waker: Option<Waker>,
275    ) -> Result<Poll<Result<(), IOError>>, ErrorOrUserError> {
276        let state_value = self.requests.get_state_value_mut(request_index)?;
277        let occupied = state_value.as_occupied()?;
278
279        Ok(match &mut occupied.state {
280            OccupiedState::Pending {
281                waker: ref mut waker_slot,
282            } => {
283                if let Some(waker) = waker {
284                    waker_slot.replace(waker);
285                }
286                Poll::Pending
287            }
288            OccupiedState::Complete { error } => {
289                let val = match error {
290                    None => Ok(()),
291                    Some(err) => Err(err.clone()),
292                };
293
294                let range = occupied.req.buf().encoded_addr_range();
295
296                match buf {
297                    PollRequestBuf::Read { buf } => {
298                        self.dma_region
299                            .as_mut_ptr()
300                            .index(range.clone())
301                            .copy_into_slice(buf);
302                    }
303                    PollRequestBuf::Write => {}
304                }
305
306                let occupied = self.requests.free(request_index, ()).unwrap();
307                self.bounce_buffer_allocator.deallocate(occupied.allocation);
308                self.report_current_num_free_current_num_free_requests_slots()?;
309
310                Poll::Ready(val)
311            }
312            _ => {
313                return Err(UserError::RequestStateMismatch.into());
314            }
315        })
316    }
317
318    pub fn poll(&mut self) -> Result<bool, ErrorOrUserError> {
319        self.report_current_num_free_current_num_free_ring_buffers_slots()?;
320
321        let mut notify = false;
322
323        while let Some(completed_req) = self.ring_buffers.used_mut().dequeue()? {
324            let request_index = completed_req.buf().cookie();
325
326            let state_value = self
327                .requests
328                .get_state_value_mut(request_index)
329                .map_err(|_| PeerMisbehaviorError::OutOfBoundsCookie)?;
330
331            let occupied = state_value
332                .as_occupied()
333                .map_err(|_| PeerMisbehaviorError::StateMismatch)?;
334
335            {
336                let mut observed_request = completed_req;
337                observed_request.set_status(BlockIORequestStatus::Pending);
338                if observed_request != occupied.req {
339                    return Err(PeerMisbehaviorError::DescriptorMismatch.into());
340                }
341            }
342
343            match &mut occupied.state {
344                OccupiedState::Pending { waker } => {
345                    let waker = waker.take();
346
347                    let status = completed_req
348                        .status()
349                        .map_err(|_| PeerMisbehaviorError::InvalidDescriptor)?;
350
351                    occupied.state = OccupiedState::Complete {
352                        error: match status {
353                            BlockIORequestStatus::Pending => {
354                                return Err(PeerMisbehaviorError::InvalidDescriptor.into());
355                            }
356                            BlockIORequestStatus::Ok => None,
357                            BlockIORequestStatus::IOError => Some(IOError),
358                        },
359                    };
360
361                    if let Some(waker) = waker {
362                        waker.wake();
363                    }
364                }
365                OccupiedState::Canceled => {
366                    let occupied = self.requests.free(request_index, ()).unwrap();
367                    self.bounce_buffer_allocator.deallocate(occupied.allocation);
368                    self.report_current_num_free_current_num_free_requests_slots()?;
369                }
370                _ => {
371                    return Err(UserError::RequestStateMismatch.into());
372                }
373            }
374
375            notify = true;
376        }
377
378        if notify {
379            self.ring_buffers.notify_mut();
380        }
381
382        Ok(notify)
383    }
384}