diff --git a/src/executor/vsock.rs b/src/executor/vsock.rs index e1d429af4e..930f40d9db 100644 --- a/src/executor/vsock.rs +++ b/src/executor/vsock.rs @@ -107,12 +107,13 @@ async fn vsock_run() { let port = header.dst_port.to_ne(); let type_ = Type::try_from(header.type_.to_ne()).unwrap(); let mut vsock_guard = VSOCK_MAP.lock(); + let header_cid: u32 = header.src_cid.to_ne().try_into().unwrap(); if let Some(raw) = vsock_guard.get_mut_socket(port) { if op == Op::Request && raw.state == VsockState::Listen && type_ == Type::Stream { raw.state = VsockState::ReceiveRequest; - raw.remote_cid = header.src_cid.to_ne().try_into().unwrap(); + raw.remote_cid = header_cid; raw.remote_port = header.src_port.to_ne(); raw.peer_buf_alloc = header.buf_alloc.to_ne(); raw.rx_waker.wake(); @@ -121,21 +122,38 @@ async fn vsock_run() { && type_ == Type::Stream && op == Op::Rw { - raw.buffer.extend_from_slice(data); - raw.fwd_cnt = raw.fwd_cnt.wrapping_add(u32::try_from(data.len()).unwrap()); - raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); - raw.tx_waker.wake(); - raw.rx_waker.wake(); - hdr = Some(*header); - fwd_cnt = raw.fwd_cnt; + if raw.remote_cid == header_cid { + raw.buffer.extend_from_slice(data); + raw.fwd_cnt = + raw.fwd_cnt.wrapping_add(u32::try_from(data.len()).unwrap()); + raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); + raw.tx_waker.wake(); + raw.rx_waker.wake(); + hdr = Some(*header); + fwd_cnt = raw.fwd_cnt; + } else { + trace!("Receive message from invalid source {}", header_cid); + } } else if op == Op::CreditUpdate { - raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); - raw.tx_waker.wake(); + if raw.remote_cid == header_cid { + raw.peer_fwd_cnt = header.fwd_cnt.to_ne(); + raw.tx_waker.wake(); + } else { + trace!("Receive message from invalid source {}", header_cid); + } } else if op == Op::Shutdown { - raw.state = VsockState::Shutdown; + if raw.remote_cid == header_cid { + raw.state = VsockState::Shutdown; + } else { + trace!("Receive message from invalid source {}", header_cid); + } } else { - hdr = Some(*header); - fwd_cnt = raw.fwd_cnt; + if raw.remote_cid == header_cid { + hdr = Some(*header); + fwd_cnt = raw.fwd_cnt; + } else { + trace!("Receive message from invalid source {}", header_cid); + } } } });