virtio_drivers/device/socket/
connectionmanager.rs

1use super::{
2    protocol::VsockAddr, vsock::ConnectionInfo, DisconnectReason, SocketError, VirtIOSocket,
3    VsockEvent, VsockEventType, DEFAULT_RX_BUFFER_SIZE,
4};
5use crate::{transport::Transport, Hal, Result};
6use alloc::{boxed::Box, vec::Vec};
7use core::cmp::min;
8use core::convert::TryInto;
9use core::hint::spin_loop;
10use log::debug;
11use zerocopy::FromZeroes;
12
/// Receive-buffer size in bytes given to each connection when the manager is built with
/// `VsockConnectionManager::new` (1 KiB). It is also the `buf_alloc` advertised to peers.
const DEFAULT_PER_CONNECTION_BUFFER_CAPACITY: u32 = 1 << 10;
14
/// A higher level interface for VirtIO socket (vsock) devices.
///
/// This keeps track of multiple vsock connections.
///
/// `RX_BUFFER_SIZE` is the size in bytes of each buffer used in the RX virtqueue. This must be
/// bigger than `size_of::<VirtioVsockHdr>()`.
///
/// # Example
///
/// ```
/// # use virtio_drivers::{Error, Hal};
/// # use virtio_drivers::transport::Transport;
/// use virtio_drivers::device::socket::{VirtIOSocket, VsockAddr, VsockConnectionManager};
///
/// # fn example<HalImpl: Hal, T: Transport>(transport: T) -> Result<(), Error> {
/// let mut socket = VsockConnectionManager::new(VirtIOSocket::<HalImpl, _>::new(transport)?);
///
/// // Start a thread to call `socket.poll()` and handle events.
///
/// let remote_address = VsockAddr { cid: 2, port: 42 };
/// let local_port = 1234;
/// socket.connect(remote_address, local_port)?;
///
/// // Wait until `socket.poll()` returns an event indicating that the socket is connected.
///
/// socket.send(remote_address, local_port, "Hello world".as_bytes())?;
///
/// socket.shutdown(remote_address, local_port)?;
/// # Ok(())
/// # }
/// ```
pub struct VsockConnectionManager<
    H: Hal,
    T: Transport,
    const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE,
