// hopr_db_sql/channels.rs

1use async_trait::async_trait;
2use futures::stream::BoxStream;
3use futures::{StreamExt, TryStreamExt};
4use sea_orm::ActiveValue::Set;
5use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter};
6
7use hopr_crypto_types::prelude::*;
8use hopr_db_entity::channel;
9use hopr_db_entity::conversions::channels::ChannelStatusUpdate;
10use hopr_db_entity::prelude::Channel;
11use hopr_internal_types::prelude::*;
12use hopr_primitive_types::prelude::*;
13
14use crate::cache::ChannelParties;
15use crate::db::HoprDb;
16use crate::errors::{DbSqlError, Result};
17use crate::{HoprDbGeneralModelOperations, OptTx};
18
19/// API for editing [ChannelEntry] in the DB.
20pub struct ChannelEditor {
21    orig: ChannelEntry,
22    model: channel::ActiveModel,
23    delete: bool,
24}
25
26impl ChannelEditor {
27    /// Original channel entry **before** the edits.
28    pub fn entry(&self) -> &ChannelEntry {
29        &self.orig
30    }
31
32    /// Change the HOPR balance of the channel.
33    pub fn change_balance(mut self, balance: Balance) -> Self {
34        assert_eq!(BalanceType::HOPR, balance.balance_type());
35        self.model.balance = Set(balance.amount().to_be_bytes().to_vec());
36        self
37    }
38
39    /// Change the channel status.
40    pub fn change_status(mut self, status: ChannelStatus) -> Self {
41        self.model.set_status(status);
42        self
43    }
44
45    /// Change the ticket index.
46    pub fn change_ticket_index(mut self, index: impl Into<U256>) -> Self {
47        self.model.ticket_index = Set(index.into().to_be_bytes().to_vec());
48        self
49    }
50
51    /// Change the channel epoch.
52    pub fn change_epoch(mut self, epoch: impl Into<U256>) -> Self {
53        self.model.epoch = Set(epoch.into().to_be_bytes().to_vec());
54        self
55    }
56
57    /// If set, the channel will be deleted, no other edits will be done.
58    pub fn delete(mut self) -> Self {
59        self.delete = true;
60        self
61    }
62}
63
64/// Defines DB API for accessing information about HOPR payment channels.
65#[async_trait]
66pub trait HoprDbChannelOperations {
67    /// Retrieves channel by its channel ID hash.
68    ///
69    /// See [generate_channel_id] on how to generate a channel ID hash from source and destination [Addresses](Address).
70    async fn get_channel_by_id<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEntry>>;
71
72    /// Start changes to channel entry.
73    /// If the channel with the given ID exists, the [ChannelEditor] is returned.
74    /// Use [`HoprDbChannelOperations::finish_channel_update`] to commit edits to the DB when done.
75    async fn begin_channel_update<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEditor>>;
76
77    /// Commits changes of the channel to the database.
78    /// Returns the updated channel, or on deletion, the deleted channel entry.
79    async fn finish_channel_update<'a>(&'a self, tx: OptTx<'a>, editor: ChannelEditor) -> Result<ChannelEntry>;
80
81    /// Retrieves the channel by source and destination.
82    /// This operation should be able to use cache since it can be also called from
83    /// performance-sensitive locations.
84    async fn get_channel_by_parties<'a>(
85        &'a self,
86        tx: OptTx<'a>,
87        src: &Address,
88        dst: &Address,
89        use_cache: bool,
90    ) -> Result<Option<ChannelEntry>>;
91
92    /// Fetches all channels that are `Incoming` to the given `target`, or `Outgoing` from the given `target`
93    async fn get_channels_via<'a>(
94        &'a self,
95        tx: OptTx<'a>,
96        direction: ChannelDirection,
97        target: &Address,
98    ) -> Result<Vec<ChannelEntry>>;
99
100    /// Fetches all channels that are `Incoming` to this node.
101    /// Shorthand for `get_channels_via(tx, ChannelDirection::Incoming, my_node)`
102    async fn get_incoming_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;
103
104    /// Fetches all channels that are `Outgoing` from this node.
105    /// Shorthand for `get_channels_via(tx, ChannelDirection::Outgoing, my_node)`
106    async fn get_outgoing_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;
107
108    /// Retrieves all channel information from the DB.
109    async fn get_all_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;
110
111    /// Returns a stream of all channels that are `Open` or `PendingToClose` with an active grace period.s
112    async fn stream_active_channels<'a>(&'a self) -> Result<BoxStream<'a, Result<ChannelEntry>>>;
113
114    /// Inserts or updates the given channel entry.
115    async fn upsert_channel<'a>(&'a self, tx: OptTx<'a>, channel_entry: ChannelEntry) -> Result<()>;
116}
117
118#[async_trait]
119impl HoprDbChannelOperations for HoprDb {
120    async fn get_channel_by_id<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEntry>> {
121        let id_hex = id.to_hex();
122        self.nest_transaction(tx)
123            .await?
124            .perform(|tx| {
125                Box::pin(async move {
126                    Ok::<_, DbSqlError>(
127                        if let Some(model) = Channel::find()
128                            .filter(channel::Column::ChannelId.eq(id_hex))
129                            .one(tx.as_ref())
130                            .await?
131                        {
132                            Some(model.try_into()?)
133                        } else {
134                            None
135                        },
136                    )
137                })
138            })
139            .await
140    }
141
142    async fn begin_channel_update<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEditor>> {
143        let id_hex = id.to_hex();
144        self.nest_transaction(tx)
145            .await?
146            .perform(|tx| {
147                Box::pin(async move {
148                    Ok::<_, DbSqlError>(
149                        if let Some(model) = Channel::find()
150                            .filter(channel::Column::ChannelId.eq(id_hex))
151                            .one(tx.as_ref())
152                            .await?
153                        {
154                            Some(ChannelEditor {
155                                orig: model.clone().try_into()?,
156                                model: model.into_active_model(),
157                                delete: false,
158                            })
159                        } else {
160                            None
161                        },
162                    )
163                })
164            })
165            .await
166    }
167
168    async fn finish_channel_update<'a>(&'a self, tx: OptTx<'a>, editor: ChannelEditor) -> Result<ChannelEntry> {
169        let epoch = editor.model.epoch.clone();
170        let parties = ChannelParties(editor.orig.source, editor.orig.destination);
171        let ret = self
172            .nest_transaction(tx)
173            .await?
174            .perform(|tx| {
175                Box::pin(async move {
176                    if !editor.delete {
177                        let model = editor.model.update(tx.as_ref()).await?;
178                        Ok::<_, DbSqlError>(model.try_into()?)
179                    } else {
180                        editor.model.delete(tx.as_ref()).await?;
181                        Ok::<_, DbSqlError>(editor.orig)
182                    }
183                })
184            })
185            .await?;
186        self.caches.src_dst_to_channel.invalidate(&parties).await;
187
188        // Finally invalidate any unrealized values from the cache.
189        // This might be a no-op if the channel was not in the cache
190        // like for channels that are not ours.
191        let channel_id = editor.orig.get_id();
192        if let Some(channel_epoch) = epoch.try_as_ref() {
193            self.caches
194                .unrealized_value
195                .invalidate(&(channel_id, channel_epoch.as_slice().into()))
196                .await;
197        }
198
199        Ok(ret)
200    }
201
202    async fn get_channel_by_parties<'a>(
203        &'a self,
204        tx: OptTx<'a>,
205        src: &Address,
206        dst: &Address,
207        use_cache: bool,
208    ) -> Result<Option<ChannelEntry>> {
209        let fetch_channel = async move {
210            let src_hex = src.to_hex();
211            let dst_hex = dst.to_hex();
212            self.nest_transaction(tx)
213                .await?
214                .perform(|tx| {
215                    Box::pin(async move {
216                        Ok::<_, DbSqlError>(
217                            if let Some(model) = Channel::find()
218                                .filter(channel::Column::Source.eq(src_hex))
219                                .filter(channel::Column::Destination.eq(dst_hex))
220                                .one(tx.as_ref())
221                                .await?
222                            {
223                                Some(model.try_into()?)
224                            } else {
225                                None
226                            },
227                        )
228                    })
229                })
230                .await
231        };
232
233        if use_cache {
234            Ok(self
235                .caches
236                .src_dst_to_channel
237                .try_get_with(ChannelParties(*src, *dst), fetch_channel)
238                .await?)
239        } else {
240            fetch_channel.await
241        }
242    }
243
244    async fn get_channels_via<'a>(
245        &'a self,
246        tx: OptTx<'a>,
247        direction: ChannelDirection,
248        target: &Address,
249    ) -> Result<Vec<ChannelEntry>> {
250        let target_hex = target.to_hex();
251        self.nest_transaction(tx)
252            .await?
253            .perform(|tx| {
254                Box::pin(async move {
255                    Channel::find()
256                        .filter(match direction {
257                            ChannelDirection::Incoming => channel::Column::Destination.eq(target_hex),
258                            ChannelDirection::Outgoing => channel::Column::Source.eq(target_hex),
259                        })
260                        .all(tx.as_ref())
261                        .await?
262                        .into_iter()
263                        .map(|x| ChannelEntry::try_from(x).map_err(DbSqlError::from))
264                        .collect::<Result<Vec<_>>>()
265                })
266            })
267            .await
268    }
269
270    async fn get_incoming_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
271        self.get_channels_via(tx, ChannelDirection::Incoming, &self.me_onchain)
272            .await
273    }
274
275    async fn get_outgoing_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
276        self.get_channels_via(tx, ChannelDirection::Outgoing, &self.me_onchain)
277            .await
278    }
279
280    async fn get_all_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
281        self.nest_transaction(tx)
282            .await?
283            .perform(|tx| {
284                Box::pin(async move {
285                    Channel::find()
286                        .stream(tx.as_ref())
287                        .await?
288                        .map_err(DbSqlError::from)
289                        .try_filter_map(|m| async move { Ok(Some(ChannelEntry::try_from(m)?)) })
290                        .try_collect()
291                        .await
292                })
293            })
294            .await
295    }
296
297    async fn stream_active_channels<'a>(&'a self) -> Result<BoxStream<'a, Result<ChannelEntry>>> {
298        Ok(Channel::find()
299            .filter(
300                channel::Column::Status
301                    .eq(i8::from(ChannelStatus::Open))
302                    .or(channel::Column::Status
303                        .eq(i8::from(ChannelStatus::PendingToClose(
304                            hopr_platform::time::native::current_time(), // irrelevant
305                        )))
306                        .and(channel::Column::ClosureTime.gt(Utc::now()))),
307            )
308            .stream(&self.index_db)
309            .await?
310            .map_err(DbSqlError::from)
311            .and_then(|m| async move { Ok(ChannelEntry::try_from(m)?) })
312            .boxed())
313    }
314
315    async fn upsert_channel<'a>(&'a self, tx: OptTx<'a>, channel_entry: ChannelEntry) -> Result<()> {
316        let parties = ChannelParties(channel_entry.source, channel_entry.destination);
317        self.nest_transaction(tx)
318            .await?
319            .perform(|tx| {
320                Box::pin(async move {
321                    let mut model = channel::ActiveModel::from(channel_entry);
322                    if let Some(channel) = channel::Entity::find()
323                        .filter(channel::Column::ChannelId.eq(channel_entry.get_id().to_hex()))
324                        .one(tx.as_ref())
325                        .await?
326                    {
327                        model.id = Set(channel.id);
328                    }
329
330                    Ok::<_, DbSqlError>(model.save(tx.as_ref()).await?)
331                })
332            })
333            .await?;
334
335        self.caches.src_dst_to_channel.invalidate(&parties).await;
336
337        // Finally invalidate any unrealized values from the cache.
338        // This might be a no-op if the channel was not in the cache
339        // like for channels that are not ours.
340        let channel_id = channel_entry.get_id();
341        let channel_epoch = channel_entry.channel_epoch;
342        self.caches
343            .unrealized_value
344            .invalidate(&(channel_id, channel_epoch))
345            .await;
346
347        Ok(())
348    }
349}
350
#[cfg(test)]
mod tests {
    use crate::channels::HoprDbChannelOperations;
    use crate::db::HoprDb;
    use crate::HoprDbGeneralModelOperations;
    use anyhow::Context;
    use hopr_crypto_random::random_bytes;
    use hopr_crypto_types::keypairs::ChainKeypair;
    use hopr_crypto_types::prelude::Keypair;
    use hopr_internal_types::channels::ChannelStatus;
    use hopr_internal_types::prelude::{ChannelDirection, ChannelEntry};
    use hopr_primitive_types::prelude::{Address, BalanceType};

