/* * Copyright (C) 2023 by Kim Minh Kaplan * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted. * * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR * IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ use std::{ io::{Error, ErrorKind, Result}, net::{SocketAddr, TcpStream, ToSocketAddrs}, time::{Duration, Instant}, }; use polling::{Event, Poller}; use socket2::{Domain, SockAddr, Socket}; /// Opens a TCP connection to a remote host. /// /// If `addr` yields multiple addresses, `connect` uses the algorithm /// described in [RFC 8305 Happy Eyeballs Version 2: Better /// Connectivity Using /// Concurrency](https://datatracker.ietf.org/doc/html/rfc8305) to /// connect. pub fn connect(addr: A, timeout: Option) -> Result { let mut happy = HappyEyeballs::new()?; let start = Instant::now(); let timeout_left = || -> Result> { let Some(v) = timeout else { return Ok(None); }; Ok(Some(v.checked_sub(Instant::now() - start).ok_or_else( || Error::new(ErrorKind::TimedOut, "Connection timed out."), )?)) }; for a in prepare_addresses(addr)? { match happy.add(a.into(), Domain::for_address(a)) { AddOutcome::Connected(tcp) => return Ok(tcp), AddOutcome::Error => continue, AddOutcome::InProgress => (), } if let Some(sock) = happy.poll_once(timeout_left()?)? { return Ok(sock); } } while !happy.is_empty() { if let Some(sock) = happy.poll_once(timeout_left()?)? { return Ok(sock); } } Err(happy .error .unwrap_or_else(|| Error::new(ErrorKind::InvalidInput, "could not resolve to any address"))) } /// Resolve addresses and order them to alternate between IPv6 and IPv4. fn prepare_addresses(addr: A) -> Result> where A: ToSocketAddrs, { let (addrs_v4, addrs_v6): (Vec<_>, Vec<_>) = addr.to_socket_addrs()?.partition(|a| match a { SocketAddr::V4(_) => true, SocketAddr::V6(_) => false, }); let mut addrs = Vec::with_capacity(addrs_v4.len() + addrs_v6.len()); let (mut left, mut right) = (addrs_v6.into_iter(), addrs_v4.into_iter()); while let Some(a) = left.next() { addrs.push(a); std::mem::swap(&mut left, &mut right); } addrs.extend(right); Ok(addrs) } struct HappyEyeballs { poller: Poller, error: Option, attempts: Vec>, attempts_in_progress: usize, } impl HappyEyeballs { fn new() -> Result { Ok(Self { poller: Poller::new()?, error: None, attempts: Vec::new(), attempts_in_progress: 0, }) } fn is_empty(&self) -> bool { self.attempts_in_progress == 0 } fn set_error(&mut self, error: Error) { self.error.get_or_insert(error); } // Initiate a fresh non-blocking TCP connection to `saddr`. // // Returns `AddOutcome::InProgress`. // // If there is an error concerning this particular connection // attempt, return `AddOutcome::Error` and the error code is // remembered in `self` if it is the first to occur; it will be // the one returned if no connection can be established at all. // // If the connection succeeds immediately, returns // `AddOutcome::Connected(stream)`. Because of non-blocking and // the way TCP works, this should *not* happen. fn add(&mut self, saddr: SockAddr, domain: Domain) -> AddOutcome { let sock = match Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP)) .and_then(|sock| { sock.set_nonblocking(true)?; Ok(sock) }) { Ok(sock) => sock, Err(e) => { self.set_error(e); return AddOutcome::Error; } }; match sock.connect(&saddr) { Ok(()) => match sock.set_nonblocking(false) { Ok(()) => return AddOutcome::Connected(sock.into()), Err(e) => self.set_error(e), }, Err(e) if e.raw_os_error() == Some(libc::EINPROGRESS) => { let interest = Event::writable(self.attempts.len()); match self.poller.add(&sock, interest) { Ok(()) => { self.attempts.push(Some((saddr, sock))); self.attempts_in_progress += 1; return AddOutcome::InProgress; } Err(e) => self.set_error(e), } } Err(e) => self.set_error(e), } AddOutcome::Error } fn poll_once(&mut self, timeout: Option) -> Result> { let mut events = Vec::new(); self.poller.wait(&mut events, timeout)?; for evt in &events { assert!(evt.writable); let (sock_addr, sock) = self.attempts[evt.key].take().expect("attempt exists"); self.attempts_in_progress -= 1; self.poller.delete(&sock).expect("socket is in poll set"); match nix::sys::socket::getsockopt(&sock, nix::sys::socket::sockopt::SocketError) { Err(err) => self.set_error(err.into()), Ok(0) => { if let Some(tcp) = self.socket_into_blocking_tcp_stream(sock) { return Ok(Some(tcp)); } return Ok(None); } Ok(_) => {} } match sock.connect(&sock_addr) { Err(err) => self.set_error(err), Ok(()) => { if let Some(tcp) = self.socket_into_blocking_tcp_stream(sock) { return Ok(Some(tcp)); } } } } Ok(None) } fn socket_into_blocking_tcp_stream(&mut self, sock: Socket) -> Option { match sock.set_nonblocking(false) { Ok(()) => Some(sock.into()), Err(err) => { self.set_error(err); None } } } } enum AddOutcome { Connected(TcpStream), InProgress, Error, } /* #[cfg(test)] mod tests { use super::*; use crate::test_utils::{serve_4, serve_6, tar_pit}; use rand::{thread_rng, Rng}; use std::io::Read; use std::net::{Ipv4Addr, Ipv6Addr}; #[test] fn test_no_ipv4() { let port = thread_rng().gen_range(49152..=65535); assert!(connect((Ipv4Addr::LOCALHOST, port)).is_err()); } #[test] fn test_connect_ipv4() { let (_serve, addr, port) = serve_4(); let mut data = String::new(); connect((addr, port)) .unwrap() .read_to_string(&mut data) .unwrap(); assert_eq!(data, format!("{addr}")); assert!(connect((Ipv6Addr::LOCALHOST, port)).is_err()); } #[test] fn test_no_ipv6() { let port = thread_rng().gen_range(49152..=65535); assert!(connect((Ipv6Addr::LOCALHOST, port)).is_err()); } #[test] fn test_connect_ipv6() { let (_serve, addr, port) = serve_6(); let mut data = String::new(); connect((addr, port)) .unwrap() .read_to_string(&mut data) .unwrap(); assert_eq!(data, format!("{addr}")); assert!(connect((Ipv4Addr::LOCALHOST, port)).is_err()); } #[test] fn test_connect_no_6_but_4() { let (_serve, addr, port) = serve_4(); let expect = format!("{addr}"); let saddr6: SocketAddr = (Ipv6Addr::LOCALHOST, port).into(); let saddr4: SocketAddr = (Ipv4Addr::LOCALHOST, port).into(); let mut data = String::new(); { let saddrs = &[saddr4, saddr6][..]; data.clear(); connect(saddrs).unwrap().read_to_string(&mut data).unwrap(); assert_eq!(data, expect); } { let saddrs = &[saddr6, saddr4][..]; data.clear(); connect(saddrs).unwrap().read_to_string(&mut data).unwrap(); assert_eq!(data, expect); } } #[test] fn test_connect_no_4_but_6() { let (_serve, addr, port) = serve_6(); let expect = format!("{addr}"); let saddr6: SocketAddr = (Ipv6Addr::LOCALHOST, port).into(); let saddr4: SocketAddr = (Ipv4Addr::LOCALHOST, port).into(); let mut data = String::new(); { let saddrs = &[saddr4, saddr6][..]; data.clear(); connect(saddrs).unwrap().read_to_string(&mut data).unwrap(); assert_eq!(data, expect); } { let saddrs = &[saddr6, saddr4][..]; data.clear(); connect(saddrs).unwrap().read_to_string(&mut data).unwrap(); assert_eq!(data, expect); } } #[test] fn test_connect() { let (_serve4, addr4, port4) = serve_4(); let (_serve6, addr6, port6) = serve_6(); let saddr4: SocketAddr = (addr4, port4).into(); let saddr6: SocketAddr = (addr6, port6).into(); let mut data = String::new(); data.clear(); connect(&[saddr6, saddr4][..]) .unwrap() .read_to_string(&mut data) .unwrap(); assert_eq!(data, format!("{addr6}")); data.clear(); connect(&[saddr4, saddr6][..]) .unwrap() .read_to_string(&mut data) .unwrap(); // IPv6 is preferred assert_eq!(data, format!("{addr6}")); } #[test] fn test_connect_tar_pit4() { let (_serve4, addr4, port4) = tar_pit(Ipv4Addr::LOCALHOST); let (_serve6, addr6, port6) = serve_6(); let saddr4: SocketAddr = (addr4, port4).into(); let saddr6: SocketAddr = (addr6, port6).into(); let mut data = String::new(); data.clear(); connect(&[saddr4, saddr6][..]) .unwrap() .read_to_string(&mut data) .unwrap(); assert_eq!(data, format!("{addr6}")); data.clear(); connect(&[saddr6, saddr4][..]) .unwrap() .read_to_string(&mut data) .unwrap(); assert_eq!(data, format!("{addr6}")); } #[test] fn test_connect_tar_pit6() { let (_serve4, addr4, port4) = serve_4(); let (_serve6, addr6, port6) = tar_pit(Ipv6Addr::LOCALHOST); let saddr4: SocketAddr = (addr4, port4).into(); let saddr6: SocketAddr = (addr6, port6).into(); let mut data = String::new(); data.clear(); let mut cnx = connect(&[saddr4, saddr6][..]).unwrap(); cnx.read_to_string(&mut data).unwrap(); assert_eq!(data, format!("{addr4}")); data.clear(); connect(&[saddr6, saddr4][..]) .unwrap() .read_to_string(&mut data) .unwrap(); assert_eq!(data, format!("{addr4}")); } #[test] fn test_connect_tar_pit_all() { let (_serve4, addr4, port4) = tar_pit(Ipv4Addr::LOCALHOST); let (_serve6, addr6, port6) = tar_pit(Ipv6Addr::LOCALHOST); let saddr4: SocketAddr = (addr4, port4).into(); let saddr6: SocketAddr = (addr6, port6).into(); assert_eq!( connect(&[saddr4, saddr6][..]).unwrap_err().kind(), std::io::ErrorKind::TimedOut ); } #[test] fn test_connect_empty() { let empty = &[][..]; assert_eq!( connect(empty).unwrap_err().kind(), std::io::ErrorKind::InvalidInput ); } } */