/*
* meli - melib library
*
* Copyright 2020 Manos Pitsidianakis
*
* This file is part of meli.
*
* meli is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* meli is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with meli. If not, see <http://www.gnu.org/licenses/>.
*/
//! Connection layers (TCP/fd/TLS/Deflate) to use with remote backends.
use std::{os::unix::io::AsRawFd, time::Duration};
#[cfg(feature = "deflate_compression")]
use flate2::{read::DeflateDecoder, write::DeflateEncoder, Compression};
#[cfg(any(target_os = "openbsd", target_os = "netbsd", target_os = "haiku"))]
use libc::SO_KEEPALIVE as KEEPALIVE_OPTION;
#[cfg(any(target_os = "macos", target_os = "ios"))]
use libc::TCP_KEEPALIVE as KEEPALIVE_OPTION;
#[cfg(not(any(
target_os = "openbsd",
target_os = "netbsd",
target_os = "haiku",
target_os = "macos",
target_os = "ios"
)))]
use libc::TCP_KEEPIDLE as KEEPALIVE_OPTION;
use libc::{self, c_int, c_void};
// pub mod smol;
pub mod std_net;
/// A fixed duration of 250 milliseconds.
///
/// NOTE(review): judging by the name this is the pause between successive
/// connection attempts — confirm against the code that consumes it.
pub const CONNECTION_ATTEMPT_DELAY: Duration = Duration::from_millis(250);
/// A connection to a remote endpoint, optionally layered.
///
/// The `Tls` and `Deflate` variants wrap another `Connection`, so layers can
/// be stacked (e.g. Deflate over TLS over TCP).
///
/// Every variant carries the same two bookkeeping fields:
///
/// - `id`: optional static label, printed as a `[id]: ` prefix in trace log
///   messages.
/// - `trace`: whether operations on this connection emit `log::trace!`
///   messages.
pub enum Connection {
    /// Plain TCP stream.
    Tcp {
        inner: std::net::TcpStream,
        id: Option<&'static str>,
        trace: bool,
    },
    /// Raw Unix file descriptor.
    Fd {
        inner: std::os::unix::io::RawFd,
        id: Option<&'static str>,
        trace: bool,
    },
    /// TLS session running on top of another `Connection`.
    #[cfg(feature = "tls")]
    Tls {
        inner: native_tls::TlsStream<Self>,
        id: Option<&'static str>,
        trace: bool,
    },
    /// Deflate compression layer: an encoder for writes wrapping a decoder
    /// for reads, which in turn owns the boxed underlying `Connection`.
    #[cfg(feature = "deflate_compression")]
    Deflate {
        inner: DeflateEncoder<DeflateDecoder<Box<Self>>>,
        id: Option<&'static str>,
        trace: bool,
    },
}
impl std::fmt::Debug for Connection {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Tcp {
ref trace,
ref inner,
ref id,
} => fmt
.debug_struct(stringify!(Connection))
.field("variant", &stringify!(Tcp))
.field(stringify!(trace), trace)
.field(stringify!(id), id)
.field(stringify!(inner), inner)
.finish(),
#[cfg(feature = "tls")]
Tls {
ref trace,
ref inner,
ref id,
} => fmt
.debug_struct(stringify!(Connection))
.field("variant", &stringify!(Tls))
.field(stringify!(trace), trace)
.field(stringify!(id), id)
.field(stringify!(inner), inner.get_ref())
.finish(),
Fd {
ref trace,
ref inner,
ref id,
} => fmt
.debug_struct(stringify!(Connection))
.field("variant", &stringify!(Fd))
.field(stringify!(trace), trace)
.field(stringify!(id), id)
.field(stringify!(inner), inner)
.finish(),
#[cfg(feature = "deflate_compression")]
Deflate {
ref trace,
ref inner,
ref id,
} => fmt
.debug_struct(stringify!(Connection))
.field("variant", &stringify!(Deflate))
.field(stringify!(trace), trace)
.field(stringify!(id), id)
.field(stringify!(inner), inner)
.finish(),
}
}
}
use Connection::*;
// Calls `libc::<fn>(<args>)` and converts the C error convention into a
// `std::io::Result`: a return value of `-1` becomes
// `Err(std::io::Error::last_os_error())`, anything else is returned in `Ok`.
//
// The call is wrapped in `unsafe` by the macro itself; the caller is
// responsible for passing arguments that are valid for the chosen function.
macro_rules! syscall {
    ($fn: ident ( $($arg: expr),* $(,)* ) ) => {{
        // The macro may be expanded inside an existing `unsafe` block.
        #[allow(unused_unsafe)]
        let ret = unsafe { libc::$fn($($arg, )*) };
        match ret {
            -1 => Err(std::io::Error::last_os_error()),
            _ => Ok(ret),
        }
    }};
}
impl Connection {
/// Length in bytes (64 KiB) of the zero-filled buffer handed to
/// `DeflateDecoder::new_with_buf` when a connection is wrapped by `deflate`.
#[cfg(feature = "deflate_compression")]
pub const IO_BUF_SIZE: usize = 1 << 16;
/// Consumes this connection and returns it wrapped in a `Deflate` layer.
///
/// The wrapped connection is boxed and placed inside a `DeflateDecoder`
/// (with an `IO_BUF_SIZE` byte buffer), which is in turn placed inside a
/// `DeflateEncoder` using the default compression level.
///
/// The trace flag and id of the wrapped connection are carried over to the
/// new outer layer, and tracing on the wrapped connection is switched off.
#[cfg(feature = "deflate_compression")]
pub fn deflate(mut self) -> Self {
    // Read both values before tracing is disabled on the inner connection.
    let (trace, id) = (self.is_trace_enabled(), self.id());
    self.set_trace(false);

    let decoder = DeflateDecoder::new_with_buf(Box::new(self), vec![0; Self::IO_BUF_SIZE]);
    let inner = DeflateEncoder::new(decoder, Compression::default());
    Self::Deflate { inner, id, trace }
}
/// Builds a `Tls` connection from an established TLS stream that wraps
/// another `Connection`.
///
/// The trace flag and id of the wrapped connection are carried over to the
/// new `Tls` layer. If tracing was enabled on the wrapped connection it is
/// switched off there.
#[cfg(feature = "tls")]
pub fn new_tls(mut inner: native_tls::TlsStream<Self>) -> Self {
    let trace = inner.get_ref().is_trace_enabled();
    let id = inner.get_ref().id();
    if trace {
        inner.get_mut().set_trace(false);
    }
    Self::Tls { inner, id, trace }
}
/// Wraps a TCP stream in a `Tcp` connection with no id and tracing disabled.
pub fn new_tcp(inner: std::net::TcpStream) -> Self {
    let (id, trace) = (None, false);
    Self::Tcp { inner, id, trace }
}
pub fn trace(mut self, val: bool) -> Self {
match self {
Tcp { ref mut trace, .. } => *trace = val,
#[cfg(feature = "tls")]
Tls { ref mut trace, .. } => *trace = val,
Fd { ref mut trace, .. } => {
*trace = val;
}
#[cfg(feature = "deflate_compression")]
Deflate { ref mut trace, .. } => *trace = val,
}
self
}
pub fn with_id(mut self, val: &'static str) -> Self {
match self {
Tcp { ref mut id, .. } => *id = Some(val),
#[cfg(feature = "tls")]
Tls { ref mut id, .. } => *id = Some(val),
Fd { ref mut id, .. } => {
*id = Some(val);
}
#[cfg(feature = "deflate_compression")]
Deflate { ref mut id, .. } => *id = Some(val),
}
self
}
pub fn set_trace(&mut self, val: bool) {
match self {
Tcp { ref mut trace, .. } => *trace = val,
#[cfg(feature = "tls")]
Tls { ref mut trace, .. } => *trace = val,
Fd { ref mut trace, .. } => {
*trace = val;
}
#[cfg(feature = "deflate_compression")]
Deflate { ref mut trace, .. } => *trace = val,
}
}
pub fn set_nonblocking(&self, nonblocking: bool) -> std::io::Result<()> {
if self.is_trace_enabled() {
let id = self.id();
log::trace!(
"{}{}{}{:?} set_nonblocking({:?})",
if id.is_some() { "[" } else { "" },
if let Some(id) = id.as_ref() { id } else { "" },
if id.is_some() { "]: " } else { "" },
self,
nonblocking
);
}
match self {
Tcp { ref inner, .. } => inner.set_nonblocking(nonblocking),
#[cfg(feature = "tls")]
Tls { ref inner, .. } => inner.get_ref().set_nonblocking(nonblocking),
Fd { inner, .. } => {
// [ref:VERIFY]
nix::fcntl::fcntl(
*inner,
nix::fcntl::FcntlArg::F_SETFL(if nonblocking {
nix::fcntl::OFlag::O_NONBLOCK
} else {
!nix::fcntl::OFlag::O_NONBLOCK
}),
)
.map_err(|err| std::io::Error::from_raw_os_error(err as i32))?;
Ok(())
}
#[cfg(feature = "deflate_compression")]
Deflate { ref inner, .. } => inner.get_ref().get_ref().set_nonblocking(nonblocking),
}
}
pub fn set_read_timeout(&self, dur: Option) -> std::io::Result<()> {
if self.is_trace_enabled() {
let id = self.id();
log::trace!(
"{}{}{}{:?} set_read_timeout({:?})",
if id.is_some() { "[" } else { "" },
if let Some(id) = id.as_ref() { id } else { "" },
if id.is_some() { "]: " } else { "" },
self,
dur
);
}
match self {
Tcp { ref inner, .. } => inner.set_read_timeout(dur),
#[cfg(feature = "tls")]
Tls { ref inner, .. } => inner.get_ref().set_read_timeout(dur),
Fd { .. } => Ok(()),
#[cfg(feature = "deflate_compression")]
Deflate { ref inner, .. } => inner.get_ref().get_ref().set_read_timeout(dur),
}
}
pub fn set_write_timeout(&self, dur: Option) -> std::io::Result<()> {
if self.is_trace_enabled() {
let id = self.id();
log::trace!(
"{}{}{}{:?} set_write_timeout({:?})",
if id.is_some() { "[" } else { "" },
if let Some(id) = id.as_ref() { id } else { "" },
if id.is_some() { "]: " } else { "" },
self,
dur
);
}
match self {
Tcp { ref inner, .. } => inner.set_write_timeout(dur),
#[cfg(feature = "tls")]
Tls { ref inner, .. } => inner.get_ref().set_write_timeout(dur),
Fd { .. } => Ok(()),
#[cfg(feature = "deflate_compression")]
Deflate { ref inner, .. } => inner.get_ref().get_ref().set_write_timeout(dur),
}
}
pub fn keepalive(&self) -> std::io::Result