hopr_transport_p2p/behavior/
discovery.rs1use std::collections::{HashMap, HashSet, VecDeque};
6
7use futures::stream::{BoxStream, Stream, StreamExt};
8use hopr_transport_protocol::PeerDiscovery;
9use libp2p::{
10 Multiaddr, PeerId,
11 core::Endpoint,
12 swarm::{
13 CloseConnection, ConnectionDenied, ConnectionId, DialFailure, NetworkBehaviour, ToSwarm,
14 dummy::ConnectionHandler,
15 },
16};
17
18#[derive(Debug)]
19pub enum DiscoveryInput {
20 Indexer(PeerDiscovery),
21}
22
23#[derive(Debug)]
24pub enum Event {
25 IncomingConnection(PeerId, Multiaddr),
26 FailedDial(PeerId),
27}
28
29pub struct Behaviour {
30 me: PeerId,
31 events: BoxStream<'static, DiscoveryInput>,
32 pending_events: VecDeque<
33 libp2p::swarm::ToSwarm<
34 <Self as NetworkBehaviour>::ToSwarm,
35 <<Self as NetworkBehaviour>::ConnectionHandler as libp2p::swarm::ConnectionHandler>::FromBehaviour,
36 >,
37 >,
38 bootstrap_peers: HashMap<PeerId, Vec<Multiaddr>>,
39 allowed_peers: HashSet<PeerId>,
40 connected_peers: HashMap<PeerId, usize>,
41}
42
43impl Behaviour {
44 pub fn new<T>(me: PeerId, onchain_events: T) -> Self
45 where
46 T: Stream<Item = PeerDiscovery> + Send + 'static,
47 {
48 Self {
49 me,
50 events: Box::pin(onchain_events.map(DiscoveryInput::Indexer)),
51 bootstrap_peers: HashMap::new(),
52 pending_events: VecDeque::new(),
53 allowed_peers: HashSet::new(),
54 connected_peers: HashMap::new(),
55 }
56 }
57
58 fn is_peer_connected(&self, peer: &PeerId) -> bool {
59 self.connected_peers.get(peer).map(|v| *v > 0).unwrap_or(false)
60 }
61}
62
63impl NetworkBehaviour for Behaviour {
64 type ConnectionHandler = ConnectionHandler;
65 type ToSwarm = Event;
66
67 #[tracing::instrument(
68 level = "debug",
69 name = "Discovery::handle_established_inbound_connection",
70 skip(self),
71 fields(transport = "p2p discovery"),
72 err(Display)
73 )]
74 fn handle_established_inbound_connection(
75 &mut self,
76 connection_id: libp2p::swarm::ConnectionId,
77 peer: libp2p::PeerId,
78 local_addr: &libp2p::Multiaddr,
79 remote_addr: &libp2p::Multiaddr,
80 ) -> Result<libp2p::swarm::THandler<Self>, libp2p::swarm::ConnectionDenied> {
81 let is_allowed = self.allowed_peers.contains(&peer);
82 tracing::trace!(%is_allowed, direction = "outbound", "Handling peer connection");
83
84 if is_allowed {
85 self.pending_events
86 .push_back(ToSwarm::GenerateEvent(Event::IncomingConnection(
87 peer,
88 remote_addr.clone(),
89 )));
90 }
91
92 is_allowed.then_some(Self::ConnectionHandler {}).ok_or_else(|| {
93 libp2p::swarm::ConnectionDenied::new(crate::errors::P2PError::Logic(format!(
94 "Connection from '{peer}' is not allowed"
95 )))
96 })
97 }
98
99 #[tracing::instrument(
100 level = "debug",
101 name = "Discovery::handle_pending_outbound_connection"
102 skip(self),
103 fields(transport = "p2p discovery"),
104 ret(Debug),
105 err(Display)
106 )]
107 fn handle_pending_outbound_connection(
108 &mut self,
109 connection_id: ConnectionId,
110 maybe_peer: Option<PeerId>,
111 addresses: &[Multiaddr],
112 effective_role: Endpoint,
113 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
114 if let Some(peer) = maybe_peer {
115 let is_allowed = self.allowed_peers.contains(&peer);
116 tracing::trace!(%is_allowed, direction = "inbound", "Handling peer connection");
117
118 if self.allowed_peers.contains(&peer) {
119 return Ok(self
121 .bootstrap_peers
122 .get(&peer)
123 .map_or(vec![], |addresses| addresses.clone()));
124 } else {
125 return Err(libp2p::swarm::ConnectionDenied::new(crate::errors::P2PError::Logic(
126 format!("Connection to '{peer}' is not allowed"),
127 )));
128 }
129 }
130
131 Ok(vec![])
132 }
133
134 #[tracing::instrument(
135 level = "trace",
136 name = "Discovery::handle_established_outbound_connection",
137 skip(self),
138 fields(transport = "p2p discovery"),
139 err(Display)
140 )]
141 fn handle_established_outbound_connection(
142 &mut self,
143 connection_id: libp2p::swarm::ConnectionId,
144 peer: libp2p::PeerId,
145 addr: &libp2p::Multiaddr,
146 role_override: libp2p::core::Endpoint,
147 port_use: libp2p::core::transport::PortUse,
148 ) -> Result<libp2p::swarm::THandler<Self>, libp2p::swarm::ConnectionDenied> {
149 Ok(Self::ConnectionHandler {})
151 }
152
153 #[tracing::instrument(
154 level = "debug",
155 name = "Discovery::on_swarm_event"
156 skip(self),
157 fields(transport = "p2p discovery"),
158 )]
159 fn on_swarm_event(&mut self, event: libp2p::swarm::FromSwarm) {
160 match event {
161 libp2p::swarm::FromSwarm::ConnectionEstablished(data) => {
162 *self.connected_peers.entry(data.peer_id).or_insert(0) += 1
163 }
164 libp2p::swarm::FromSwarm::ConnectionClosed(data) => {
165 let v = self.connected_peers.entry(data.peer_id).or_insert(0);
166 if *v > 0 {
167 *v -= 1;
168 };
169 }
170 libp2p::swarm::FromSwarm::DialFailure(DialFailure { peer_id, error, .. }) => {
171 tracing::debug!(?peer_id, %error, "Failed to dial peer");
172
173 if let Some(peer) = peer_id {
174 self.pending_events
175 .push_back(ToSwarm::GenerateEvent(Event::FailedDial(peer)));
176 }
177 }
178 _ => {}
179 }
180 }
181
182 fn on_connection_handler_event(
183 &mut self,
184 _peer_id: libp2p::PeerId,
185 _connection_id: libp2p::swarm::ConnectionId,
186 _event: libp2p::swarm::THandlerOutEvent<Self>,
187 ) {
188 }
190
191 #[tracing::instrument(
192 level = "debug",
193 name = "Discovery::poll"
194 skip(self, cx),
195 fields(transport = "p2p discovery")
196 )]
197 fn poll(
198 &mut self,
199 cx: &mut std::task::Context<'_>,
200 ) -> std::task::Poll<libp2p::swarm::ToSwarm<Self::ToSwarm, libp2p::swarm::THandlerInEvent<Self>>> {
201 if let Some(value) = self.pending_events.pop_front() {
202 return std::task::Poll::Ready(value);
203 };
204
205 let poll_result = self.events.poll_next_unpin(cx).map(|e| match e {
206 Some(DiscoveryInput::Indexer(event)) => match event {
207 PeerDiscovery::Allow(peer) => {
208 let inserted_into_allow_list = self.allowed_peers.insert(peer);
209
210 let multiaddresses = self.bootstrap_peers.get(&peer);
211 if let Some(multiaddresses) = multiaddresses {
212 for address in multiaddresses {
213 self.pending_events.push_back(ToSwarm::NewExternalAddrOfPeer {
214 peer_id: peer,
215 address: address.clone(),
216 });
217 }
218 }
219
220 tracing::debug!(%peer, state = "allow", inserted_into_allow_list, emitted_libp2p_address_announce = multiaddresses.is_some_and(|v| !v.is_empty()), "Network registry");
221 }
222 PeerDiscovery::Ban(peer) => {
223 let was_allowed = self.allowed_peers.remove(&peer);
224 let is_connected = self.is_peer_connected(&peer);
225
226 if is_connected {
227 self.pending_events.push_back(ToSwarm::CloseConnection {
228 peer_id: peer,
229 connection: CloseConnection::default(),
230 });
231 }
232
233 tracing::debug!(%peer, state = "ban", was_allowed, will_close_active_connection = is_connected, "Network registry");
234 }
235 PeerDiscovery::Announce(peer, multiaddresses) => {
236 if peer != self.me {
237 tracing::debug!(%peer, addresses = ?&multiaddresses, "Announcement");
238
239 for multiaddress in &multiaddresses {
240 self.pending_events.push_back(ToSwarm::NewExternalAddrOfPeer {
241 peer_id: peer,
242 address: multiaddress.clone(),
243 });
244 }
245
246 self.bootstrap_peers.insert(peer, multiaddresses.clone());
247 }
248 }
249 },
250 None => {}
251 });
252
253 if matches!(poll_result, std::task::Poll::Pending) {
254 std::task::Poll::Pending
255 } else if let Some(value) = self.pending_events.pop_front() {
256 std::task::Poll::Ready(value)
257 } else {
258 std::task::Poll::Pending
259 }
260 }
261}