use futures::stream::{Stream, StreamExt};
use futures::{
channel::{
mpsc::{channel, Receiver, Sender},
oneshot,
},
future::{poll_fn, Either},
pin_mut,
};
use libp2p::request_response::{OutboundRequestId, ResponseChannel};
use rust_stream_ext_concurrent::then_concurrent::StreamThenConcurrentExt;
use std::{pin::Pin, task::Poll};
use tracing::{error, warn};
use hopr_async_runtime::prelude::{sleep, spawn};
use hopr_crypto_types::prelude::*;
use hopr_db_api::{
errors::DbError,
tickets::{AggregationPrerequisites, HoprDbTicketOperations},
};
use hopr_internal_types::prelude::*;
use hopr_transport_identity::PeerId;
use crate::errors::{
ProtocolError::{Retry, TransportError},
Result,
};
#[cfg(all(feature = "prometheus", not(test)))]
use hopr_metrics::metrics::SimpleCounter;
// Prometheus counters for ticket aggregation.
// Compiled only when the `prometheus` feature is enabled and the crate is not built for tests.
#[cfg(all(feature = "prometheus", not(test)))]
lazy_static::lazy_static! {
// Increased by the number of tickets contained in each prepared outgoing aggregation request.
static ref METRIC_AGGREGATED_TICKETS: SimpleCounter = SimpleCounter::new(
"hopr_aggregated_tickets_count",
"Number of aggregated tickets"
)
.unwrap();
// Increased by one for each prepared outgoing aggregation request.
static ref METRIC_AGGREGATION_COUNT: SimpleCounter = SimpleCounter::new(
"hopr_aggregations_count",
"Number of performed ticket aggregations"
)
.unwrap();
}
/// Queue size reserved for the sending (TX) direction of the ticket aggregation protocol.
///
/// It is added to [`TICKET_AGGREGATION_RX_QUEUE_SIZE`] to obtain the capacity of each
/// internal channel created by [`TicketAggregationInteraction::new`].
pub const TICKET_AGGREGATION_TX_QUEUE_SIZE: usize = 2048;
/// Queue size reserved for the receiving (RX) direction of the ticket aggregation protocol.
///
/// It is added to [`TICKET_AGGREGATION_TX_QUEUE_SIZE`] to obtain the capacity of each
/// internal channel created by [`TicketAggregationInteraction::new`].
pub const TICKET_AGGREGATION_RX_QUEUE_SIZE: usize = 2048;
/// Input events accepted by the ticket aggregation processing pipeline.
///
/// `T` is the handle used to answer an incoming aggregation request and `U` identifies
/// an outgoing request whose response is being delivered (see
/// [`BasicTicketAggregationActions`] for the concrete types used with libp2p).
#[allow(clippy::type_complexity)] #[allow(clippy::large_enum_variant)] #[derive(Debug)]
pub enum TicketAggregationToProcess<T, U> {
/// A counterparty's answer to our aggregation request: the sending peer, either the
/// aggregated ticket or the textual reason of the refusal, and the request handle.
ToReceive(PeerId, std::result::Result<Ticket, String>, U),
/// An aggregation request received from a peer: the requesting peer, the tickets
/// to be aggregated and the handle through which the reply is to be delivered.
ToProcess(PeerId, Vec<TransferableWinningTicket>, T),
/// A locally initiated aggregation in the given channel, with the prerequisites to be
/// checked and the finalizer used to signal completion to the awaiting caller.
ToSend(Hash, AggregationPrerequisites, TicketAggregationFinalizer),
}
/// Output events emitted by the ticket aggregation processing pipeline.
///
/// Each variant is the successfully processed counterpart of a
/// [`TicketAggregationToProcess`] input event.
#[allow(clippy::large_enum_variant)] #[derive(Debug)]
pub enum TicketAggregationProcessed<T, U> {
/// An aggregated ticket received from the given peer was accepted by the database;
/// carries the resulting acknowledged ticket and the original request handle.
Receive(PeerId, AcknowledgedTicket, U),
/// The reply to an aggregation request of the given peer: either the aggregated ticket
/// or the textual aggregation error, along with the handle used to send the reply.
Reply(PeerId, std::result::Result<Ticket, String>, T),
/// An aggregation request that is ready to be sent to the given peer, with the tickets
/// to aggregate and the finalizer to be triggered once the aggregation completes.
Send(
PeerId,
Vec<hopr_internal_types::legacy::AcknowledgedTicket>,
TicketAggregationFinalizer,
),
}
/// Abstraction over components that are able to trigger ticket aggregation in a channel.
#[async_trait::async_trait]
pub trait TicketAggregatorTrait {
/// Requests aggregation of tickets in the channel identified by `channel`,
/// subject to the given `prerequisites`.
///
/// # Errors
/// Returns an error if the aggregation could not be carried out by the implementation.
async fn aggregate_tickets(&self, channel: &Hash, prerequisites: AggregationPrerequisites) -> Result<()>;
}
/// A [`TicketAggregatorTrait`] implementation that enqueues the aggregation request
/// and then waits (with a timeout) until the aggregation is finalized.
///
/// When the wait fails, the aggregation in the channel is rolled back in the database.
#[derive(Debug)]
pub struct AwaitingAggregator<T, U, Db>
where
Db: HoprDbTicketOperations + Send + Sync + Clone + std::fmt::Debug,
T: Send,
U: Send,
{
// Database used to roll back an aggregation that failed or timed out.
db: Db,
// Handle used to enqueue aggregation requests into the processing pipeline.
writer: TicketAggregationActions<T, U>,
// Maximum time to wait for a single aggregation to be finalized.
agg_timeout: std::time::Duration,
}
// `Clone` is implemented by hand so that no `T: Clone` / `U: Clone` bounds are required.
impl<T, U, Db> Clone for AwaitingAggregator<T, U, Db>
where
    Db: HoprDbTicketOperations + Send + Sync + Clone + std::fmt::Debug,
    T: Send,
    U: Send,
{
    /// Produces an independent aggregator sharing the same queue and database handle.
    fn clone(&self) -> Self {
        let Self {
            db,
            writer,
            agg_timeout,
        } = self;

        Self {
            db: db.clone(),
            writer: writer.clone(),
            agg_timeout: *agg_timeout,
        }
    }
}

