From 857de9a0ab1ca48d1f7ee0ba174bd935bb470f60 Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Tue, 2 Apr 2024 23:42:57 +0200 Subject: [PATCH] Separate TCP implementation from the host bindings (#8282) * Move bind into Tcp type Signed-off-by: Ryan Levick * Move start_connect into Tcp type Signed-off-by: Ryan Levick * Move finish_connect into Tcp type Signed-off-by: Ryan Levick * Move *_listen into Tcp type Signed-off-by: Ryan Levick * Move accept into Tcp type Signed-off-by: Ryan Levick * Move address methods into Tcp type Signed-off-by: Ryan Levick * Move various option methods into Tcp type Signed-off-by: Ryan Levick * Move shutdown methods into Tcp type Signed-off-by: Ryan Levick * Move finish bind methods into Tcp type Signed-off-by: Ryan Levick * Change connect's return type Signed-off-by: Ryan Levick * Move shutdown over to io::Result Signed-off-by: Ryan Levick * Rearrange some code Signed-off-by: Ryan Levick * Move bind to io Error Signed-off-by: Ryan Levick --------- Signed-off-by: Ryan Levick --- crates/wasi/src/host/network.rs | 46 ++- crates/wasi/src/host/tcp.rs | 448 ++------------------ crates/wasi/src/tcp.rs | 705 +++++++++++++++++++++++++++----- 3 files changed, 673 insertions(+), 526 deletions(-) diff --git a/crates/wasi/src/host/network.rs b/crates/wasi/src/host/network.rs index be832d448a..287a4b5b4e 100644 --- a/crates/wasi/src/host/network.rs +++ b/crates/wasi/src/host/network.rs @@ -212,29 +212,34 @@ impl From for IpAddressFamily { } pub(crate) mod util { + use std::io; use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::time::Duration; - use crate::bindings::sockets::network::ErrorCode; use crate::network::SocketAddressFamily; - use crate::SocketResult; use cap_net_ext::{AddressFamily, Blocking, UdpSocketExt}; use rustix::fd::{AsFd, OwnedFd}; use rustix::io::Errno; use rustix::net::sockopt; - pub fn validate_unicast(addr: &SocketAddr) -> SocketResult<()> { + pub fn validate_unicast(addr: &SocketAddr) -> io::Result<()> { match to_canonical(&addr.ip()) { IpAddr::V4(ipv4) => { if ipv4.is_multicast() || ipv4.is_broadcast() { - Err(ErrorCode::InvalidArgument.into()) + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Both IPv4 broadcast and multicast addresses are not supported", + )) } else { Ok(()) } } IpAddr::V6(ipv6) => { if ipv6.is_multicast() { - Err(ErrorCode::InvalidArgument.into()) + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "IPv6 multicast addresses are not supported", + )) } else { Ok(()) } @@ -242,13 +247,19 @@ pub(crate) mod util { } } - pub fn validate_remote_address(addr: &SocketAddr) -> SocketResult<()> { + pub fn validate_remote_address(addr: &SocketAddr) -> io::Result<()> { if to_canonical(&addr.ip()).is_unspecified() { - return Err(ErrorCode::InvalidArgument.into()); + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Remote address may not be `0.0.0.0` or `::`", + )); } if addr.port() == 0 { - return Err(ErrorCode::InvalidArgument.into()); + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Remote port may not be 0", + )); } Ok(()) @@ -257,7 +268,7 @@ pub(crate) mod util { pub fn validate_address_family( addr: &SocketAddr, socket_family: &SocketAddressFamily, - ) -> SocketResult<()> { + ) -> io::Result<()> { match (socket_family, addr.ip()) { (SocketAddressFamily::Ipv4, IpAddr::V4(_)) => Ok(()), (SocketAddressFamily::Ipv6, IpAddr::V6(ipv6)) => { @@ -266,14 +277,23 @@ pub(crate) mod util { // since 2006, OS handling of them is inconsistent and our own // validations don't take them into account either. // Note that these are not the same as IPv4-*mapped* IPv6 addresses. - Err(ErrorCode::InvalidArgument.into()) + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "IPv4-compatible IPv6 addresses are not supported", + )) } else if ipv6.to_ipv4_mapped().is_some() { - Err(ErrorCode::InvalidArgument.into()) + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "IPv4-mapped IPv6 address passed to an IPv6-only socket", + )) } else { Ok(()) } } - _ => Err(ErrorCode::InvalidArgument.into()), + _ => Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Address family mismatch", + )), } } @@ -301,7 +321,7 @@ pub(crate) mod util { * Syscalls wrappers with (opinionated) portability fixes. */ - pub fn udp_socket(family: AddressFamily, blocking: Blocking) -> std::io::Result { + pub fn udp_socket(family: AddressFamily, blocking: Blocking) -> io::Result { // Delegate socket creation to cap_net_ext. They handle a couple of things for us: // - On Windows: call WSAStartup if not done before. // - Set the NONBLOCK and CLOEXEC flags. Either immediately during socket creation, diff --git a/crates/wasi/src/host/tcp.rs b/crates/wasi/src/host/tcp.rs index 11d44d8825..d0ccaaa42c 100644 --- a/crates/wasi/src/host/tcp.rs +++ b/crates/wasi/src/host/tcp.rs @@ -1,22 +1,14 @@ -use crate::host::network::util; use crate::network::SocketAddrUse; -use crate::runtime::with_ambient_tokio_runtime; -use crate::tcp::{TcpReadStream, TcpSocket, TcpState, TcpWriteStream}; use crate::{ bindings::{ io::streams::{InputStream, OutputStream}, - sockets::network::{ErrorCode, IpAddressFamily, IpSocketAddress, Network}, + sockets::network::{IpAddressFamily, IpSocketAddress, Network}, sockets::tcp::{self, ShutdownType}, }, network::SocketAddressFamily, }; use crate::{Pollable, SocketResult, WasiView}; -use io_lifetimes::AsSocketlike; -use rustix::io::Errno; -use rustix::net::sockopt; use std::net::SocketAddr; -use std::sync::Arc; -use std::task::Poll; use std::time::Duration; use wasmtime::component::Resource; @@ -31,60 +23,14 @@ impl crate::host::tcp::tcp::HostTcpSocket for T { ) -> SocketResult<()> { self.ctx().allowed_network_uses.check_allowed_tcp()?; let table = self.table(); - let socket = table.get(&this)?; let network = table.get(&network)?; let local_address: SocketAddr = local_address.into(); - let tokio_socket = match &socket.tcp_state { - TcpState::Default(socket) => socket, - TcpState::BindStarted(..) => return Err(ErrorCode::ConcurrencyConflict.into()), - _ => return Err(ErrorCode::InvalidState.into()), - }; - - util::validate_unicast(&local_address)?; - util::validate_address_family(&local_address, &socket.family)?; - - { - // Ensure that we're allowed to connect to this address. - network.check_socket_addr(&local_address, SocketAddrUse::TcpBind)?; - - // Automatically bypass the TIME_WAIT state when the user is trying - // to bind to a specific port: - let reuse_addr = local_address.port() > 0; - - // Unconditionally (re)set SO_REUSEADDR, even when the value is false. - // This ensures we're not accidentally affected by any socket option - // state left behind by a previous failed call to this method (start_bind). - util::set_tcp_reuseaddr(&tokio_socket, reuse_addr)?; - - // Perform the OS bind call. - tokio_socket.bind(local_address).map_err(|error| { - match Errno::from_io_error(&error) { - // From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html: - // > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket - // - // The most common reasons for this error should have already - // been handled by our own validation slightly higher up in this - // function. This error mapping is here just in case there is - // an edge case we didn't catch. - Some(Errno::AFNOSUPPORT) => ErrorCode::InvalidArgument, - - // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-bind#:~:text=WSAENOBUFS - // Windows returns WSAENOBUFS when the ephemeral ports have been exhausted. - #[cfg(windows)] - Some(Errno::NOBUFS) => ErrorCode::AddressInUse, - - _ => ErrorCode::from(error), - } - })?; - } - - let socket = table.get_mut(&this)?; + // Ensure that we're allowed to connect to this address. + network.check_socket_addr(&local_address, SocketAddrUse::TcpBind)?; - socket.tcp_state = match std::mem::replace(&mut socket.tcp_state, TcpState::Closed) { - TcpState::Default(socket) => TcpState::BindStarted(socket), - _ => unreachable!(), - }; + // Bind to the address. + table.get_mut(&this)?.start_bind(local_address)?; Ok(()) } @@ -93,17 +39,7 @@ impl crate::host::tcp::tcp::HostTcpSocket for T { let table = self.table(); let socket = table.get_mut(&this)?; - match socket.tcp_state { - TcpState::BindStarted(..) => {} - _ => return Err(ErrorCode::NotInProgress.into()), - } - - socket.tcp_state = match std::mem::replace(&mut socket.tcp_state, TcpState::Closed) { - TcpState::BindStarted(socket) => TcpState::Bound(socket), - _ => unreachable!(), - }; - - Ok(()) + socket.finish_bind() } fn start_connect( @@ -114,37 +50,14 @@ impl crate::host::tcp::tcp::HostTcpSocket for T { ) -> SocketResult<()> { self.ctx().allowed_network_uses.check_allowed_tcp()?; let table = self.table(); - let socket = table.get(&this)?; let network = table.get(&network)?; let remote_address: SocketAddr = remote_address.into(); - match socket.tcp_state { - TcpState::Default(..) => {} - - TcpState::Connecting(..) | TcpState::ConnectReady(..) => { - return Err(ErrorCode::ConcurrencyConflict.into()) - } - - _ => return Err(ErrorCode::InvalidState.into()), - }; - - util::validate_unicast(&remote_address)?; - util::validate_remote_address(&remote_address)?; - util::validate_address_family(&remote_address, &socket.family)?; - // Ensure that we're allowed to connect to this address. network.check_socket_addr(&remote_address, SocketAddrUse::TcpConnect)?; - let socket = table.get_mut(&this)?; - let TcpState::Default(tokio_socket) = - std::mem::replace(&mut socket.tcp_state, TcpState::Closed) - else { - unreachable!(); - }; - - let future = tokio_socket.connect(remote_address); - - socket.tcp_state = TcpState::Connecting(Box::pin(future)); + // Start connection + table.get_mut(&this)?.start_connect(remote_address)?; Ok(()) } @@ -156,45 +69,12 @@ impl crate::host::tcp::tcp::HostTcpSocket for T { let table = self.table(); let socket = table.get_mut(&this)?; - let previous_state = std::mem::replace(&mut socket.tcp_state, TcpState::Closed); - let result = match previous_state { - TcpState::ConnectReady(result) => result, - TcpState::Connecting(mut future) => { - let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref()); - match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) { - Poll::Ready(result) => result, - Poll::Pending => { - socket.tcp_state = TcpState::Connecting(future); - return Err(ErrorCode::WouldBlock.into()); - } - } - } - previous_state => { - socket.tcp_state = previous_state; - return Err(ErrorCode::NotInProgress.into()); - } - }; + let (input, output) = socket.finish_connect()?; - match result { - Ok(stream) => { - let stream = Arc::new(stream); - - let input: InputStream = - InputStream::Host(Box::new(TcpReadStream::new(stream.clone()))); - let output: OutputStream = Box::new(TcpWriteStream::new(stream.clone())); - - let input_stream = self.table().push_child(input, &this)?; - let output_stream = self.table().push_child(output, &this)?; - - let socket = self.table().get_mut(&this)?; - socket.tcp_state = TcpState::Connected(stream); - Ok((input_stream, output_stream)) - } - Err(err) => { - socket.tcp_state = TcpState::Closed; - Err(err.into()) - } - } + let input_stream = self.table().push_child(input, &this)?; + let output_stream = self.table().push_child(output, &this)?; + + Ok((input_stream, output_stream)) } fn start_listen(&mut self, this: Resource) -> SocketResult<()> { @@ -202,62 +82,13 @@ impl crate::host::tcp::tcp::HostTcpSocket for T { let table = self.table(); let socket = table.get_mut(&this)?; - match std::mem::replace(&mut socket.tcp_state, TcpState::Closed) { - TcpState::Bound(tokio_socket) => { - socket.tcp_state = TcpState::ListenStarted(tokio_socket); - Ok(()) - } - TcpState::ListenStarted(tokio_socket) => { - socket.tcp_state = TcpState::ListenStarted(tokio_socket); - Err(ErrorCode::ConcurrencyConflict.into()) - } - previous_state => { - socket.tcp_state = previous_state; - Err(ErrorCode::InvalidState.into()) - } - } + socket.start_listen() } fn finish_listen(&mut self, this: Resource) -> SocketResult<()> { let table = self.table(); let socket = table.get_mut(&this)?; - - let tokio_socket = match std::mem::replace(&mut socket.tcp_state, TcpState::Closed) { - TcpState::ListenStarted(tokio_socket) => tokio_socket, - previous_state => { - socket.tcp_state = previous_state; - return Err(ErrorCode::NotInProgress.into()); - } - }; - - match with_ambient_tokio_runtime(|| tokio_socket.listen(socket.listen_backlog_size)) { - Ok(listener) => { - socket.tcp_state = TcpState::Listening { - listener, - pending_accept: None, - }; - Ok(()) - } - Err(err) => { - socket.tcp_state = TcpState::Closed; - - Err(match Errno::from_io_error(&err) { - // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE - // According to the docs, `listen` can return EMFILE on Windows. - // This is odd, because we're not trying to create a new socket - // or file descriptor of any kind. So we rewrite it to less - // surprising error code. - // - // At the time of writing, this behavior has never been experimentally - // observed by any of the wasmtime authors, so we're relying fully - // on Microsoft's documentation here. - #[cfg(windows)] - Some(Errno::MFILE) => Errno::NOBUFS.into(), - - _ => err.into(), - }) - } - } + socket.finish_listen() } fn accept( @@ -272,90 +103,7 @@ impl crate::host::tcp::tcp::HostTcpSocket for T { let table = self.table(); let socket = table.get_mut(&this)?; - let TcpState::Listening { - listener, - pending_accept, - } = &mut socket.tcp_state - else { - return Err(ErrorCode::InvalidState.into()); - }; - - let result = match pending_accept.take() { - Some(result) => result, - None => { - let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref()); - match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx)) - .map_ok(|(stream, _)| stream) - { - Poll::Ready(result) => result, - Poll::Pending => Err(Errno::WOULDBLOCK.into()), - } - } - }; - - let client = result.map_err(|err| match Errno::from_io_error(&err) { - // From: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS - // > WSAEINPROGRESS: A blocking Windows Sockets 1.1 call is in progress, - // > or the service provider is still processing a callback function. - // - // wasi-sockets doesn't have an equivalent to the EINPROGRESS error, - // because in POSIX this error is only returned by a non-blocking - // `connect` and wasi-sockets has a different solution for that. - #[cfg(windows)] - Some(Errno::INPROGRESS) => Errno::INTR.into(), - - // Normalize Linux' non-standard behavior. - // - // From https://man7.org/linux/man-pages/man2/accept.2.html: - // > Linux accept() passes already-pending network errors on the - // > new socket as an error code from accept(). This behavior - // > differs from other BSD socket implementations. (...) - #[cfg(target_os = "linux")] - Some( - Errno::CONNRESET - | Errno::NETRESET - | Errno::HOSTUNREACH - | Errno::HOSTDOWN - | Errno::NETDOWN - | Errno::NETUNREACH - | Errno::PROTO - | Errno::NOPROTOOPT - | Errno::NONET - | Errno::OPNOTSUPP, - ) => Errno::CONNABORTED.into(), - - _ => err, - })?; - - #[cfg(target_os = "macos")] - { - // Manually inherit socket options from listener. We only have to - // do this on platforms that don't already do this automatically - // and only if a specific value was explicitly set on the listener. - - if let Some(size) = socket.receive_buffer_size { - _ = util::set_socket_recv_buffer_size(&client, size); // Ignore potential error. - } - - if let Some(size) = socket.send_buffer_size { - _ = util::set_socket_send_buffer_size(&client, size); // Ignore potential error. - } - - // For some reason, IP_TTL is inherited, but IPV6_UNICAST_HOPS isn't. - if let (SocketAddressFamily::Ipv6, Some(ttl)) = (socket.family, socket.hop_limit) { - _ = util::set_ipv6_unicast_hops(&client, ttl); // Ignore potential error. - } - - if let Some(value) = socket.keep_alive_idle_time { - _ = util::set_tcp_keepidle(&client, value); // Ignore potential error. - } - } - - let client = Arc::new(client); - - let input: InputStream = InputStream::Host(Box::new(TcpReadStream::new(client.clone()))); - let output: OutputStream = Box::new(TcpWriteStream::new(client.clone())); - let tcp_socket = TcpSocket::from_state(TcpState::Connected(client), socket.family)?; + let (tcp_socket, input, output) = socket.accept()?; let tcp_socket = self.table().push(tcp_socket)?; let input_stream = self.table().push_child(input, &tcp_socket)?; @@ -368,38 +116,21 @@ impl crate::host::tcp::tcp::HostTcpSocket for T { let table = self.table(); let socket = table.get(&this)?; - let view = match socket.tcp_state { - TcpState::Default(..) => return Err(ErrorCode::InvalidState.into()), - TcpState::BindStarted(..) => return Err(ErrorCode::ConcurrencyConflict.into()), - _ => socket.as_std_view()?, - }; - - Ok(view.local_addr()?.into()) + socket.local_address().map(Into::into) } fn remote_address(&mut self, this: Resource) -> SocketResult { let table = self.table(); let socket = table.get(&this)?; - let view = match socket.tcp_state { - TcpState::Connected(..) => socket.as_std_view()?, - TcpState::Connecting(..) | TcpState::ConnectReady(..) => { - return Err(ErrorCode::ConcurrencyConflict.into()) - } - _ => return Err(ErrorCode::InvalidState.into()), - }; - - Ok(view.peer_addr()?.into()) + socket.remote_address().map(Into::into) } fn is_listening(&mut self, this: Resource) -> Result { let table = self.table(); let socket = table.get(&this)?; - match socket.tcp_state { - TcpState::Listening { .. } => Ok(true), - _ => Ok(false), - } + Ok(socket.is_listening()) } fn address_family( @@ -409,7 +140,7 @@ impl crate::host::tcp::tcp::HostTcpSocket for T { let table = self.table(); let socket = table.get(&this)?; - match socket.family { + match socket.address_family() { SocketAddressFamily::Ipv4 => Ok(IpAddressFamily::Ipv4), SocketAddressFamily::Ipv6 => Ok(IpAddressFamily::Ipv6), } @@ -420,49 +151,19 @@ impl crate::host::tcp::tcp::HostTcpSocket for T { this: Resource, value: u64, ) -> SocketResult<()> { - const MIN_BACKLOG: u32 = 1; - const MAX_BACKLOG: u32 = i32::MAX as u32; // OS'es will most likely limit it down even further. - let table = self.table(); let socket = table.get_mut(&this)?; - if value == 0 { - return Err(ErrorCode::InvalidArgument.into()); - } - // Silently clamp backlog size. This is OK for us to do, because operating systems do this too. - let value = value - .try_into() - .unwrap_or(u32::MAX) - .clamp(MIN_BACKLOG, MAX_BACKLOG); - - match &socket.tcp_state { - TcpState::Default(..) | TcpState::Bound(..) => { - // Socket not listening yet. Stash value for first invocation to `listen`. - socket.listen_backlog_size = value; - - Ok(()) - } - TcpState::Listening { listener, .. } => { - // Try to update the backlog by calling `listen` again. - // Not all platforms support this. We'll only update our own value if the OS supports changing the backlog size after the fact. - - rustix::net::listen(&listener, value.try_into().unwrap()) - .map_err(|_| ErrorCode::NotSupported)?; - - socket.listen_backlog_size = value; - - Ok(()) - } - _ => Err(ErrorCode::InvalidState.into()), - } + let value = value.try_into().unwrap_or(u32::MAX); + + socket.set_listen_backlog_size(value) } fn keep_alive_enabled(&mut self, this: Resource) -> SocketResult { let table = self.table(); let socket = table.get(&this)?; - let view = &*socket.as_std_view()?; - Ok(sockopt::get_socket_keepalive(view)?) + socket.keep_alive_enabled() } fn set_keep_alive_enabled( @@ -472,15 +173,13 @@ impl crate::host::tcp::tcp::HostTcpSocket for T { ) -> SocketResult<()> { let table = self.table(); let socket = table.get(&this)?; - let view = &*socket.as_std_view()?; - Ok(sockopt::set_socket_keepalive(view, value)?) + socket.set_keep_alive_enabled(value) } fn keep_alive_idle_time(&mut self, this: Resource) -> SocketResult { let table = self.table(); let socket = table.get(&this)?; - let view = &*socket.as_std_view()?; - Ok(sockopt::get_tcp_keepidle(view)?.as_nanos() as u64) + Ok(socket.keep_alive_idle_time()?.as_nanos() as u64) } fn set_keep_alive_idle_time( @@ -491,25 +190,13 @@ impl crate::host::tcp::tcp::HostTcpSocket for T { let table = self.table(); let socket = table.get_mut(&this)?; let duration = Duration::from_nanos(value); - { - let view = &*socket.as_std_view()?; - - util::set_tcp_keepidle(view, duration)?; - } - - #[cfg(target_os = "macos")] - { - socket.keep_alive_idle_time = Some(duration); - } - - Ok(()) + socket.set_keep_alive_idle_time(duration) } fn keep_alive_interval(&mut self, this: Resource) -> SocketResult { let table = self.table(); let socket = table.get(&this)?; - let view = &*socket.as_std_view()?; - Ok(sockopt::get_tcp_keepintvl(view)?.as_nanos() as u64) + Ok(socket.keep_alive_interval()?.as_nanos() as u64) } fn set_keep_alive_interval( @@ -519,15 +206,13 @@ impl crate::host::tcp::tcp::HostTcpSocket for T { ) -> SocketResult<()> { let table = self.table(); let socket = table.get(&this)?; - let view = &*socket.as_std_view()?; - Ok(util::set_tcp_keepintvl(view, Duration::from_nanos(value))?) + socket.set_keep_alive_interval(Duration::from_nanos(value)) } fn keep_alive_count(&mut self, this: Resource) -> SocketResult { let table = self.table(); let socket = table.get(&this)?; - let view = &*socket.as_std_view()?; - Ok(sockopt::get_tcp_keepcnt(view)?) + socket.keep_alive_count() } fn set_keep_alive_count( @@ -537,50 +222,26 @@ impl crate::host::tcp::tcp::HostTcpSocket for T { ) -> SocketResult<()> { let table = self.table(); let socket = table.get(&this)?; - let view = &*socket.as_std_view()?; - Ok(util::set_tcp_keepcnt(view, value)?) + socket.set_keep_alive_count(value) } fn hop_limit(&mut self, this: Resource) -> SocketResult { let table = self.table(); let socket = table.get(&this)?; - let view = &*socket.as_std_view()?; - - let ttl = match socket.family { - SocketAddressFamily::Ipv4 => util::get_ip_ttl(view)?, - SocketAddressFamily::Ipv6 => util::get_ipv6_unicast_hops(view)?, - }; - - Ok(ttl) + socket.hop_limit() } fn set_hop_limit(&mut self, this: Resource, value: u8) -> SocketResult<()> { let table = self.table(); let socket = table.get_mut(&this)?; - { - let view = &*socket.as_std_view()?; - - match socket.family { - SocketAddressFamily::Ipv4 => util::set_ip_ttl(view, value)?, - SocketAddressFamily::Ipv6 => util::set_ipv6_unicast_hops(view, value)?, - } - } - - #[cfg(target_os = "macos")] - { - socket.hop_limit = Some(value); - } - - Ok(()) + socket.set_hop_limit(value) } fn receive_buffer_size(&mut self, this: Resource) -> SocketResult { let table = self.table(); let socket = table.get(&this)?; - let view = &*socket.as_std_view()?; - let value = util::get_socket_recv_buffer_size(view)?; - Ok(value as u64) + Ok(socket.receive_buffer_size()? as u64) } fn set_receive_buffer_size( @@ -591,27 +252,14 @@ impl crate::host::tcp::tcp::HostTcpSocket for T { let table = self.table(); let socket = table.get_mut(&this)?; let value = value.try_into().unwrap_or(usize::MAX); - { - let view = &*socket.as_std_view()?; - - util::set_socket_recv_buffer_size(view, value)?; - } - - #[cfg(target_os = "macos")] - { - socket.receive_buffer_size = Some(value); - } - - Ok(()) + socket.set_receive_buffer_size(value) } fn send_buffer_size(&mut self, this: Resource) -> SocketResult { let table = self.table(); let socket = table.get(&this)?; - let view = &*socket.as_std_view()?; - let value = util::get_socket_send_buffer_size(view)?; - Ok(value as u64) + Ok(socket.send_buffer_size()? as u64) } fn set_send_buffer_size( @@ -622,18 +270,7 @@ impl crate::host::tcp::tcp::HostTcpSocket for T { let table = self.table(); let socket = table.get_mut(&this)?; let value = value.try_into().unwrap_or(usize::MAX); - { - let view = &*socket.as_std_view()?; - - util::set_socket_send_buffer_size(view, value)?; - } - - #[cfg(target_os = "macos")] - { - socket.send_buffer_size = Some(value); - } - - Ok(()) + socket.set_send_buffer_size(value) } fn subscribe(&mut self, this: Resource) -> anyhow::Result> { @@ -648,21 +285,12 @@ impl crate::host::tcp::tcp::HostTcpSocket for T { let table = self.table(); let socket = table.get(&this)?; - let stream = match &socket.tcp_state { - TcpState::Connected(stream) => stream, - _ => return Err(ErrorCode::InvalidState.into()), - }; - let how = match shutdown_type { ShutdownType::Receive => std::net::Shutdown::Read, ShutdownType::Send => std::net::Shutdown::Write, ShutdownType::Both => std::net::Shutdown::Both, }; - - stream - .as_socketlike_view::() - .shutdown(how)?; - Ok(()) + Ok(socket.shutdown(how)?) } fn drop(&mut self, this: Resource) -> Result<(), anyhow::Error> { diff --git a/crates/wasi/src/tcp.rs b/crates/wasi/src/tcp.rs index bc970c3bf6..e385874d93 100644 --- a/crates/wasi/src/tcp.rs +++ b/crates/wasi/src/tcp.rs @@ -1,16 +1,24 @@ +use crate::bindings::sockets::tcp::ErrorCode; +use crate::host::network; use crate::network::SocketAddressFamily; use crate::runtime::{with_ambient_tokio_runtime, AbortOnDropJoinHandle}; -use crate::{HostInputStream, HostOutputStream, SocketResult, StreamError, Subscribe}; +use crate::{ + HostInputStream, HostOutputStream, InputStream, OutputStream, SocketResult, StreamError, + Subscribe, +}; use anyhow::{Error, Result}; use cap_net_ext::AddressFamily; use futures::Future; use io_lifetimes::views::SocketlikeView; use io_lifetimes::AsSocketlike; +use rustix::io::Errno; use rustix::net::sockopt; use std::io; use std::mem; +use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; +use std::task::Poll; /// Value taken from rust std library. const DEFAULT_BACKLOG: u32 = 128; @@ -19,7 +27,7 @@ const DEFAULT_BACKLOG: u32 = 128; /// /// This represents the various states a socket can be in during the /// activities of binding, listening, accepting, and connecting. -pub(crate) enum TcpState { +enum TcpState { /// The initial state for a newly-created socket. Default(tokio::net::TcpSocket), @@ -51,30 +59,612 @@ pub(crate) enum TcpState { Closed, } +impl std::fmt::Debug for TcpState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Default(_) => f.debug_tuple("Default").finish(), + Self::BindStarted(_) => f.debug_tuple("BindStarted").finish(), + Self::Bound(_) => f.debug_tuple("Bound").finish(), + Self::ListenStarted(_) => f.debug_tuple("ListenStarted").finish(), + Self::Listening { pending_accept, .. } => f + .debug_struct("Listening") + .field("pending_accept", pending_accept) + .finish(), + Self::Connecting(_) => f.debug_tuple("Connecting").finish(), + Self::ConnectReady(_) => f.debug_tuple("ConnectReady").finish(), + Self::Connected(_) => f.debug_tuple("Connected").finish(), + Self::Closed => write!(f, "Closed"), + } + } +} + /// A host TCP socket, plus associated bookkeeping. -/// -/// The inner state is wrapped in an Arc because the same underlying socket is -/// used for implementing the stream types. pub struct TcpSocket { /// The current state in the bind/listen/accept/connect progression. - pub(crate) tcp_state: TcpState, + tcp_state: TcpState, /// The desired listen queue size. - pub(crate) listen_backlog_size: u32, + listen_backlog_size: u32, - pub(crate) family: SocketAddressFamily, + family: SocketAddressFamily, // The socket options below are not automatically inherited from the listener // on all platforms. So we keep track of which options have been explicitly // set and manually apply those values to newly accepted clients. #[cfg(target_os = "macos")] - pub(crate) receive_buffer_size: Option, + receive_buffer_size: Option, #[cfg(target_os = "macos")] - pub(crate) send_buffer_size: Option, + send_buffer_size: Option, #[cfg(target_os = "macos")] - pub(crate) hop_limit: Option, + hop_limit: Option, #[cfg(target_os = "macos")] - pub(crate) keep_alive_idle_time: Option, + keep_alive_idle_time: Option, +} + +impl TcpSocket { + /// Create a new socket in the given family. + pub fn new(family: AddressFamily) -> io::Result { + with_ambient_tokio_runtime(|| { + let (socket, family) = match family { + AddressFamily::Ipv4 => { + let socket = tokio::net::TcpSocket::new_v4()?; + (socket, SocketAddressFamily::Ipv4) + } + AddressFamily::Ipv6 => { + let socket = tokio::net::TcpSocket::new_v6()?; + sockopt::set_ipv6_v6only(&socket, true)?; + (socket, SocketAddressFamily::Ipv6) + } + }; + + Self::from_state(TcpState::Default(socket), family) + }) + } + + /// Create a `TcpSocket` from an existing socket. + fn from_state(state: TcpState, family: SocketAddressFamily) -> io::Result { + Ok(Self { + tcp_state: state, + listen_backlog_size: DEFAULT_BACKLOG, + family, + #[cfg(target_os = "macos")] + receive_buffer_size: None, + #[cfg(target_os = "macos")] + send_buffer_size: None, + #[cfg(target_os = "macos")] + hop_limit: None, + #[cfg(target_os = "macos")] + keep_alive_idle_time: None, + }) + } + + fn as_std_view(&self) -> SocketResult> { + use crate::bindings::sockets::network::ErrorCode; + + match &self.tcp_state { + TcpState::Default(socket) | TcpState::Bound(socket) => { + Ok(socket.as_socketlike_view::()) + } + TcpState::Connected(stream) => Ok(stream.as_socketlike_view::()), + TcpState::Listening { listener, .. } => { + Ok(listener.as_socketlike_view::()) + } + + TcpState::BindStarted(..) + | TcpState::ListenStarted(..) + | TcpState::Connecting(..) + | TcpState::ConnectReady(..) + | TcpState::Closed => Err(ErrorCode::InvalidState.into()), + } + } +} + +impl TcpSocket { + pub fn start_bind(&mut self, local_address: SocketAddr) -> io::Result<()> { + let tokio_socket = match &self.tcp_state { + TcpState::Default(socket) => socket, + TcpState::BindStarted(..) => return Err(Errno::ALREADY.into()), + _ => return Err(Errno::ISCONN.into()), + }; + + network::util::validate_unicast(&local_address)?; + network::util::validate_address_family(&local_address, &self.family)?; + + { + // Automatically bypass the TIME_WAIT state when the user is trying + // to bind to a specific port: + let reuse_addr = local_address.port() > 0; + + // Unconditionally (re)set SO_REUSEADDR, even when the value is false. + // This ensures we're not accidentally affected by any socket option + // state left behind by a previous failed call to this method (start_bind). + network::util::set_tcp_reuseaddr(&tokio_socket, reuse_addr)?; + + // Perform the OS bind call. + tokio_socket.bind(local_address).map_err(|error| { + match Errno::from_io_error(&error) { + // From https://pubs.opengroup.org/onlinepubs/9699919799/functions/bind.html: + // > [EAFNOSUPPORT] The specified address is not a valid address for the address family of the specified socket + // + // The most common reasons for this error should have already + // been handled by our own validation slightly higher up in this + // function. This error mapping is here just in case there is + // an edge case we didn't catch. + Some(Errno::AFNOSUPPORT) => io::Error::new( + io::ErrorKind::InvalidInput, + "The specified address is not a valid address for the address family of the specified socket", + ), + + // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-bind#:~:text=WSAENOBUFS + // Windows returns WSAENOBUFS when the ephemeral ports have been exhausted. + #[cfg(windows)] + Some(Errno::NOBUFS) => io::Error::new(io::ErrorKind::AddrInUse, "no more free local ports"), + + _ => error, + } + })?; + + self.tcp_state = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) { + TcpState::Default(socket) => TcpState::BindStarted(socket), + _ => unreachable!(), + }; + + Ok(()) + } + } + + pub fn finish_bind(&mut self) -> SocketResult<()> { + match std::mem::replace(&mut self.tcp_state, TcpState::Closed) { + TcpState::BindStarted(socket) => { + self.tcp_state = TcpState::Bound(socket); + Ok(()) + } + current_state => { + // Reset the state so that the outside world doesn't see this socket as closed + self.tcp_state = current_state; + Err(ErrorCode::NotInProgress.into()) + } + } + } + + pub fn start_connect(&mut self, remote_address: SocketAddr) -> SocketResult<()> { + match self.tcp_state { + TcpState::Default(..) => {} + + TcpState::Connecting(..) | TcpState::ConnectReady(..) => { + return Err(ErrorCode::ConcurrencyConflict.into()) + } + + _ => return Err(ErrorCode::InvalidState.into()), + }; + + network::util::validate_unicast(&remote_address)?; + network::util::validate_remote_address(&remote_address)?; + network::util::validate_address_family(&remote_address, &self.family)?; + + let TcpState::Default(tokio_socket) = + std::mem::replace(&mut self.tcp_state, TcpState::Closed) + else { + unreachable!(); + }; + + let future = tokio_socket.connect(remote_address); + + self.tcp_state = TcpState::Connecting(Box::pin(future)); + Ok(()) + } + + pub fn finish_connect(&mut self) -> SocketResult<(InputStream, OutputStream)> { + let previous_state = std::mem::replace(&mut self.tcp_state, TcpState::Closed); + let result = match previous_state { + TcpState::ConnectReady(result) => result, + TcpState::Connecting(mut future) => { + let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref()); + match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) { + Poll::Ready(result) => result, + Poll::Pending => { + self.tcp_state = TcpState::Connecting(future); + return Err(ErrorCode::WouldBlock.into()); + } + } + } + previous_state => { + self.tcp_state = previous_state; + return Err(ErrorCode::NotInProgress.into()); + } + }; + + match result { + Ok(stream) => { + let stream = Arc::new(stream); + self.tcp_state = TcpState::Connected(stream.clone()); + let input: InputStream = + InputStream::Host(Box::new(TcpReadStream::new(stream.clone()))); + let output: OutputStream = Box::new(TcpWriteStream::new(stream)); + Ok((input, output)) + } + Err(err) => { + self.tcp_state = TcpState::Closed; + Err(err.into()) + } + } + } + + pub fn start_listen(&mut self) -> SocketResult<()> { + match std::mem::replace(&mut self.tcp_state, TcpState::Closed) { + TcpState::Bound(tokio_socket) => { + self.tcp_state = TcpState::ListenStarted(tokio_socket); + Ok(()) + } + TcpState::ListenStarted(tokio_socket) => { + self.tcp_state = TcpState::ListenStarted(tokio_socket); + Err(ErrorCode::ConcurrencyConflict.into()) + } + previous_state => { + self.tcp_state = previous_state; + Err(ErrorCode::InvalidState.into()) + } + } + } + + pub fn finish_listen(&mut self) -> SocketResult<()> { + let tokio_socket = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) { + TcpState::ListenStarted(tokio_socket) => tokio_socket, + previous_state => { + self.tcp_state = previous_state; + return Err(ErrorCode::NotInProgress.into()); + } + }; + + match with_ambient_tokio_runtime(|| tokio_socket.listen(self.listen_backlog_size)) { + Ok(listener) => { + self.tcp_state = TcpState::Listening { + listener, + pending_accept: None, + }; + Ok(()) + } + Err(err) => { + self.tcp_state = TcpState::Closed; + + Err(match Errno::from_io_error(&err) { + // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-listen#:~:text=WSAEMFILE + // According to the docs, `listen` can return EMFILE on Windows. + // This is odd, because we're not trying to create a new socket + // or file descriptor of any kind. So we rewrite it to less + // surprising error code. + // + // At the time of writing, this behavior has never been experimentally + // observed by any of the wasmtime authors, so we're relying fully + // on Microsoft's documentation here. + #[cfg(windows)] + Some(Errno::MFILE) => Errno::NOBUFS.into(), + + _ => err.into(), + }) + } + } + } + + pub fn accept(&mut self) -> SocketResult<(Self, InputStream, OutputStream)> { + let TcpState::Listening { + listener, + pending_accept, + } = &mut self.tcp_state + else { + return Err(ErrorCode::InvalidState.into()); + }; + + let result = match pending_accept.take() { + Some(result) => result, + None => { + let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref()); + match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx)) + .map_ok(|(stream, _)| stream) + { + Poll::Ready(result) => result, + Poll::Pending => Err(Errno::WOULDBLOCK.into()), + } + } + }; + + let client = result.map_err(|err| match Errno::from_io_error(&err) { + // From: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-accept#:~:text=WSAEINPROGRESS + // > WSAEINPROGRESS: A blocking Windows Sockets 1.1 call is in progress, + // > or the service provider is still processing a callback function. + // + // wasi-sockets doesn't have an equivalent to the EINPROGRESS error, + // because in POSIX this error is only returned by a non-blocking + // `connect` and wasi-sockets has a different solution for that. + #[cfg(windows)] + Some(Errno::INPROGRESS) => Errno::INTR.into(), + + // Normalize Linux' non-standard behavior. + // + // From https://man7.org/linux/man-pages/man2/accept.2.html: + // > Linux accept() passes already-pending network errors on the + // > new socket as an error code from accept(). This behavior + // > differs from other BSD socket implementations. (...) + #[cfg(target_os = "linux")] + Some( + Errno::CONNRESET + | Errno::NETRESET + | Errno::HOSTUNREACH + | Errno::HOSTDOWN + | Errno::NETDOWN + | Errno::NETUNREACH + | Errno::PROTO + | Errno::NOPROTOOPT + | Errno::NONET + | Errno::OPNOTSUPP, + ) => Errno::CONNABORTED.into(), + + _ => err, + })?; + + #[cfg(target_os = "macos")] + { + // Manually inherit socket options from listener. We only have to + // do this on platforms that don't already do this automatically + // and only if a specific value was explicitly set on the listener. + + if let Some(size) = self.receive_buffer_size { + _ = network::util::set_socket_recv_buffer_size(&client, size); // Ignore potential error. + } + + if let Some(size) = self.send_buffer_size { + _ = network::util::set_socket_send_buffer_size(&client, size); // Ignore potential error. + } + + // For some reason, IP_TTL is inherited, but IPV6_UNICAST_HOPS isn't. + if let (SocketAddressFamily::Ipv6, Some(ttl)) = (self.family, self.hop_limit) { + _ = network::util::set_ipv6_unicast_hops(&client, ttl); // Ignore potential error. + } + + if let Some(value) = self.keep_alive_idle_time { + _ = network::util::set_tcp_keepidle(&client, value); // Ignore potential error. + } + } + + let client = Arc::new(client); + + let input: InputStream = InputStream::Host(Box::new(TcpReadStream::new(client.clone()))); + let output: OutputStream = Box::new(TcpWriteStream::new(client.clone())); + let tcp_socket = TcpSocket::from_state(TcpState::Connected(client), self.family)?; + + Ok((tcp_socket, input, output)) + } + + pub fn local_address(&self) -> SocketResult { + let view = match self.tcp_state { + TcpState::Default(..) => return Err(ErrorCode::InvalidState.into()), + TcpState::BindStarted(..) => return Err(ErrorCode::ConcurrencyConflict.into()), + _ => self.as_std_view()?, + }; + + Ok(view.local_addr()?) + } + + pub fn remote_address(&self) -> SocketResult { + let view = match self.tcp_state { + TcpState::Connected(..) => self.as_std_view()?, + TcpState::Connecting(..) | TcpState::ConnectReady(..) => { + return Err(ErrorCode::ConcurrencyConflict.into()) + } + _ => return Err(ErrorCode::InvalidState.into()), + }; + + Ok(view.peer_addr()?) + } + + pub fn is_listening(&self) -> bool { + matches!(self.tcp_state, TcpState::Listening { .. }) + } + + pub fn address_family(&self) -> SocketAddressFamily { + self.family + } + + pub fn set_listen_backlog_size(&mut self, value: u32) -> SocketResult<()> { + const MIN_BACKLOG: u32 = 1; + const MAX_BACKLOG: u32 = i32::MAX as u32; // OS'es will most likely limit it down even further. + + if value == 0 { + return Err(ErrorCode::InvalidArgument.into()); + } + + // Silently clamp backlog size. This is OK for us to do, because operating systems do this too. + let value = value.clamp(MIN_BACKLOG, MAX_BACKLOG); + + match &self.tcp_state { + TcpState::Default(..) | TcpState::Bound(..) => { + // Socket not listening yet. Stash value for first invocation to `listen`. + } + TcpState::Listening { listener, .. } => { + // Try to update the backlog by calling `listen` again. + // Not all platforms support this. We'll only update our own value if the OS supports changing the backlog size after the fact. + + rustix::net::listen(&listener, value.try_into().unwrap()) + .map_err(|_| ErrorCode::NotSupported)?; + } + _ => return Err(ErrorCode::InvalidState.into()), + } + self.listen_backlog_size = value; + + Ok(()) + } + + pub fn keep_alive_enabled(&self) -> SocketResult { + let view = &*self.as_std_view()?; + Ok(sockopt::get_socket_keepalive(view)?) + } + + pub fn set_keep_alive_enabled(&self, value: bool) -> SocketResult<()> { + let view = &*self.as_std_view()?; + Ok(sockopt::set_socket_keepalive(view, value)?) + } + + pub fn keep_alive_idle_time(&self) -> SocketResult { + let view = &*self.as_std_view()?; + Ok(sockopt::get_tcp_keepidle(view)?) + } + + pub fn set_keep_alive_idle_time(&mut self, duration: std::time::Duration) -> SocketResult<()> { + { + let view = &*self.as_std_view()?; + network::util::set_tcp_keepidle(view, duration)?; + } + + #[cfg(target_os = "macos")] + { + self.keep_alive_idle_time = Some(duration); + } + + Ok(()) + } + + pub fn keep_alive_interval(&self) -> SocketResult { + let view = &*self.as_std_view()?; + Ok(sockopt::get_tcp_keepintvl(view)?) + } + + pub fn set_keep_alive_interval(&self, duration: std::time::Duration) -> SocketResult<()> { + let view = &*self.as_std_view()?; + Ok(network::util::set_tcp_keepintvl(view, duration)?) + } + + pub fn keep_alive_count(&self) -> SocketResult { + let view = &*self.as_std_view()?; + Ok(sockopt::get_tcp_keepcnt(view)?) + } + + pub fn set_keep_alive_count(&self, value: u32) -> SocketResult<()> { + // TODO(rylev): do we need to check and clamp the value? + + let view = &*self.as_std_view()?; + Ok(network::util::set_tcp_keepcnt(view, value)?) + } + + pub fn hop_limit(&self) -> SocketResult { + let view = &*self.as_std_view()?; + + let ttl = match self.family { + SocketAddressFamily::Ipv4 => network::util::get_ip_ttl(view)?, + SocketAddressFamily::Ipv6 => network::util::get_ipv6_unicast_hops(view)?, + }; + + Ok(ttl) + } + + pub fn set_hop_limit(&mut self, value: u8) -> SocketResult<()> { + { + let view = &*self.as_std_view()?; + + match self.family { + SocketAddressFamily::Ipv4 => network::util::set_ip_ttl(view, value)?, + SocketAddressFamily::Ipv6 => network::util::set_ipv6_unicast_hops(view, value)?, + } + } + + #[cfg(target_os = "macos")] + { + self.hop_limit = Some(value); + } + + Ok(()) + } + + pub fn receive_buffer_size(&self) -> SocketResult { + let view = &*self.as_std_view()?; + + Ok(network::util::get_socket_recv_buffer_size(view)?) + } + + pub fn set_receive_buffer_size(&mut self, value: usize) -> SocketResult<()> { + { + let view = &*self.as_std_view()?; + + network::util::set_socket_recv_buffer_size(view, value)?; + } + + #[cfg(target_os = "macos")] + { + self.receive_buffer_size = Some(value); + } + + Ok(()) + } + + pub fn send_buffer_size(&self) -> SocketResult { + let view = &*self.as_std_view()?; + + Ok(network::util::get_socket_send_buffer_size(view)?) + } + + pub fn set_send_buffer_size(&mut self, value: usize) -> SocketResult<()> { + { + let view = &*self.as_std_view()?; + + network::util::set_socket_send_buffer_size(view, value)?; + } + + #[cfg(target_os = "macos")] + { + self.send_buffer_size = Some(value); + } + + Ok(()) + } + + pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> { + let stream = match &self.tcp_state { + TcpState::Connected(stream) => stream, + _ => { + return Err(io::Error::new( + io::ErrorKind::NotConnected, + "socket not connected", + )) + } + }; + + stream + .as_socketlike_view::() + .shutdown(how)?; + Ok(()) + } +} + +#[async_trait::async_trait] +impl Subscribe for TcpSocket { + async fn ready(&mut self) { + match &mut self.tcp_state { + TcpState::Default(..) + | TcpState::BindStarted(..) + | TcpState::Bound(..) + | TcpState::ListenStarted(..) + | TcpState::ConnectReady(..) + | TcpState::Closed + | TcpState::Connected(..) => { + // No async operation in progress. + } + TcpState::Connecting(future) => { + self.tcp_state = TcpState::ConnectReady(future.as_mut().await); + } + TcpState::Listening { + listener, + pending_accept, + } => match pending_accept { + Some(_) => {} + None => { + let result = futures::future::poll_fn(|cx| { + listener.poll_accept(cx).map_ok(|(stream, _)| stream) + }) + .await; + *pending_accept = Some(result); + } + }, + } + } } pub(crate) struct TcpReadStream { @@ -255,94 +845,3 @@ impl Subscribe for TcpWriteStream { } } } - -impl TcpSocket { - /// Create a new socket in the given family. - pub fn new(family: AddressFamily) -> io::Result { - with_ambient_tokio_runtime(|| { - let (socket, family) = match family { - AddressFamily::Ipv4 => { - let socket = tokio::net::TcpSocket::new_v4()?; - (socket, SocketAddressFamily::Ipv4) - } - AddressFamily::Ipv6 => { - let socket = tokio::net::TcpSocket::new_v6()?; - sockopt::set_ipv6_v6only(&socket, true)?; - (socket, SocketAddressFamily::Ipv6) - } - }; - - Self::from_state(TcpState::Default(socket), family) - }) - } - - /// Create a `TcpSocket` from an existing socket. - pub(crate) fn from_state(state: TcpState, family: SocketAddressFamily) -> io::Result { - Ok(Self { - tcp_state: state, - listen_backlog_size: DEFAULT_BACKLOG, - family, - #[cfg(target_os = "macos")] - receive_buffer_size: None, - #[cfg(target_os = "macos")] - send_buffer_size: None, - #[cfg(target_os = "macos")] - hop_limit: None, - #[cfg(target_os = "macos")] - keep_alive_idle_time: None, - }) - } - - pub(crate) fn as_std_view(&self) -> SocketResult> { - use crate::bindings::sockets::network::ErrorCode; - - match &self.tcp_state { - TcpState::Default(socket) | TcpState::Bound(socket) => { - Ok(socket.as_socketlike_view::()) - } - TcpState::Connected(stream) => Ok(stream.as_socketlike_view::()), - TcpState::Listening { listener, .. } => { - Ok(listener.as_socketlike_view::()) - } - - TcpState::BindStarted(..) - | TcpState::ListenStarted(..) - | TcpState::Connecting(..) - | TcpState::ConnectReady(..) - | TcpState::Closed => Err(ErrorCode::InvalidState.into()), - } - } -} - -#[async_trait::async_trait] -impl Subscribe for TcpSocket { - async fn ready(&mut self) { - match &mut self.tcp_state { - TcpState::Default(..) - | TcpState::BindStarted(..) - | TcpState::Bound(..) - | TcpState::ListenStarted(..) - | TcpState::ConnectReady(..) - | TcpState::Closed - | TcpState::Connected(..) => { - // No async operation in progress. - } - TcpState::Connecting(future) => { - self.tcp_state = TcpState::ConnectReady(future.as_mut().await); - } - TcpState::Listening { - listener, - pending_accept, - } => match pending_accept { - Some(_) => {} - None => { - let result = futures::future::poll_fn(|cx| { - listener.poll_accept(cx).map_ok(|(stream, _)| stream) - }) - .await; - *pending_accept = Some(result); - } - }, - } - } -}