    #[async_std::test]
    async fn test_insert_get_by_id() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let ce = ChannelEntry::new(
            Address::default(),
            Address::default(),
            BalanceType::HOPR.zero(),
            0_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        db.upsert_channel(None, ce).await?;
        let from_db = db
            .get_channel_by_id(None, &ce.get_id())
            .await?
            .context("channel must be present")?;

        assert_eq!(ce, from_db, "channels must be equal");

        Ok(())
    }

    #[async_std::test]
    async fn test_insert_get_by_parties() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let a = Address::from(random_bytes());
        let b = Address::from(random_bytes());

        let ce = ChannelEntry::new(
            a,
            b,
            BalanceType::HOPR.zero(),
            0_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        db.upsert_channel(None, ce).await?;
        let from_db = db
            .get_channel_by_parties(None, &a, &b, false)
            .await?
            .context("channel must be present")?;

        assert_eq!(ce, from_db, "channels must be equal");

        Ok(())
    }

    #[async_std::test]
    async fn test_channel_update_and_delete() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let a = Address::from(random_bytes());
        let b = Address::from(random_bytes());

        let ce = ChannelEntry::new(
            a,
            b,
            BalanceType::HOPR.zero(),
            0_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        assert!(
            db.begin_channel_update(None, &ce.get_id()).await?.is_none(),
            "a non-existent channel must not be editable"
        );

        db.upsert_channel(None, ce).await?;

        // Populate the parties cache so that the update below has to invalidate it.
        assert_eq!(Some(ce), db.get_channel_by_parties(None, &a, &b, true).await?);

        let editor = db
            .begin_channel_update(None, &ce.get_id())
            .await?
            .context("channel must be present")?;
        assert_eq!(&ce, editor.entry(), "editor must expose the original entry");

        let updated = db
            .finish_channel_update(None, editor.change_ticket_index(5_u32))
            .await?;
        let expected = ChannelEntry::new(
            a,
            b,
            BalanceType::HOPR.zero(),
            5_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );
        assert_eq!(expected, updated, "update must return the updated channel");
        assert_eq!(
            Some(expected),
            db.get_channel_by_id(None, &ce.get_id()).await?,
            "update must be persisted"
        );
        assert_eq!(
            Some(expected),
            db.get_channel_by_parties(None, &a, &b, true).await?,
            "cached channel must be invalidated on update"
        );

        let editor = db
            .begin_channel_update(None, &ce.get_id())
            .await?
            .context("channel must be present")?;
        let deleted = db.finish_channel_update(None, editor.delete()).await?;
        assert_eq!(expected, deleted, "deletion must return the deleted channel");
        assert_eq!(None, db.get_channel_by_id(None, &ce.get_id()).await?);
        assert_eq!(
            None,
            db.get_channel_by_parties(None, &a, &b, true).await?,
            "cached channel must be invalidated on deletion"
        );

        Ok(())
    }

