use std::fmt::Formatter;
use std::future::Future;
use std::str::FromStr;
use axum::extract::Path;
use axum::Error;
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Json, State,
},
http::status::StatusCode,
response::IntoResponse,
};
use axum_extra::extract::Query;
use base64::Engine;
use futures::{AsyncReadExt, AsyncWriteExt, SinkExt, StreamExt, TryStreamExt};
use futures_concurrency::stream::Merge;
use libp2p_identity::PeerId;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
use std::net::IpAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tracing::{debug, error, info, trace};
use hopr_lib::errors::HoprLibError;
use hopr_lib::transfer_session;
use hopr_lib::{HoprSession, ServiceId, SessionClientConfig, SessionTarget};
use hopr_network_types::prelude::{ConnectedUdpStream, IpOrHost, SealedHost, UdpStreamParallelism};
use hopr_network_types::udp::ForeignDataMode;
use hopr_network_types::utils::AsyncReadStreamer;
use crate::types::PeerOrAddress;
use crate::{ApiError, ApiErrorStatus, InternalState, ListenerId, BASE_PATH};
/// Buffer size (bytes) handed to `bind_session_to_stream` for every TCP session client.
pub const HOPR_TCP_BUFFER_SIZE: usize = 4 * 1024;
/// Buffer size (bytes) of the UDP listener socket and of UDP session transfers.
pub const HOPR_UDP_BUFFER_SIZE: usize = 16 * 1024;
/// Queue size configured on the UDP listener socket builder.
pub const HOPR_UDP_QUEUE_SIZE: usize = 8 * 1024;
// Prometheus gauge counting clients currently bridged through a session listener.
// Labelled by transport `type`; `create_client` uses the label values "tcp" and "udp".
// Only compiled with the `prometheus` feature and outside of tests.
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
static ref METRIC_ACTIVE_CLIENTS: hopr_metrics::MultiGauge = hopr_metrics::MultiGauge::new(
"hopr_session_hoprd_clients",
"Number of clients connected at this Entry node",
&["type"]
).unwrap();
}
// Session target as accepted by the REST API.
// - `Plain`: a host string, parsed later as `IpOrHost` by `into_target`.
// - `Sealed`: opaque sealed-target bytes (Base64-encoded in serde).
// - `Service`: a service identifier on the exit node.
// Its `Display`/`FromStr` text form is `<plain>`, `$$<url-safe base64>` or `#<service id>`.
// NOTE: plain `//` comments on purpose, because `///` docs would be picked up by `utoipa::ToSchema`.
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub enum SessionTargetSpec {
Plain(String),
Sealed(#[serde_as(as = "serde_with::base64::Base64")] Vec<u8>),
Service(ServiceId),
}
/// Renders the target in the textual form accepted by `FromStr`:
/// plain targets verbatim, sealed ones as `$$` + URL-safe Base64, services as `#` + id.
impl std::fmt::Display for SessionTargetSpec {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Plain(target) => f.write_str(target),
            Self::Sealed(bytes) => {
                f.write_str("$$")?;
                f.write_str(&base64::prelude::BASE64_URL_SAFE.encode(bytes))
            }
            Self::Service(id) => write!(f, "#{id}"),
        }
    }
}
/// Parses the textual target form produced by `Display`.
///
/// * `$$<url-safe base64>` → `Sealed` (decoding errors become `HoprLibError::GeneralError`)
/// * `#<service id>` → `Service`
/// * anything else → `Plain`, kept verbatim.
impl std::str::FromStr for SessionTargetSpec {
    type Err = HoprLibError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if let Some(encoded) = s.strip_prefix("$$") {
            let raw = base64::prelude::BASE64_URL_SAFE
                .decode(encoded)
                .map_err(|e| HoprLibError::GeneralError(e.to_string()))?;
            return Ok(Self::Sealed(raw));
        }

