hopr_transport_p2p/behavior/
discovery.rs1use std::collections::{HashMap, HashSet, VecDeque};
3
4use futures::stream::{BoxStream, Stream, StreamExt};
5use futures_concurrency::stream::Merge;
6use libp2p::{
7 swarm::{dial_opts::DialOpts, dummy::ConnectionHandler, CloseConnection, NetworkBehaviour, ToSwarm},
8 Multiaddr, PeerId,
9};
10use tracing::debug;
11
12use hopr_transport_network::network::NetworkTriggeredEvent;
13use hopr_transport_protocol::PeerDiscovery;
14
15#[derive(Debug)]
16pub enum DiscoveryInput {
17 NetworkUpdate(NetworkTriggeredEvent),
18 Indexer(PeerDiscovery),
19}
20
/// Events emitted by the discovery behaviour to the swarm owner.
///
/// The enum is uninhabited: this behaviour never reports anything upward and only
/// drives the swarm through `ToSwarm` commands.
#[derive(Debug)]
pub enum Event {
    // Intentionally empty.
}
23
24pub struct Behaviour {
25 me: PeerId,
26 events: BoxStream<'static, DiscoveryInput>,
27 pending_events: VecDeque<
28 libp2p::swarm::ToSwarm<
29 <Self as NetworkBehaviour>::ToSwarm,
30 <<Self as NetworkBehaviour>::ConnectionHandler as libp2p::swarm::ConnectionHandler>::FromBehaviour,
31 >,
32 >,
33 all_peers: HashMap<PeerId, Multiaddr>,
34 allowed_peers: HashSet<PeerId>,
35 connected_peers: HashMap<PeerId, usize>,
36}
37
38impl Behaviour {
39 pub fn new<T, U>(me: PeerId, network_events: T, onchain_events: U) -> Self
40 where
41 T: Stream<Item = NetworkTriggeredEvent> + Send + 'static,
42 U: Stream<Item = PeerDiscovery> + Send + 'static,
43 {
44 Self {
45 me,
46 events: Box::pin(
47 (
48 network_events.map(DiscoveryInput::NetworkUpdate),
49 onchain_events.map(DiscoveryInput::Indexer),
50 )
51 .merge()
52 .fuse(),
53 ),
54 all_peers: HashMap::new(),
55 pending_events: VecDeque::new(),
56 allowed_peers: HashSet::new(),
57 connected_peers: HashMap::new(),
58 }
59 }
60
61 fn is_peer_connected(&self, peer: &PeerId) -> bool {
62 self.connected_peers.get(peer).map(|v| *v > 0).unwrap_or(false)
63 }
64}
65
66impl NetworkBehaviour for Behaviour {
67 type ConnectionHandler = ConnectionHandler;
68
69 type ToSwarm = Event;
70
71 fn handle_established_inbound_connection(
72 &mut self,
73 _connection_id: libp2p::swarm::ConnectionId,
74 peer: libp2p::PeerId,
75 _local_addr: &libp2p::Multiaddr,
76 _remote_addr: &libp2p::Multiaddr,
77 ) -> Result<libp2p::swarm::THandler<Self>, libp2p::swarm::ConnectionDenied> {
78 if self.allowed_peers.contains(&peer) {
79 Ok(Self::ConnectionHandler {})
80 } else {
81 Err(libp2p::swarm::ConnectionDenied::new(crate::errors::P2PError::Logic(
82 format!("Connection from '{peer}' is not allowed"),
83 )))
84 }
85 }
86
87 fn handle_established_outbound_connection(
88 &mut self,
89 _connection_id: libp2p::swarm::ConnectionId,
90 peer: libp2p::PeerId,
91 _addr: &libp2p::Multiaddr,
92 _role_override: libp2p::core::Endpoint,
93 _port_use: libp2p::core::transport::PortUse,
94 ) -> Result<libp2p::swarm::THandler<Self>, libp2p::swarm::ConnectionDenied> {
95 if self.allowed_peers.contains(&peer) {
96 Ok(Self::ConnectionHandler {})
97 } else {
98 Err(libp2p::swarm::ConnectionDenied::new(crate::errors::P2PError::Logic(
99 format!("Connection to '{peer}' is not allowed"),
100 )))
101 }
102 }
103
104 fn on_swarm_event(&mut self, event: libp2p::swarm::FromSwarm) {
105 match event {
106 libp2p::swarm::FromSwarm::ConnectionEstablished(data) => {
107 *self.connected_peers.entry(data.peer_id).or_insert(0) += 1
108 }
109 libp2p::swarm::FromSwarm::ConnectionClosed(data) => {
110 let v = self.connected_peers.entry(data.peer_id).or_insert(0);
111 if *v > 0 {
112 *v -= 1;
113 };
114 }
115 libp2p::swarm::FromSwarm::DialFailure(failure) => {
116 if let Some(peer_id) = failure.peer_id {
120 if let Some(multiaddress) = self.all_peers.get(&peer_id) {
121 self.pending_events.push_back(ToSwarm::NewExternalAddrOfPeer {
122 peer_id,
123 address: multiaddress.clone(),
124 });
125 }
126 }
127 }
128 _ => {}
129 }
130 }
131
132 fn on_connection_handler_event(
133 &mut self,
134 _peer_id: libp2p::PeerId,
135 _connection_id: libp2p::swarm::ConnectionId,
136 _event: libp2p::swarm::THandlerOutEvent<Self>,
137 ) {
138 }
140
141 fn poll(
142 &mut self,
143 cx: &mut std::task::Context<'_>,
144 ) -> std::task::Poll<libp2p::swarm::ToSwarm<Self::ToSwarm, libp2p::swarm::THandlerInEvent<Self>>> {
145 if let Some(value) = self.pending_events.pop_front() {
146 return std::task::Poll::Ready(value);
147 };
148
149 let poll_result = self.events.poll_next_unpin(cx).map(|e| match e {
150 Some(DiscoveryInput::NetworkUpdate(event)) => match event {
151 NetworkTriggeredEvent::CloseConnection(peer) => {
152 debug!(peer = %peer, "p2p - discovery - Closing connection (reason: low ping connection quality");
153 if self.is_peer_connected(&peer) {
154 self.pending_events.push_back(ToSwarm::CloseConnection {
155 peer_id: peer,
156 connection: CloseConnection::default(),
157 });
158 }
159 }
160 NetworkTriggeredEvent::UpdateQuality(_, _) => {}
161 },
162 Some(DiscoveryInput::Indexer(event)) => match event {
163 PeerDiscovery::Allow(peer) => {
164 debug!(peer = %peer, "p2p - discovery - Network registry allow");
165 let _ = self.allowed_peers.insert(peer);
166
167 if let Some(multiaddress) = self.all_peers.get(&peer) {
168 self.pending_events.push_back(ToSwarm::NewExternalAddrOfPeer {
169 peer_id: peer,
170 address: multiaddress.clone(),
171 });
172 }
173 }
174 PeerDiscovery::Ban(peer) => {
175 debug!(peer = %peer, "p2p - discovery - Network registry ban");
176 self.allowed_peers.remove(&peer);
177
178 if self.is_peer_connected(&peer) {
179 debug!(peer = %peer, "p2p - discovery - Requesting disconnect due to ban");
180 self.pending_events.push_back(ToSwarm::CloseConnection {
181 peer_id: peer,
182 connection: CloseConnection::default(),
183 });
184 }
185 }
186 PeerDiscovery::Announce(peer, multiaddresses) => {
187 if peer != self.me {
188 debug!(peer = %peer, addresses = tracing::field::debug(&multiaddresses), "p2p - discovery - Announcement");
189 if let Some(multiaddress) = multiaddresses.last() {
190 self.all_peers.insert(peer, multiaddress.clone());
191
192 self.pending_events.push_back(ToSwarm::NewExternalAddrOfPeer {
193 peer_id: peer,
194 address: multiaddress.clone(),
195 });
196
197 self.pending_events.push_back(ToSwarm::Dial { opts: DialOpts::peer_id(peer).addresses(multiaddresses).build()});
201 }
202 }
203 }
204 },
205 None => {}
206 });
207
208 if matches!(poll_result, std::task::Poll::Pending) {
209 std::task::Poll::Pending
210 } else if let Some(value) = self.pending_events.pop_front() {
211 std::task::Poll::Ready(value)
212 } else {
213 std::task::Poll::Pending
214 }
215 }
216}