impl<T, U, Db> AwaitingAggregator<T, U, Db>
where
    Db: HoprDbTicketOperations + Send + Sync + Clone + std::fmt::Debug,
    T: Send,
    U: Send,
{
    /// Creates a new aggregator.
    ///
    /// * `db` - database used for rolling back failed aggregations
    /// * `writer` - handle for enqueueing aggregation requests
    /// * `agg_timeout` - maximum time to wait for a single aggregation to complete
    pub fn new(db: Db, writer: TicketAggregationActions<T, U>, agg_timeout: std::time::Duration) -> Self {
        Self { db, writer, agg_timeout }
    }
}
#[async_trait::async_trait]
impl<T, U, Db> TicketAggregatorTrait for AwaitingAggregator<T, U, Db>
where
    Db: HoprDbTicketOperations + Send + Sync + Clone + std::fmt::Debug,
    T: Send,
    U: Send,
{
    /// Enqueues the aggregation request for `channel` and waits for its completion
    /// for at most the configured aggregation timeout.
    ///
    /// A failed or timed-out wait is only logged and followed by a rollback of the
    /// aggregation in the channel; the method still returns `Ok(())` in that case.
    ///
    /// # Errors
    /// Fails if the request cannot be enqueued or if the rollback itself fails.
    #[tracing::instrument(level = "debug", skip(self))]
    async fn aggregate_tickets(&self, channel: &Hash, prerequisites: AggregationPrerequisites) -> Result<()> {
        let awaiter = self.writer.clone().aggregate_tickets(channel, prerequisites)?;
        let outcome = awaiter.consume_and_wait(self.agg_timeout).await;

        match outcome {
            Ok(()) => Ok(()),
            Err(e) => {
                warn!(%channel, error = %e, "Error during ticket aggregation, performing a rollback");
                self.db.rollback_aggregation_in_channel(*channel).await?;
                Ok(())
            }
        }
    }
}
/// Handle that allows waiting for the completion of a requested ticket aggregation.
///
/// It wraps the receiving half of a oneshot channel whose sending half is held by
/// the corresponding [`TicketAggregationFinalizer`].
#[derive(Debug)]
pub struct TicketAggregationAwaiter {
// Receives the completion signal of the aggregation.
rx: oneshot::Receiver<()>,
}
/// Wraps the receiving half of a oneshot channel into an awaiter.
impl From<oneshot::Receiver<()>> for TicketAggregationAwaiter {
fn from(value: oneshot::Receiver<()>) -> Self {
Self { rx: value }
}
}
impl TicketAggregationAwaiter {
    /// Consumes the awaiter and waits until the aggregation is finalized
    /// or `until_timeout` elapses, whichever happens first.
    ///
    /// # Errors
    /// * `TransportError("Canceled")` if the sending half of the oneshot channel was
    ///   dropped without signalling completion.
    /// * `TransportError` describing a timeout if `until_timeout` elapsed first.
    pub async fn consume_and_wait(self, until_timeout: std::time::Duration) -> Result<()> {
        let timeout = sleep(until_timeout);
        let resolve = self.rx;
        pin_mut!(resolve, timeout);

        // Race the completion signal against the timer.
        match futures::future::select(resolve, timeout).await {
            Either::Left((result, _)) => result.map_err(|_| TransportError("Canceled".to_owned())),
            // The message previously referred to "sending a packet", which is unrelated to
            // what this awaiter actually waits for.
            Either::Right(_) => Err(TransportError(
                "Timed out while waiting for the ticket aggregation to finish".to_owned(),
            )),
        }
    }
}
/// Handle used to signal that a requested ticket aggregation has been completed.
///
/// It holds the sending half of a oneshot channel whose receiving half is wrapped
/// by the corresponding [`TicketAggregationAwaiter`].
#[derive(Debug)]
pub struct TicketAggregationFinalizer {
// Sending half of the completion channel; taken out when the finalizer is triggered.
tx: Option<oneshot::Sender<()>>,
}
impl TicketAggregationFinalizer {
pub fn new(tx: oneshot::Sender<()>) -> Self {
Self { tx: Some(tx) }
}
pub fn finalize(mut self) {
if let Some(sender) = self.tx.take() {
if sender.send(()).is_err() {
error!("Failed to notify the awaiter about the successful ticket aggregation")
}
} else {
error!("Sender for packet send signalization is already spent")
}
}
}
/// Handle used to push events into the ticket aggregation processing pipeline.
#[derive(Debug)]
pub struct TicketAggregationActions<T, U> {
/// Sending half of the queue of events awaiting processing.
pub queue: Sender<TicketAggregationToProcess<T, U>>,
}
/// [`TicketAggregationActions`] specialized for libp2p request-response types:
/// replies go through a `ResponseChannel<T>` and requests are identified by `OutboundRequestId`.
pub type BasicTicketAggregationActions<T> = TicketAggregationActions<ResponseChannel<T>, OutboundRequestId>;
// `Clone` is implemented by hand so that no `T: Clone` / `U: Clone` bounds are required.
impl<T, U> Clone for TicketAggregationActions<T, U> {
    /// Creates another handle feeding the same processing queue.
    fn clone(&self) -> Self {
        let queue = self.queue.clone();
        Self { queue }
    }
}
impl<T, U> TicketAggregationActions<T, U> {
pub fn receive_ticket(
&mut self,
source: PeerId,
ticket: std::result::Result<Ticket, String>,
request: U,
) -> Result<()> {
self.process(TicketAggregationToProcess::ToReceive(source, ticket, request))
}
pub fn receive_aggregation_request(
&mut self,
source: PeerId,
tickets: Vec<TransferableWinningTicket>,
request: T,
) -> Result<()> {
self.process(TicketAggregationToProcess::ToProcess(source, tickets, request))
}
pub fn aggregate_tickets(
&mut self,
channel: &Hash,
prerequisites: AggregationPrerequisites,
) -> Result<TicketAggregationAwaiter> {
let (tx, rx) = oneshot::channel::<()>();
self.process(TicketAggregationToProcess::ToSend(
*channel,
prerequisites,
TicketAggregationFinalizer::new(tx),
))?;
Ok(rx.into())
}
fn process(&mut self, event: TicketAggregationToProcess<T, U>) -> Result<()> {
self.queue.try_send(event).map_err(|e| {
if e.is_full() {
Retry
} else if e.is_disconnected() {
TransportError("queue is closed".to_string())
} else {
TransportError(format!("Unknown error: {}", e))
}
})
}
}
/// Pair of channel ends owned by [`TicketAggregationInteraction`]: the sender of
/// events to be processed and the receiver of the processed events.
type AckEventQueue<T, U> = (
Sender<TicketAggregationToProcess<T, U>>,
Receiver<TicketAggregationProcessed<T, U>>,
);
/// The ticket aggregation processing pipeline.
///
/// Events are pushed in through the handle obtained from `writer()`, and the processed
/// events are obtained by polling this object as a `Stream`.
pub struct TicketAggregationInteraction<T, U>
where
T: Send,
U: Send,
{
// (input queue sender, output queue receiver)
ack_event_queue: AckEventQueue<T, U>,
}
impl<T: 'static, U: 'static> TicketAggregationInteraction<T, U>
where
T: Send,
U: Send,
{
pub fn new<Db>(db: Db, chain_key: &ChainKeypair) -> Self
where
Db: HoprDbTicketOperations + Send + Sync + Clone + std::fmt::Debug + 'static,
{
let (processing_in_tx, processing_in_rx) = channel::<TicketAggregationToProcess<T, U>>(
TICKET_AGGREGATION_RX_QUEUE_SIZE + TICKET_AGGREGATION_TX_QUEUE_SIZE,
);
let (processing_out_tx, processing_out_rx) = channel::<TicketAggregationProcessed<T, U>>(
TICKET_AGGREGATION_RX_QUEUE_SIZE + TICKET_AGGREGATION_TX_QUEUE_SIZE,
);
let chain_key = chain_key.clone();
let mut processing_stream = processing_in_rx.then_concurrent(move |event| {
let chain_key = chain_key.clone();
let db = db.clone();
let mut processed_tx = processing_out_tx.clone();
async move {
let processed = match event {
TicketAggregationToProcess::ToProcess(destination, acked_tickets, response) => {
let opk: std::result::Result<OffchainPublicKey, hopr_primitive_types::errors::GeneralError> =
destination.try_into();
match opk {
Ok(opk) => match db.aggregate_tickets(opk, acked_tickets, &chain_key).await {
Ok(ticket) => Some(TicketAggregationProcessed::Reply(
destination,
Ok(ticket.leak()),
response,
)),
Err(DbError::TicketAggregationError(e)) => {
Some(TicketAggregationProcessed::Reply(destination, Err(e), response))
}
Err(e) => {
error!(error = %e, "Dropping tickets aggregation request due to an error");
None
}
},
Err(e) => {
error!(
?destination, error = %e,
"Failed to deserialize the destination to an offchain public key"
);
None
}
}
}
TicketAggregationToProcess::ToReceive(destination, aggregated_ticket, request) => {
match aggregated_ticket {
Ok(ticket) => match db.process_received_aggregated_ticket(ticket.clone(), &chain_key).await
{
Ok(acked_ticket) => {
Some(TicketAggregationProcessed::Receive(destination, acked_ticket, request))
}
Err(e) => {
error!(error = %e, "Error while handling aggregated ticket");
None
}
},
Err(e) => {
warn!(error = %e, "Counterparty refused to aggregate tickets");
None
}
}
}
TicketAggregationToProcess::ToSend(channel, prerequsites, finalizer) => {
match db.prepare_aggregation_in_channel(&channel, prerequsites).await {
Ok(Some((source, tickets, dst))) if !tickets.is_empty() => {
#[cfg(all(feature = "prometheus", not(test)))]
{
METRIC_AGGREGATED_TICKETS.increment_by(tickets.len() as u64);
METRIC_AGGREGATION_COUNT.increment();
}
let addr = chain_key.public().to_address();
let tickets = tickets
.into_iter()
.map(|t| hopr_internal_types::legacy::AcknowledgedTicket::new(t, &addr, &dst))
.collect::<Vec<_>>();
Some(TicketAggregationProcessed::Send(source.into(), tickets, finalizer))
}
Err(e) => {
error!(error = %e, "An error occured when preparing the channel aggregation");
None
}
_ => {
finalizer.finalize();
None
}
}
}
};
if let Some(event) = processed {
match poll_fn(|cx| Pin::new(&mut processed_tx).poll_ready(cx)).await {
Ok(_) => match processed_tx.start_send(event) {
Ok(_) => {}
Err(e) => error!(error = %e, "Failed to pass a processed ack message"),
},
Err(e) => {
warn!(error = %e, "The receiver for processed ack no longer exists");
}
};
}
}
});
spawn(async move {
while processing_stream.next().await.is_some() {}
});
Self {
ack_event_queue: (processing_in_tx, processing_out_rx),
}
}
pub fn writer(&self) -> TicketAggregationActions<T, U> {
TicketAggregationActions {
queue: self.ack_event_queue.0.clone(),
}
}
}
/// Yields the processed events produced by the background processing task.
impl<T, U> Stream for TicketAggregationInteraction<T, U>
where
    T: Send,
    U: Send,
{
    type Item = TicketAggregationProcessed<T, U>;

    /// Polls the output queue for the next processed event.
    fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
        let (_, processed_rx) = &mut self.get_mut().ack_event_queue;
        processed_rx.poll_next_unpin(cx)
    }
}
#[cfg(test)]
mod tests {
use super::TicketAggregationProcessed;
use async_std::prelude::FutureExt;
use futures::pin_mut;
use futures::stream::StreamExt;
use hex_literal::hex;
use hopr_crypto_types::{
keypairs::{ChainKeypair, Keypair, OffchainKeypair},
types::{Hash, Response},
};
use hopr_db_sql::accounts::HoprDbAccountOperations;
use hopr_db_sql::api::tickets::HoprDbTicketOperations;
use hopr_db_sql::channels::HoprDbChannelOperations;
use hopr_db_sql::info::HoprDbInfoOperations;
use hopr_db_sql::HoprDbGeneralModelOperations;
use hopr_db_sql::{api::info::DomainSeparator, db::HoprDb};
use hopr_internal_types::prelude::*;
use hopr_primitive_types::prelude::*;
use lazy_static::lazy_static;
use std::ops::{Add, Mul};
use std::time::Duration;
// Fixed key material shared by the tests. Index 0 is used for Alice, index 1 for Bob.
lazy_static! {
// Offchain keypairs; their public keys serve as the packet keys of the peers.
static ref PEERS: Vec<OffchainKeypair> = [
hex!("b91a28ff9840e9c93e5fafd581131f0b9f33f3e61b02bf5dd83458aa0221f572"),
hex!("82283757872f99541ce33a47b90c2ce9f64875abf08b5119a8a434b2fa83ea98")
]
.iter()
.map(|private| OffchainKeypair::from_secret(private).expect("lazy static keypair should be valid"))
.collect();
// Chain keypairs; used for signing tickets and deriving the chain addresses of the peers.
static ref PEERS_CHAIN: Vec<ChainKeypair> = [
hex!("51d3003d908045a4d76d0bfc0d84f6ff946b5934b7ea6a2958faf02fead4567a"),
hex!("e1f89073a01831d0eed9fe2c67e7d65c144b9d9945320f6d325b1cccc2d124e9")
]
.iter()
.map(|private| ChainKeypair::from_secret(private).expect("lazy static keypair should be valid"))
.collect();
}
fn mock_acknowledged_ticket(
signer: &ChainKeypair,
destination: &ChainKeypair,
index: u64,
) -> anyhow::Result<AcknowledgedTicket> {
let price_per_packet: U256 = 10000000000000000u128.into();
let ticket_win_prob = 1.0f64;
let channel_id = generate_channel_id(&signer.into(), &destination.into());
let channel_epoch = 1u64;
let domain_separator = Hash::default();
let response = Response::try_from(
Hash::create(&[channel_id.as_ref(), &channel_epoch.to_be_bytes(), &index.to_be_bytes()]).as_ref(),
)?;
Ok(TicketBuilder::default()
.addresses(signer, destination)
.amount(price_per_packet.div_f64(ticket_win_prob)?)
.index(index)
.index_offset(1)
.win_prob(ticket_win_prob)
.channel_epoch(1)
.challenge(response.to_challenge().into())
.build_signed(signer, &domain_separator)?
.into_acknowledged(response))
}
// Prepares the given database for a test: within a single transaction it sets
// the channel domain separator to the default hash and inserts an account entry
// for every (offchain, chain) keypair pair from `PEERS` and `PEERS_CHAIN`.
async fn init_db(db: HoprDb) -> anyhow::Result<()> {
// Owned copies are moved into the transaction closure below.
let db_clone = db.clone();
let peers = PEERS.clone();
let peers_chain = PEERS_CHAIN.clone();
db.begin_transaction()
.await?
.perform(move |tx| {
Box::pin(async move {
db_clone
.set_domain_separator(Some(tx), DomainSeparator::Channel, Hash::default())
.await?;
for (offchain, chain) in peers.iter().zip(peers_chain.iter()) {
db_clone
.insert_account(
Some(tx),
AccountEntry::new(
*offchain.public(),
chain.public().to_address(),
AccountType::NotAnnounced,
),
)
.await?
}
Ok::<(), hopr_db_sql::errors::DbSqlError>(())
})
})
.await?;
Ok(())
}
// Runs the complete aggregation exchange between Bob (the requester holding the tickets)
// and Alice (the ticket issuer performing the aggregation) and checks that Bob ends up
// with one aggregated ticket next to the ticket that was being redeemed.
#[async_std::test]
async fn test_ticket_aggregation() -> anyhow::Result<()> {
let db_alice = HoprDb::new_in_memory(PEERS_CHAIN[0].clone()).await?;
let db_bob = HoprDb::new_in_memory(PEERS_CHAIN[1].clone()).await?;
init_db(db_alice.clone()).await?;
init_db(db_bob.clone()).await?;
const NUM_TICKETS: u64 = 30;
let mut tickets = vec![];
let mut agg_balance = Balance::zero(BalanceType::HOPR);
// Generate the tickets issued by Alice to Bob.
for i in 1..=NUM_TICKETS {
let mut ack_ticket = mock_acknowledged_ticket(&PEERS_CHAIN[0], &PEERS_CHAIN[1], i)?;
// The first ticket is marked as being redeemed and its amount is
// not counted into the expected aggregated balance.
if i == 1 {
ack_ticket.status = AcknowledgedTicketStatus::BeingRedeemed;
} else {
agg_balance = agg_balance.add(&ack_ticket.verified_ticket().amount);
}
tickets.push(ack_ticket)
}
let alice_addr: Address = (&PEERS_CHAIN[0]).into();
let bob_addr: Address = (&PEERS_CHAIN[1]).into();
let alice_packet_key = PEERS[0].public().into();
let bob_packet_key = PEERS[1].public().into();
// The channel from Alice to Bob holds ten times the expected aggregated balance.
let channel_alice_bob = ChannelEntry::new(
alice_addr,
bob_addr,
agg_balance.mul(10),
1_u32.into(),
ChannelStatus::Open,
1u32.into(),
);
db_alice.upsert_channel(None, channel_alice_bob).await?;
db_bob.upsert_channel(None, channel_alice_bob).await?;
// Only Bob stores the tickets.
for ticket in tickets.into_iter() {
db_bob.upsert_ticket(None, ticket).await?;
}
let (bob_notify_tx, bob_notify_rx) = futures::channel::mpsc::unbounded();
db_bob.start_ticket_processing(bob_notify_tx.into())?;
let mut alice = super::TicketAggregationInteraction::<(), ()>::new(db_alice.clone(), &PEERS_CHAIN[0]);
let mut bob = super::TicketAggregationInteraction::<(), ()>::new(db_bob.clone(), &PEERS_CHAIN[1]);
// Step 1: Bob initiates the aggregation in the channel.
let awaiter = bob
.writer()
.aggregate_tickets(&channel_alice_bob.get_id(), Default::default())?;
let mut finalizer = None;
// Step 2: Bob's pipeline emits the request, which is handed over to Alice.
match bob.next().timeout(Duration::from_secs(5)).await {
Ok(Some(TicketAggregationProcessed::Send(_, acked_tickets, request_finalizer))) => {
// Keep the finalizer until the aggregated ticket is received back.
let _ = finalizer.insert(request_finalizer);
assert_eq!(
NUM_TICKETS - 1,
acked_tickets.len() as u64,
"invalid number of tickets to aggregate"
);
alice.writer().receive_aggregation_request(
bob_packet_key,
acked_tickets.into_iter().map(TransferableWinningTicket::from).collect(),
(),
)?;
}
_ => panic!("unexpected action happened while sending agg request by Bob"),
};
// Step 3: Alice's pipeline emits the reply, which is handed back to Bob.
match alice.next().timeout(Duration::from_secs(5)).await {
Ok(Some(TicketAggregationProcessed::Reply(_, aggregated_ticket, ()))) => {
bob.writer().receive_ticket(alice_packet_key, aggregated_ticket, ())?
}
_ => panic!("unexpected action happened while awaiting agg request at Alice"),
};
// Step 4: Bob accepts the aggregated ticket and the aggregation is finalized.
match bob.next().timeout(Duration::from_secs(5)).await {
Ok(Some(TicketAggregationProcessed::Receive(_destination, _acked_tkt, ()))) => {
finalizer.take().expect("finalizer should be present").finalize()
}
_ => panic!("unexpected action happened while awaiting agg response at Bob"),
}
// Bob's ticket processing must have sent a notification.
pin_mut!(bob_notify_rx);
bob_notify_rx
.next()
.await
.expect("bob should have received the ticket notification");
// Verify the tickets stored in Bob's database after the aggregation.
let stored_acked_tickets = db_bob.get_tickets((&channel_alice_bob).into()).await?;
assert_eq!(
stored_acked_tickets.len(),
2,
"there should be 1 aggregated ticket and 1 ticket being redeemed"
);
assert_eq!(
AcknowledgedTicketStatus::BeingRedeemed,
stored_acked_tickets[0].status,
"first ticket must be being redeemed"
);
assert!(
stored_acked_tickets[1].verified_ticket().is_aggregated(),
"last ticket must be the aggregated one"
);
assert_eq!(
AcknowledgedTicketStatus::Untouched,
stored_acked_tickets[1].status,
"second ticket must be untouched"
);
assert_eq!(
agg_balance,
stored_acked_tickets[1].verified_ticket().amount,
"aggregated balance invalid"
);
// The awaiter must resolve, because the finalizer has already been triggered.
Ok(awaiter.consume_and_wait(Duration::from_millis(2000)).await?)
}
// Runs the complete aggregation exchange between Bob and Alice on a channel created
// with `CHANNEL_TICKET_IDX` and checks that only the tickets with an index of at least
// `CHANNEL_TICKET_IDX` get aggregated, while the tickets with lower indices stay untouched.
#[async_std::test]
async fn test_ticket_aggregation_skip_lower_indices() -> anyhow::Result<()> {
let db_alice = HoprDb::new_in_memory(PEERS_CHAIN[0].clone()).await?;
let db_bob = HoprDb::new_in_memory(PEERS_CHAIN[1].clone()).await?;
init_db(db_alice.clone()).await?;
init_db(db_bob.clone()).await?;
let (bob_notify_tx, bob_notify_rx) = futures::channel::mpsc::unbounded();
db_bob.start_ticket_processing(bob_notify_tx.into())?;
const NUM_TICKETS: u64 = 30;
const CHANNEL_TICKET_IDX: u64 = 20;
let mut tickets = vec![];
let mut agg_balance = Balance::zero(BalanceType::HOPR);
// Generate the tickets issued by Alice to Bob; only those with an index of at least
// `CHANNEL_TICKET_IDX` are counted into the expected aggregated balance.
for i in 1..=NUM_TICKETS {
let ack_ticket = mock_acknowledged_ticket(&PEERS_CHAIN[0], &PEERS_CHAIN[1], i)?;
if i >= CHANNEL_TICKET_IDX {
agg_balance = agg_balance.add(&ack_ticket.verified_ticket().amount);
}
tickets.push(ack_ticket)
}
let alice_addr: Address = (&PEERS_CHAIN[0]).into();
let bob_addr: Address = (&PEERS_CHAIN[1]).into();
let alice_packet_key = PEERS[0].public().into();
let bob_packet_key = PEERS[1].public().into();
// NOTE(review): the fourth argument is presumably the channel's ticket index - confirm
// against `ChannelEntry::new`.
let channel_alice_bob = ChannelEntry::new(
alice_addr,
bob_addr,
agg_balance.mul(10),
CHANNEL_TICKET_IDX.into(),
ChannelStatus::Open,
1u32.into(),
);
db_alice.upsert_channel(None, channel_alice_bob).await?;
db_bob.upsert_channel(None, channel_alice_bob).await?;
// Only Bob stores the tickets.
for ticket in tickets.into_iter() {
db_bob.upsert_ticket(None, ticket).await?;
}
let mut alice = super::TicketAggregationInteraction::<(), ()>::new(db_alice.clone(), &PEERS_CHAIN[0]);
let mut bob = super::TicketAggregationInteraction::<(), ()>::new(db_bob.clone(), &PEERS_CHAIN[1]);
// Step 1: Bob initiates the aggregation in the channel.
let awaiter = bob
.writer()
.aggregate_tickets(&channel_alice_bob.get_id(), Default::default())?;
let mut finalizer = None;
// Step 2: Bob's pipeline emits the request, which is handed over to Alice.
match bob.next().timeout(Duration::from_secs(5)).await {
Ok(Some(TicketAggregationProcessed::Send(_, acked_tickets, request_finalizer))) => {
// Keep the finalizer until the aggregated ticket is received back.
let _ = finalizer.insert(request_finalizer);
assert_eq!(
NUM_TICKETS - CHANNEL_TICKET_IDX + 1,
acked_tickets.len() as u64,
"invalid number of tickets to aggregate"
);
alice.writer().receive_aggregation_request(
bob_packet_key,
acked_tickets.into_iter().map(TransferableWinningTicket::from).collect(),
(),
)?;
}
_ => panic!("unexpected action happened while sending agg request by Bob"),
};
// Step 3: Alice's pipeline emits the reply, which is handed back to Bob.
match alice.next().timeout(Duration::from_secs(5)).await {
Ok(Some(TicketAggregationProcessed::Reply(_, aggregated_ticket, ()))) => {
bob.writer().receive_ticket(alice_packet_key, aggregated_ticket, ())?
}
_ => panic!("unexpected action happened while awaiting agg request at Alice"),
};
// Step 4: Bob accepts the aggregated ticket and the aggregation is finalized.
match bob.next().timeout(Duration::from_secs(5)).await {
Ok(Some(TicketAggregationProcessed::Receive(_destination, _acked_tkt, ()))) => {
finalizer.take().expect("finalizer should be present").finalize()
}
_ => panic!("unexpected action happened while awaiting agg response at Bob"),
}
// Bob's ticket processing must have sent a notification.
pin_mut!(bob_notify_rx);
bob_notify_rx
.next()
.await
.expect("bob should have received the ticket notification");
// Verify the tickets stored in Bob's database after the aggregation.
let stored_acked_tickets = db_bob.get_tickets((&channel_alice_bob).into()).await?;
assert_eq!(
stored_acked_tickets.len(),
20,
"there should be 1 aggregated ticket and 19 old tickets"
);
assert!(
stored_acked_tickets[19].verified_ticket().is_aggregated(),
"last ticket must be the aggregated one"
);
// The first 19 stored tickets must not have been touched by the aggregation.
for i in 0..19 {
assert_eq!(
AcknowledgedTicketStatus::Untouched,
stored_acked_tickets[i].status,
"ticket #{i} must be untouched"
);
}
assert_eq!(
agg_balance,
stored_acked_tickets[19].verified_ticket().amount,
"aggregated balance invalid"
);
// The awaiter must resolve, because the finalizer has already been triggered.
Ok(awaiter.consume_and_wait(Duration::from_millis(2000)).await?)
}
}