        match s.strip_prefix('#') {
            Some(service) => service
                .parse()
                .map(Self::Service)
                .map_err(|_| HoprLibError::GeneralError("cannot parse service id".into())),
            None => Ok(Self::Plain(s.to_owned())),
        }
    }
}
impl SessionTargetSpec {
    /// Converts the API-level target into a `SessionTarget` for the given transport.
    ///
    /// `Service` targets map to `SessionTarget::ExitNode` regardless of `protocol`.
    /// Plain and sealed hosts become a TCP or UDP stream target depending on `protocol`.
    ///
    /// # Errors
    /// `HoprLibError::GeneralError` when a plain target cannot be parsed as `IpOrHost`.
    pub fn into_target(self, protocol: IpProtocol) -> Result<SessionTarget, HoprLibError> {
        let host = match self {
            SessionTargetSpec::Service(id) => return Ok(SessionTarget::ExitNode(id)),
            SessionTargetSpec::Sealed(enc) => SealedHost::Sealed(enc.into_boxed_slice()),
            SessionTargetSpec::Plain(plain) => IpOrHost::from_str(&plain)
                .map(SealedHost::from)
                .map_err(|e| HoprLibError::GeneralError(e.to_string()))?,
        };

        Ok(match protocol {
            IpProtocol::TCP => SessionTarget::TcpStream(host),
            IpProtocol::UDP => SessionTarget::UdpStream(host),
        })
    }
}
/// Bookkeeping record for one open session listener, stored in
/// `InternalState::open_listeners` keyed by [`ListenerId`].
#[derive(Debug)]
pub struct StoredSessionEntry {
/// Target the listener's sessions connect to (reported by `list_clients`).
pub target: SessionTargetSpec,
/// Routing options the listener was created with (reported by `list_clients`).
pub path: RoutingOptions,
/// Handle of the task serving the listener; `close_client` cancels it.
pub jh: hopr_async_runtime::prelude::JoinHandle<()>,
}
// Session capabilities that can be requested through the REST API.
// Each variant converts 1:1 into `hopr_lib::SessionCapability`. The text form (variant name)
// comes from strum and is used by `DisplayFromStr` in query/body parsing.
// Plain `//` comments: `///` would be emitted into the OpenAPI schema by `utoipa::ToSchema`.
#[repr(u8)]
#[derive(
Debug, Clone, strum::EnumIter, strum::Display, strum::EnumString, Serialize, Deserialize, utoipa::ToSchema,
)]
pub enum SessionCapability {
Segmentation,
Retransmission,
RetransmissionAckOnly,
NoDelay,
}
/// Maps each API-level capability onto the identically named `hopr_lib` capability.
impl From<SessionCapability> for hopr_lib::SessionCapability {
    fn from(cap: SessionCapability) -> hopr_lib::SessionCapability {
        match cap {
            SessionCapability::Segmentation => Self::Segmentation,
            SessionCapability::Retransmission => Self::Retransmission,
            SessionCapability::RetransmissionAckOnly => Self::RetransmissionAckOnly,
            SessionCapability::NoDelay => Self::NoDelay,
        }
    }
}
// Query parameters of `GET /session/websocket`, converted into a `SessionClientConfig` by
// `into_protocol_session_config`.
// Plain `//` comments: `///` would be emitted into the OpenAPI params/schema by utoipa.
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::IntoParams, utoipa::ToSchema)]
#[into_params(parameter_in = Query)]
#[serde(rename_all = "camelCase")]
pub(crate) struct SessionWebsocketClientQueryRequest {
// destination: peer ID of the session counterparty, parsed with `PeerId::from_str` later.
// hops: number of intermediate hops (used unless an explicit `path` is given).
#[serde_as(as = "DisplayFromStr")]
#[schema(required = true, value_type = String)]
pub destination: String, #[schema(required = true)]
pub hops: u8,
// path (feature `explicit-path`): comma-separated peer IDs; when present it overrides `hops`.
#[cfg(feature = "explicit-path")]
#[schema(required = false)]
pub path: Option<String>,
// capabilities: requested session capabilities, each parsed from its variant name.
#[schema(required = true)]
#[serde_as(as = "Vec<DisplayFromStr>")]
pub capabilities: Vec<SessionCapability>,
// target: textual target (`<plain>`, `$$<base64>` or `#<service>`), see `SessionTargetSpec`.
#[schema(required = true)]
#[serde_as(as = "DisplayFromStr")]
pub target: SessionTargetSpec,
// protocol: transport used to reach the target; defaults to TCP when omitted.
#[schema(required = false)]
#[serde(default = "default_protocol")]
pub protocol: IpProtocol,
}
/// Serde default for `SessionWebsocketClientQueryRequest::protocol`: TCP.
#[inline]
fn default_protocol() -> IpProtocol {
    const DEFAULT: IpProtocol = IpProtocol::TCP;
    DEFAULT
}
impl SessionWebsocketClientQueryRequest {
    /// Converts the websocket query into a `SessionClientConfig`.
    ///
    /// Routing uses `hops`. With the `explicit-path` feature, a supplied comma-separated `path`
    /// of peer IDs takes precedence and `hops` is ignored.
    ///
    /// # Errors
    /// Checked in this order: invalid hop count / path, invalid `destination` peer ID,
    /// then an unparsable `target`.
    pub(crate) fn into_protocol_session_config(self) -> Result<SessionClientConfig, HoprLibError> {
        #[cfg(not(feature = "explicit-path"))]
        let path_options = hopr_lib::RoutingOptions::Hops(u32::from(self.hops).try_into()?);

        #[cfg(feature = "explicit-path")]
        let path_options = match self.path {
            Some(path) => {
                let peers = path
                    .split(',')
                    .map(PeerId::from_str)
                    .collect::<Result<Vec<PeerId>, _>>()
                    .map_err(|e| HoprLibError::GeneralError(format!("invalid peer id on path: {e}")))?;
                hopr_lib::RoutingOptions::IntermediatePath(peers.try_into()?)
            }
            None => hopr_lib::RoutingOptions::Hops(u32::from(self.hops).try_into()?),
        };

        let peer = PeerId::from_str(&self.destination)
            .map_err(|_| HoprLibError::GeneralError(format!("invalid destination: {}", self.destination)))?;
        let target = self.target.into_target(self.protocol)?;
        let capabilities = self
            .capabilities
            .into_iter()
            .map(hopr_lib::SessionCapability::from)
            .collect();

        Ok(SessionClientConfig {
            peer,
            path_options,
            target,
            capabilities,
        })
    }
}
// Schema placeholder describing websocket payloads as binary strings in the OpenAPI document.
// It is `dead_code`-allowed and not used anywhere in this file; presumably it exists only for
// schema generation (verify against the OpenAPI component registration).
#[derive(Debug, Default, Clone, Deserialize, utoipa::ToSchema)]
#[schema(value_type = String, format = Binary)]
#[allow(dead_code)] struct WssData(Vec<u8>);
// GET `/session/websocket`: opens a HOPR session described by the query parameters and, on
// success, upgrades the connection to a websocket bridged to that session by
// `websocket_connection`. Invalid parameters or a failed session setup yield 422.
// Plain `//` comments: `utoipa::path` would publish `///` docs as the operation description.
#[allow(dead_code)]
#[utoipa::path(
    get,
    path = const_format::formatcp!("{BASE_PATH}/session/websocket"),
    params(SessionWebsocketClientQueryRequest),
    responses(
        (status = 200, description = "Successfully created a new client websocket session."),
        (status = 401, description = "Invalid authorization token.", body = ApiError),
        (status = 422, description = "Unknown failure", body = ApiError),
        (status = 429, description = "Too many open websocket connections.", body = ApiError),
    ),
    security(
        ("api_token" = []),
        ("bearer_token" = [])
    ),
    tag = "Session",
)]
pub(crate) async fn websocket(
    ws: WebSocketUpgrade,
    Query(query): Query<SessionWebsocketClientQueryRequest>,
    State(state): State<Arc<InternalState>>,
) -> Result<impl IntoResponse, impl IntoResponse> {
    // Both failure paths are reported identically as 422 with the error text.
    fn unprocessable<E: std::fmt::Display>(e: E) -> (StatusCode, ApiErrorStatus) {
        (
            StatusCode::UNPROCESSABLE_ENTITY,
            ApiErrorStatus::UnknownFailure(e.to_string()),
        )
    }

    let config = query.into_protocol_session_config().map_err(unprocessable)?;

    let hopr = state.hopr.clone();
    let session: HoprSession = hopr.connect_to(config).await.map_err(|e| {
        error!(error = %e, "Failed to establish session");
        unprocessable(e)
    })?;

    Ok::<_, (StatusCode, ApiErrorStatus)>(ws.on_upgrade(move |socket| websocket_connection(socket, session)))
}
/// One event of the merged input stream in `websocket_connection`.
enum WebSocketInput {
/// A chunk read from the HOPR session (or the read error).
Network(Result<Box<[u8]>, std::io::Error>),
/// A frame received from the websocket client (or the receive error).
WsInput(Result<Message, Error>),
}
/// Maximum number of bytes read from the session per websocket frame sent to the client.
const WS_MAX_SESSION_READ_SIZE: usize = 4 * 1024;
/// Pumps data between a client websocket and a HOPR session until either side stops.
///
/// * Data read from the session (at most [`WS_MAX_SESSION_READ_SIZE`] bytes per read) is sent
///   to the client as binary frames.
/// * Binary frames from the client are written, in full, to the session.
/// * Text frames, close frames, and any read/write error end the bridge.
/// * Other frame kinds are logged at trace level and skipped.
///
/// On exit, the byte counters in both directions are logged.
#[tracing::instrument(level = "debug", skip(socket, session))]
async fn websocket_connection(socket: WebSocket, session: HoprSession) {
    let session_id = *session.id();
    let (rx, mut tx) = session.split();
    let (mut sender, receiver) = socket.split();

    // Merge both directions into one stream so a single loop services whichever side is ready.
    let mut queue = (
        receiver.map(WebSocketInput::WsInput),
        AsyncReadStreamer::<WS_MAX_SESSION_READ_SIZE, _>(rx).map(WebSocketInput::Network),
    )
        .merge();

    let (mut bytes_to_session, mut bytes_from_session) = (0, 0);
    while let Some(v) = queue.next().await {
        match v {
            WebSocketInput::Network(bytes) => match bytes {
                Ok(bytes) => {
                    let len = bytes.len();
                    if let Err(e) = sender.send(Message::Binary(bytes.into())).await {
                        error!(
                            error = %e,
                            "Failed to emit read data onto the websocket, closing connection"
                        );
                        break;
                    };
                    bytes_from_session += len;
                }
                Err(e) => {
                    error!(
                        error = %e,
                        "Failed to push data from network to socket, closing connection"
                    );
                    break;
                }
            },
            WebSocketInput::WsInput(ws_in) => match ws_in {
                Ok(Message::Binary(data)) => {
                    let len = data.len();
                    // `write` may accept only a prefix of the buffer. `write_all` keeps writing
                    // until the whole frame is handed to the session, so no payload is lost.
                    if let Err(e) = tx.write_all(data.as_ref()).await {
                        error!(error = %e, "Failed to write data to the session, closing connection");
                        break;
                    }
                    bytes_to_session += len;
                }
                Ok(Message::Text(_)) => {
                    error!("Received string instead of binary data, closing connection");
                    break;
                }
                Ok(Message::Close(_)) => {
                    debug!("Received close frame, closing connection");
                    break;
                }
                Ok(m) => trace!(message = ?m, "skipping an unsupported websocket message"),
                Err(e) => {
                    error!(error = %e, "Failed to get a valid websocket message, closing connection");
                    break;
                }
            },
        }
    }

    info!(%session_id, bytes_from_session, bytes_to_session, "WS session connection ended");
}
// API-level routing options for a session.
// - `IntermediatePath` (feature `explicit-path`): explicit relay peer IDs, given as strings.
// - `Hops`: number of intermediate hops.
// Converted into `hopr_lib::RoutingOptions` via `TryFrom`.
// Plain `//` comments: `///` would be emitted into the OpenAPI schema by `utoipa::ToSchema`.
#[serde_as]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, utoipa::ToSchema)]
pub enum RoutingOptions {
#[cfg(feature = "explicit-path")]
#[schema(value_type = Vec<String>)]
IntermediatePath(#[serde_as(as = "Vec<DisplayFromStr>")] Vec<PeerId>),
Hops(usize),
}
/// Converts API routing options into `hopr_lib` ones.
///
/// # Errors
/// Fails when the hop count (or, with `explicit-path`, the path) is rejected by the
/// corresponding `hopr_lib` conversion.
impl TryFrom<RoutingOptions> for hopr_lib::RoutingOptions {
    type Error = HoprLibError;

