virtio_drivers/device/socket/connectionmanager.rs

1use super::{
2    protocol::VsockAddr, vsock::ConnectionInfo, DisconnectReason, SocketError, VirtIOSocket,
3    VsockEvent, VsockEventType, DEFAULT_RX_BUFFER_SIZE,
4};
5use crate::{transport::Transport, Hal, Result};
6use alloc::{boxed::Box, vec::Vec};
7use core::cmp::min;
8use core::convert::TryInto;
9use core::hint::spin_loop;
10use log::debug;
11use zerocopy::FromZeros;
12
13const DEFAULT_PER_CONNECTION_BUFFER_CAPACITY: u32 = 1024;
14
15/// A higher level interface for VirtIO socket (vsock) devices.
16///
17/// This keeps track of multiple vsock connections.
18///
19/// `RX_BUFFER_SIZE` is the size in bytes of each buffer used in the RX virtqueue. This must be
20/// bigger than `size_of::<VirtioVsockHdr>()`.
21///
22/// # Example
23///
24/// ```
25/// # use virtio_drivers::{Error, Hal};
26/// # use virtio_drivers::transport::Transport;
27/// use virtio_drivers::device::socket::{VirtIOSocket, VsockAddr, VsockConnectionManager};
28///
29/// # fn example<HalImpl: Hal, T: Transport>(transport: T) -> Result<(), Error> {
30/// let mut socket = VsockConnectionManager::new(VirtIOSocket::<HalImpl, _>::new(transport)?);
31///
32/// // Start a thread to call `socket.poll()` and handle events.
33///
34/// let remote_address = VsockAddr { cid: 2, port: 42 };
35/// let local_port = 1234;
36/// socket.connect(remote_address, local_port)?;
37///
38/// // Wait until `socket.poll()` returns an event indicating that the socket is connected.
39///
40/// socket.send(remote_address, local_port, "Hello world".as_bytes())?;
41///
42/// socket.shutdown(remote_address, local_port)?;
43/// # Ok(())
44/// # }
45/// ```
46pub struct VsockConnectionManager<
47    H: Hal,
48    T: Transport,
49    const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE,
50> {
51    driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>,
52    per_connection_buffer_capacity: u32,
53    connections: Vec<Connection>,
54    listening_ports: Vec<u32>,
55}
56
57#[derive(Debug)]
58struct Connection {
59    info: ConnectionInfo,
60    buffer: RingBuffer,
61    /// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is
62    /// still data in the buffer.
63    peer_requested_shutdown: bool,
64}
65
66impl Connection {
67    fn new(peer: VsockAddr, local_port: u32, buffer_capacity: u32) -> Self {
68        let mut info = ConnectionInfo::new(peer, local_port);
69        info.buf_alloc = buffer_capacity;
70        Self {
71            info,
72            buffer: RingBuffer::new(buffer_capacity.try_into().unwrap()),
73            peer_requested_shutdown: false,
74        }
75    }
76}
77
78impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize>
79    VsockConnectionManager<H, T, RX_BUFFER_SIZE>
80{
81    /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
82    pub fn new(driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>) -> Self {
83        Self::new_with_capacity(driver, DEFAULT_PER_CONNECTION_BUFFER_CAPACITY)
84    }
85
86    /// Construct a new connection manager wrapping the given low-level VirtIO socket driver, with
87    /// the given per-connection buffer capacity.
88    pub fn new_with_capacity(
89        driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>,
90        per_connection_buffer_capacity: u32,
91    ) -> Self {
92        Self {
93            driver,
94            connections: Vec::new(),
95            listening_ports: Vec::new(),
96            per_connection_buffer_capacity,
97        }
98    }
99
100    /// Returns the CID which has been assigned to this guest.
101    pub fn guest_cid(&self) -> u64 {
102        self.driver.guest_cid()
103    }
104
105    /// Allows incoming connections on the given port number.
106    pub fn listen(&mut self, port: u32) {
107        if !self.listening_ports.contains(&port) {
108            self.listening_ports.push(port);
109        }
110    }
111
112    /// Stops allowing incoming connections on the given port number.
113    pub fn unlisten(&mut self, port: u32) {
114        self.listening_ports.retain(|p| *p != port);
115    }
116
117    /// Sends a request to connect to the given destination.
118    ///
119    /// This returns as soon as the request is sent; you should wait until `poll` returns a
120    /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
121    /// before sending data.
122    pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
123        if self.connections.iter().any(|connection| {
124            connection.info.dst == destination && connection.info.src_port == src_port
125        }) {
126            return Err(SocketError::ConnectionExists.into());
127        }
128
129        let new_connection =
130            Connection::new(destination, src_port, self.per_connection_buffer_capacity);
131
132        self.driver.connect(&new_connection.info)?;
133        debug!("Connection requested: {:?}", new_connection.info);
134        self.connections.push(new_connection);
135        Ok(())
136    }
137
138    /// Sends the buffer to the destination.
139    pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
140        let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
141
142        self.driver.send(buffer, &mut connection.info)
143    }
144
145    /// Polls the vsock device to receive data or other updates.
146    pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
147        let guest_cid = self.driver.guest_cid();
148        let connections = &mut self.connections;
149        let per_connection_buffer_capacity = self.per_connection_buffer_capacity;
150
151        let result = self.driver.poll(|event, body| {
152            let connection = get_connection_for_event(connections, &event, guest_cid);
153
154            // Skip events which don't match any connection we know about, unless they are a
155            // connection request.
156            let connection = if let Some((_, connection)) = connection {
157                connection
158            } else if let VsockEventType::ConnectionRequest = event.event_type {
159                // If the requested connection already exists or the CID isn't ours, ignore it.
160                if connection.is_some() || event.destination.cid != guest_cid {
161                    return Ok(None);
162                }
163                // Add the new connection to our list, at least for now. It will be removed again
164                // below if we weren't listening on the port.
165                connections.push(Connection::new(
166                    event.source,
167                    event.destination.port,
168                    per_connection_buffer_capacity,
169                ));
170                connections.last_mut().unwrap()
171            } else {
172                return Ok(None);
173            };
174
175            // Update stored connection info.
176            connection.info.update_for_event(&event);
177
178            if let VsockEventType::Received { length } = event.event_type {
179                // Copy to buffer
180                if !connection.buffer.add(body) {
181                    return Err(SocketError::OutputBufferTooShort(length).into());
182                }
183            }
184
185            Ok(Some(event))
186        })?;
187
188        let Some(event) = result else {
189            return Ok(None);
190        };
191
192        // The connection must exist because we found it above in the callback.
193        let (connection_index, connection) =
194            get_connection_for_event(connections, &event, guest_cid).unwrap();
195
196        match event.event_type {
197            VsockEventType::ConnectionRequest => {
198                if self.listening_ports.contains(&event.destination.port) {
199                    self.driver.accept(&connection.info)?;
200                } else {
201                    // Reject the connection request and remove it from our list.
202                    self.driver.force_close(&connection.info)?;
203                    self.connections.swap_remove(connection_index);
204
205                    // No need to pass the request on to the client, as we've already rejected it.
206                    return Ok(None);
207                }
208            }
209            VsockEventType::Connected => {}
210            VsockEventType::Disconnected { reason } => {
211                // Wait until client reads all data before removing connection.
212                if connection.buffer.is_empty() {
213                    if reason == DisconnectReason::Shutdown {
214                        self.driver.force_close(&connection.info)?;
215                    }
216                    self.connections.swap_remove(connection_index);
217                } else {
218                    connection.peer_requested_shutdown = true;
219                }
220            }
221            VsockEventType::Received { .. } => {
222                // Already copied the buffer in the callback above.
223            }
224            VsockEventType::CreditRequest => {
225                // If the peer requested credit, send an update.
226                self.driver.credit_update(&connection.info)?;
227                // No need to pass the request on to the client, we've already handled it.
228                return Ok(None);
229            }
230            VsockEventType::CreditUpdate => {}
231        }
232
233        Ok(Some(event))
234    }
235
236    /// Reads data received from the given connection.
237    pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
238        let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
239
240        // Copy from ring buffer
241        let bytes_read = connection.buffer.drain(buffer);
242
243        connection.info.done_forwarding(bytes_read);
244
245        // If buffer is now empty and the peer requested shutdown, finish shutting down the
246        // connection.
247        if connection.peer_requested_shutdown && connection.buffer.is_empty() {
248            self.driver.force_close(&connection.info)?;
249            self.connections.swap_remove(connection_index);
250        }
251
252        Ok(bytes_read)
253    }
254
255    /// Returns the number of bytes in the receive buffer available to be read by `recv`.
256    ///
257    /// When the available bytes is 0, it indicates that the receive buffer is empty and does not
258    /// contain any data.
259    pub fn recv_buffer_available_bytes(&mut self, peer: VsockAddr, src_port: u32) -> Result<usize> {
260        let (_, connection) = get_connection(&mut self.connections, peer, src_port)?;
261        Ok(connection.buffer.used())
262    }
263
264    /// Sends a credit update to the given peer.
265    pub fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result {
266        let (_, connection) = get_connection(&mut self.connections, peer, src_port)?;
267        self.driver.credit_update(&connection.info)
268    }
269
270    /// Blocks until we get some event from the vsock device.
271    pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
272        loop {
273            if let Some(event) = self.poll()? {
274                return Ok(event);
275            } else {
276                spin_loop();
277            }
278        }
279    }
280
281    /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive
282    /// any more data.
283    ///
284    /// This returns as soon as the request is sent; you should wait until `poll` returns a
285    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
286    /// shutdown.
287    pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
288        let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
289
290        self.driver.shutdown(&connection.info)
291    }
292
293    /// Forcibly closes the connection without waiting for the peer.
294    pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
295        let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
296
297        self.driver.force_close(&connection.info)?;
298
299        self.connections.swap_remove(index);
300        Ok(())
301    }
302}
303
304/// Returns the connection from the given list matching the given peer address and local port, and
305/// its index.
306///
307/// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list.
308fn get_connection(
309    connections: &mut [Connection],
310    peer: VsockAddr,
311    local_port: u32,
312) -> core::result::Result<(usize, &mut Connection), SocketError> {
313    connections
314        .iter_mut()
315        .enumerate()
316        .find(|(_, connection)| {
317            connection.info.dst == peer && connection.info.src_port == local_port
318        })
319        .ok_or(SocketError::NotConnected)
320}
321
322/// Returns the connection from the given list matching the event, if any, and its index.
323fn get_connection_for_event<'a>(
324    connections: &'a mut [Connection],
325    event: &VsockEvent,
326    local_cid: u64,
327) -> Option<(usize, &'a mut Connection)> {
328    connections
329        .iter_mut()
330        .enumerate()
331        .find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
332}
333
334#[derive(Debug)]
335struct RingBuffer {
336    buffer: Box<[u8]>,
337    /// The number of bytes currently in the buffer.
338    used: usize,
339    /// The index of the first used byte in the buffer.
340    start: usize,
341}
342
343impl RingBuffer {
344    pub fn new(capacity: usize) -> Self {
345        Self {
346            buffer: FromZeros::new_box_zeroed_with_elems(capacity).unwrap(),
347            used: 0,
348            start: 0,
349        }
350    }
351
352    /// Returns the number of bytes currently used in the buffer.
353    pub fn used(&self) -> usize {
354        self.used
355    }
356
357    /// Returns true iff there are currently no bytes in the buffer.
358    pub fn is_empty(&self) -> bool {
359        self.used == 0
360    }
361
362    /// Returns the number of bytes currently free in the buffer.
363    pub fn free(&self) -> usize {
364        self.buffer.len() - self.used
365    }
366
367    /// Adds the given bytes to the buffer if there is enough capacity for them all.
368    ///
369    /// Returns true if they were added, or false if they were not.
370    pub fn add(&mut self, bytes: &[u8]) -> bool {
371        if bytes.len() > self.free() {
372            return false;
373        }
374
375        // The index of the first available position in the buffer.
376        let first_available = (self.start + self.used) % self.buffer.len();
377        // The number of bytes to copy from `bytes` to `buffer` between `first_available` and
378        // `buffer.len()`.
379        let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
380        self.buffer[first_available..first_available + copy_length_before_wraparound]
381            .copy_from_slice(&bytes[0..copy_length_before_wraparound]);
382        if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
383            self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
384        }
385        self.used += bytes.len();
386
387        true
388    }
389
390    /// Reads and removes as many bytes as possible from the buffer, up to the length of the given
391    /// buffer.
392    pub fn drain(&mut self, out: &mut [u8]) -> usize {
393        let bytes_read = min(self.used, out.len());
394
395        // The number of bytes to copy out between `start` and the end of the buffer.
396        let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
397        // The number of bytes to copy out from the beginning of the buffer after wrapping around.
398        let read_after_wraparound = bytes_read
399            .checked_sub(read_before_wraparound)
400            .unwrap_or_default();
401
402        out[0..read_before_wraparound]
403            .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
404        out[read_before_wraparound..bytes_read]
405            .copy_from_slice(&self.buffer[0..read_after_wraparound]);
406
407        self.used -= bytes_read;
408        self.start = (self.start + bytes_read) % self.buffer.len();
409
410        bytes_read
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        config::ReadOnly,
        device::socket::{
            protocol::{
                SocketType, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp,
            },
            vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
        },
        hal::fake::FakeHal,
        transport::{
            fake::{FakeTransport, QueueStatus, State},
            DeviceType,
        },
    };
    use alloc::{sync::Arc, vec};
    use core::mem::size_of;
    use std::{sync::Mutex, thread};
    use zerocopy::{FromBytes, IntoBytes};

    #[test]
    fn send_recv() {
        let host_cid = 2;
        let guest_cid = 66;
        let host_port = 1234;
        let guest_port = 4321;
        let host_address = VsockAddr {
            cid: host_cid,
            port: host_port,
        };
        let hello_from_guest = "Hello from guest";
        let hello_from_host = "Hello from host";

        let config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        let state = Arc::new(Mutex::new(State::new(
            vec![
                QueueStatus::default(),
                QueueStatus::default(),
                QueueStatus::default(),
            ],
            config_space,
        )));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            state: state.clone(),
        };
        let mut socket = VsockConnectionManager::new(
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
        );

        // Start a thread to simulate the device.
        let handle = thread::spawn(move || {
            // Wait for connection request.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from_bytes(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            // Accept connection and give the peer enough credit to send the message.
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Response.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // Expect the guest to send some data.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            let request = state
                .lock()
                .unwrap()
                .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
            assert_eq!(
                request.len(),
                size_of::<VirtioVsockHdr>() + hello_from_guest.len()
            );
            assert_eq!(
                VirtioVsockHdr::read_from_prefix(request.as_slice())
                    .unwrap()
                    .0,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Rw.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: (hello_from_guest.len() as u32).into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );
            assert_eq!(
                &request[size_of::<VirtioVsockHdr>()..],
                hello_from_guest.as_bytes()
            );

            println!("Host sending");

            // Send a response.
            let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
            VirtioVsockHdr {
                op: VirtioVsockOp::Rw.into(),
                src_cid: host_cid.into(),
                dst_cid: guest_cid.into(),
                src_port: host_port.into(),
                dst_port: guest_port.into(),
                len: (hello_from_host.len() as u32).into(),
                socket_type: SocketType::Stream.into(),
                flags: 0.into(),
                buf_alloc: 50.into(),
                fwd_cnt: (hello_from_guest.len() as u32).into(),
            }
            .write_to_prefix(response.as_mut_slice())
            .unwrap();
            response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
            state
                .lock()
                .unwrap()
                .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);

            // Expect a shutdown.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from_bytes(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Shutdown.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: (StreamShutdown::SEND | StreamShutdown::RECEIVE).into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: (hello_from_host.len() as u32).into(),
                }
            );
        });

        socket.connect(host_address, guest_port).unwrap();
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::Connected,
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: 0,
                },
            }
        );
        println!("Guest sending");
        socket
            .send(host_address, guest_port, "Hello from guest".as_bytes())
            .unwrap();
        println!("Guest waiting to receive.");
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::Received {
                    length: hello_from_host.len()
                },
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: hello_from_guest.len() as u32,
                },
            }
        );
        println!("Guest getting received data.");
        let mut buffer = [0u8; 64];
        assert_eq!(
            socket.recv(host_address, guest_port, &mut buffer).unwrap(),
            hello_from_host.len()
        );
        assert_eq!(
            &buffer[0..hello_from_host.len()],
            hello_from_host.as_bytes()
        );
        socket.shutdown(host_address, guest_port).unwrap();

        handle.join().unwrap();
    }

    #[test]
    fn incoming_connection() {
        let host_cid = 2;
        let guest_cid = 66;
        let host_port = 1234;
        let guest_port = 4321;
        let wrong_guest_port = 4444;
        let host_address = VsockAddr {
            cid: host_cid,
            port: host_port,
        };

        let config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        let state = Arc::new(Mutex::new(State::new(
            vec![
                QueueStatus::default(),
                QueueStatus::default(),
                QueueStatus::default(),
            ],
            config_space,
        )));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            state: state.clone(),
        };
        let mut socket = VsockConnectionManager::new(
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
        );

        socket.listen(guest_port);

        // Start a thread to simulate the device.
        let handle = thread::spawn(move || {
            // Send a connection request for a port the guest isn't listening on.
            println!("Host sending connection request to wrong port");
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: wrong_guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // Expect a rejection.
            println!("Host waiting for rejection");
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from_bytes(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Rst.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: wrong_guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            // Send a connection request for a port the guest is listening on.
            println!("Host sending connection request to right port");
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // Expect a response.
            println!("Host waiting for response");
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from_bytes(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Response.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            println!("Host finished");
        });

        // Expect an incoming connection.
        println!("Guest expecting incoming connection.");
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::ConnectionRequest,
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: 0,
                },
            }
        );

        handle.join().unwrap();
    }

    /// Data written across the end of the ring buffer is read back in order.
    #[test]
    fn ring_buffer_wraps_around() {
        let mut ring = RingBuffer::new(8);
        assert!(ring.add(b"abcdef"));

        let mut out = [0u8; 4];
        assert_eq!(ring.drain(&mut out), 4);
        assert_eq!(&out, b"abcd");

        // "ef" now sits at indices 4..6, so this write wraps: "gh" at 6..8, "ijk" at 0..3.
        assert!(ring.add(b"ghijk"));
        assert_eq!(ring.used(), 7);

        let mut out = [0u8; 16];
        assert_eq!(ring.drain(&mut out), 7);
        assert_eq!(&out[..7], b"efghijk");
        assert!(ring.is_empty());
    }

    /// Writes that don't fit are rejected as a whole and leave the buffer unchanged.
    #[test]
    fn ring_buffer_rejects_data_that_does_not_fit() {
        let mut ring = RingBuffer::new(4);
        assert!(ring.add(b"abc"));
        assert!(!ring.add(b"de"));
        assert_eq!(ring.used(), 3);
        assert!(ring.add(b"d"));
        assert_eq!(ring.free(), 0);

        let mut out = [0u8; 4];
        assert_eq!(ring.drain(&mut out), 4);
        assert_eq!(&out, b"abcd");
    }
}