Browse Source

Separate TCP implementation from the host bindings (#8282)

* Move bind into Tcp type

Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>

* Move start_connect into Tcp type

Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>

* Move finish_connect into Tcp type

Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>

* Move *_listen into Tcp type

Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>

* Move accept into Tcp type

Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>

* Move address methods into Tcp type

Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>

* Move various option methods into Tcp type

Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>

* Move shutdown methods into Tcp type

Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>

* Move finish bind methods into Tcp type

Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>

* Change connect's return type

Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>

* Move shutdown over to io::Result

Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>

* Rearrange some code

Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>

* Move bind to io Error

Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>

---------

Signed-off-by: Ryan Levick <ryan.levick@fermyon.com>
pull/8294/head
Ryan Levick 7 months ago
committed by GitHub
parent
commit
857de9a0ab
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 46
      crates/wasi/src/host/network.rs
  2. 448
      crates/wasi/src/host/tcp.rs
  3. 705
      crates/wasi/src/tcp.rs

46
crates/wasi/src/host/network.rs

@ -212,29 +212,34 @@ impl From<cap_net_ext::AddressFamily> for IpAddressFamily {
}
pub(crate) mod util {
use std::io;
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::time::Duration;
use crate::bindings::sockets::network::ErrorCode;
use crate::network::SocketAddressFamily;
use crate::SocketResult;
use cap_net_ext::{AddressFamily, Blocking, UdpSocketExt};
use rustix::fd::{AsFd, OwnedFd};
use rustix::io::Errno;
use rustix::net::sockopt;
pub fn validate_unicast(addr: &SocketAddr) -> SocketResult<()> {
pub fn validate_unicast(addr: &SocketAddr) -> io::Result<()> {
match to_canonical(&addr.ip()) {
IpAddr::V4(ipv4) => {
if ipv4.is_multicast() || ipv4.is_broadcast() {
Err(ErrorCode::InvalidArgument.into())
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Both IPv4 broadcast and multicast addresses are not supported",
))
} else {
Ok(())
}
}
IpAddr::V6(ipv6) => {
if ipv6.is_multicast() {
Err(ErrorCode::InvalidArgument.into())
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"IPv6 multicast addresses are not supported",
))
} else {
Ok(())
}
@ -242,13 +247,19 @@ pub(crate) mod util {
}
}
pub fn validate_remote_address(addr: &SocketAddr) -> SocketResult<()> {
pub fn validate_remote_address(addr: &SocketAddr) -> io::Result<()> {
if to_canonical(&addr.ip()).is_unspecified() {
return Err(ErrorCode::InvalidArgument.into());
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Remote address may not be `0.0.0.0` or `::`",
));
}
if addr.port() == 0 {
return Err(ErrorCode::InvalidArgument.into());
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Remote port may not be 0",
));
}
Ok(())
@ -257,7 +268,7 @@ pub(crate) mod util {
pub fn validate_address_family(
addr: &SocketAddr,
socket_family: &SocketAddressFamily,
) -> SocketResult<()> {
) -> io::Result<()> {
match (socket_family, addr.ip()) {
(SocketAddressFamily::Ipv4, IpAddr::V4(_)) => Ok(()),
(SocketAddressFamily::Ipv6, IpAddr::V6(ipv6)) => {
@ -266,14 +277,23 @@ pub(crate) mod util {
// since 2006, OS handling of them is inconsistent and our own
// validations don't take them into account either.
// Note that these are not the same as IPv4-*mapped* IPv6 addresses.
Err(ErrorCode::InvalidArgument.into())
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"IPv4-compatible IPv6 addresses are not supported",
))
} else if ipv6.to_ipv4_mapped().is_some() {
Err(ErrorCode::InvalidArgument.into())
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"IPv4-mapped IPv6 address passed to an IPv6-only socket",
))
} else {
Ok(())
}
}
_ => Err(ErrorCode::InvalidArgument.into()),
_ => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Address family mismatch",
)),
}
}
@ -301,7 +321,7 @@ pub(crate) mod util {
* Syscalls wrappers with (opinionated) portability fixes.
*/
pub fn udp_socket(family: AddressFamily, blocking: Blocking) -> std::io::Result<OwnedFd> {
pub fn udp_socket(family: AddressFamily, blocking: Blocking) -> io::Result<OwnedFd> {
// Delegate socket creation to cap_net_ext. They handle a couple of things for us:
// - On Windows: call WSAStartup if not done before.
// - Set the NONBLOCK and CLOEXEC flags. Either immediately during socket creation,

448
crates/wasi/src/host/tcp.rs

@ -1,22 +1,14 @@
use crate::host::network::util;
use crate::network::SocketAddrUse;
use crate::runtime::with_ambient_tokio_runtime;
use crate::tcp::{TcpReadStream, TcpSocket, TcpState, TcpWriteStream};
use crate::{
bindings::{
io::streams::{InputStream, OutputStream},
sockets::network::{ErrorCode, IpAddressFamily, IpSocketAddress, Network},
sockets::network::{IpAddressFamily, IpSocketAddress, Network},
sockets::tcp::{self, ShutdownType},
},
network::SocketAddressFamily,
};
use crate::{Pollable, SocketResult, WasiView};
use io_lifetimes::AsSocketlike;
use rustix::io::Errno;
use rustix::net::sockopt;
use std::net::SocketAddr;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
use wasmtime::component::Resource;
@ -31,60 +23,14 @@ impl<T: WasiView> crate::host::tcp::tcp::HostTcpSocket for T {
) -> SocketResult<()> {
self.ctx().allowed_network_uses.check_allowed_tcp()?;
let table = self.table();
let socket = table.get(&this)?;
let network = table.get(&network)?;
let local_address: SocketAddr = local_address.into();
let tokio_socket = match &socket.tcp_state {
TcpState::Default(socket) => socket,
TcpState::BindStarted(..) => return Err(ErrorCode::ConcurrencyConflict.into()),
_ => return Err(ErrorCode::InvalidState.into()),
};
util::validate_unicast(&local_address)?;
util::validate_address_family(&local_address, &socket.family)?;
{
// Ensure that we're allowed to connect to this address.
network.check_socket_addr(&local_address, SocketAddrUse::TcpBind)?;
// Automatically bypass the TIME_WAIT state when the user is trying
// to bind to a specific port:
let reuse_addr = local_address.port() > 0;
// Unconditionally (re)set SO_REUSEADDR, even when the value is false.
// This ensures we're not accidentally affected by any socket option
// state left behind by a previous failed call to this method (start_bind).
util::set_tcp_reuseaddr(&tokio_socket, reuse_addr)?;
// Perform the OS bind call.
tokio_socket.bind(local_address).map_err(|error| {
match Errno::from_io_error(&error) {
// From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html:
// > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket
//
// The most common reasons for this error should have already
// been handled by our own validation slightly higher up in this
// function. This error mapping is here just in case there is
// an edge case we didn't catch.
Some(Errno::AFNOSUPPORT) => ErrorCode::InvalidArgument,
// See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-bind#:~:text=WSAENOBUFS
// Windows returns WSAENOBUFS when the ephemeral ports have been exhausted.
#[cfg(windows)]
Some(Errno::NOBUFS) => ErrorCode::AddressInUse,
_ => ErrorCode::from(error),
}
})?;
}
let socket = table.get_mut(&this)?;
// Ensure that we're allowed to connect to this address.
network.check_socket_addr(&local_address, SocketAddrUse::TcpBind)?;
socket.tcp_state = match std::mem::replace(&mut socket.tcp_state, TcpState::Closed) {
TcpState::Default(socket) => TcpState::BindStarted(socket),
_ => unreachable!(),
};
// Bind to the address.
table.get_mut(&this)?.start_bind(local_address)?;
Ok(())
}
@ -93,17 +39,7 @@ impl<T: WasiView> crate::host::tcp::tcp::HostTcpSocket for T {
let table = self.table();
let socket = table.get_mut(&this)?;
match socket.tcp_state {
TcpState::BindStarted(..) => {}
_ => return Err(ErrorCode::NotInProgress.into()),
}
socket.tcp_state = match std::mem::replace(&mut socket.tcp_state, TcpState::Closed) {
TcpState::BindStarted(socket) => TcpState::Bound(socket),
_ => unreachable!(),
};
Ok(())
socket.finish_bind()
}
fn start_connect(
@ -114,37 +50,14 @@ impl<T: WasiView> crate::host::tcp::tcp::HostTcpSocket for T {
) -> SocketResult<()> {
self.ctx().allowed_network_uses.check_allowed_tcp()?;
let table = self.table();
let socket = table.get(&this)?;
let network = table.get(&network)?;
let remote_address: SocketAddr = remote_address.into();
match socket.tcp_state {
TcpState::Default(..) => {}
TcpState::Connecting(..) | TcpState::ConnectReady(..) => {
return Err(ErrorCode::ConcurrencyConflict.into())
}
_ => return Err(ErrorCode::InvalidState.into()),
};
util::validate_unicast(&remote_address)?;
util::validate_remote_address(&remote_address)?;
util::validate_address_family(&remote_address, &socket.family)?;
// Ensure that we're allowed to connect to this address.
network.check_socket_addr(&remote_address, SocketAddrUse::TcpConnect)?;
let socket = table.get_mut(&this)?;
let TcpState::Default(tokio_socket) =
std::mem::replace(&mut socket.tcp_state, TcpState::Closed)
else {
unreachable!();
};
let future = tokio_socket.connect(remote_address);
socket.tcp_state = TcpState::Connecting(Box::pin(future));
// Start connection
table.get_mut(&this)?.start_connect(remote_address)?;
Ok(())
}
@ -156,45 +69,12 @@ impl<T: WasiView> crate::host::tcp::tcp::HostTcpSocket for T {
let table = self.table();
let socket = table.get_mut(&this)?;
let previous_state = std::mem::replace(&mut socket.tcp_state, TcpState::Closed);
let result = match previous_state {
TcpState::ConnectReady(result) => result,
TcpState::Connecting(mut future) => {
let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) {
Poll::Ready(result) => result,
Poll::Pending => {
socket.tcp_state = TcpState::Connecting(future);
return Err(ErrorCode::WouldBlock.into());
}
}
}
previous_state => {
socket.tcp_state = previous_state;
return Err(ErrorCode::NotInProgress.into());
}
};
let (input, output) = socket.finish_connect()?;
match result {
Ok(stream) => {
let stream = Arc::new(stream);
let input: InputStream =
InputStream::Host(Box::new(TcpReadStream::new(stream.clone())));
let output: OutputStream = Box::new(TcpWriteStream::new(stream.clone()));
let input_stream = self.table().push_child(input, &this)?;
let output_stream = self.table().push_child(output, &this)?;
let socket = self.table().get_mut(&this)?;
socket.tcp_state = TcpState::Connected(stream);
Ok((input_stream, output_stream))
}
Err(err) => {
socket.tcp_state = TcpState::Closed;
Err(err.into())
}
}
let input_stream = self.table().push_child(input, &this)?;
let output_stream = self.table().push_child(output, &this)?;
Ok((input_stream, output_stream))
}
fn start_listen(&mut self, this: Resource<tcp::TcpSocket>) -> SocketResult<()> {
@ -202,62 +82,13 @@ impl<T: WasiView> crate::host::tcp::tcp::HostTcpSocket for T {
let table = self.table();
let socket = table.get_mut(&this)?;
match std::mem::replace(&mut socket.tcp_state, TcpState::Closed) {
TcpState::Bound(tokio_socket) => {
socket.tcp_state = TcpState::ListenStarted(tokio_socket);
Ok(())
}
TcpState::ListenStarted(tokio_socket) => {
socket.tcp_state = TcpState::ListenStarted(tokio_socket);
Err(ErrorCode::ConcurrencyConflict.into())
}
previous_state => {
socket.tcp_state = previous_state;
Err(ErrorCode::InvalidState.into())
}
}
socket.start_listen()
}
fn finish_listen(&mut self, this: Resource<tcp::TcpSocket>) -> SocketResult<()> {
let table = self.table();
let socket = table.get_mut(&this)?;
let tokio_socket = match std::mem::replace(&mut socket.tcp_state, TcpState::Closed) {
TcpState::ListenStarted(tokio_socket) => tokio_socket,
previous_state => {
socket.tcp_state = previous_state;
return Err(ErrorCode::NotInProgress.into());
}
};
match with_ambient_tokio_runtime(|| tokio_socket.listen(socket.listen_backlog_size)) {
Ok(listener) => {
socket.tcp_state = TcpState::Listening {
listener,
pending_accept: None,
};
Ok(())
}
Err(err) => {
socket.tcp_state = TcpState::Closed;
Err(match Errno::from_io_error(&err) {
// See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE
// According to the docs, `listen` can return EMFILE on Windows.
// This is odd, because we're not trying to create a new socket
// or file descriptor of any kind. So we rewrite it to less
// surprising error code.
//
// At the time of writing, this behavior has never been experimentally
// observed by any of the wasmtime authors, so we're relying fully
// on Microsoft's documentation here.
#[cfg(windows)]
Some(Errno::MFILE) => Errno::NOBUFS.into(),
_ => err.into(),
})
}
}
socket.finish_listen()
}
fn accept(
@ -272,90 +103,7 @@ impl<T: WasiView> crate::host::tcp::tcp::HostTcpSocket for T {
let table = self.table();
let socket = table.get_mut(&this)?;
let TcpState::Listening {
listener,
pending_accept,
} = &mut socket.tcp_state
else {
return Err(ErrorCode::InvalidState.into());
};
let result = match pending_accept.take() {
Some(result) => result,
None => {
let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx))
.map_ok(|(stream, _)| stream)
{
Poll::Ready(result) => result,
Poll::Pending => Err(Errno::WOULDBLOCK.into()),
}
}
};
let client = result.map_err(|err| match Errno::from_io_error(&err) {
// From: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS
// > WSAEINPROGRESS: A blocking Windows Sockets 1.1 call is in progress,
// > or the service provider is still processing a callback function.
//
// wasi-sockets doesn't have an equivalent to the EINPROGRESS error,
// because in POSIX this error is only returned by a non-blocking
// `connect` and wasi-sockets has a different solution for that.
#[cfg(windows)]
Some(Errno::INPROGRESS) => Errno::INTR.into(),
// Normalize Linux' non-standard behavior.
//
// From https://man7.org/linux/man-pages/man2/accept.2.html:
// > Linux accept() passes already-pending network errors on the
// > new socket as an error code from accept(). This behavior
// > differs from other BSD socket implementations. (...)
#[cfg(target_os = "linux")]
Some(
Errno::CONNRESET
| Errno::NETRESET
| Errno::HOSTUNREACH
| Errno::HOSTDOWN
| Errno::NETDOWN
| Errno::NETUNREACH
| Errno::PROTO
| Errno::NOPROTOOPT
| Errno::NONET
| Errno::OPNOTSUPP,
) => Errno::CONNABORTED.into(),
_ => err,
})?;
#[cfg(target_os = "macos")]
{
// Manually inherit socket options from listener. We only have to
// do this on platforms that don't already do this automatically
// and only if a specific value was explicitly set on the listener.
if let Some(size) = socket.receive_buffer_size {
_ = util::set_socket_recv_buffer_size(&client, size); // Ignore potential error.
}
if let Some(size) = socket.send_buffer_size {
_ = util::set_socket_send_buffer_size(&client, size); // Ignore potential error.
}
// For some reason, IP_TTL is inherited, but IPV6_UNICAST_HOPS isn't.
if let (SocketAddressFamily::Ipv6, Some(ttl)) = (socket.family, socket.hop_limit) {
_ = util::set_ipv6_unicast_hops(&client, ttl); // Ignore potential error.
}
if let Some(value) = socket.keep_alive_idle_time {
_ = util::set_tcp_keepidle(&client, value); // Ignore potential error.
}
}
let client = Arc::new(client);
let input: InputStream = InputStream::Host(Box::new(TcpReadStream::new(client.clone())));
let output: OutputStream = Box::new(TcpWriteStream::new(client.clone()));
let tcp_socket = TcpSocket::from_state(TcpState::Connected(client), socket.family)?;
let (tcp_socket, input, output) = socket.accept()?;
let tcp_socket = self.table().push(tcp_socket)?;
let input_stream = self.table().push_child(input, &tcp_socket)?;
@ -368,38 +116,21 @@ impl<T: WasiView> crate::host::tcp::tcp::HostTcpSocket for T {
let table = self.table();
let socket = table.get(&this)?;
let view = match socket.tcp_state {
TcpState::Default(..) => return Err(ErrorCode::InvalidState.into()),
TcpState::BindStarted(..) => return Err(ErrorCode::ConcurrencyConflict.into()),
_ => socket.as_std_view()?,
};
Ok(view.local_addr()?.into())
socket.local_address().map(Into::into)
}
fn remote_address(&mut self, this: Resource<tcp::TcpSocket>) -> SocketResult<IpSocketAddress> {
let table = self.table();
let socket = table.get(&this)?;
let view = match socket.tcp_state {
TcpState::Connected(..) => socket.as_std_view()?,
TcpState::Connecting(..) | TcpState::ConnectReady(..) => {
return Err(ErrorCode::ConcurrencyConflict.into())
}
_ => return Err(ErrorCode::InvalidState.into()),
};
Ok(view.peer_addr()?.into())
socket.remote_address().map(Into::into)
}
fn is_listening(&mut self, this: Resource<tcp::TcpSocket>) -> Result<bool, anyhow::Error> {
let table = self.table();
let socket = table.get(&this)?;
match socket.tcp_state {
TcpState::Listening { .. } => Ok(true),
_ => Ok(false),
}
Ok(socket.is_listening())
}
fn address_family(
@ -409,7 +140,7 @@ impl<T: WasiView> crate::host::tcp::tcp::HostTcpSocket for T {
let table = self.table();
let socket = table.get(&this)?;
match socket.family {
match socket.address_family() {
SocketAddressFamily::Ipv4 => Ok(IpAddressFamily::Ipv4),
SocketAddressFamily::Ipv6 => Ok(IpAddressFamily::Ipv6),
}
@ -420,49 +151,19 @@ impl<T: WasiView> crate::host::tcp::tcp::HostTcpSocket for T {
this: Resource<tcp::TcpSocket>,
value: u64,
) -> SocketResult<()> {
const MIN_BACKLOG: u32 = 1;
const MAX_BACKLOG: u32 = i32::MAX as u32; // OS'es will most likely limit it down even further.
let table = self.table();
let socket = table.get_mut(&this)?;
if value == 0 {
return Err(ErrorCode::InvalidArgument.into());
}
// Silently clamp backlog size. This is OK for us to do, because operating systems do this too.
let value = value
.try_into()
.unwrap_or(u32::MAX)
.clamp(MIN_BACKLOG, MAX_BACKLOG);
match &socket.tcp_state {
TcpState::Default(..) | TcpState::Bound(..) => {
// Socket not listening yet. Stash value for first invocation to `listen`.
socket.listen_backlog_size = value;
Ok(())
}
TcpState::Listening { listener, .. } => {
// Try to update the backlog by calling `listen` again.
// Not all platforms support this. We'll only update our own value if the OS supports changing the backlog size after the fact.
rustix::net::listen(&listener, value.try_into().unwrap())
.map_err(|_| ErrorCode::NotSupported)?;
socket.listen_backlog_size = value;
Ok(())
}
_ => Err(ErrorCode::InvalidState.into()),
}
let value = value.try_into().unwrap_or(u32::MAX);
socket.set_listen_backlog_size(value)
}
fn keep_alive_enabled(&mut self, this: Resource<tcp::TcpSocket>) -> SocketResult<bool> {
let table = self.table();
let socket = table.get(&this)?;
let view = &*socket.as_std_view()?;
Ok(sockopt::get_socket_keepalive(view)?)
socket.keep_alive_enabled()
}
fn set_keep_alive_enabled(
@ -472,15 +173,13 @@ impl<T: WasiView> crate::host::tcp::tcp::HostTcpSocket for T {
) -> SocketResult<()> {
let table = self.table();
let socket = table.get(&this)?;
let view = &*socket.as_std_view()?;
Ok(sockopt::set_socket_keepalive(view, value)?)
socket.set_keep_alive_enabled(value)
}
fn keep_alive_idle_time(&mut self, this: Resource<tcp::TcpSocket>) -> SocketResult<u64> {
let table = self.table();
let socket = table.get(&this)?;
let view = &*socket.as_std_view()?;
Ok(sockopt::get_tcp_keepidle(view)?.as_nanos() as u64)
Ok(socket.keep_alive_idle_time()?.as_nanos() as u64)
}
fn set_keep_alive_idle_time(
@ -491,25 +190,13 @@ impl<T: WasiView> crate::host::tcp::tcp::HostTcpSocket for T {
let table = self.table();
let socket = table.get_mut(&this)?;
let duration = Duration::from_nanos(value);
{
let view = &*socket.as_std_view()?;
util::set_tcp_keepidle(view, duration)?;
}
#[cfg(target_os = "macos")]
{
socket.keep_alive_idle_time = Some(duration);
}
Ok(())
socket.set_keep_alive_idle_time(duration)
}
fn keep_alive_interval(&mut self, this: Resource<tcp::TcpSocket>) -> SocketResult<u64> {
let table = self.table();
let socket = table.get(&this)?;
let view = &*socket.as_std_view()?;
Ok(sockopt::get_tcp_keepintvl(view)?.as_nanos() as u64)
Ok(socket.keep_alive_interval()?.as_nanos() as u64)
}
fn set_keep_alive_interval(
@ -519,15 +206,13 @@ impl<T: WasiView> crate::host::tcp::tcp::HostTcpSocket for T {
) -> SocketResult<()> {
let table = self.table();
let socket = table.get(&this)?;
let view = &*socket.as_std_view()?;
Ok(util::set_tcp_keepintvl(view, Duration::from_nanos(value))?)
socket.set_keep_alive_interval(Duration::from_nanos(value))
}
fn keep_alive_count(&mut self, this: Resource<tcp::TcpSocket>) -> SocketResult<u32> {
let table = self.table();
let socket = table.get(&this)?;
let view = &*socket.as_std_view()?;
Ok(sockopt::get_tcp_keepcnt(view)?)
socket.keep_alive_count()
}
fn set_keep_alive_count(
@ -537,50 +222,26 @@ impl<T: WasiView> crate::host::tcp::tcp::HostTcpSocket for T {
) -> SocketResult<()> {
let table = self.table();
let socket = table.get(&this)?;
let view = &*socket.as_std_view()?;
Ok(util::set_tcp_keepcnt(view, value)?)
socket.set_keep_alive_count(value)
}
fn hop_limit(&mut self, this: Resource<tcp::TcpSocket>) -> SocketResult<u8> {
let table = self.table();
let socket = table.get(&this)?;
let view = &*socket.as_std_view()?;
let ttl = match socket.family {
SocketAddressFamily::Ipv4 => util::get_ip_ttl(view)?,
SocketAddressFamily::Ipv6 => util::get_ipv6_unicast_hops(view)?,
};
Ok(ttl)
socket.hop_limit()
}
fn set_hop_limit(&mut self, this: Resource<tcp::TcpSocket>, value: u8) -> SocketResult<()> {
let table = self.table();
let socket = table.get_mut(&this)?;
{
let view = &*socket.as_std_view()?;
match socket.family {
SocketAddressFamily::Ipv4 => util::set_ip_ttl(view, value)?,
SocketAddressFamily::Ipv6 => util::set_ipv6_unicast_hops(view, value)?,
}
}
#[cfg(target_os = "macos")]
{
socket.hop_limit = Some(value);
}
Ok(())
socket.set_hop_limit(value)
}
fn receive_buffer_size(&mut self, this: Resource<tcp::TcpSocket>) -> SocketResult<u64> {
let table = self.table();
let socket = table.get(&this)?;
let view = &*socket.as_std_view()?;
let value = util::get_socket_recv_buffer_size(view)?;
Ok(value as u64)
Ok(socket.receive_buffer_size()? as u64)
}
fn set_receive_buffer_size(
@ -591,27 +252,14 @@ impl<T: WasiView> crate::host::tcp::tcp::HostTcpSocket for T {
let table = self.table();
let socket = table.get_mut(&this)?;
let value = value.try_into().unwrap_or(usize::MAX);
{
let view = &*socket.as_std_view()?;
util::set_socket_recv_buffer_size(view, value)?;
}
#[cfg(target_os = "macos")]
{
socket.receive_buffer_size = Some(value);
}
Ok(())
socket.set_receive_buffer_size(value)
}
fn send_buffer_size(&mut self, this: Resource<tcp::TcpSocket>) -> SocketResult<u64> {
let table = self.table();
let socket = table.get(&this)?;
let view = &*socket.as_std_view()?;
let value = util::get_socket_send_buffer_size(view)?;
Ok(value as u64)
Ok(socket.send_buffer_size()? as u64)
}
fn set_send_buffer_size(
@ -622,18 +270,7 @@ impl<T: WasiView> crate::host::tcp::tcp::HostTcpSocket for T {
let table = self.table();
let socket = table.get_mut(&this)?;
let value = value.try_into().unwrap_or(usize::MAX);
{
let view = &*socket.as_std_view()?;
util::set_socket_send_buffer_size(view, value)?;
}
#[cfg(target_os = "macos")]
{
socket.send_buffer_size = Some(value);
}
Ok(())
socket.set_send_buffer_size(value)
}
fn subscribe(&mut self, this: Resource<tcp::TcpSocket>) -> anyhow::Result<Resource<Pollable>> {
@ -648,21 +285,12 @@ impl<T: WasiView> crate::host::tcp::tcp::HostTcpSocket for T {
let table = self.table();
let socket = table.get(&this)?;
let stream = match &socket.tcp_state {
TcpState::Connected(stream) => stream,
_ => return Err(ErrorCode::InvalidState.into()),
};
let how = match shutdown_type {
ShutdownType::Receive => std::net::Shutdown::Read,
ShutdownType::Send => std::net::Shutdown::Write,
ShutdownType::Both => std::net::Shutdown::Both,
};
stream
.as_socketlike_view::<std::net::TcpStream>()
.shutdown(how)?;
Ok(())
Ok(socket.shutdown(how)?)
}
fn drop(&mut self, this: Resource<tcp::TcpSocket>) -> Result<(), anyhow::Error> {

705
crates/wasi/src/tcp.rs

@ -1,16 +1,24 @@
use crate::bindings::sockets::tcp::ErrorCode;
use crate::host::network;
use crate::network::SocketAddressFamily;
use crate::runtime::{with_ambient_tokio_runtime, AbortOnDropJoinHandle};
use crate::{HostInputStream, HostOutputStream, SocketResult, StreamError, Subscribe};
use crate::{
HostInputStream, HostOutputStream, InputStream, OutputStream, SocketResult, StreamError,
Subscribe,
};
use anyhow::{Error, Result};
use cap_net_ext::AddressFamily;
use futures::Future;
use io_lifetimes::views::SocketlikeView;
use io_lifetimes::AsSocketlike;
use rustix::io::Errno;
use rustix::net::sockopt;
use std::io;
use std::mem;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
/// Value taken from rust std library.
const DEFAULT_BACKLOG: u32 = 128;
@ -19,7 +27,7 @@ const DEFAULT_BACKLOG: u32 = 128;
///
/// This represents the various states a socket can be in during the
/// activities of binding, listening, accepting, and connecting.
pub(crate) enum TcpState {
enum TcpState {
/// The initial state for a newly-created socket.
Default(tokio::net::TcpSocket),
@ -51,30 +59,612 @@ pub(crate) enum TcpState {
Closed,
}
impl std::fmt::Debug for TcpState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Default(_) => f.debug_tuple("Default").finish(),
Self::BindStarted(_) => f.debug_tuple("BindStarted").finish(),
Self::Bound(_) => f.debug_tuple("Bound").finish(),
Self::ListenStarted(_) => f.debug_tuple("ListenStarted").finish(),
Self::Listening { pending_accept, .. } => f
.debug_struct("Listening")
.field("pending_accept", pending_accept)
.finish(),
Self::Connecting(_) => f.debug_tuple("Connecting").finish(),
Self::ConnectReady(_) => f.debug_tuple("ConnectReady").finish(),
Self::Connected(_) => f.debug_tuple("Connected").finish(),
Self::Closed => write!(f, "Closed"),
}
}
}
/// A host TCP socket, plus associated bookkeeping.
///
/// The inner state is wrapped in an Arc because the same underlying socket is
/// used for implementing the stream types.
pub struct TcpSocket {
/// The current state in the bind/listen/accept/connect progression.
pub(crate) tcp_state: TcpState,
tcp_state: TcpState,
/// The desired listen queue size.
pub(crate) listen_backlog_size: u32,
listen_backlog_size: u32,
pub(crate) family: SocketAddressFamily,
family: SocketAddressFamily,
// The socket options below are not automatically inherited from the listener
// on all platforms. So we keep track of which options have been explicitly
// set and manually apply those values to newly accepted clients.
#[cfg(target_os = "macos")]
pub(crate) receive_buffer_size: Option<usize>,
receive_buffer_size: Option<usize>,
#[cfg(target_os = "macos")]
pub(crate) send_buffer_size: Option<usize>,
send_buffer_size: Option<usize>,
#[cfg(target_os = "macos")]
pub(crate) hop_limit: Option<u8>,
hop_limit: Option<u8>,
#[cfg(target_os = "macos")]
pub(crate) keep_alive_idle_time: Option<std::time::Duration>,
keep_alive_idle_time: Option<std::time::Duration>,
}
impl TcpSocket {
/// Create a new socket in the given family.
pub fn new(family: AddressFamily) -> io::Result<Self> {
with_ambient_tokio_runtime(|| {
let (socket, family) = match family {
AddressFamily::Ipv4 => {
let socket = tokio::net::TcpSocket::new_v4()?;
(socket, SocketAddressFamily::Ipv4)
}
AddressFamily::Ipv6 => {
let socket = tokio::net::TcpSocket::new_v6()?;
sockopt::set_ipv6_v6only(&socket, true)?;
(socket, SocketAddressFamily::Ipv6)
}
};
Self::from_state(TcpState::Default(socket), family)
})
}
/// Create a `TcpSocket` from an existing socket.
fn from_state(state: TcpState, family: SocketAddressFamily) -> io::Result<Self> {
Ok(Self {
tcp_state: state,
listen_backlog_size: DEFAULT_BACKLOG,
family,
#[cfg(target_os = "macos")]
receive_buffer_size: None,
#[cfg(target_os = "macos")]
send_buffer_size: None,
#[cfg(target_os = "macos")]
hop_limit: None,
#[cfg(target_os = "macos")]
keep_alive_idle_time: None,
})
}
fn as_std_view(&self) -> SocketResult<SocketlikeView<'_, std::net::TcpStream>> {
use crate::bindings::sockets::network::ErrorCode;
match &self.tcp_state {
TcpState::Default(socket) | TcpState::Bound(socket) => {
Ok(socket.as_socketlike_view::<std::net::TcpStream>())
}
TcpState::Connected(stream) => Ok(stream.as_socketlike_view::<std::net::TcpStream>()),
TcpState::Listening { listener, .. } => {
Ok(listener.as_socketlike_view::<std::net::TcpStream>())
}
TcpState::BindStarted(..)
| TcpState::ListenStarted(..)
| TcpState::Connecting(..)
| TcpState::ConnectReady(..)
| TcpState::Closed => Err(ErrorCode::InvalidState.into()),
}
}
}
impl TcpSocket {
pub fn start_bind(&mut self, local_address: SocketAddr) -> io::Result<()> {
let tokio_socket = match &self.tcp_state {
TcpState::Default(socket) => socket,
TcpState::BindStarted(..) => return Err(Errno::ALREADY.into()),
_ => return Err(Errno::ISCONN.into()),
};
network::util::validate_unicast(&local_address)?;
network::util::validate_address_family(&local_address, &self.family)?;
{
// Automatically bypass the TIME_WAIT state when the user is trying
// to bind to a specific port:
let reuse_addr = local_address.port() > 0;
// Unconditionally (re)set SO_REUSEADDR, even when the value is false.
// This ensures we're not accidentally affected by any socket option
// state left behind by a previous failed call to this method (start_bind).
network::util::set_tcp_reuseaddr(&tokio_socket, reuse_addr)?;
// Perform the OS bind call.
tokio_socket.bind(local_address).map_err(|error| {
match Errno::from_io_error(&error) {
// From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html:
// > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket
//
// The most common reasons for this error should have already
// been handled by our own validation slightly higher up in this
// function. This error mapping is here just in case there is
// an edge case we didn't catch.
Some(Errno::AFNOSUPPORT) => io::Error::new(
io::ErrorKind::InvalidInput,
"The specified address is not a valid address for the address family of the specified socket",
),
// See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-bind#:~:text=WSAENOBUFS
// Windows returns WSAENOBUFS when the ephemeral ports have been exhausted.
#[cfg(windows)]
Some(Errno::NOBUFS) => io::Error::new(io::ErrorKind::AddrInUse, "no more free local ports"),
_ => error,
}
})?;
self.tcp_state = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
TcpState::Default(socket) => TcpState::BindStarted(socket),
_ => unreachable!(),
};
Ok(())
}
}
pub fn finish_bind(&mut self) -> SocketResult<()> {
match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
TcpState::BindStarted(socket) => {
self.tcp_state = TcpState::Bound(socket);
Ok(())
}
current_state => {
// Reset the state so that the outside world doesn't see this socket as closed
self.tcp_state = current_state;
Err(ErrorCode::NotInProgress.into())
}
}
}
pub fn start_connect(&mut self, remote_address: SocketAddr) -> SocketResult<()> {
match self.tcp_state {
TcpState::Default(..) => {}
TcpState::Connecting(..) | TcpState::ConnectReady(..) => {
return Err(ErrorCode::ConcurrencyConflict.into())
}
_ => return Err(ErrorCode::InvalidState.into()),
};
network::util::validate_unicast(&remote_address)?;
network::util::validate_remote_address(&remote_address)?;
network::util::validate_address_family(&remote_address, &self.family)?;
let TcpState::Default(tokio_socket) =
std::mem::replace(&mut self.tcp_state, TcpState::Closed)
else {
unreachable!();
};
let future = tokio_socket.connect(remote_address);
self.tcp_state = TcpState::Connecting(Box::pin(future));
Ok(())
}
pub fn finish_connect(&mut self) -> SocketResult<(InputStream, OutputStream)> {
let previous_state = std::mem::replace(&mut self.tcp_state, TcpState::Closed);
let result = match previous_state {
TcpState::ConnectReady(result) => result,
TcpState::Connecting(mut future) => {
let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) {
Poll::Ready(result) => result,
Poll::Pending => {
self.tcp_state = TcpState::Connecting(future);
return Err(ErrorCode::WouldBlock.into());
}
}
}
previous_state => {
self.tcp_state = previous_state;
return Err(ErrorCode::NotInProgress.into());
}
};
match result {
Ok(stream) => {
let stream = Arc::new(stream);
self.tcp_state = TcpState::Connected(stream.clone());
let input: InputStream =
InputStream::Host(Box::new(TcpReadStream::new(stream.clone())));
let output: OutputStream = Box::new(TcpWriteStream::new(stream));
Ok((input, output))
}
Err(err) => {
self.tcp_state = TcpState::Closed;
Err(err.into())
}
}
}
pub fn start_listen(&mut self) -> SocketResult<()> {
match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
TcpState::Bound(tokio_socket) => {
self.tcp_state = TcpState::ListenStarted(tokio_socket);
Ok(())
}
TcpState::ListenStarted(tokio_socket) => {
self.tcp_state = TcpState::ListenStarted(tokio_socket);
Err(ErrorCode::ConcurrencyConflict.into())
}
previous_state => {
self.tcp_state = previous_state;
Err(ErrorCode::InvalidState.into())
}
}
}
pub fn finish_listen(&mut self) -> SocketResult<()> {
let tokio_socket = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
TcpState::ListenStarted(tokio_socket) => tokio_socket,
previous_state => {
self.tcp_state = previous_state;
return Err(ErrorCode::NotInProgress.into());
}
};
match with_ambient_tokio_runtime(|| tokio_socket.listen(self.listen_backlog_size)) {
Ok(listener) => {
self.tcp_state = TcpState::Listening {
listener,
pending_accept: None,
};
Ok(())
}
Err(err) => {
self.tcp_state = TcpState::Closed;
Err(match Errno::from_io_error(&err) {
// See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE
// According to the docs, `listen` can return EMFILE on Windows.
// This is odd, because we're not trying to create a new socket
// or file descriptor of any kind. So we rewrite it to less
// surprising error code.
//
// At the time of writing, this behavior has never been experimentally
// observed by any of the wasmtime authors, so we're relying fully
// on Microsoft's documentation here.
#[cfg(windows)]
Some(Errno::MFILE) => Errno::NOBUFS.into(),
_ => err.into(),
})
}
}
}
pub fn accept(&mut self) -> SocketResult<(Self, InputStream, OutputStream)> {
let TcpState::Listening {
listener,
pending_accept,
} = &mut self.tcp_state
else {
return Err(ErrorCode::InvalidState.into());
};
let result = match pending_accept.take() {
Some(result) => result,
None => {
let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx))
.map_ok(|(stream, _)| stream)
{
Poll::Ready(result) => result,
Poll::Pending => Err(Errno::WOULDBLOCK.into()),
}
}
};
let client = result.map_err(|err| match Errno::from_io_error(&err) {
// From: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS
// > WSAEINPROGRESS: A blocking Windows Sockets 1.1 call is in progress,
// > or the service provider is still processing a callback function.
//
// wasi-sockets doesn't have an equivalent to the EINPROGRESS error,
// because in POSIX this error is only returned by a non-blocking
// `connect` and wasi-sockets has a different solution for that.
#[cfg(windows)]
Some(Errno::INPROGRESS) => Errno::INTR.into(),
// Normalize Linux' non-standard behavior.
//
// From https://man7.org/linux/man-pages/man2/accept.2.html:
// > Linux accept() passes already-pending network errors on the
// > new socket as an error code from accept(). This behavior
// > differs from other BSD socket implementations. (...)
#[cfg(target_os = "linux")]
Some(
Errno::CONNRESET
| Errno::NETRESET
| Errno::HOSTUNREACH
| Errno::HOSTDOWN
| Errno::NETDOWN
| Errno::NETUNREACH
| Errno::PROTO
| Errno::NOPROTOOPT
| Errno::NONET
| Errno::OPNOTSUPP,
) => Errno::CONNABORTED.into(),
_ => err,
})?;
#[cfg(target_os = "macos")]
{
// Manually inherit socket options from listener. We only have to
// do this on platforms that don't already do this automatically
// and only if a specific value was explicitly set on the listener.
if let Some(size) = self.receive_buffer_size {
_ = network::util::set_socket_recv_buffer_size(&client, size); // Ignore potential error.
}
if let Some(size) = self.send_buffer_size {
_ = network::util::set_socket_send_buffer_size(&client, size); // Ignore potential error.
}
// For some reason, IP_TTL is inherited, but IPV6_UNICAST_HOPS isn't.
if let (SocketAddressFamily::Ipv6, Some(ttl)) = (self.family, self.hop_limit) {
_ = network::util::set_ipv6_unicast_hops(&client, ttl); // Ignore potential error.
}
if let Some(value) = self.keep_alive_idle_time {
_ = network::util::set_tcp_keepidle(&client, value); // Ignore potential error.
}
}
let client = Arc::new(client);
let input: InputStream = InputStream::Host(Box::new(TcpReadStream::new(client.clone())));
let output: OutputStream = Box::new(TcpWriteStream::new(client.clone()));
let tcp_socket = TcpSocket::from_state(TcpState::Connected(client), self.family)?;
Ok((tcp_socket, input, output))
}
pub fn local_address(&self) -> SocketResult<SocketAddr> {
let view = match self.tcp_state {
TcpState::Default(..) => return Err(ErrorCode::InvalidState.into()),
TcpState::BindStarted(..) => return Err(ErrorCode::ConcurrencyConflict.into()),
_ => self.as_std_view()?,
};
Ok(view.local_addr()?)
}
pub fn remote_address(&self) -> SocketResult<SocketAddr> {
let view = match self.tcp_state {
TcpState::Connected(..) => self.as_std_view()?,
TcpState::Connecting(..) | TcpState::ConnectReady(..) => {
return Err(ErrorCode::ConcurrencyConflict.into())
}
_ => return Err(ErrorCode::InvalidState.into()),
};
Ok(view.peer_addr()?)
}
pub fn is_listening(&self) -> bool {
matches!(self.tcp_state, TcpState::Listening { .. })
}
pub fn address_family(&self) -> SocketAddressFamily {
self.family
}
pub fn set_listen_backlog_size(&mut self, value: u32) -> SocketResult<()> {
const MIN_BACKLOG: u32 = 1;
const MAX_BACKLOG: u32 = i32::MAX as u32; // OS'es will most likely limit it down even further.
if value == 0 {
return Err(ErrorCode::InvalidArgument.into());
}
// Silently clamp backlog size. This is OK for us to do, because operating systems do this too.
let value = value.clamp(MIN_BACKLOG, MAX_BACKLOG);
match &self.tcp_state {
TcpState::Default(..) | TcpState::Bound(..) => {
// Socket not listening yet. Stash value for first invocation to `listen`.
}
TcpState::Listening { listener, .. } => {
// Try to update the backlog by calling `listen` again.
// Not all platforms support this. We'll only update our own value if the OS supports changing the backlog size after the fact.
rustix::net::listen(&listener, value.try_into().unwrap())
.map_err(|_| ErrorCode::NotSupported)?;
}
_ => return Err(ErrorCode::InvalidState.into()),
}
self.listen_backlog_size = value;
Ok(())
}
pub fn keep_alive_enabled(&self) -> SocketResult<bool> {
let view = &*self.as_std_view()?;
Ok(sockopt::get_socket_keepalive(view)?)
}
pub fn set_keep_alive_enabled(&self, value: bool) -> SocketResult<()> {
let view = &*self.as_std_view()?;
Ok(sockopt::set_socket_keepalive(view, value)?)
}
pub fn keep_alive_idle_time(&self) -> SocketResult<std::time::Duration> {
let view = &*self.as_std_view()?;
Ok(sockopt::get_tcp_keepidle(view)?)
}
pub fn set_keep_alive_idle_time(&mut self, duration: std::time::Duration) -> SocketResult<()> {
{
let view = &*self.as_std_view()?;
network::util::set_tcp_keepidle(view, duration)?;
}
#[cfg(target_os = "macos")]
{
self.keep_alive_idle_time = Some(duration);
}
Ok(())
}
pub fn keep_alive_interval(&self) -> SocketResult<std::time::Duration> {
let view = &*self.as_std_view()?;
Ok(sockopt::get_tcp_keepintvl(view)?)
}
pub fn set_keep_alive_interval(&self, duration: std::time::Duration) -> SocketResult<()> {
let view = &*self.as_std_view()?;
Ok(network::util::set_tcp_keepintvl(view, duration)?)
}
pub fn keep_alive_count(&self) -> SocketResult<u32> {
let view = &*self.as_std_view()?;
Ok(sockopt::get_tcp_keepcnt(view)?)
}
pub fn set_keep_alive_count(&self, value: u32) -> SocketResult<()> {
// TODO(rylev): do we need to check and clamp the value?
let view = &*self.as_std_view()?;
Ok(network::util::set_tcp_keepcnt(view, value)?)
}
pub fn hop_limit(&self) -> SocketResult<u8> {
let view = &*self.as_std_view()?;
let ttl = match self.family {
SocketAddressFamily::Ipv4 => network::util::get_ip_ttl(view)?,
SocketAddressFamily::Ipv6 => network::util::get_ipv6_unicast_hops(view)?,
};
Ok(ttl)
}
pub fn set_hop_limit(&mut self, value: u8) -> SocketResult<()> {
{
let view = &*self.as_std_view()?;
match self.family {
SocketAddressFamily::Ipv4 => network::util::set_ip_ttl(view, value)?,
SocketAddressFamily::Ipv6 => network::util::set_ipv6_unicast_hops(view, value)?,
}
}
#[cfg(target_os = "macos")]
{
self.hop_limit = Some(value);
}
Ok(())
}
pub fn receive_buffer_size(&self) -> SocketResult<usize> {
let view = &*self.as_std_view()?;
Ok(network::util::get_socket_recv_buffer_size(view)?)
}
pub fn set_receive_buffer_size(&mut self, value: usize) -> SocketResult<()> {
{
let view = &*self.as_std_view()?;
network::util::set_socket_recv_buffer_size(view, value)?;
}
#[cfg(target_os = "macos")]
{
self.receive_buffer_size = Some(value);
}
Ok(())
}
pub fn send_buffer_size(&self) -> SocketResult<usize> {
let view = &*self.as_std_view()?;
Ok(network::util::get_socket_send_buffer_size(view)?)
}
pub fn set_send_buffer_size(&mut self, value: usize) -> SocketResult<()> {
{
let view = &*self.as_std_view()?;
network::util::set_socket_send_buffer_size(view, value)?;
}
#[cfg(target_os = "macos")]
{
self.send_buffer_size = Some(value);
}
Ok(())
}
pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
let stream = match &self.tcp_state {
TcpState::Connected(stream) => stream,
_ => {
return Err(io::Error::new(
io::ErrorKind::NotConnected,
"socket not connected",
))
}
};
stream
.as_socketlike_view::<std::net::TcpStream>()
.shutdown(how)?;
Ok(())
}
}
#[async_trait::async_trait]
impl Subscribe for TcpSocket {
async fn ready(&mut self) {
match &mut self.tcp_state {
TcpState::Default(..)
| TcpState::BindStarted(..)
| TcpState::Bound(..)
| TcpState::ListenStarted(..)
| TcpState::ConnectReady(..)
| TcpState::Closed
| TcpState::Connected(..) => {
// No async operation in progress.
}
TcpState::Connecting(future) => {
self.tcp_state = TcpState::ConnectReady(future.as_mut().await);
}
TcpState::Listening {
listener,
pending_accept,
} => match pending_accept {
Some(_) => {}
None => {
let result = futures::future::poll_fn(|cx| {
listener.poll_accept(cx).map_ok(|(stream, _)| stream)
})
.await;
*pending_accept = Some(result);
}
},
}
}
}
pub(crate) struct TcpReadStream {
@ -255,94 +845,3 @@ impl Subscribe for TcpWriteStream {
}
}
}
impl TcpSocket {
/// Create a new socket in the given family.
pub fn new(family: AddressFamily) -> io::Result<Self> {
with_ambient_tokio_runtime(|| {
let (socket, family) = match family {
AddressFamily::Ipv4 => {
let socket = tokio::net::TcpSocket::new_v4()?;
(socket, SocketAddressFamily::Ipv4)
}
AddressFamily::Ipv6 => {
let socket = tokio::net::TcpSocket::new_v6()?;
sockopt::set_ipv6_v6only(&socket, true)?;
(socket, SocketAddressFamily::Ipv6)
}
};
Self::from_state(TcpState::Default(socket), family)
})
}
/// Create a `TcpSocket` from an existing socket.
pub(crate) fn from_state(state: TcpState, family: SocketAddressFamily) -> io::Result<Self> {
Ok(Self {
tcp_state: state,
listen_backlog_size: DEFAULT_BACKLOG,
family,
#[cfg(target_os = "macos")]
receive_buffer_size: None,
#[cfg(target_os = "macos")]
send_buffer_size: None,
#[cfg(target_os = "macos")]
hop_limit: None,
#[cfg(target_os = "macos")]
keep_alive_idle_time: None,
})
}
pub(crate) fn as_std_view(&self) -> SocketResult<SocketlikeView<'_, std::net::TcpStream>> {
use crate::bindings::sockets::network::ErrorCode;
match &self.tcp_state {
TcpState::Default(socket) | TcpState::Bound(socket) => {
Ok(socket.as_socketlike_view::<std::net::TcpStream>())
}
TcpState::Connected(stream) => Ok(stream.as_socketlike_view::<std::net::TcpStream>()),
TcpState::Listening { listener, .. } => {
Ok(listener.as_socketlike_view::<std::net::TcpStream>())
}
TcpState::BindStarted(..)
| TcpState::ListenStarted(..)
| TcpState::Connecting(..)
| TcpState::ConnectReady(..)
| TcpState::Closed => Err(ErrorCode::InvalidState.into()),
}
}
}
#[async_trait::async_trait]
impl Subscribe for TcpSocket {
async fn ready(&mut self) {
match &mut self.tcp_state {
TcpState::Default(..)
| TcpState::BindStarted(..)
| TcpState::Bound(..)
| TcpState::ListenStarted(..)
| TcpState::ConnectReady(..)
| TcpState::Closed
| TcpState::Connected(..) => {
// No async operation in progress.
}
TcpState::Connecting(future) => {
self.tcp_state = TcpState::ConnectReady(future.as_mut().await);
}
TcpState::Listening {
listener,
pending_accept,
} => match pending_accept {
Some(_) => {}
None => {
let result = futures::future::poll_fn(|cx| {
listener.poll_accept(cx).map_ok(|(stream, _)| stream)
})
.await;
*pending_accept = Some(result);
}
},
}
}
}

Loading…
Cancel
Save