    fn try_from(value: RoutingOptions) -> Result<Self, Self::Error> {
        let converted = match value {
            #[cfg(feature = "explicit-path")]
            RoutingOptions::IntermediatePath(path) => Self::IntermediatePath(path.into_iter().collect()),
            RoutingOptions::Hops(hops) => Self::Hops(hops.try_into()?),
        };
        Ok(converted)
    }
}
// JSON body of `POST /session/{protocol}`.
// - destination: counterparty; must be a peer ID (addresses are rejected by `into_protocol_session_config`).
// - path: routing options for the session.
// - target: what the exit side should connect to.
// - listenHost: optional local bind address; missing parts fall back to the node default (see `build_binding_host`).
// - capabilities: optional; when omitted, TCP defaults to Retransmission + Segmentation and UDP to none.
// Plain `//` comments: `///` would be emitted into the OpenAPI schema by `utoipa::ToSchema`.
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
#[schema(example = json!({
"destination": "12D3KooWR4uwjKCDCAY1xsEFB4esuWLF9Q5ijYvCjz5PNkTbnu33",
"path": {
"Hops": 1
},
"target": {"Plain": "localhost:8080"},
"listenHost": "127.0.0.1:10000",
"capabilities": ["Retransmission", "Segmentation"]
}))]
#[serde(rename_all = "camelCase")]
pub(crate) struct SessionClientRequest {
#[serde_as(as = "DisplayFromStr")]
#[schema(value_type = String)]
pub destination: PeerOrAddress,
pub path: RoutingOptions,
pub target: SessionTargetSpec,
pub listen_host: Option<String>,
#[serde_as(as = "Option<Vec<DisplayFromStr>>")]
pub capabilities: Option<Vec<SessionCapability>>,
}
impl SessionClientRequest {
    /// Builds the `SessionClientConfig` for a listener using `target_protocol`.
    ///
    /// When no capabilities were requested, TCP gets `[Retransmission, Segmentation]` and UDP
    /// gets none.
    ///
    /// # Errors
    /// Checked in this order: destination given as an address instead of a peer ID, invalid
    /// routing options, then an unparsable target.
    pub(crate) fn into_protocol_session_config(
        self,
        target_protocol: IpProtocol,
    ) -> Result<SessionClientConfig, HoprLibError> {
        let peer = match self.destination {
            PeerOrAddress::Address(address) => {
                return Err(HoprLibError::GeneralError(format!("invalid destination: {address}")))
            }
            PeerOrAddress::PeerId(peer_id) => peer_id,
        };

        let path_options = hopr_lib::RoutingOptions::try_from(self.path)?;
        let target = self.target.into_target(target_protocol)?;

        let capabilities = match self.capabilities {
            Some(requested) => requested
                .into_iter()
                .map(hopr_lib::SessionCapability::from)
                .collect(),
            None => match target_protocol {
                IpProtocol::TCP => vec![
                    hopr_lib::SessionCapability::Retransmission,
                    hopr_lib::SessionCapability::Segmentation,
                ],
                IpProtocol::UDP => Vec::new(),
            },
        };

        Ok(SessionClientConfig {
            peer,
            path_options,
            target,
            capabilities,
        })
    }
}
// Description of an open session listener, returned by `create_client` and `list_clients`.
// - target: the `SessionTargetSpec` in its text form.
// - protocol: the listener's transport.
// - ip/port: the address the listener is actually bound to.
// - path: routing options the listener was created with.
// Plain `//` comments: `///` would be emitted into the OpenAPI schema by `utoipa::ToSchema`.
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
#[schema(example = json!({
"target": "example.com:80",
"protocol": "tcp",
"ip": "127.0.0.1",
"port": 5542,
"path": { "Hops": 1 }
}))]
#[serde(rename_all = "camelCase")]
pub(crate) struct SessionClientResponse {
pub target: String,
#[serde_as(as = "DisplayFromStr")]
#[schema(value_type = String)]
pub protocol: IpProtocol,
pub ip: String,
pub path: RoutingOptions,
pub port: u16,
}
/// Resolves the socket address a session listener should bind to.
///
/// * `None` → `default`.
/// * A complete socket address (`"ip:port"`) → used as-is.
/// * Anything else is a partial override. A bare IP replaces the default IP, and `":<port>"`
///   replaces the default port. Whatever cannot be parsed falls back to `default`.
fn build_binding_host(requested: Option<&str>, default: std::net::SocketAddr) -> std::net::SocketAddr {
    let Some(requested) = requested else {
        debug!(%default, "using default listen host");
        return default;
    };

    if let Ok(full) = std::net::SocketAddr::from_str(requested) {
        debug!(requested = %full, "using requested listen host");
        return full;
    }

    debug!(requested, %default, "using partially default listen host");
    let ip = requested.parse::<std::net::IpAddr>().unwrap_or(default.ip());
    let port = requested
        .strip_prefix(':')
        .and_then(|p| p.parse::<u16>().ok())
        .unwrap_or(default.port());
    std::net::SocketAddr::new(ip, port)
}
#[utoipa::path(
post,
path = const_format::formatcp!("{BASE_PATH}/session/{{protocol}}"),
params(
("protocol" = String, Path, description = "IP transport protocol")
),
request_body(
content = SessionClientRequest,
description = "Creates a new client HOPR session that will start listening on a dedicated port. Once the port is bound, it is possible to use the socket for bidirectional read and write communication.",
content_type = "application/json"),
responses(
(status = 200, description = "Successfully created a new client session.", body = SessionClientResponse),
(status = 400, description = "Invalid IP protocol.", body = ApiError),
(status = 401, description = "Invalid authorization token.", body = ApiError),
(status = 409, description = "Listening address and port already in use.", body = ApiError),
(status = 422, description = "Unknown failure", body = ApiError),
),
security(
("api_token" = []),
("bearer_token" = [])
),
tag = "Session"
)]
pub(crate) async fn create_client(
State(state): State<Arc<InternalState>>,
Path(protocol): Path<IpProtocol>,
Json(args): Json<SessionClientRequest>,
) -> Result<impl IntoResponse, impl IntoResponse> {
let bind_host: std::net::SocketAddr = build_binding_host(args.listen_host.as_deref(), state.default_listen_host);
if bind_host.port() > 0
&& state
.open_listeners
.read()
.await
.contains_key(&ListenerId(protocol.into(), bind_host))
{
return Err((StatusCode::CONFLICT, ApiErrorStatus::ListenHostAlreadyUsed));
}
let target = args.target.clone();
let path = args.path.clone();
let data = args.into_protocol_session_config(protocol).map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
ApiErrorStatus::UnknownFailure(e.to_string()),
)
})?;
debug!("binding {protocol} session listening socket to {bind_host}");
let bound_host = match protocol {
IpProtocol::TCP => {
let (bound_host, tcp_listener) = tcp_listen_on(bind_host).await.map_err(|e| {
if e.kind() == std::io::ErrorKind::AddrInUse {
(StatusCode::CONFLICT, ApiErrorStatus::ListenHostAlreadyUsed)
} else {
(
StatusCode::UNPROCESSABLE_ENTITY,
ApiErrorStatus::UnknownFailure(format!("failed to start TCP listener on {bind_host}: {e}")),
)
}
})?;
info!(%bound_host, "TCP session listener bound");
let hopr = state.hopr.clone();
let jh = hopr_async_runtime::prelude::spawn(
tokio_stream::wrappers::TcpListenerStream::new(tcp_listener)
.and_then(|sock| async { Ok((sock.peer_addr()?, sock)) })
.for_each_concurrent(None, move |accepted_client| {
let data = data.clone();
let hopr = hopr.clone();
async move {
match accepted_client {
Ok((sock_addr, stream)) => {
debug!(socket = ?sock_addr, "incoming TCP connection");
let session = match hopr.connect_to(data).await {
Ok(s) => s,
Err(e) => {
error!(error = %e, "failed to establish session");
return;
}
};
debug!(
socket = ?sock_addr,
session_id = tracing::field::debug(*session.id()),
"new session for incoming TCP connection",
);
#[cfg(all(feature = "prometheus", not(test)))]
METRIC_ACTIVE_CLIENTS.increment(&["tcp"], 1.0);
bind_session_to_stream(session, stream, HOPR_TCP_BUFFER_SIZE).await;
#[cfg(all(feature = "prometheus", not(test)))]
METRIC_ACTIVE_CLIENTS.decrement(&["tcp"], 1.0);
}
Err(e) => error!(error = %e, "failed to accept connection"),
}
}
}),
);
state.open_listeners.write().await.insert(
ListenerId(protocol.into(), bound_host),
StoredSessionEntry {
target: target.clone(),
path: path.clone(),
jh,
},
);
bound_host
}
IpProtocol::UDP => {
let (bound_host, udp_socket) = udp_bind_to(bind_host).await.map_err(|e| {
if e.kind() == std::io::ErrorKind::AddrInUse {
(StatusCode::CONFLICT, ApiErrorStatus::ListenHostAlreadyUsed)
} else {
(
StatusCode::UNPROCESSABLE_ENTITY,
ApiErrorStatus::UnknownFailure(format!("failed to start UDP listener on {bind_host}: {e}")),
)
}
})?;
info!(%bound_host, "UDP session listener bound");
let hopr = state.hopr.clone();
let session = hopr.connect_to(data).await.map_err(|e| {
(
StatusCode::UNPROCESSABLE_ENTITY,
ApiErrorStatus::UnknownFailure(e.to_string()),
)
})?;
let open_listeners_clone = state.open_listeners.clone();
let listener_id = ListenerId(protocol.into(), bound_host);
state.open_listeners.write().await.insert(
listener_id,
StoredSessionEntry {
target: target.clone(),
path: path.clone(),
jh: hopr_async_runtime::prelude::spawn(async move {
#[cfg(all(feature = "prometheus", not(test)))]
METRIC_ACTIVE_CLIENTS.increment(&["udp"], 1.0);
bind_session_to_stream(session, udp_socket, HOPR_UDP_BUFFER_SIZE).await;
#[cfg(all(feature = "prometheus", not(test)))]
METRIC_ACTIVE_CLIENTS.decrement(&["udp"], 1.0);
open_listeners_clone.write().await.remove(&listener_id);
}),
},
);
bound_host
}
};
Ok::<_, (StatusCode, ApiErrorStatus)>(
(
StatusCode::OK,
Json(SessionClientResponse {
protocol,
path,
target: target.to_string(),
ip: bound_host.ip().to_string(),
port: bound_host.port(),
}),
)
.into_response(),
)
}
#[utoipa::path(
get,
path = const_format::formatcp!("{BASE_PATH}/session/{{protocol}}"),
params(
("protocol" = String, Path, description = "IP transport protocol")
),
responses(
(status = 200, description = "Opened session listeners for the given IP protocol.", body = Vec<SessionClientResponse>),
(status = 400, description = "Invalid IP protocol.", body = ApiError),
(status = 401, description = "Invalid authorization token.", body = ApiError),
(status = 422, description = "Unknown failure", body = ApiError)
),
security(
("api_token" = []),
("bearer_token" = [])
),
tag = "Session",
)]
pub(crate) async fn list_clients(
State(state): State<Arc<InternalState>>,
Path(protocol): Path<IpProtocol>,
) -> Result<impl IntoResponse, impl IntoResponse> {
let response = state
.open_listeners
.read()
.await
.iter()
.filter(|(id, _)| id.0 == protocol.into())
.map(|(id, entry)| SessionClientResponse {
protocol,
target: entry.target.to_string(),
ip: id.1.ip().to_string(),
port: id.1.port(),
path: entry.path.clone(),
})
.collect::<Vec<_>>();
Ok::<_, (StatusCode, ApiErrorStatus)>((StatusCode::OK, Json(response)).into_response())
}
// Transport protocol of a session listener. Serialized and parsed in lowercase ("tcp"/"udp");
// strum parsing is ASCII case-insensitive. Converted into `hopr_lib::IpProtocol` via `From`.
// Plain `//` comments: `///` would be emitted into the OpenAPI schema by `utoipa::ToSchema`.
#[derive(
Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, strum::Display, strum::EnumString, utoipa::ToSchema,
)]
#[strum(serialize_all = "lowercase", ascii_case_insensitive)]
#[serde(rename_all = "lowercase")]
pub enum IpProtocol {
#[allow(clippy::upper_case_acronyms)]
TCP,
#[allow(clippy::upper_case_acronyms)]
UDP,
}
/// Maps the API protocol onto the identically named `hopr_lib` protocol.
impl From<IpProtocol> for hopr_lib::IpProtocol {
    fn from(protocol: IpProtocol) -> hopr_lib::IpProtocol {
        match protocol {
            IpProtocol::TCP => Self::TCP,
            IpProtocol::UDP => Self::UDP,
        }
    }
}
// Path parameters of `DELETE /session/{protocol}/{ip}/{port}`.
// `ip` is parsed as an `IpAddr` by `close_client`; `port` 0 matches any port on that IP.
// Plain `//` comments: `///` would be emitted into the OpenAPI params/schema by utoipa.
#[serde_as]
#[derive(Debug, Serialize, Deserialize, utoipa::IntoParams, utoipa::ToSchema)]
pub struct SessionCloseClientQuery {
#[serde_as(as = "DisplayFromStr")]
#[schema(value_type = String)]
pub protocol: IpProtocol,
pub ip: String,
pub port: u16,
}
// DELETE `/session/{protocol}/{ip}/{port}`: closes every listener of `protocol` bound to `ip`
// and `port`. Port `0` acts as a wildcard for any port on that IP.
// Responds 400 for an unparsable IP, 404 when nothing matches, and 204 otherwise.
// Plain `//` comments: `utoipa::path` would publish `///` docs as the operation description.
#[utoipa::path(
    delete,
    path = const_format::formatcp!("{BASE_PATH}/session/{{protocol}}/{{ip}}/{{port}}"),
    params(SessionCloseClientQuery),
    responses(
        (status = 204, description = "Listener closed successfully"),
        (status = 400, description = "Invalid IP protocol or port.", body = ApiError),
        (status = 401, description = "Invalid authorization token.", body = ApiError),
        (status = 404, description = "Listener not found.", body = ApiError),
        (status = 422, description = "Unknown failure", body = ApiError)
    ),
    security(
        ("api_token" = []),
        ("bearer_token" = [])
    ),
    tag = "Session",
)]
pub(crate) async fn close_client(
    State(state): State<Arc<InternalState>>,
    Path(SessionCloseClientQuery { protocol, ip, port }): Path<SessionCloseClientQuery>,
) -> Result<impl IntoResponse, impl IntoResponse> {
    let listening_ip: IpAddr = ip
        .parse()
        .map_err(|_| (StatusCode::BAD_REQUEST, ApiErrorStatus::InvalidInput))?;
    let protocol: hopr_lib::IpProtocol = protocol.into();

    // Detach all matching entries while holding the write lock, then release it before
    // awaiting the cancellations. Holding it across those awaits would stall every other
    // listener operation, including a UDP task's own cleanup path.
    let removed = {
        let mut open_listeners = state.open_listeners.write().await;
        let matching = open_listeners
            .iter()
            .filter(|(ListenerId(proto, addr), _)| {
                *proto == protocol && addr.ip() == listening_ip && (port == 0 || addr.port() == port)
            })
            .map(|(id, _)| *id)
            .collect::<Vec<_>>();
        matching
            .into_iter()
            .filter_map(|id| open_listeners.remove(&id))
            .collect::<Vec<_>>()
    };

    if removed.is_empty() {
        return Err((StatusCode::NOT_FOUND, ApiErrorStatus::InvalidInput));
    }

    for entry in removed {
        hopr_async_runtime::prelude::cancel_join_handle(entry.jh).await;
    }

    Ok::<_, (StatusCode, ApiErrorStatus)>((StatusCode::NO_CONTENT, "").into_response())
}
async fn try_restricted_bind<F, S, Fut>(
addrs: Vec<std::net::SocketAddr>,
range_str: &str,
binder: F,
) -> std::io::Result<S>
where
F: Fn(Vec<std::net::SocketAddr>) -> Fut,
Fut: Future<Output = std::io::Result<S>>,
{
if addrs.is_empty() {
return Err(std::io::Error::other("no valid socket addresses found"));
}
let range = range_str
.split_once(":")
.and_then(
|(a, b)| match u16::from_str(a).and_then(|a| Ok((a, u16::from_str(b)?))) {
Ok((a, b)) if a <= b => Some(a..=b),
_ => None,
},
)
.ok_or(std::io::Error::other(format!("invalid port range {range_str}")))?;
for port in range {
let addrs = addrs
.iter()
.map(|addr| std::net::SocketAddr::new(addr.ip(), port))
.collect::<Vec<_>>();
match binder(addrs).await {
Ok(listener) => return Ok(listener),
Err(error) => debug!(%error, "listen address not usable"),
}
}
Err(std::io::Error::new(
std::io::ErrorKind::AddrNotAvailable,
format!("no valid socket addresses found within range: {range_str}"),
))
}
/// Binds a TCP listener and returns it together with its actual local address.
///
/// If every resolved address requests port 0 and `HOPRD_SESSION_PORT_RANGE` is set (and
/// readable), the port is picked from that range via [`try_restricted_bind`]. Otherwise the
/// resolved addresses are bound directly.
async fn tcp_listen_on<A: std::net::ToSocketAddrs>(address: A) -> std::io::Result<(std::net::SocketAddr, TcpListener)> {
    let addrs: Vec<std::net::SocketAddr> = address.to_socket_addrs()?.collect();

    let restricted_range = if addrs.iter().all(|a| a.port() == 0) {
        std::env::var(crate::env::HOPRD_SESSION_PORT_RANGE).ok()
    } else {
        None
    };

    let listener = match restricted_range {
        Some(range_str) => {
            try_restricted_bind(addrs, &range_str, |candidates| async move {
                TcpListener::bind(candidates.as_slice()).await
            })
            .await?
        }
        None => TcpListener::bind(addrs.as_slice()).await?,
    };

    Ok((listener.local_addr()?, listener))
}
/// Binds a [`ConnectedUdpStream`] for a session listener and returns it with its bound address.
///
/// The socket uses [`HOPR_UDP_BUFFER_SIZE`], [`HOPR_UDP_QUEUE_SIZE`], automatic receiver
/// parallelism and `ForeignDataMode::Discard` (presumably dropping datagrams from peers other
/// than the counterparty; confirm in `hopr_network_types`).
///
/// If every resolved address requests port 0 and `HOPRD_SESSION_PORT_RANGE` is set, the port is
/// picked from that range via [`try_restricted_bind`]. `address` is resolved exactly once, and
/// both binding paths use that same resolved list.
async fn udp_bind_to<A: std::net::ToSocketAddrs>(
    address: A,
) -> std::io::Result<(std::net::SocketAddr, ConnectedUdpStream)> {
    let addrs = address.to_socket_addrs()?.collect::<Vec<_>>();
    let builder = ConnectedUdpStream::builder()
        .with_buffer_size(HOPR_UDP_BUFFER_SIZE)
        .with_foreign_data_mode(ForeignDataMode::Discard)
        .with_queue_size(HOPR_UDP_QUEUE_SIZE)
        .with_receiver_parallelism(UdpStreamParallelism::Auto);

    if addrs.iter().all(|a| a.port() == 0) {
        if let Ok(range_str) = std::env::var(crate::env::HOPRD_SESSION_PORT_RANGE) {
            let udp_listener = try_restricted_bind(addrs, &range_str, |addrs| {
                futures::future::ready(builder.clone().build(addrs.as_slice()))
            })
            .await?;
            return Ok((*udp_listener.bound_address(), udp_listener));
        }
    }

    // Reuse the already-resolved addresses instead of resolving `address` a second time.
    let udp_socket = builder.build(addrs.as_slice())?;
    Ok((*udp_socket.bound_address(), udp_socket))
}
/// Bridges `session` and `stream` with `transfer_session` until the transfer finishes.
///
/// `max_buf` is forwarded to `transfer_session` (presumably the copy buffer size; confirm in
/// `hopr_lib`). The outcome is logged: byte counts on success, the error otherwise.
async fn bind_session_to_stream<T>(mut session: HoprSession, mut stream: T, max_buf: usize)
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    let session_id = *session.id();
    let outcome = transfer_session(&mut session, &mut stream, max_buf).await;
    match outcome {
        Err(error) => error!(
            session_id = ?session_id,
            %error,
            "error during data transfer"
        ),
        Ok((session_to_stream_bytes, stream_to_session_bytes)) => info!(
            session_id = ?session_id,
            session_to_stream_bytes,
            stream_to_session_bytes,
            "client session ended",
        ),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Context;
    use futures::channel::mpsc::UnboundedSender;
    use hopr_lib::{ApplicationData, Keypair, PeerId, SendMsg};
    use hopr_transport_session::errors::TransportSessionError;
    use std::collections::HashSet;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// `SendMsg` test double that forwards every outgoing session payload into `tx`.
    ///
    /// The tests pair `tx` with the receiver given to the same `HoprSession`, so whatever the
    /// session sends comes straight back to it (an echo).
    pub struct SendMsgResender {
        tx: UnboundedSender<Box<[u8]>>,
    }

    impl SendMsgResender {
        pub fn new(tx: UnboundedSender<Box<[u8]>>) -> Self {
            Self { tx }
        }
    }

    #[hopr_lib::async_trait]
    impl SendMsg for SendMsgResender {
        async fn send_message(
            &self,
            data: ApplicationData,
            _destination: PeerId,
            _options: hopr_lib::RoutingOptions,
        ) -> std::result::Result<(), TransportSessionError> {
            let (_peer, data) = hopr_transport_session::types::unwrap_offchain_key(data.plain_text)?;

            self.tx
                .clone()
                .unbounded_send(data)
                .map_err(|_| TransportSessionError::Closed)?;

            Ok(())
        }
    }

    #[tokio::test]
    async fn hoprd_session_connection_should_create_a_working_tcp_socket_through_which_data_can_be_sent_and_received(
    ) -> anyhow::Result<()> {
        let (tx, rx) = futures::channel::mpsc::unbounded::<Box<[u8]>>();

        let peer: hopr_lib::PeerId = hopr_lib::HoprOffchainKeypair::random().public().into();
        let session = hopr_lib::HoprSession::new(
            hopr_lib::HoprSessionId::new(4567, peer),
            peer,
            hopr_lib::RoutingOptions::IntermediatePath(Default::default()),
            HashSet::default(),
            Arc::new(SendMsgResender::new(tx)),
            rx,
            None,
        );

        let (bound_addr, tcp_listener) = tcp_listen_on(("127.0.0.1", 0)).await.context("listen_on failed")?;

        tokio::task::spawn(async move {
            match tcp_listener.accept().await {
                Ok((stream, _)) => bind_session_to_stream(session, stream, HOPR_TCP_BUFFER_SIZE).await,
                Err(e) => error!("failed to accept connection: {e}"),
            }
        });

        let mut tcp_stream = tokio::net::TcpStream::connect(bound_addr)
            .await
            .context("connect failed")?;

        // Annotated as slices: the byte-string literals have different lengths, so they would
        // not form a homogeneous `[&[u8; N]]` collection.
        let data: [&[u8]; 6] = [b"hello", b"world", b"this ", b"is ", b" a", b" test"];

        for d in data {
            tcp_stream.write_all(d).await.context("write failed")?;
        }

        // The session echoes everything back in order, so each chunk must come back unchanged.
        for d in data {
            let mut buf = vec![0; d.len()];
            tcp_stream.read_exact(&mut buf).await.context("read failed")?;
            assert_eq!(buf.as_slice(), d, "echoed TCP data must match what was sent");
        }

        Ok(())
    }

    #[tokio::test]
    async fn hoprd_session_connection_should_create_a_working_udp_socket_through_which_data_can_be_sent_and_received(
    ) -> anyhow::Result<()> {
        let (tx, rx) = futures::channel::mpsc::unbounded::<Box<[u8]>>();

        let peer: hopr_lib::PeerId = hopr_lib::HoprOffchainKeypair::random().public().into();
        let session = hopr_lib::HoprSession::new(
            hopr_lib::HoprSessionId::new(4567, peer),
            peer,
            hopr_lib::RoutingOptions::IntermediatePath(Default::default()),
            HashSet::default(),
            Arc::new(SendMsgResender::new(tx)),
            rx,
            None,
        );

        let (listen_addr, udp_listener) = udp_bind_to(("127.0.0.1", 0)).await.context("udp_bind_to failed")?;

        tokio::task::spawn(bind_session_to_stream(
            session,
            udp_listener,
            hopr_lib::SESSION_USABLE_MTU_SIZE,
        ));

        let mut udp_stream = ConnectedUdpStream::builder()
            .with_buffer_size(hopr_lib::SESSION_USABLE_MTU_SIZE)
            .with_queue_size(HOPR_UDP_QUEUE_SIZE)
            .with_counterparty(listen_addr)
            .build(("127.0.0.1", 0))
            .context("bind failed")?;

        // Annotated as slices: the byte-string literals have different lengths.
        let data: [&[u8]; 6] = [b"hello", b"world", b"this ", b"is ", b" a", b" test"];

        for d in data {
            udp_stream.write_all(d).await.context("write failed")?;
        }

        for d in data {
            let mut buf = vec![0; d.len()];
            udp_stream.read_exact(&mut buf).await.context("read failed")?;
            assert_eq!(buf.as_slice(), d, "echoed UDP data must match what was sent");
        }

        Ok(())
    }

    #[test]
    fn test_build_binding_address() {
        let default = "10.0.0.1:10000".parse().unwrap();

        let result = build_binding_host(Some("127.0.0.1:10000"), default);
        assert_eq!(result, "127.0.0.1:10000".parse::<std::net::SocketAddr>().unwrap());

        let result = build_binding_host(None, default);
        assert_eq!(result, "10.0.0.1:10000".parse::<std::net::SocketAddr>().unwrap());

        let result = build_binding_host(Some("127.0.0.1"), default);
        assert_eq!(result, "127.0.0.1:10000".parse::<std::net::SocketAddr>().unwrap());

        let result = build_binding_host(Some(":1234"), default);
        assert_eq!(result, "10.0.0.1:1234".parse::<std::net::SocketAddr>().unwrap());

        let result = build_binding_host(Some(":"), default);
        assert_eq!(result, "10.0.0.1:10000".parse::<std::net::SocketAddr>().unwrap());

        let result = build_binding_host(Some(""), default);
        assert_eq!(result, "10.0.0.1:10000".parse::<std::net::SocketAddr>().unwrap());
    }
}