Skip to main content

hopr_protocol_hopr/codec/
decoder.rs

1use std::{ops::Mul, time::Duration};
2
3use bytes::{BufMut, Bytes, BytesMut};
4use hopr_api::{
5    chain::*,
6    types::{crypto::prelude::*, internal::prelude::*, primitive::prelude::*},
7};
8use hopr_crypto_packet::prelude::*;
9use hopr_utils::trace_timed;
10
11use crate::{
12    AuxiliaryPacketInfo, HoprCodecConfig, IncomingAcknowledgementPacket, IncomingFinalPacket, IncomingForwardedPacket,
13    IncomingPacket, IncomingPacketError, PacketDecoder, SurbStore, errors::HoprProtocolError, tbf::TagBloomFilter,
14};
15
16/// Default [decoder](PacketDecoder) implementation for HOPR packets.
17pub struct HoprDecoder<Chain, S, T> {
18    chain_api: Chain,
19    surb_store: std::sync::Arc<S>,
20    ticket_factory: T,
21    packet_key: OffchainKeypair,
22    chain_key: ChainKeypair,
23    channels_dst: Hash,
24    cfg: HoprCodecConfig,
25    tbf: parking_lot::Mutex<TagBloomFilter>,
26    peer_id_cache: moka::sync::Cache<PeerId, OffchainPublicKey>,
27}
28
impl<Chain, S, T> HoprDecoder<Chain, S, T>
where
    Chain: ChainReadChannelOperations + ChainKeyOperations + ChainReadTicketOperations + ChainValues + Send + Sync,
    S: SurbStore + Send + Sync,
    T: hopr_api::tickets::TicketFactory + Send + Sync,
{
    /// Creates a new instance of the decoder.
    ///
    /// # Arguments
    /// * `(packet_key, chain_key)` - this node's off-chain (packet) keypair and on-chain keypair.
    /// * `chain_api` - chain access used to resolve keys, look up channels and ticket values.
    /// * `surb_store` - SURB store; it is wrapped into an `Arc` internally.
    /// * `ticket_factory` - used to query remaining incoming stake and to create outgoing tickets.
    /// * `channels_dst` - hash passed to ticket validation and ticket signing.
    /// * `cfg` - codec configuration.
    ///
    /// The replay (tag Bloom) filter starts from its `Default` state. The peer-ID cache holds
    /// at most 100 000 entries, and each entry expires after 600 s without being accessed.
    pub fn new(
        (packet_key, chain_key): (OffchainKeypair, ChainKeypair),
        chain_api: Chain,
        surb_store: S,
        ticket_factory: T,
        channels_dst: Hash,
        cfg: HoprCodecConfig,
    ) -> Self {
        Self {
            chain_api,
            surb_store: std::sync::Arc::new(surb_store),
            packet_key,
            chain_key,
            channels_dst,
            cfg,
            ticket_factory,
            tbf: parking_lot::Mutex::new(Default::default()),
            peer_id_cache: moka::sync::Cache::builder()
                .time_to_idle(Duration::from_secs(600))
                .max_capacity(100_000)
                .build(),
        }
    }

    /// Validates the ticket carried by the forwarded packet `fwd` and replaces it with a new
    /// ticket addressed to the next hop.
    ///
    /// The incoming ticket must:
    /// - belong to the channel from the previous hop to this node;
    /// - be priced at least at `max(oracle price * path_pos, cfg.min_incoming_ticket_price)`;
    /// - pass the checks performed by `validate_unacknowledged_ticket`, including signature
    ///   verification.
    ///
    /// The new ticket depends on whether the next node is the final hop:
    /// - if it is not (`path_pos > 1`), a multi-hop ticket is created on the outgoing channel
    ///   to the next hop;
    /// - otherwise, a zero-hop ticket is used.
    ///
    /// # Returns
    /// The packet with its ticket replaced (signed with this node's chain key), together with
    /// the verified incoming ticket converted into an `UnacknowledgedTicket` bound to
    /// `fwd.own_key`.
    ///
    /// # Errors
    /// - `HoprProtocolError::KeyNotFound` if the previous or next hop's packet key has no
    ///   known chain key.
    /// - `HoprProtocolError::ChannelNotFound` if the incoming channel, or (for non-final next
    ///   hops) the outgoing channel, does not exist.
    /// - Resolver and ticket-factory errors from the respective lookups.
    /// - Errors from `validate_unacknowledged_ticket` and from ticket signing, propagated via `?`.
    #[tracing::instrument(skip(self, fwd), level = "debug", fields(path_pos = fwd.path_pos))]
    fn validate_and_replace_ticket(
        &self,
        mut fwd: HoprForwardedPacket,
    ) -> Result<(HoprForwardedPacket, UnacknowledgedTicket), HoprProtocolError> {
        // Resolve the chain address of the node that sent us this packet (ticket issuer).
        let previous_hop_addr = trace_timed!("previous_hop_addr lookup", {
            self.chain_api
                .packet_key_to_chain_key(&fwd.previous_hop)
                .map_err(HoprProtocolError::resolver)?
                .ok_or(HoprProtocolError::KeyNotFound)?
        });

        // Resolve the chain address of the node the packet goes to next (new ticket counterparty).
        let next_hop_addr = trace_timed!("next_hop_addr lookup", {
            self.chain_api
                .packet_key_to_chain_key(&fwd.outgoing.next_hop)
                .map_err(HoprProtocolError::resolver)?
                .ok_or(HoprProtocolError::KeyNotFound)?
        });

        // The incoming ticket must belong to the channel previous hop -> this node.
        let incoming_channel = trace_timed!("incoming_channel lookup", {
            self.chain_api
                .channel_by_parties(&previous_hop_addr, self.chain_key.as_ref())
                .map_err(HoprProtocolError::resolver)?
                .ok_or_else(|| HoprProtocolError::ChannelNotFound(previous_hop_addr, *self.chain_key.as_ref()))?
        });

        // The ticket price from the oracle times my node's position on the
        // path is the acceptable minimum
        let (win_prob, minimum_ticket_price) = self
            .chain_api
            .incoming_ticket_values()
            .map_err(HoprProtocolError::resolver)?;

        // A locally configured minimum (if any) takes precedence when it is higher.
        let minimum_ticket_price = minimum_ticket_price
            .mul(U256::from(fwd.path_pos))
            .max(self.cfg.min_incoming_ticket_price.unwrap_or_default());

        // Stake still available on the incoming channel, as reported by the ticket factory.
        let remaining_balance = trace_timed!("unrealized_balance lookup", {
            self.ticket_factory
                .remaining_incoming_channel_stake(&incoming_channel)
                .map_err(HoprProtocolError::ticket_factory)?
        });

        // Here also the signature on the ticket gets validated,
        // so afterward we are sure the source of the `channel`
        // (which is equal to `previous_hop_addr`) has issued this
        // ticket.

        let verified_incoming_ticket = trace_timed!("ticket_signature_verification", {
            validate_unacknowledged_ticket(
                fwd.outgoing.ticket,
                &incoming_channel,
                minimum_ticket_price,
                win_prob,
                remaining_balance,
                &self.channels_dst,
            )
        })?;

        // The ticket is now validated:
        tracing::trace!(%verified_incoming_ticket, "successfully verified incoming ticket");

        // NOTE: the path position according to the ticket value
        // may no longer match the path position from the packet header,
        // because the ticket issuer may set the price of the ticket higher.

        // Create the new ticket for the new packet
        let ticket_builder = if fwd.path_pos > 1 {
            // There must be a channel to the next node if it's not the final hop.
            // If the channel does not exist, the ticket we extracted before cannot be saved,
            // as there would be no way to acknowledge it without the channel.
            let outgoing_channel = self
                .chain_api
                .channel_by_parties(self.chain_key.as_ref(), &next_hop_addr)
                .map_err(HoprProtocolError::resolver)?
                .ok_or_else(|| HoprProtocolError::ChannelNotFound(*self.chain_key.as_ref(), next_hop_addr))?;

            let (outgoing_ticket_win_prob, outgoing_ticket_price) = self
                .chain_api
                .outgoing_ticket_values(self.cfg.outgoing_win_prob, self.cfg.outgoing_ticket_price)
                .map_err(HoprProtocolError::resolver)?;

            // We currently take the maximum of the win prob from the incoming ticket
            // and the one configured on this node.
            // Therefore, the winning probability can only increase along the path.
            let outgoing_ticket_win_prob = outgoing_ticket_win_prob.max(&verified_incoming_ticket.win_prob());

            // The following operation fails if there's not enough balance on the channel to the next hop.
            // Again, in this case, we cannot save the ticket we previously extracted because there is no way it gets
            // acknowledged without enough balance.
            self.ticket_factory
                .new_multihop_ticket(
                    &outgoing_channel,
                    fwd.path_pos.try_into().expect("path position is always > 1"),
                    outgoing_ticket_win_prob,
                    outgoing_ticket_price,
                )
                .map_err(HoprProtocolError::ticket_factory)?
        } else {
            // The next node is the final hop: no outgoing channel is required, a zero-hop
            // ticket addressed to it is used instead.
            TicketBuilder::zero_hop().counterparty(next_hop_addr)
        };

        // Finally, replace the ticket in the outgoing packet with a new one
        let ticket_builder = ticket_builder.eth_challenge(fwd.next_challenge);
        fwd.outgoing.ticket = trace_timed!("ticket_signing", {
            ticket_builder.build_signed(&self.chain_key, &self.channels_dst)?.leak()
        });

        // Keep the verified incoming ticket, bound to this packet's own key, so it can be
        // turned into an acknowledged ticket later.
        let unack_ticket = verified_incoming_ticket.into_unacknowledged(fwd.own_key);
        Ok((fwd, unack_ticket))
    }
}
172
173impl<Chain, S, T> PacketDecoder for HoprDecoder<Chain, S, T>
174where
175    Chain: ChainReadChannelOperations + ChainKeyOperations + ChainReadTicketOperations + ChainValues + Send + Sync,
176    S: SurbStore + Send + Sync + 'static,
177    T: hopr_api::tickets::TicketFactory + Send + Sync,
178{
179    type Error = HoprProtocolError;
180
181    #[tracing::instrument(skip(self, sender, data), level = "trace", fields(%sender))]
182    fn decode(&self, sender: PeerId, data: Bytes) -> Result<IncomingPacket, IncomingPacketError<Self::Error>> {
183        #[cfg(feature = "trace-timing")]
184        let decode_start = std::time::Instant::now();
185        tracing::trace!(data_len = data.len(), "decoding packet");
186
187        // Phase 1: Peer ID conversion
188        // Try to retrieve the peer's public key from the cache or compute it if it does not exist yet.
189        // The async block ensures the Rayon task is only submitted on cache miss.
190        let previous_hop = trace_timed!("peer_id_conversion complete", {
191            match self
192                .peer_id_cache
193                .try_get_with_by_ref(&sender, || OffchainPublicKey::from_peerid(&sender))
194            {
195                Ok(peer) => Ok(peer),
196                Err(error) => {
197                    tracing::error!(%sender, %error, "dropping packet - cannot convert peer id");
198                    Err(IncomingPacketError::Undecodable(HoprProtocolError::InvalidSender))
199                }
200            }
201        })?;
202
203        // Phase 2: Sphinx packet decoding
204
205        // If the following operation fails, it means that the packet is not a valid Hopr packet,
206        // and as such should not be acknowledged later.
207        let packet = trace_timed!("sphinx_decode complete", {
208            HoprPacket::from_incoming(
209                &data,
210                &self.packet_key,
211                previous_hop,
212                self.chain_api.key_id_mapper_ref(),
213                |p| self.surb_store.find_reply_opener(p),
214            )
215        })
216        .map_err(IncomingPacketError::undecodable)?;
217
218        // This is checked on both Final and Forwarded packets,
219        // Outgoing packets are not allowed to pass and are later reported as invalid state.
220        if let Some(tag) = packet.packet_tag() {
221            // This operation has run-time of ~10 nanoseconds,
222            // and therefore does not need to be invoked via spawn_blocking
223            if self.tbf.lock().check_and_set(tag) {
224                return Err(IncomingPacketError::ProcessingError(
225                    previous_hop.into(),
226                    HoprProtocolError::Replay,
227                ));
228            }
229        }
230
231        match packet {
232            HoprPacket::Final(incoming) => {
233                // Extract additional information from the packet that will be passed upwards
234                let info = AuxiliaryPacketInfo {
235                    packet_signals: incoming.signals,
236                    num_surbs: incoming.surbs.len(),
237                };
238
239                // Store all incoming SURBs if any
240                if !incoming.surbs.is_empty() {
241                    self.surb_store.insert_surbs(incoming.sender, incoming.surbs);
242                    tracing::trace!(pseudonym = %incoming.sender, num_surbs = info.num_surbs, packet_type = "final", "stored incoming surbs for pseudonym");
243                }
244
245                let result = match incoming.ack_key {
246                    None => {
247                        if incoming.plain_text.len() < size_of::<u16>() {
248                            return Err(IncomingPacketError::Undecodable(
249                                GeneralError::ParseError("invalid acknowledgement packet size".into()).into(),
250                            ));
251                        }
252
253                        let num_acks =
254                            u16::from_be_bytes(incoming.plain_text[..size_of::<u16>()].try_into().map_err(|_| {
255                                IncomingPacketError::Undecodable(
256                                    GeneralError::ParseError("invalid num acks".into()).into(),
257                                )
258                            })?);
259
260                        if incoming.plain_text.len() < size_of::<u16>() + (num_acks as usize) * Acknowledgement::SIZE {
261                            return Err(IncomingPacketError::Undecodable(
262                                GeneralError::ParseError("invalid number of acknowledgements in packet".into()).into(),
263                            ));
264                        }
265                        tracing::trace!(num_acks, packet_type = "final", "received acknowledgement packet");
266
267                        // The contained payload represents an Acknowledgement
268                        IncomingPacket::Acknowledgement(
269                            IncomingAcknowledgementPacket {
270                                packet_tag: incoming.packet_tag,
271                                previous_hop: incoming.previous_hop,
272                                received_acks: incoming.plain_text
273                                    [size_of::<u16>()..size_of::<u16>() + num_acks as usize * Acknowledgement::SIZE]
274                                    .chunks_exact(Acknowledgement::SIZE)
275                                    .map(Acknowledgement::try_from)
276                                    .collect::<Result<Vec<_>, _>>()
277                                    .map_err(|e: GeneralError| IncomingPacketError::Undecodable(e.into()))?,
278                            }
279                            .into(),
280                        )
281                    }
282                    Some(ack_key) => IncomingPacket::Final(
283                        IncomingFinalPacket {
284                            packet_tag: incoming.packet_tag,
285                            previous_hop: incoming.previous_hop,
286                            sender: incoming.sender,
287                            plain_text: incoming.plain_text,
288                            ack_key,
289                            info,
290                        }
291                        .into(),
292                    ),
293                };
294                #[cfg(feature = "trace-timing")]
295                tracing::trace!(
296                    total_ms = decode_start.elapsed().as_millis() as u64,
297                    packet_type = "final",
298                    "decode complete"
299                );
300                Ok(result)
301            }
302            HoprPacket::Forwarded(fwd) => {
303                // Phase 3: Ticket validation and replacement for forwarded packets
304                // Transform the ticket so it can be sent to the next hop
305                let (fwd, verified_unack_ticket) = trace_timed!("ticket_validation complete", {
306                    self.validate_and_replace_ticket(*fwd).map_err(|error| match error {
307                        // Distinguish ticket validation errors so that they can get extra treatment later
308                        HoprProtocolError::TicketValidationError(e) => {
309                            IncomingPacketError::InvalidTicket(previous_hop.into(), e)
310                        }
311                        e => IncomingPacketError::ProcessingError(previous_hop.into(), e),
312                    })?
313                });
314
315                let mut payload = BytesMut::with_capacity(HoprPacket::SIZE);
316                payload.put_slice(fwd.outgoing.packet.as_ref());
317                payload.put_slice(&fwd.outgoing.ticket.into_encoded());
318
319                #[cfg(feature = "trace-timing")]
320                tracing::trace!(
321                    total_ms = decode_start.elapsed().as_millis() as u64,
322                    packet_type = "forwarded",
323                    "decode complete"
324                );
325                Ok(IncomingPacket::Forwarded(
326                    IncomingForwardedPacket {
327                        packet_tag: fwd.packet_tag,
328                        previous_hop: fwd.previous_hop,
329                        next_hop: fwd.outgoing.next_hop,
330                        data: payload.freeze(),
331                        ack_challenge: fwd.outgoing.ack_challenge,
332                        received_ticket: verified_unack_ticket,
333                        ack_key_prev_hop: fwd.ack_key,
334                    }
335                    .into(),
336                ))
337            }
338            HoprPacket::Outgoing(_) => {
339                #[cfg(feature = "trace-timing")]
340                tracing::trace!(
341                    total_ms = decode_start.elapsed().as_millis() as u64,
342                    packet_type = "outgoing",
343                    "decode complete"
344                );
345                Err(IncomingPacketError::ProcessingError(
346                    previous_hop.into(),
347                    HoprProtocolError::InvalidState("cannot be outgoing packet"),
348                ))
349            }
350        }
351    }
352}