> {
    /// The low-level driver used to send packets and to poll for incoming ones.
    driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>,
    /// Size in bytes of the receive ring buffer allocated for each new connection. It is also
    /// advertised to the peer as this side's `buf_alloc`.
    per_connection_buffer_capacity: u32,
    /// Every connection currently tracked. This includes outgoing connections requested via
    /// `connect` and incoming ones created from connection requests in `poll`.
    connections: Vec<Connection>,
    /// Local ports on which incoming connection requests are accepted (see `listen`).
    listening_ports: Vec<u32>,
}
56
/// Bookkeeping for a single vsock connection tracked by `VsockConnectionManager`.
#[derive(Debug)]
struct Connection {
    /// Addresses and credit/flow-control state shared with the low-level driver.
    info: ConnectionInfo,
    /// Data received from the peer that the client has not yet read with `recv`.
    buffer: RingBuffer,
    /// The peer disconnected while there was still unread data in `buffer`, so the connection is
    /// kept until `recv` drains that data, at which point it is force-closed (RST) and removed.
    ///
    /// NOTE(review): `poll` sets this for any `Disconnected` event, whatever its reason.
    /// `recv` then always sends a RST, even though `poll` only sends one for
    /// `DisconnectReason::Shutdown` when the buffer is already empty. Confirm whether replying
    /// to a peer reset with a RST is intended.
    peer_requested_shutdown: bool,
}
65
66impl Connection {
67    fn new(peer: VsockAddr, local_port: u32, buffer_capacity: u32) -> Self {
68        let mut info = ConnectionInfo::new(peer, local_port);
69        info.buf_alloc = buffer_capacity;
70        Self {
71            info,
72            buffer: RingBuffer::new(buffer_capacity.try_into().unwrap()),
73            peer_requested_shutdown: false,
74        }
75    }
76}
77
78impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize>
79    VsockConnectionManager<H, T, RX_BUFFER_SIZE>
80{
81    /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
82    pub fn new(driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>) -> Self {
83        Self::new_with_capacity(driver, DEFAULT_PER_CONNECTION_BUFFER_CAPACITY)
84    }
85
86    /// Construct a new connection manager wrapping the given low-level VirtIO socket driver, with
87    /// the given per-connection buffer capacity.
88    pub fn new_with_capacity(
89        driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>,
90        per_connection_buffer_capacity: u32,
91    ) -> Self {
92        Self {
93            driver,
94            connections: Vec::new(),
95            listening_ports: Vec::new(),
96            per_connection_buffer_capacity,
97        }
98    }
99
100    /// Returns the CID which has been assigned to this guest.
101    pub fn guest_cid(&self) -> u64 {
102        self.driver.guest_cid()
103    }
104
105    /// Allows incoming connections on the given port number.
106    pub fn listen(&mut self, port: u32) {
107        if !self.listening_ports.contains(&port) {
108            self.listening_ports.push(port);
109        }
110    }
111
112    /// Stops allowing incoming connections on the given port number.
113    pub fn unlisten(&mut self, port: u32) {
114        self.listening_ports.retain(|p| *p != port);
115    }
116
117    /// Sends a request to connect to the given destination.
118    ///
119    /// This returns as soon as the request is sent; you should wait until `poll` returns a
120    /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
121    /// before sending data.
122    pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
123        if self.connections.iter().any(|connection| {
124            connection.info.dst == destination && connection.info.src_port == src_port
125        }) {
126            return Err(SocketError::ConnectionExists.into());
127        }
128
129        let new_connection =
130            Connection::new(destination, src_port, self.per_connection_buffer_capacity);
131
132        self.driver.connect(&new_connection.info)?;
133        debug!("Connection requested: {:?}", new_connection.info);
134        self.connections.push(new_connection);
135        Ok(())
136    }
137
138    /// Sends the buffer to the destination.
139    pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
140        let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
141
142        self.driver.send(buffer, &mut connection.info)
143    }
144
145    /// Polls the vsock device to receive data or other updates.
146    pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
147        let guest_cid = self.driver.guest_cid();
148        let connections = &mut self.connections;
149        let per_connection_buffer_capacity = self.per_connection_buffer_capacity;
150
151        let result = self.driver.poll(|event, body| {
152            let connection = get_connection_for_event(connections, &event, guest_cid);
153
154            // Skip events which don't match any connection we know about, unless they are a
155            // connection request.
156            let connection = if let Some((_, connection)) = connection {
157                connection
158            } else if let VsockEventType::ConnectionRequest = event.event_type {
159                // If the requested connection already exists or the CID isn't ours, ignore it.
160                if connection.is_some() || event.destination.cid != guest_cid {
161                    return Ok(None);
162                }
163                // Add the new connection to our list, at least for now. It will be removed again
164                // below if we weren't listening on the port.
165                connections.push(Connection::new(
166                    event.source,
167                    event.destination.port,
168                    per_connection_buffer_capacity,
169                ));
170                connections.last_mut().unwrap()
171            } else {
172                return Ok(None);
173            };
174
175            // Update stored connection info.
176            connection.info.update_for_event(&event);
177
178            if let VsockEventType::Received { length } = event.event_type {
179                // Copy to buffer
180                if !connection.buffer.add(body) {
181                    return Err(SocketError::OutputBufferTooShort(length).into());
182                }
183            }
184
185            Ok(Some(event))
186        })?;
187
188        let Some(event) = result else {
189            return Ok(None);
190        };
191
192        // The connection must exist because we found it above in the callback.
193        let (connection_index, connection) =
194            get_connection_for_event(connections, &event, guest_cid).unwrap();
195
196        match event.event_type {
197            VsockEventType::ConnectionRequest => {
198                if self.listening_ports.contains(&event.destination.port) {
199                    self.driver.accept(&connection.info)?;
200                } else {
201                    // Reject the connection request and remove it from our list.
202                    self.driver.force_close(&connection.info)?;
203                    self.connections.swap_remove(connection_index);
204
205                    // No need to pass the request on to the client, as we've already rejected it.
206                    return Ok(None);
207                }
208            }
209            VsockEventType::Connected => {}
210            VsockEventType::Disconnected { reason } => {
211                // Wait until client reads all data before removing connection.
212                if connection.buffer.is_empty() {
213                    if reason == DisconnectReason::Shutdown {
214                        self.driver.force_close(&connection.info)?;
215                    }
216                    self.connections.swap_remove(connection_index);
217                } else {
218                    connection.peer_requested_shutdown = true;
219                }
220            }
221            VsockEventType::Received { .. } => {
222                // Already copied the buffer in the callback above.
223            }
224            VsockEventType::CreditRequest => {
225                // If the peer requested credit, send an update.
226                self.driver.credit_update(&connection.info)?;
227                // No need to pass the request on to the client, we've already handled it.
228                return Ok(None);
229            }
230            VsockEventType::CreditUpdate => {}
231        }
232
233        Ok(Some(event))
234    }
235
236    /// Reads data received from the given connection.
237    pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
238        let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
239
240        // Copy from ring buffer
241        let bytes_read = connection.buffer.drain(buffer);
242
243        connection.info.done_forwarding(bytes_read);
244
245        // If buffer is now empty and the peer requested shutdown, finish shutting down the
246        // connection.
247        if connection.peer_requested_shutdown && connection.buffer.is_empty() {
248            self.driver.force_close(&connection.info)?;
249            self.connections.swap_remove(connection_index);
250        }
251
252        Ok(bytes_read)
253    }
254
255    /// Returns the number of bytes in the receive buffer available to be read by `recv`.
256    ///
257    /// When the available bytes is 0, it indicates that the receive buffer is empty and does not
258    /// contain any data.
259    pub fn recv_buffer_available_bytes(&mut self, peer: VsockAddr, src_port: u32) -> Result<usize> {
260        let (_, connection) = get_connection(&mut self.connections, peer, src_port)?;
261        Ok(connection.buffer.used())
262    }
263
264    /// Sends a credit update to the given peer.
265    pub fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result {
266        let (_, connection) = get_connection(&mut self.connections, peer, src_port)?;
267        self.driver.credit_update(&connection.info)
268    }
269
270    /// Blocks until we get some event from the vsock device.
271    pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
272        loop {
273            if let Some(event) = self.poll()? {
274                return Ok(event);
275            } else {
276                spin_loop();
277            }
278        }
279    }
280
281    /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive
282    /// any more data.
283    ///
284    /// This returns as soon as the request is sent; you should wait until `poll` returns a
285    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
286    /// shutdown.
287    pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
288        let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
289
290        self.driver.shutdown(&connection.info)
291    }
292
293    /// Forcibly closes the connection without waiting for the peer.
294    pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
295        let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
296
297        self.driver.force_close(&connection.info)?;
298
299        self.connections.swap_remove(index);
300        Ok(())
301    }
302}
303
304/// Returns the connection from the given list matching the given peer address and local port, and
305/// its index.
306///
307/// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list.
308fn get_connection(
309    connections: &mut [Connection],
310    peer: VsockAddr,
311    local_port: u32,
312) -> core::result::Result<(usize, &mut Connection), SocketError> {
313    connections
314        .iter_mut()
315        .enumerate()
316        .find(|(_, connection)| {
317            connection.info.dst == peer && connection.info.src_port == local_port
318        })
319        .ok_or(SocketError::NotConnected)
320}
321
322/// Returns the connection from the given list matching the event, if any, and its index.
323fn get_connection_for_event<'a>(
324    connections: &'a mut [Connection],
325    event: &VsockEvent,
326    local_cid: u64,
327) -> Option<(usize, &'a mut Connection)> {
328    connections
329        .iter_mut()
330        .enumerate()
331        .find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
332}
333
/// A fixed-capacity circular byte queue used to hold data received on a connection until the
/// client reads it.
#[derive(Debug)]
struct RingBuffer {
    /// Backing storage; its length is the buffer's capacity.
    buffer: Box<[u8]>,
    /// The number of bytes currently in the buffer.
    used: usize,
    /// The index of the first used byte in the buffer.
    start: usize,
}
342
343impl RingBuffer {
344    pub fn new(capacity: usize) -> Self {
345        Self {
346            buffer: FromZeroes::new_box_slice_zeroed(capacity),
347            used: 0,
348            start: 0,
349        }
350    }
351
352    /// Returns the number of bytes currently used in the buffer.
353    pub fn used(&self) -> usize {
354        self.used
355    }
356
357    /// Returns true iff there are currently no bytes in the buffer.
358    pub fn is_empty(&self) -> bool {
359        self.used == 0
360    }
361
362    /// Returns the number of bytes currently free in the buffer.
363    pub fn free(&self) -> usize {
364        self.buffer.len() - self.used
365    }
366
367    /// Adds the given bytes to the buffer if there is enough capacity for them all.
368    ///
369    /// Returns true if they were added, or false if they were not.
370    pub fn add(&mut self, bytes: &[u8]) -> bool {
371        if bytes.len() > self.free() {
372            return false;
373        }
374
375        // The index of the first available position in the buffer.
376        let first_available = (self.start + self.used) % self.buffer.len();
377        // The number of bytes to copy from `bytes` to `buffer` between `first_available` and
378        // `buffer.len()`.
379        let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
380        self.buffer[first_available..first_available + copy_length_before_wraparound]
381            .copy_from_slice(&bytes[0..copy_length_before_wraparound]);
382        if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
383            self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
384        }
385        self.used += bytes.len();
386
387        true
388    }
389
390    /// Reads and removes as many bytes as possible from the buffer, up to the length of the given
391    /// buffer.
392    pub fn drain(&mut self, out: &mut [u8]) -> usize {
393        let bytes_read = min(self.used, out.len());
394
395        // The number of bytes to copy out between `start` and the end of the buffer.
396        let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
397        // The number of bytes to copy out from the beginning of the buffer after wrapping around.
398        let read_after_wraparound = bytes_read
399            .checked_sub(read_before_wraparound)
400            .unwrap_or_default();
401
402        out[0..read_before_wraparound]
403            .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
404        out[read_before_wraparound..bytes_read]
405            .copy_from_slice(&self.buffer[0..read_after_wraparound]);
406
407        self.used -= bytes_read;
408        self.start = (self.start + bytes_read) % self.buffer.len();
409
410        bytes_read
411    }
412}
413
// Tests that drive `VsockConnectionManager` against a `FakeTransport`. In each test a separate
// thread plays the device: it reads the packets the guest places on the TX queue and writes
// the host's packets to the RX queue.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        device::socket::{
            protocol::{
                SocketType, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp,
            },
            vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
        },
        hal::fake::FakeHal,
        transport::{
            fake::{FakeTransport, QueueStatus, State},
            DeviceType,
        },
        volatile::ReadOnly,
    };
    use alloc::{sync::Arc, vec};
    use core::{mem::size_of, ptr::NonNull};
    use std::{sync::Mutex, thread};
    use zerocopy::{AsBytes, FromBytes};

    /// Outgoing connection round trip. The guest connects to the host, sends a message, receives
    /// a reply and shuts the connection down. The device thread asserts on every header (and
    /// payload) the guest transmits.
    #[test]
    fn send_recv() {
        let host_cid = 2;
        let guest_cid = 66;
        let host_port = 1234;
        let guest_port = 4321;
        let host_address = VsockAddr {
            cid: host_cid,
            port: host_port,
        };
        let hello_from_guest = "Hello from guest";
        let hello_from_host = "Hello from host";

        // Config space of the fake device: the guest CID is 66 (low word), high word 0.
        let mut config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        // One status entry for each of the device's three virtqueues.
        let state = Arc::new(Mutex::new(State {
            queues: vec![
                QueueStatus::default(),
                QueueStatus::default(),
                QueueStatus::default(),
            ],
            ..Default::default()
        }));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            config_space: NonNull::from(&mut config_space),
            state: state.clone(),
        };
        let mut socket = VsockConnectionManager::new(
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
        );

        // Start a thread to simulate the device.
        let handle = thread::spawn(move || {
            // Wait for connection request.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            // The request advertises the default per-connection buffer capacity (1024).
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            // Accept connection and give the peer enough credit to send the message.
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Response.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // Expect the guest to send some data.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            let request = state
                .lock()
                .unwrap()
                .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
            // The packet is a header immediately followed by the payload.
            assert_eq!(
                request.len(),
                size_of::<VirtioVsockHdr>() + hello_from_guest.len()
            );
            assert_eq!(
                VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Rw.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: (hello_from_guest.len() as u32).into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );
            assert_eq!(
                &request[size_of::<VirtioVsockHdr>()..],
                hello_from_guest.as_bytes()
            );

            println!("Host sending");

            // Send a response. Its fwd_cnt acknowledges the bytes the guest just sent.
            let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
            VirtioVsockHdr {
                op: VirtioVsockOp::Rw.into(),
                src_cid: host_cid.into(),
                dst_cid: guest_cid.into(),
                src_port: host_port.into(),
                dst_port: guest_port.into(),
                len: (hello_from_host.len() as u32).into(),
                socket_type: SocketType::Stream.into(),
                flags: 0.into(),
                buf_alloc: 50.into(),
                fwd_cnt: (hello_from_guest.len() as u32).into(),
            }
            .write_to_prefix(response.as_mut_slice());
            response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
            state
                .lock()
                .unwrap()
                .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);

            // Expect a shutdown. By now the guest has read the reply, so its fwd_cnt covers it.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Shutdown.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: (StreamShutdown::SEND | StreamShutdown::RECEIVE).into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: (hello_from_host.len() as u32).into(),
                }
            );
        });

        // Guest side: connect and wait for the host to accept.
        socket.connect(host_address, guest_port).unwrap();
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::Connected,
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: 0,
                },
            }
        );
        println!("Guest sending");
        socket
            .send(host_address, guest_port, "Hello from guest".as_bytes())
            .unwrap();
        println!("Guest waiting to receive.");
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::Received {
                    length: hello_from_host.len()
                },
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: hello_from_guest.len() as u32,
                },
            }
        );
        println!("Guest getting received data.");
        let mut buffer = [0u8; 64];
        assert_eq!(
            socket.recv(host_address, guest_port, &mut buffer).unwrap(),
            hello_from_host.len()
        );
        assert_eq!(
            &buffer[0..hello_from_host.len()],
            hello_from_host.as_bytes()
        );
        socket.shutdown(host_address, guest_port).unwrap();

        handle.join().unwrap();
    }

    /// Incoming connections. A request to a port the guest is not listening on gets a RST and
    /// never surfaces as an event. A request to a listening port gets a Response and is reported
    /// to the client as `ConnectionRequest`.
    #[test]
    fn incoming_connection() {
        let host_cid = 2;
        let guest_cid = 66;
        let host_port = 1234;
        let guest_port = 4321;
        let wrong_guest_port = 4444;
        let host_address = VsockAddr {
            cid: host_cid,
            port: host_port,
        };

        // Config space of the fake device: the guest CID is 66 (low word), high word 0.
        let mut config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        // One status entry for each of the device's three virtqueues.
        let state = Arc::new(Mutex::new(State {
            queues: vec![
                QueueStatus::default(),
                QueueStatus::default(),
                QueueStatus::default(),
            ],
            ..Default::default()
        }));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            config_space: NonNull::from(&mut config_space),
            state: state.clone(),
        };
        let mut socket = VsockConnectionManager::new(
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
        );

        socket.listen(guest_port);

        // Start a thread to simulate the device.
        let handle = thread::spawn(move || {
            // Send a connection request for a port the guest isn't listening on.
            println!("Host sending connection request to wrong port");
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: wrong_guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // Expect a rejection.
            println!("Host waiting for rejection");
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Rst.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: wrong_guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            // Send a connection request for a port the guest is listening on.
            println!("Host sending connection request to right port");
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // Expect a response.
            println!("Host waiting for response");
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Response.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            println!("Host finished");
        });

        // Expect an incoming connection. The request to the wrong port was rejected inside
        // `poll` and is not reported, so the first event is the one for `guest_port`.
        println!("Guest expecting incoming connection.");
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::ConnectionRequest,
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: 0,
                },
            }
        );

        handle.join().unwrap();
    }
}