    #[async_std::test]
    async fn test_channel_get_for_destination_that_does_not_exist_returns_none() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let from_db = db
            .get_channels_via(None, ChannelDirection::Incoming, &Address::default())
            .await?
            .first()
            .cloned();

        assert_eq!(None, from_db, "should return None");

        Ok(())
    }

    #[async_std::test]
    async fn test_channel_get_for_destination_that_exists_should_be_returned() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let expected_destination = Address::default();

        let ce = ChannelEntry::new(
            Address::default(),
            expected_destination,
            BalanceType::HOPR.zero(),
            0_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        db.upsert_channel(None, ce).await?;
        let from_db = db
            .get_channels_via(None, ChannelDirection::Incoming, &Address::default())
            .await?
            .first()
            .cloned();

        assert_eq!(Some(ce), from_db, "should return a valid channel");

        Ok(())
    }

    #[async_std::test]
    async fn test_incoming_outgoing_channels() -> anyhow::Result<()> {
        let ckp = ChainKeypair::random();
        let addr_1 = ckp.public().to_address();
        let addr_2 = ChainKeypair::random().public().to_address();

        let db = HoprDb::new_in_memory(ckp).await?;

        let ce_1 = ChannelEntry::new(
            addr_1,
            addr_2,
            BalanceType::HOPR.zero(),
            1_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        let ce_2 = ChannelEntry::new(
            addr_2,
            addr_1,
            BalanceType::HOPR.zero(),
            2_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        let db_clone = db.clone();
        db.begin_transaction()
            .await?
            .perform(|tx| {
                Box::pin(async move {
                    db_clone.upsert_channel(Some(tx), ce_1).await?;
                    db_clone.upsert_channel(Some(tx), ce_2).await
                })
            })
            .await?;

        assert_eq!(vec![ce_2], db.get_incoming_channels(None).await?);
        assert_eq!(vec![ce_1], db.get_outgoing_channels(None).await?);

        Ok(())
    }
}
499}