hopr_transport_p2p/behavior/
discovery.rs

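//! Peer discovery behaviour for the HOPR libp2p swarm.
//!
//! Merges network events and on-chain (indexer) peer events into a single stream,
//! gates inbound and outbound connections on the network-registry allow list, and
//! forwards announced peer addresses, dial requests and connection closes to the swarm.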
2use std::collections::{HashMap, HashSet, VecDeque};
3
4use futures::stream::{BoxStream, Stream, StreamExt};
5use futures_concurrency::stream::Merge;
6use libp2p::{
7    swarm::{dial_opts::DialOpts, dummy::ConnectionHandler, CloseConnection, NetworkBehaviour, ToSwarm},
8    Multiaddr, PeerId,
9};
10use tracing::debug;
11
12use hopr_transport_network::network::NetworkTriggeredEvent;
13use hopr_transport_protocol::PeerDiscovery;
14
15#[derive(Debug)]
16pub enum DiscoveryInput {
17    NetworkUpdate(NetworkTriggeredEvent),
18    Indexer(PeerDiscovery),
19}
20
/// Events this behaviour reports to the swarm.
///
/// The enum has no variants and is therefore uninhabited: discovery never emits
/// behaviour-level events, it only issues swarm commands.
#[derive(Debug)]
pub enum Event {}
23
24pub struct Behaviour {
25    me: PeerId,
26    events: BoxStream<'static, DiscoveryInput>,
27    pending_events: VecDeque<
28        libp2p::swarm::ToSwarm<
29            <Self as NetworkBehaviour>::ToSwarm,
30            <<Self as NetworkBehaviour>::ConnectionHandler as libp2p::swarm::ConnectionHandler>::FromBehaviour,
31        >,
32    >,
33    all_peers: HashMap<PeerId, Multiaddr>,
34    allowed_peers: HashSet<PeerId>,
35    connected_peers: HashMap<PeerId, usize>,
36}
37
38impl Behaviour {
39    pub fn new<T, U>(me: PeerId, network_events: T, onchain_events: U) -> Self
40    where
41        T: Stream<Item = NetworkTriggeredEvent> + Send + 'static,
42        U: Stream<Item = PeerDiscovery> + Send + 'static,
43    {
44        Self {
45            me,
46            events: Box::pin(
47                (
48                    network_events.map(DiscoveryInput::NetworkUpdate),
49                    onchain_events.map(DiscoveryInput::Indexer),
50                )
51                    .merge()
52                    .fuse(),
53            ),
54            all_peers: HashMap::new(),
55            pending_events: VecDeque::new(),
56            allowed_peers: HashSet::new(),
57            connected_peers: HashMap::new(),
58        }
59    }
60
61    fn is_peer_connected(&self, peer: &PeerId) -> bool {
62        self.connected_peers.get(peer).map(|v| *v > 0).unwrap_or(false)
63    }
64}
65
66impl NetworkBehaviour for Behaviour {
67    type ConnectionHandler = ConnectionHandler;
68
69    type ToSwarm = Event;
70
71    fn handle_established_inbound_connection(
72        &mut self,
73        _connection_id: libp2p::swarm::ConnectionId,
74        peer: libp2p::PeerId,
75        _local_addr: &libp2p::Multiaddr,
76        _remote_addr: &libp2p::Multiaddr,
77    ) -> Result<libp2p::swarm::THandler<Self>, libp2p::swarm::ConnectionDenied> {
78        if self.allowed_peers.contains(&peer) {
79            Ok(Self::ConnectionHandler {})
80        } else {
81            Err(libp2p::swarm::ConnectionDenied::new(crate::errors::P2PError::Logic(
82                format!("Connection from '{peer}' is not allowed"),
83            )))
84        }
85    }
86
87    fn handle_established_outbound_connection(
88        &mut self,
89        _connection_id: libp2p::swarm::ConnectionId,
90        peer: libp2p::PeerId,
91        _addr: &libp2p::Multiaddr,
92        _role_override: libp2p::core::Endpoint,
93        _port_use: libp2p::core::transport::PortUse,
94    ) -> Result<libp2p::swarm::THandler<Self>, libp2p::swarm::ConnectionDenied> {
95        if self.allowed_peers.contains(&peer) {
96            Ok(Self::ConnectionHandler {})
97        } else {
98            Err(libp2p::swarm::ConnectionDenied::new(crate::errors::P2PError::Logic(
99                format!("Connection to '{peer}' is not allowed"),
100            )))
101        }
102    }
103
104    fn on_swarm_event(&mut self, event: libp2p::swarm::FromSwarm) {
105        match event {
106            libp2p::swarm::FromSwarm::ConnectionEstablished(data) => {
107                *self.connected_peers.entry(data.peer_id).or_insert(0) += 1
108            }
109            libp2p::swarm::FromSwarm::ConnectionClosed(data) => {
110                let v = self.connected_peers.entry(data.peer_id).or_insert(0);
111                if *v > 0 {
112                    *v -= 1;
113                };
114            }
115            libp2p::swarm::FromSwarm::DialFailure(failure) => {
116                // NOTE: libp2p swarm in the current version removes the (PeerId, Multiaddr) from the cache on a dial failure,
117                // therefore it needs to be readded back to the swarm on every dial failure, for now we want to mirror the entire
118                // announcement back to the swarm
119                if let Some(peer_id) = failure.peer_id {
120                    if let Some(multiaddress) = self.all_peers.get(&peer_id) {
121                        self.pending_events.push_back(ToSwarm::NewExternalAddrOfPeer {
122                            peer_id,
123                            address: multiaddress.clone(),
124                        });
125                    }
126                }
127            }
128            _ => {}
129        }
130    }
131
132    fn on_connection_handler_event(
133        &mut self,
134        _peer_id: libp2p::PeerId,
135        _connection_id: libp2p::swarm::ConnectionId,
136        _event: libp2p::swarm::THandlerOutEvent<Self>,
137    ) {
138        // Nothing is necessary here, because no ConnectionHandler events should be generated
139    }
140
141    fn poll(
142        &mut self,
143        cx: &mut std::task::Context<'_>,
144    ) -> std::task::Poll<libp2p::swarm::ToSwarm<Self::ToSwarm, libp2p::swarm::THandlerInEvent<Self>>> {
145        if let Some(value) = self.pending_events.pop_front() {
146            return std::task::Poll::Ready(value);
147        };
148
149        let poll_result = self.events.poll_next_unpin(cx).map(|e| match e {
150            Some(DiscoveryInput::NetworkUpdate(event)) => match event {
151                NetworkTriggeredEvent::CloseConnection(peer) => {
152                    debug!(peer = %peer, "p2p - discovery - Closing connection (reason: low ping connection quality");
153                    if self.is_peer_connected(&peer) {
154                        self.pending_events.push_back(ToSwarm::CloseConnection {
155                            peer_id: peer,
156                            connection: CloseConnection::default(),
157                        });
158                    }
159                }
160                NetworkTriggeredEvent::UpdateQuality(_, _) => {}
161            },
162            Some(DiscoveryInput::Indexer(event)) => match event {
163                PeerDiscovery::Allow(peer) => {
164                    debug!(peer = %peer, "p2p - discovery - Network registry allow");
165                    let _ = self.allowed_peers.insert(peer);
166
167                    if let Some(multiaddress) = self.all_peers.get(&peer) {
168                        self.pending_events.push_back(ToSwarm::NewExternalAddrOfPeer {
169                            peer_id: peer,
170                            address: multiaddress.clone(),
171                        });
172                    }
173                }
174                PeerDiscovery::Ban(peer) => {
175                    debug!(peer = %peer, "p2p - discovery - Network registry ban");
176                    self.allowed_peers.remove(&peer);
177
178                    if self.is_peer_connected(&peer) {
179                        debug!(peer = %peer, "p2p - discovery - Requesting disconnect due to ban");
180                        self.pending_events.push_back(ToSwarm::CloseConnection {
181                            peer_id: peer,
182                            connection: CloseConnection::default(),
183                        });
184                    }
185                }
186                PeerDiscovery::Announce(peer, multiaddresses) => {
187                    if peer != self.me {
188                        debug!(peer = %peer, addresses = tracing::field::debug(&multiaddresses), "p2p - discovery - Announcement");
189                        if let Some(multiaddress) = multiaddresses.last() {
190                            self.all_peers.insert(peer, multiaddress.clone());
191
192                            self.pending_events.push_back(ToSwarm::NewExternalAddrOfPeer {
193                                peer_id: peer,
194                                address: multiaddress.clone(),
195                            });
196
197                            // the dial is important to create a first connection some time before the heartbeat mechanism
198                            // kicks in, otherwise the heartbeat is likely to fail on the first try due to dial and protocol
199                            // negotiation taking longer than the request response timeout
200                            self.pending_events.push_back(ToSwarm::Dial { opts: DialOpts::peer_id(peer).addresses(multiaddresses).build()});
201                        }
202                    }
203                }
204            },
205            None => {}
206        });
207
208        if matches!(poll_result, std::task::Poll::Pending) {
209            std::task::Poll::Pending
210        } else if let Some(value) = self.pending_events.pop_front() {
211            std::task::Poll::Ready(value)
212        } else {
213            std::task::Poll::Pending
214        }
215    }
216}