hopr_transport_p2p/behavior/
discovery.rs

1//! The discovery mechanism uses an external stimulus to trigger the discovery
2//! process on the libp2p side. It is responsible for processing the events
3//! generated by other components and passing them to the libp2p swarm in
4//! an appropriate format.
5use std::collections::{HashMap, HashSet, VecDeque};
6
7use futures::stream::{BoxStream, Stream, StreamExt};
8use hopr_transport_protocol::PeerDiscovery;
9use libp2p::{
10    Multiaddr, PeerId,
11    core::Endpoint,
12    swarm::{
13        CloseConnection, ConnectionDenied, ConnectionId, DialFailure, NetworkBehaviour, ToSwarm,
14        dummy::ConnectionHandler,
15    },
16};
17
18#[derive(Debug)]
19pub enum DiscoveryInput {
20    Indexer(PeerDiscovery),
21}
22
23#[derive(Debug)]
24pub enum Event {
25    IncomingConnection(PeerId, Multiaddr),
26    FailedDial(PeerId),
27}
28
29pub struct Behaviour {
30    me: PeerId,
31    events: BoxStream<'static, DiscoveryInput>,
32    pending_events: VecDeque<
33        libp2p::swarm::ToSwarm<
34            <Self as NetworkBehaviour>::ToSwarm,
35            <<Self as NetworkBehaviour>::ConnectionHandler as libp2p::swarm::ConnectionHandler>::FromBehaviour,
36        >,
37    >,
38    bootstrap_peers: HashMap<PeerId, Vec<Multiaddr>>,
39    allowed_peers: HashSet<PeerId>,
40    connected_peers: HashMap<PeerId, usize>,
41}
42
43impl Behaviour {
44    pub fn new<T>(me: PeerId, onchain_events: T) -> Self
45    where
46        T: Stream<Item = PeerDiscovery> + Send + 'static,
47    {
48        Self {
49            me,
50            events: Box::pin(onchain_events.map(DiscoveryInput::Indexer)),
51            bootstrap_peers: HashMap::new(),
52            pending_events: VecDeque::new(),
53            allowed_peers: HashSet::new(),
54            connected_peers: HashMap::new(),
55        }
56    }
57
58    fn is_peer_connected(&self, peer: &PeerId) -> bool {
59        self.connected_peers.get(peer).map(|v| *v > 0).unwrap_or(false)
60    }
61}
62
63impl NetworkBehaviour for Behaviour {
64    type ConnectionHandler = ConnectionHandler;
65    type ToSwarm = Event;
66
67    #[tracing::instrument(
68        level = "debug",
69        name = "Discovery::handle_established_inbound_connection",
70        skip(self),
71        fields(transport = "p2p discovery"),
72        err(Display)
73    )]
74    fn handle_established_inbound_connection(
75        &mut self,
76        connection_id: libp2p::swarm::ConnectionId,
77        peer: libp2p::PeerId,
78        local_addr: &libp2p::Multiaddr,
79        remote_addr: &libp2p::Multiaddr,
80    ) -> Result<libp2p::swarm::THandler<Self>, libp2p::swarm::ConnectionDenied> {
81        let is_allowed = self.allowed_peers.contains(&peer);
82        tracing::trace!(%is_allowed, direction = "outbound", "Handling peer connection");
83
84        if is_allowed {
85            self.pending_events
86                .push_back(ToSwarm::GenerateEvent(Event::IncomingConnection(
87                    peer,
88                    remote_addr.clone(),
89                )));
90        }
91
92        is_allowed.then_some(Self::ConnectionHandler {}).ok_or_else(|| {
93            libp2p::swarm::ConnectionDenied::new(crate::errors::P2PError::Logic(format!(
94                "Connection from '{peer}' is not allowed"
95            )))
96        })
97    }
98
99    #[tracing::instrument(
100        level = "debug",
101        name = "Discovery::handle_pending_outbound_connection"
102        skip(self),
103        fields(transport = "p2p discovery"),
104        ret(Debug),
105        err(Display)
106    )]
107    fn handle_pending_outbound_connection(
108        &mut self,
109        connection_id: ConnectionId,
110        maybe_peer: Option<PeerId>,
111        addresses: &[Multiaddr],
112        effective_role: Endpoint,
113    ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
114        if let Some(peer) = maybe_peer {
115            let is_allowed = self.allowed_peers.contains(&peer);
116            tracing::trace!(%is_allowed, direction = "inbound", "Handling peer connection");
117
118            if self.allowed_peers.contains(&peer) {
119                // inject the multiaddress of the peer for possible dial usage by stream protocols
120                return Ok(self
121                    .bootstrap_peers
122                    .get(&peer)
123                    .map_or(vec![], |addresses| addresses.clone()));
124            } else {
125                return Err(libp2p::swarm::ConnectionDenied::new(crate::errors::P2PError::Logic(
126                    format!("Connection to '{peer}' is not allowed"),
127                )));
128            }
129        }
130
131        Ok(vec![])
132    }
133
134    #[tracing::instrument(
135        level = "trace",
136        name = "Discovery::handle_established_outbound_connection",
137        skip(self),
138        fields(transport = "p2p discovery"),
139        err(Display)
140    )]
141    fn handle_established_outbound_connection(
142        &mut self,
143        connection_id: libp2p::swarm::ConnectionId,
144        peer: libp2p::PeerId,
145        addr: &libp2p::Multiaddr,
146        role_override: libp2p::core::Endpoint,
147        port_use: libp2p::core::transport::PortUse,
148    ) -> Result<libp2p::swarm::THandler<Self>, libp2p::swarm::ConnectionDenied> {
149        // cannot connect without the handle_ending_outbound_connection being called first
150        Ok(Self::ConnectionHandler {})
151    }
152
153    #[tracing::instrument(
154        level = "debug",
155        name = "Discovery::on_swarm_event"
156        skip(self),
157        fields(transport = "p2p discovery"),
158    )]
159    fn on_swarm_event(&mut self, event: libp2p::swarm::FromSwarm) {
160        match event {
161            libp2p::swarm::FromSwarm::ConnectionEstablished(data) => {
162                *self.connected_peers.entry(data.peer_id).or_insert(0) += 1
163            }
164            libp2p::swarm::FromSwarm::ConnectionClosed(data) => {
165                let v = self.connected_peers.entry(data.peer_id).or_insert(0);
166                if *v > 0 {
167                    *v -= 1;
168                };
169            }
170            libp2p::swarm::FromSwarm::DialFailure(DialFailure { peer_id, error, .. }) => {
171                tracing::debug!(?peer_id, %error, "Failed to dial peer");
172
173                if let Some(peer) = peer_id {
174                    self.pending_events
175                        .push_back(ToSwarm::GenerateEvent(Event::FailedDial(peer)));
176                }
177            }
178            _ => {}
179        }
180    }
181
182    fn on_connection_handler_event(
183        &mut self,
184        _peer_id: libp2p::PeerId,
185        _connection_id: libp2p::swarm::ConnectionId,
186        _event: libp2p::swarm::THandlerOutEvent<Self>,
187    ) {
188        // Nothing is necessary here, because no ConnectionHandler events should be generated
189    }
190
191    #[tracing::instrument(
192        level = "debug",
193        name = "Discovery::poll"
194        skip(self, cx),
195        fields(transport = "p2p discovery")
196    )]
197    fn poll(
198        &mut self,
199        cx: &mut std::task::Context<'_>,
200    ) -> std::task::Poll<libp2p::swarm::ToSwarm<Self::ToSwarm, libp2p::swarm::THandlerInEvent<Self>>> {
201        if let Some(value) = self.pending_events.pop_front() {
202            return std::task::Poll::Ready(value);
203        };
204
205        let poll_result = self.events.poll_next_unpin(cx).map(|e| match e {
206            Some(DiscoveryInput::Indexer(event)) => match event {
207                PeerDiscovery::Allow(peer) => {
208                    let inserted_into_allow_list = self.allowed_peers.insert(peer);
209
210                    let multiaddresses = self.bootstrap_peers.get(&peer);
211                    if let Some(multiaddresses) = multiaddresses {
212                        for address in multiaddresses {
213                            self.pending_events.push_back(ToSwarm::NewExternalAddrOfPeer {
214                                peer_id: peer,
215                                address: address.clone(),
216                            });
217                        }
218                    }
219
220                    tracing::debug!(%peer, state = "allow", inserted_into_allow_list, emitted_libp2p_address_announce = multiaddresses.is_some_and(|v| !v.is_empty()), "Network registry");
221                }
222                PeerDiscovery::Ban(peer) => {
223                    let was_allowed = self.allowed_peers.remove(&peer);
224                    let is_connected = self.is_peer_connected(&peer);
225
226                    if is_connected {
227                        self.pending_events.push_back(ToSwarm::CloseConnection {
228                            peer_id: peer,
229                            connection: CloseConnection::default(),
230                        });
231                    }
232
233                    tracing::debug!(%peer, state = "ban", was_allowed, will_close_active_connection = is_connected, "Network registry");
234                }
235                PeerDiscovery::Announce(peer, multiaddresses) => {
236                    if peer != self.me {
237                        tracing::debug!(%peer, addresses = ?&multiaddresses, "Announcement");
238
239                        for multiaddress in &multiaddresses {
240                            self.pending_events.push_back(ToSwarm::NewExternalAddrOfPeer {
241                                peer_id: peer,
242                                address: multiaddress.clone(),
243                            });
244                        }
245
246                        self.bootstrap_peers.insert(peer, multiaddresses.clone());
247                    }
248                }
249            },
250            None => {}
251        });
252
253        if matches!(poll_result, std::task::Poll::Pending) {
254            std::task::Poll::Pending
255        } else if let Some(value) = self.pending_events.pop_front() {
256            std::task::Poll::Ready(value)
257        } else {
258            std::task::Poll::Pending
259        }
260    }
261}