From c37e5e0721085199333ccf2f24941be4947bc1a7 Mon Sep 17 00:00:00 2001 From: Manos Pitsidianakis Date: Fri, 11 Aug 2023 21:03:32 +0300 Subject: [PATCH] =?UTF-8?q?melib/connections:=20use=20Happy=20Eyeballs=20a?= =?UTF-8?q?lgorithm=20=EA=99=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit adds a Happy Eyeballs [1] implementation taken from the happy-eyeballs crate, which is in public domain. While the function lookup_ip[0] iterates through the addresses returned by A and AAAA records from a DNS lookup, it returns the first one which always is an IPv4 address, unless there only is an AAAA record. RFC6555 [1] recommends an algorithm for choosing the fastest address to connect to, called "Happy Eyeballs". Ꙭ [0]: melib/src/utils/connections.rs:497 [1]: https://www.rfc-editor.org/rfc/rfc6555 Fixes #268 --- Cargo.lock | 2 + melib/Cargo.toml | 2 + melib/src/imap/connection.rs | 19 +- melib/src/nntp/connection.rs | 11 +- melib/src/smtp.rs | 20 +- melib/src/utils/connections.rs | 7 + melib/src/utils/connections/smol.rs | 304 +++++++++++++++++++ melib/src/utils/connections/std_net.rs | 393 +++++++++++++++++++++++++ 8 files changed, 728 insertions(+), 30 deletions(-) create mode 100644 melib/src/utils/connections/smol.rs create mode 100644 melib/src/utils/connections/std_net.rs diff --git a/Cargo.lock b/Cargo.lock index 75edbb87..ceccd340 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1279,6 +1279,7 @@ dependencies = [ "nix", "nom", "notify", + "polling", "regex", "rusqlite", "serde", @@ -1287,6 +1288,7 @@ dependencies = [ "serde_path_to_error", "smallvec", "smol", + "socket2", "stderrlog", "unicode-segmentation", "uuid", diff --git a/melib/Cargo.toml b/melib/Cargo.toml index 7fd97ef6..a5f191c4 100644 --- a/melib/Cargo.toml +++ b/melib/Cargo.toml @@ -38,6 +38,7 @@ native-tls = { version = "0.2.3", default-features = false, optional = true } nix = "^0.24" nom = { version = "7" } notify = { version = "4.0.15", optional = true } +polling = "2.8" regex = { version = "1" } rusqlite = { version = "^0.28", default-features = false, optional = true } serde = { version = "1.0", features = ["rc", ] } @@ -46,6 +47,7 @@ serde_json = { version = "1.0", features = ["raw_value",] } serde_path_to_error = { version = "0.1" } smallvec = { version = "^1.5.0", features = ["serde", ] } smol = "1.0.0" +socket2 = { version = "0.4", features = [] } unicode-segmentation = { version = "1.2.1", default-features = false, optional = true } uuid = { version = "^1", features = ["serde", "v4", "v5"] } diff --git a/melib/src/imap/connection.rs b/melib/src/imap/connection.rs index 137ba49c..1f171469 100644 --- a/melib/src/imap/connection.rs +++ b/melib/src/imap/connection.rs @@ -25,7 +25,7 @@ use crate::{ email::parser::BytesExt, error::*, utils::{ - connections::{lookup_ip, Connection}, + connections::{std_net::connect as tcp_stream_connect, Connection}, futures::timeout, }, LogLevel, @@ -158,7 +158,6 @@ impl ImapStream { #[cfg(debug_assertions)] id: Cow<'static, str>, uid_store: &UIDStore, ) -> Result<(Capabilities, Self)> { - use std::net::TcpStream; let path = &server_conf.server_hostname; let cmd_id = 1; @@ -177,14 +176,10 @@ impl ImapStream { .build() .chain_err_kind(ErrorKind::Network(NetworkErrorKind::InvalidTLSConnection))?; - let addr = lookup_ip(path, server_conf.server_port)?; + let addr = (path.as_str(), server_conf.server_port); let mut socket = AsyncWrapper::new({ - let conn = Connection::new_tcp(if let Some(timeout) = server_conf.timeout { - TcpStream::connect_timeout(&addr, timeout)? - } else { - TcpStream::connect(addr)? - }); + let conn = Connection::new_tcp(tcp_stream_connect(addr, server_conf.timeout)?); #[cfg(feature = "imap-trace")] let conn = conn.trace(true).with_id("imap"); conn @@ -280,13 +275,9 @@ impl ImapStream { .chain_err_summary(|| format!("Could not initiate TLS negotiation to {}.", path))? } } else { - let addr = lookup_ip(path, server_conf.server_port)?; + let addr = (path.as_str(), server_conf.server_port); AsyncWrapper::new({ - let conn = Connection::new_tcp(if let Some(timeout) = server_conf.timeout { - TcpStream::connect_timeout(&addr, timeout)? - } else { - TcpStream::connect(addr)? - }); + let conn = Connection::new_tcp(tcp_stream_connect(addr, server_conf.timeout)?); #[cfg(feature = "imap-trace")] let conn = conn.trace(true).with_id("imap"); conn diff --git a/melib/src/nntp/connection.rs b/melib/src/nntp/connection.rs index 30a76456..a06cf56c 100644 --- a/melib/src/nntp/connection.rs +++ b/melib/src/nntp/connection.rs @@ -24,7 +24,7 @@ use crate::{ email::parser::BytesExt, error::*, log, - utils::connections::{lookup_ip, Connection}, + utils::connections::{std_net::connect as tcp_stream_connect, Connection}, }; extern crate native_tls; use std::{collections::HashSet, future::Future, pin::Pin, sync::Arc, time::Instant}; @@ -74,15 +74,14 @@ pub struct NntpConnection { impl NntpStream { pub async fn new_connection(server_conf: &NntpServerConf) -> Result<(Capabilities, Self)> { - use std::net::TcpStream; let path = &server_conf.server_hostname; let stream = { - let addr = lookup_ip(path, server_conf.server_port)?; + let addr = (path.as_str(), server_conf.server_port); AsyncWrapper::new({ - let conn = Connection::new_tcp(TcpStream::connect_timeout( - &addr, - std::time::Duration::new(16, 0), + let conn = Connection::new_tcp(tcp_stream_connect( + addr, + Some(std::time::Duration::new(16, 0)), )?); #[cfg(feature = "nntp-trace")] let conn = conn.trace(true).with_id("nntp"); diff --git a/melib/src/smtp.rs b/melib/src/smtp.rs index 2db5b6ba..155fddfb 100644 --- a/melib/src/smtp.rs +++ b/melib/src/smtp.rs @@ -72,7 +72,7 @@ //! Ok(()) //! ``` -use std::{borrow::Cow, convert::TryFrom, net::TcpStream, process::Command}; +use std::{borrow::Cow, convert::TryFrom, process::Command}; use futures::io::{AsyncReadExt, AsyncWriteExt}; use native_tls::TlsConnector; @@ -82,7 +82,7 @@ use smol::{unblock, Async as AsyncWrapper}; use crate::{ email::{parser::BytesExt, Address, Envelope}, error::{Error, Result, ResultIntoError}, - utils::connections::{lookup_ip, Connection}, + utils::connections::{std_net::connect as tcp_stream_connect, Connection}, }; /// Kind of server security (StartTLS/TLS/None) the client should attempt @@ -261,11 +261,11 @@ impl SmtpConnection { } let connector = connector.build()?; - let addr = lookup_ip(path, server_conf.port)?; + let addr = (path.as_str(), server_conf.port); let mut socket = { - let conn = Connection::new_tcp(TcpStream::connect_timeout( - &addr, - std::time::Duration::new(4, 0), + let conn = Connection::new_tcp(tcp_stream_connect( + addr, + Some(std::time::Duration::new(4, 0)), )?); #[cfg(feature = "smtp-trace")] let conn = conn.trace(true).with_id("smtp"); @@ -373,11 +373,11 @@ impl SmtpConnection { ret } SmtpSecurity::None => { - let addr = lookup_ip(path, server_conf.port)?; + let addr = (path.as_str(), server_conf.port); let mut ret = AsyncWrapper::new({ - let conn = Connection::new_tcp(TcpStream::connect_timeout( - &addr, - std::time::Duration::new(4, 0), + let conn = Connection::new_tcp(tcp_stream_connect( + addr, + Some(std::time::Duration::new(4, 0)), )?); #[cfg(feature = "smtp-trace")] let conn = conn.trace(true).with_id("smtp"); diff --git a/melib/src/utils/connections.rs b/melib/src/utils/connections.rs index f2695533..a5d5b60b 100644 --- a/melib/src/utils/connections.rs +++ b/melib/src/utils/connections.rs @@ -38,6 +38,11 @@ use libc::TCP_KEEPALIVE as KEEPALIVE_OPTION; use libc::TCP_KEEPIDLE as KEEPALIVE_OPTION; use libc::{self, c_int, c_void}; +// pub mod smol; +pub mod std_net; + +pub const CONNECTION_ATTEMPT_DELAY: std::time::Duration = std::time::Duration::from_millis(250); + pub enum Connection { Tcp { inner: std::net::TcpStream, @@ -494,6 +499,8 @@ impl std::os::unix::io::AsRawFd for Connection { } } +#[deprecated = "While it supports IPv6, it does not implement the happy eyeballs algorithm. Use \ + {std_net,smol}::tcp_stream_connect instead."] pub fn lookup_ip(host: &str, port: u16) -> crate::Result { use std::net::ToSocketAddrs; diff --git a/melib/src/utils/connections/smol.rs b/melib/src/utils/connections/smol.rs new file mode 100644 index 00000000..63643446 --- /dev/null +++ b/melib/src/utils/connections/smol.rs @@ -0,0 +1,304 @@ +/* + * Copyright (C) 2023 by Kim Minh Kaplan + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR + * IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +use std::{io::Result, sync::mpsc}; + +use smol::net::{AsyncToSocketAddrs, SocketAddr, TcpStream}; + +use super::CONNECTION_ATTEMPT_DELAY; + +pub async fn connect(addr: A) -> Result { + // modify the ordered list to interleave address families. + let (addrs_v4, addrs_v6): (Vec<_>, Vec<_>) = { + let fut = addr.to_socket_addrs(); + fut.await?.partition(|a| match a { + SocketAddr::V4(_) => true, + SocketAddr::V6(_) => false, + }) + }; + let mut addrs = Vec::with_capacity(addrs_v4.len() + addrs_v6.len()); + let (mut left, mut right) = (addrs_v6.into_iter(), addrs_v4.into_iter()); + while let Some(a) = left.next() { + addrs.push(a); + std::mem::swap(&mut left, &mut right); + } + addrs.extend(right); + + let (tx, rx) = mpsc::channel(); + let mut last_error = None; + let mut attempts = Vec::new(); + let mut attempts_count = 0; + for a in addrs { + attempts.push(smol::spawn({ + let tx = tx.clone(); + Box::pin(async move { + let res = TcpStream::connect(a).await; + tx.send(res).expect("channel is available"); + }) + })); + attempts_count += 1; + let recv = rx.recv_timeout(CONNECTION_ATTEMPT_DELAY); + match recv { + Ok(Ok(tcp)) => { + for t in attempts { + t.cancel().await; + } + return Ok(tcp); + } + Ok(Err(error)) => { + last_error = Some(error); + attempts_count -= 1; + } + Err(mpsc::RecvTimeoutError::Timeout) => (), + Err(error) => unreachable!("{}", error), + } + } + drop(tx); + + while attempts_count > 0 { + let res = rx.recv(); + match res { + Ok(Ok(tcp)) => { + for t in attempts { + t.cancel().await; + } + return Ok(tcp); + } + Ok(Err(error)) => { + last_error = Some(error); + attempts_count -= 1; + } + Err(error) => unreachable!("{}", error), + } + } + Err(last_error.unwrap_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "could not resolve to any address", + ) + })) +} + +/* +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::{serve_4, serve_6, tar_pit}; + use rand::{thread_rng, Rng}; + use smol::io::ReadExt; + use std::net::{Ipv4Addr, Ipv6Addr}; + + #[async_std::test] + async fn test_no_ipv4() { + let port = thread_rng().gen_range(49152..=65535); + assert!(connect((Ipv4Addr::LOCALHOST, port)).await.is_err()); + } + + #[async_std::test] + async fn test_connect_ipv4() { + let (_serve, addr, port) = serve_4(); + let mut data = String::new(); + connect((addr, port)) + .await + .unwrap() + .read_to_string(&mut data) + .await + .unwrap(); + assert_eq!(data, format!("{addr}")); + assert!(connect((Ipv6Addr::LOCALHOST, port)).await.is_err()); + } + + #[async_std::test] + async fn test_no_ipv6() { + let port = thread_rng().gen_range(49152..=65535); + assert!(connect((Ipv6Addr::LOCALHOST, port)).await.is_err()); + } + + #[async_std::test] + async fn test_connect_ipv6() { + let (_serve, addr, port) = serve_6(); + let mut data = String::new(); + connect((addr, port)) + .await + .unwrap() + .read_to_string(&mut data) + .await + .unwrap(); + assert_eq!(data, format!("{addr}")); + assert!(connect((Ipv4Addr::LOCALHOST, port)).await.is_err()); + } + + #[async_std::test] + async fn test_connect_no_6_but_4() { + let (_serve, addr, port) = serve_4(); + let expect = format!("{addr}"); + let saddr6: SocketAddr = (Ipv6Addr::LOCALHOST, port).into(); + let saddr4: SocketAddr = (Ipv4Addr::LOCALHOST, port).into(); + let mut data = String::new(); + { + let saddrs = &[saddr4, saddr6][..]; + data.clear(); + connect(saddrs) + .await + .unwrap() + .read_to_string(&mut data) + .await + .unwrap(); + assert_eq!(data, expect); + } + { + let saddrs = &[saddr6, saddr4][..]; + data.clear(); + connect(saddrs) + .await + .unwrap() + .read_to_string(&mut data) + .await + .unwrap(); + assert_eq!(data, expect); + } + } + + #[async_std::test] + async fn test_connect_no_4_but_6() { + let (_serve, addr, port) = serve_6(); + let expect = format!("{addr}"); + let saddr6: SocketAddr = (Ipv6Addr::LOCALHOST, port).into(); + let saddr4: SocketAddr = (Ipv4Addr::LOCALHOST, port).into(); + let mut data = String::new(); + { + let saddrs = &[saddr4, saddr6][..]; + data.clear(); + connect(saddrs) + .await + .unwrap() + .read_to_string(&mut data) + .await + .unwrap(); + assert_eq!(data, expect); + } + { + let saddrs = &[saddr6, saddr4][..]; + data.clear(); + connect(saddrs) + .await + .unwrap() + .read_to_string(&mut data) + .await + .unwrap(); + assert_eq!(data, expect); + } + } + + #[async_std::test] + async fn test_connect() { + let (_serve4, addr4, port4) = serve_4(); + let (_serve6, addr6, port6) = serve_6(); + let saddr4: SocketAddr = (addr4, port4).into(); + let saddr6: SocketAddr = (addr6, port6).into(); + let mut data = String::new(); + + data.clear(); + connect(&[saddr6, saddr4][..]) + .await + .unwrap() + .read_to_string(&mut data) + .await + .unwrap(); + assert_eq!(data, format!("{addr6}")); + + data.clear(); + connect(&[saddr4, saddr6][..]) + .await + .unwrap() + .read_to_string(&mut data) + .await + .unwrap(); + // IPv6 is preferred + assert_eq!(data, format!("{addr6}")); + } + + #[async_std::test] + async fn test_connect_tar_pit4() { + let (_serve4, addr4, port4) = tar_pit(Ipv4Addr::LOCALHOST); + let (_serve6, addr6, port6) = serve_6(); + let saddr4: SocketAddr = (addr4, port4).into(); + let saddr6: SocketAddr = (addr6, port6).into(); + let mut data = String::new(); + + data.clear(); + connect(&[saddr4, saddr6][..]) + .await + .unwrap() + .read_to_string(&mut data) + .await + .unwrap(); + assert_eq!(data, format!("{addr6}")); + + data.clear(); + connect(&[saddr6, saddr4][..]) + .await + .unwrap() + .read_to_string(&mut data) + .await + .unwrap(); + assert_eq!(data, format!("{addr6}")); + } + + #[async_std::test] + async fn test_connect_tar_pit6() { + let (_serve4, addr4, port4) = serve_4(); + let (_serve6, addr6, port6) = tar_pit(Ipv6Addr::LOCALHOST); + let saddr4: SocketAddr = (addr4, port4).into(); + let saddr6: SocketAddr = (addr6, port6).into(); + let mut data = String::new(); + + data.clear(); + let mut cnx = connect(&[saddr4, saddr6][..]).await.unwrap(); + cnx.read_to_string(&mut data).await.unwrap(); + assert_eq!(data, format!("{addr4}")); + + data.clear(); + connect(&[saddr6, saddr4][..]) + .await + .unwrap() + .read_to_string(&mut data) + .await + .unwrap(); + assert_eq!(data, format!("{addr4}")); + } + + #[async_std::test] + async fn test_connect_tar_pit_all() { + let (_serve4, addr4, port4) = tar_pit(Ipv4Addr::LOCALHOST); + let (_serve6, addr6, port6) = tar_pit(Ipv6Addr::LOCALHOST); + let saddr4: SocketAddr = (addr4, port4).into(); + let saddr6: SocketAddr = (addr6, port6).into(); + assert_eq!( + connect(&[saddr4, saddr6][..]).await.unwrap_err().kind(), + std::io::ErrorKind::TimedOut + ); + } + + #[async_std::test] + async fn test_connect_empty() { + let empty = &[][..]; + assert_eq!( + connect(empty).await.unwrap_err().kind(), + std::io::ErrorKind::InvalidInput + ); + } +} +*/ diff --git a/melib/src/utils/connections/std_net.rs b/melib/src/utils/connections/std_net.rs new file mode 100644 index 00000000..636a4755 --- /dev/null +++ b/melib/src/utils/connections/std_net.rs @@ -0,0 +1,393 @@ +/* + * Copyright (C) 2023 by Kim Minh Kaplan + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR + * IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +use std::{ + io::{Error, ErrorKind, Result}, + net::{SocketAddr, TcpStream, ToSocketAddrs}, + os::fd::AsRawFd, + time::{Duration, Instant}, +}; + +use polling::{Event, Poller}; +use socket2::{Domain, SockAddr, Socket}; + +/// Opens a TCP connection to a remote host. +/// +/// If `addr` yelds multiple addresses, `connect` uses the algorithm +/// described in [RFC 8305 Happy Eyeballs Version 2: Better +/// Connectivity Using +/// Concurrency](https://datatracker.ietf.org/doc/html/rfc8305) to +/// connect. +/// +/// Examples +/// ======== +/// +/// ```no_run +/// let socket_addr = ("www.example", 80); +/// let tcp_stream = connect(socket_addr, None)?; +/// # Ok::<(), std::io::Error>(()) +/// ``` +pub fn connect(addr: A, timeout: Option) -> Result { + let mut happy = HappyEyeballs::new()?; + let start = Instant::now(); + let timeout_left = || -> Result> { + let Some(v) = timeout else { + return Ok(None); + }; + Ok(Some(v.checked_sub(Instant::now() - start).ok_or_else( + || Error::new(ErrorKind::TimedOut, "Connection timed out."), + )?)) + }; + + for a in prepare_addresses(addr)? { + log::trace!("a = {:?}", a); + match happy.add(a.into(), Domain::for_address(a)) { + AddOutcome::Connected(tcp) => return Ok(tcp), + AddOutcome::Error => continue, + AddOutcome::InProgress => (), + } + if let Some(sock) = happy.poll_once(timeout_left()?)? { + return Ok(sock); + } + } + + while !happy.is_empty() { + if let Some(sock) = happy.poll_once(timeout_left()?)? { + return Ok(sock); + } + } + Err(happy + .error + .unwrap_or_else(|| Error::new(ErrorKind::InvalidInput, "could not resolve to any address"))) +} + +/// Resolve addresses and order them to alternate between IPv6 and IPv4. +fn prepare_addresses(addr: A) -> Result> +where + A: ToSocketAddrs, +{ + let (addrs_v4, addrs_v6): (Vec<_>, Vec<_>) = addr.to_socket_addrs()?.partition(|a| match a { + SocketAddr::V4(_) => true, + SocketAddr::V6(_) => false, + }); + log::trace!("prepare_addresses 4 = {:?} 6 = {:?}", addrs_v4, addrs_v6); + let mut addrs = Vec::with_capacity(addrs_v4.len() + addrs_v6.len()); + let (mut left, mut right) = (addrs_v6.into_iter(), addrs_v4.into_iter()); + while let Some(a) = left.next() { + addrs.push(a); + std::mem::swap(&mut left, &mut right); + } + addrs.extend(right); + Ok(addrs) +} + +struct HappyEyeballs { + poller: Poller, + error: Option, + attempts: Vec>, + attempts_in_progress: usize, +} + +impl HappyEyeballs { + fn new() -> Result { + Ok(Self { + poller: Poller::new()?, + error: None, + attempts: Vec::new(), + attempts_in_progress: 0, + }) + } + + fn is_empty(&self) -> bool { + self.attempts_in_progress == 0 + } + + fn set_error(&mut self, error: Error) { + self.error.get_or_insert(error); + } + + // Initiate a fresh non-blocking TCP connection to `saddr`. + // + // Returns `AddOutcome::InProgress`. + // + // If there is an error concerning this particular connection + // attempt, return `AddOutcome::Error` and the error code is + // remembered in `self` if it is the first to occur; it will be + // the one returned if no connection can be established at all. + // + // If the connection succeeds immediatly, returns + // `AddOutcome::Connected(stream)`. Because of non-blocking and + // the way TCP works, this should *not* happen. + fn add(&mut self, saddr: SockAddr, domain: Domain) -> AddOutcome { + let sock = match Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP)) + .and_then(|sock| { + sock.set_nonblocking(true)?; + Ok(sock) + }) { + Ok(sock) => sock, + Err(e) => { + self.set_error(e); + return AddOutcome::Error; + } + }; + match sock.connect(&saddr) { + Ok(()) => match sock.set_nonblocking(false) { + Ok(()) => return AddOutcome::Connected(sock.into()), + Err(e) => self.set_error(e), + }, + Err(e) if e.raw_os_error() == Some(libc::EINPROGRESS) => { + let interest = Event::writable(self.attempts.len()); + match self.poller.add(&sock, interest) { + Ok(()) => { + self.attempts.push(Some((saddr, sock))); + self.attempts_in_progress += 1; + return AddOutcome::InProgress; + } + Err(e) => self.set_error(e), + } + } + Err(e) => self.set_error(e), + } + AddOutcome::Error + } + + fn poll_once(&mut self, timeout: Option) -> Result> { + let mut events = Vec::new(); + self.poller.wait(&mut events, timeout)?; + for evt in &events { + log::trace!("poll_once evt = {:?}", evt); + assert!(evt.writable); + let (sock_addr, sock) = self.attempts[evt.key].take().expect("attempt exists"); + self.attempts_in_progress -= 1; + self.poller.delete(&sock).expect("socket is in poll set"); + match debug!(nix::sys::socket::getsockopt( + sock.as_raw_fd(), + nix::sys::socket::sockopt::SocketError + )) { + Err(e) => self.set_error(e.into()), + Ok(0) => { + if let Some(tcp) = self.socket_into_blocking_tcp_stream(sock) { + return Ok(Some(tcp)); + } + return Ok(None); + } + Ok(_) => {} + } + match debug!(sock.connect(&sock_addr)) { + Err(e) => self.set_error(e), + Ok(()) => { + if let Some(tcp) = self.socket_into_blocking_tcp_stream(sock) { + return Ok(Some(tcp)); + } + } + } + } + Ok(None) + } + + fn socket_into_blocking_tcp_stream(&mut self, sock: Socket) -> Option { + match sock.set_nonblocking(false) { + Ok(()) => Some(sock.into()), + Err(e) => { + self.set_error(e); + None + } + } + } +} + +enum AddOutcome { + Connected(TcpStream), + InProgress, + Error, +} + +/* +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::{serve_4, serve_6, tar_pit}; + use rand::{thread_rng, Rng}; + use std::io::Read; + use std::net::{Ipv4Addr, Ipv6Addr}; + + #[test] + fn test_no_ipv4() { + let port = thread_rng().gen_range(49152..=65535); + assert!(connect((Ipv4Addr::LOCALHOST, port)).is_err()); + } + + #[test] + fn test_connect_ipv4() { + let (_serve, addr, port) = serve_4(); + let mut data = String::new(); + connect((addr, port)) + .unwrap() + .read_to_string(&mut data) + .unwrap(); + assert_eq!(data, format!("{addr}")); + assert!(connect((Ipv6Addr::LOCALHOST, port)).is_err()); + } + + #[test] + fn test_no_ipv6() { + let port = thread_rng().gen_range(49152..=65535); + assert!(connect((Ipv6Addr::LOCALHOST, port)).is_err()); + } + + #[test] + fn test_connect_ipv6() { + let (_serve, addr, port) = serve_6(); + let mut data = String::new(); + connect((addr, port)) + .unwrap() + .read_to_string(&mut data) + .unwrap(); + assert_eq!(data, format!("{addr}")); + assert!(connect((Ipv4Addr::LOCALHOST, port)).is_err()); + } + + #[test] + fn test_connect_no_6_but_4() { + let (_serve, addr, port) = serve_4(); + let expect = format!("{addr}"); + let saddr6: SocketAddr = (Ipv6Addr::LOCALHOST, port).into(); + let saddr4: SocketAddr = (Ipv4Addr::LOCALHOST, port).into(); + let mut data = String::new(); + { + let saddrs = &[saddr4, saddr6][..]; + data.clear(); + connect(saddrs).unwrap().read_to_string(&mut data).unwrap(); + assert_eq!(data, expect); + } + { + let saddrs = &[saddr6, saddr4][..]; + data.clear(); + connect(saddrs).unwrap().read_to_string(&mut data).unwrap(); + assert_eq!(data, expect); + } + } + + #[test] + fn test_connect_no_4_but_6() { + let (_serve, addr, port) = serve_6(); + let expect = format!("{addr}"); + let saddr6: SocketAddr = (Ipv6Addr::LOCALHOST, port).into(); + let saddr4: SocketAddr = (Ipv4Addr::LOCALHOST, port).into(); + let mut data = String::new(); + { + let saddrs = &[saddr4, saddr6][..]; + data.clear(); + connect(saddrs).unwrap().read_to_string(&mut data).unwrap(); + assert_eq!(data, expect); + } + { + let saddrs = &[saddr6, saddr4][..]; + data.clear(); + connect(saddrs).unwrap().read_to_string(&mut data).unwrap(); + assert_eq!(data, expect); + } + } + + #[test] + fn test_connect() { + let (_serve4, addr4, port4) = serve_4(); + let (_serve6, addr6, port6) = serve_6(); + let saddr4: SocketAddr = (addr4, port4).into(); + let saddr6: SocketAddr = (addr6, port6).into(); + let mut data = String::new(); + + data.clear(); + connect(&[saddr6, saddr4][..]) + .unwrap() + .read_to_string(&mut data) + .unwrap(); + assert_eq!(data, format!("{addr6}")); + + data.clear(); + connect(&[saddr4, saddr6][..]) + .unwrap() + .read_to_string(&mut data) + .unwrap(); + // IPv6 is preferred + assert_eq!(data, format!("{addr6}")); + } + + #[test] + fn test_connect_tar_pit4() { + let (_serve4, addr4, port4) = tar_pit(Ipv4Addr::LOCALHOST); + let (_serve6, addr6, port6) = serve_6(); + let saddr4: SocketAddr = (addr4, port4).into(); + let saddr6: SocketAddr = (addr6, port6).into(); + let mut data = String::new(); + + data.clear(); + connect(&[saddr4, saddr6][..]) + .unwrap() + .read_to_string(&mut data) + .unwrap(); + assert_eq!(data, format!("{addr6}")); + + data.clear(); + connect(&[saddr6, saddr4][..]) + .unwrap() + .read_to_string(&mut data) + .unwrap(); + assert_eq!(data, format!("{addr6}")); + } + + #[test] + fn test_connect_tar_pit6() { + let (_serve4, addr4, port4) = serve_4(); + let (_serve6, addr6, port6) = tar_pit(Ipv6Addr::LOCALHOST); + let saddr4: SocketAddr = (addr4, port4).into(); + let saddr6: SocketAddr = (addr6, port6).into(); + let mut data = String::new(); + + data.clear(); + let mut cnx = connect(&[saddr4, saddr6][..]).unwrap(); + cnx.read_to_string(&mut data).unwrap(); + assert_eq!(data, format!("{addr4}")); + + data.clear(); + connect(&[saddr6, saddr4][..]) + .unwrap() + .read_to_string(&mut data) + .unwrap(); + assert_eq!(data, format!("{addr4}")); + } + + #[test] + fn test_connect_tar_pit_all() { + let (_serve4, addr4, port4) = tar_pit(Ipv4Addr::LOCALHOST); + let (_serve6, addr6, port6) = tar_pit(Ipv6Addr::LOCALHOST); + let saddr4: SocketAddr = (addr4, port4).into(); + let saddr6: SocketAddr = (addr6, port6).into(); + assert_eq!( + connect(&[saddr4, saddr6][..]).unwrap_err().kind(), + std::io::ErrorKind::TimedOut + ); + } + + #[test] + fn test_connect_empty() { + let empty = &[][..]; + assert_eq!( + connect(empty).unwrap_err().kind(), + std::io::ErrorKind::InvalidInput + ); + } +} +*/