hopr_db_sql/
channels.rs

1use async_trait::async_trait;
2use futures::{StreamExt, TryStreamExt, stream::BoxStream};
3use hopr_crypto_types::prelude::*;
4use hopr_db_entity::{channel, conversions::channels::ChannelStatusUpdate, prelude::Channel};
5use hopr_internal_types::prelude::*;
6use hopr_primitive_types::prelude::*;
7use sea_orm::{ActiveModelTrait, ActiveValue::Set, ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter};
8
9use crate::{
10    HoprDbGeneralModelOperations, OptTx,
11    cache::ChannelParties,
12    db::HoprDb,
13    errors::{DbSqlError, Result},
14};
15
16/// API for editing [ChannelEntry] in the DB.
17pub struct ChannelEditor {
18    orig: ChannelEntry,
19    model: channel::ActiveModel,
20    delete: bool,
21}
22
23impl ChannelEditor {
24    /// Original channel entry **before** the edits.
25    pub fn entry(&self) -> &ChannelEntry {
26        &self.orig
27    }
28
29    /// Change the HOPR balance of the channel.
30    pub fn change_balance(mut self, balance: HoprBalance) -> Self {
31        self.model.balance = Set(balance.amount().to_be_bytes().to_vec());
32        self
33    }
34
35    /// Change the channel status.
36    pub fn change_status(mut self, status: ChannelStatus) -> Self {
37        self.model.set_status(status);
38        self
39    }
40
41    /// Change the ticket index.
42    pub fn change_ticket_index(mut self, index: impl Into<U256>) -> Self {
43        self.model.ticket_index = Set(index.into().to_be_bytes().to_vec());
44        self
45    }
46
47    /// Change the channel epoch.
48    pub fn change_epoch(mut self, epoch: impl Into<U256>) -> Self {
49        self.model.epoch = Set(epoch.into().to_be_bytes().to_vec());
50        self
51    }
52
53    /// If set, the channel will be deleted, no other edits will be done.
54    pub fn delete(mut self) -> Self {
55        self.delete = true;
56        self
57    }
58}
59
60/// Defines DB API for accessing information about HOPR payment channels.
61#[async_trait]
62pub trait HoprDbChannelOperations {
63    /// Retrieves channel by its channel ID hash.
64    ///
65    /// See [generate_channel_id] on how to generate a channel ID hash from source and destination [Addresses](Address).
66    async fn get_channel_by_id<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEntry>>;
67
68    /// Start changes to channel entry.
69    /// If the channel with the given ID exists, the [ChannelEditor] is returned.
70    /// Use [`HoprDbChannelOperations::finish_channel_update`] to commit edits to the DB when done.
71    async fn begin_channel_update<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEditor>>;
72
73    /// Commits changes of the channel to the database.
74    /// Returns the updated channel, or on deletion, the deleted channel entry.
75    async fn finish_channel_update<'a>(&'a self, tx: OptTx<'a>, editor: ChannelEditor) -> Result<ChannelEntry>;
76
77    /// Retrieves the channel by source and destination.
78    /// This operation should be able to use cache since it can be also called from
79    /// performance-sensitive locations.
80    async fn get_channel_by_parties<'a>(
81        &'a self,
82        tx: OptTx<'a>,
83        src: &Address,
84        dst: &Address,
85        use_cache: bool,
86    ) -> Result<Option<ChannelEntry>>;
87
88    /// Fetches all channels that are `Incoming` to the given `target`, or `Outgoing` from the given `target`
89    async fn get_channels_via<'a>(
90        &'a self,
91        tx: OptTx<'a>,
92        direction: ChannelDirection,
93        target: &Address,
94    ) -> Result<Vec<ChannelEntry>>;
95
96    /// Fetches all channels that are `Incoming` to this node.
97    /// Shorthand for `get_channels_via(tx, ChannelDirection::Incoming, my_node)`
98    async fn get_incoming_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;
99
100    /// Fetches all channels that are `Outgoing` from this node.
101    /// Shorthand for `get_channels_via(tx, ChannelDirection::Outgoing, my_node)`
102    async fn get_outgoing_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;
103
104    /// Retrieves all channel information from the DB.
105    async fn get_all_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;
106
107    /// Returns a stream of all channels that are `Open` or `PendingToClose` with an active grace period.s
108    async fn stream_active_channels<'a>(&'a self) -> Result<BoxStream<'a, Result<ChannelEntry>>>;
109
110    /// Inserts or updates the given channel entry.
111    async fn upsert_channel<'a>(&'a self, tx: OptTx<'a>, channel_entry: ChannelEntry) -> Result<()>;
112}
113
114#[async_trait]
115impl HoprDbChannelOperations for HoprDb {
116    async fn get_channel_by_id<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEntry>> {
117        let id_hex = id.to_hex();
118        self.nest_transaction(tx)
119            .await?
120            .perform(|tx| {
121                Box::pin(async move {
122                    Ok::<_, DbSqlError>(
123                        if let Some(model) = Channel::find()
124                            .filter(channel::Column::ChannelId.eq(id_hex))
125                            .one(tx.as_ref())
126                            .await?
127                        {
128                            Some(model.try_into()?)
129                        } else {
130                            None
131                        },
132                    )
133                })
134            })
135            .await
136    }
137
138    async fn begin_channel_update<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEditor>> {
139        let id_hex = id.to_hex();
140        self.nest_transaction(tx)
141            .await?
142            .perform(|tx| {
143                Box::pin(async move {
144                    Ok::<_, DbSqlError>(
145                        if let Some(model) = Channel::find()
146                            .filter(channel::Column::ChannelId.eq(id_hex))
147                            .one(tx.as_ref())
148                            .await?
149                        {
150                            Some(ChannelEditor {
151                                orig: model.clone().try_into()?,
152                                model: model.into_active_model(),
153                                delete: false,
154                            })
155                        } else {
156                            None
157                        },
158                    )
159                })
160            })
161            .await
162    }
163
164    async fn finish_channel_update<'a>(&'a self, tx: OptTx<'a>, editor: ChannelEditor) -> Result<ChannelEntry> {
165        let epoch = editor.model.epoch.clone();
166        let parties = ChannelParties(editor.orig.source, editor.orig.destination);
167        let ret = self
168            .nest_transaction(tx)
169            .await?
170            .perform(|tx| {
171                Box::pin(async move {
172                    if !editor.delete {
173                        let model = editor.model.update(tx.as_ref()).await?;
174                        Ok::<_, DbSqlError>(model.try_into()?)
175                    } else {
176                        editor.model.delete(tx.as_ref()).await?;
177                        Ok::<_, DbSqlError>(editor.orig)
178                    }
179                })
180            })
181            .await?;
182        self.caches.src_dst_to_channel.invalidate(&parties).await;
183
184        // Finally invalidate any unrealized values from the cache.
185        // This might be a no-op if the channel was not in the cache
186        // like for channels that are not ours.
187        let channel_id = editor.orig.get_id();
188        if let Some(channel_epoch) = epoch.try_as_ref() {
189            self.caches
190                .unrealized_value
191                .invalidate(&(channel_id, U256::from_big_endian(channel_epoch.as_slice())))
192                .await;
193        }
194
195        Ok(ret)
196    }
197
198    async fn get_channel_by_parties<'a>(
199        &'a self,
200        tx: OptTx<'a>,
201        src: &Address,
202        dst: &Address,
203        use_cache: bool,
204    ) -> Result<Option<ChannelEntry>> {
205        let fetch_channel = async move {
206            let src_hex = src.to_hex();
207            let dst_hex = dst.to_hex();
208            self.nest_transaction(tx)
209                .await?
210                .perform(|tx| {
211                    Box::pin(async move {
212                        Ok::<_, DbSqlError>(
213                            if let Some(model) = Channel::find()
214                                .filter(channel::Column::Source.eq(src_hex))
215                                .filter(channel::Column::Destination.eq(dst_hex))
216                                .one(tx.as_ref())
217                                .await?
218                            {
219                                Some(model.try_into()?)
220                            } else {
221                                None
222                            },
223                        )
224                    })
225                })
226                .await
227        };
228
229        if use_cache {
230            Ok(self
231                .caches
232                .src_dst_to_channel
233                .try_get_with(ChannelParties(*src, *dst), fetch_channel)
234                .await?)
235        } else {
236            fetch_channel.await
237        }
238    }
239
240    async fn get_channels_via<'a>(
241        &'a self,
242        tx: OptTx<'a>,
243        direction: ChannelDirection,
244        target: &Address,
245    ) -> Result<Vec<ChannelEntry>> {
246        let target_hex = target.to_hex();
247        self.nest_transaction(tx)
248            .await?
249            .perform(|tx| {
250                Box::pin(async move {
251                    Channel::find()
252                        .filter(match direction {
253                            ChannelDirection::Incoming => channel::Column::Destination.eq(target_hex),
254                            ChannelDirection::Outgoing => channel::Column::Source.eq(target_hex),
255                        })
256                        .all(tx.as_ref())
257                        .await?
258                        .into_iter()
259                        .map(|x| ChannelEntry::try_from(x).map_err(DbSqlError::from))
260                        .collect::<Result<Vec<_>>>()
261                })
262            })
263            .await
264    }
265
266    async fn get_incoming_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
267        self.get_channels_via(tx, ChannelDirection::Incoming, &self.me_onchain)
268            .await
269    }
270
271    async fn get_outgoing_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
272        self.get_channels_via(tx, ChannelDirection::Outgoing, &self.me_onchain)
273            .await
274    }
275
276    async fn get_all_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
277        self.nest_transaction(tx)
278            .await?
279            .perform(|tx| {
280                Box::pin(async move {
281                    Channel::find()
282                        .stream(tx.as_ref())
283                        .await?
284                        .map_err(DbSqlError::from)
285                        .try_filter_map(|m| async move { Ok(Some(ChannelEntry::try_from(m)?)) })
286                        .try_collect()
287                        .await
288                })
289            })
290            .await
291    }
292
293    async fn stream_active_channels<'a>(&'a self) -> Result<BoxStream<'a, Result<ChannelEntry>>> {
294        Ok(Channel::find()
295            .filter(
296                channel::Column::Status
297                    .eq(i8::from(ChannelStatus::Open))
298                    .or(channel::Column::Status
299                        .eq(i8::from(ChannelStatus::PendingToClose(
300                            hopr_platform::time::native::current_time(), // irrelevant
301                        )))
302                        .and(channel::Column::ClosureTime.gt(Utc::now()))),
303            )
304            .stream(&self.index_db)
305            .await?
306            .map_err(DbSqlError::from)
307            .and_then(|m| async move { Ok(ChannelEntry::try_from(m)?) })
308            .boxed())
309    }
310
311    async fn upsert_channel<'a>(&'a self, tx: OptTx<'a>, channel_entry: ChannelEntry) -> Result<()> {
312        let parties = ChannelParties(channel_entry.source, channel_entry.destination);
313        self.nest_transaction(tx)
314            .await?
315            .perform(|tx| {
316                Box::pin(async move {
317                    let mut model = channel::ActiveModel::from(channel_entry);
318                    if let Some(channel) = channel::Entity::find()
319                        .filter(channel::Column::ChannelId.eq(channel_entry.get_id().to_hex()))
320                        .one(tx.as_ref())
321                        .await?
322                    {
323                        model.id = Set(channel.id);
324                    }
325
326                    Ok::<_, DbSqlError>(model.save(tx.as_ref()).await?)
327                })
328            })
329            .await?;
330
331        self.caches.src_dst_to_channel.invalidate(&parties).await;
332
333        // Finally, invalidate any unrealized values from the cache.
334        // This might be a no-op if the channel was not in the cache
335        // like for channels that are not ours.
336        let channel_id = channel_entry.get_id();
337        let channel_epoch = channel_entry.channel_epoch;
338        self.caches
339            .unrealized_value
340            .invalidate(&(channel_id, channel_epoch))
341            .await;
342
343        Ok(())
344    }
345}
346
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use hopr_crypto_random::random_bytes;
    use hopr_crypto_types::{keypairs::ChainKeypair, prelude::Keypair};
    use hopr_internal_types::{
        channels::ChannelStatus,
        prelude::{ChannelDirection, ChannelEntry},
    };
    use hopr_primitive_types::prelude::Address;

    use crate::{HoprDbGeneralModelOperations, channels::HoprDbChannelOperations, db::HoprDb};

    /// Builds an `Open` channel between `source` and `destination` with zero balance, zero epoch
    /// and the given ticket index.
    fn open_channel(source: Address, destination: Address, ticket_index: u32) -> ChannelEntry {
        ChannelEntry::new(
            source,
            destination,
            0.into(),
            ticket_index.into(),
            ChannelStatus::Open,
            0_u32.into(),
        )
    }

    #[tokio::test]
    async fn test_insert_get_by_id() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;
        let channel = open_channel(Address::default(), Address::default(), 0);

        db.upsert_channel(None, channel).await?;

        let stored = db.get_channel_by_id(None, &channel.get_id()).await?;
        assert_eq!(channel, stored.expect("channel must be present"), "channels must be equal");

        Ok(())
    }

    #[tokio::test]
    async fn test_insert_get_by_parties() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let (src, dst) = (Address::from(random_bytes()), Address::from(random_bytes()));
        let channel = open_channel(src, dst, 0);

        db.upsert_channel(None, channel).await?;

        let stored = db
            .get_channel_by_parties(None, &src, &dst, false)
            .await?
            .context("channel must be present")?;
        assert_eq!(channel, stored, "channels must be equal");

        Ok(())
    }

    #[tokio::test]
    async fn test_channel_get_for_destination_that_does_not_exist_returns_none() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let first_incoming = db
            .get_channels_via(None, ChannelDirection::Incoming, &Address::default())
            .await?
            .into_iter()
            .next();
        assert_eq!(None, first_incoming, "should return None");

        Ok(())
    }

    #[tokio::test]
    async fn test_channel_get_for_destination_that_exists_should_be_returned() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let expected_destination = Address::default();
        let channel = open_channel(Address::default(), expected_destination, 0);

        db.upsert_channel(None, channel).await?;

        let first_incoming = db
            .get_channels_via(None, ChannelDirection::Incoming, &Address::default())
            .await?
            .into_iter()
            .next();
        assert_eq!(Some(channel), first_incoming, "should return a valid channel");

        Ok(())
    }

    #[tokio::test]
    async fn test_incoming_outgoing_channels() -> anyhow::Result<()> {
        let me = ChainKeypair::random();
        let my_address = me.public().to_address();
        let peer_address = ChainKeypair::random().public().to_address();

        let db = HoprDb::new_in_memory(me).await?;

        let outgoing = open_channel(my_address, peer_address, 1);
        let incoming = open_channel(peer_address, my_address, 2);

        // Insert both channels within a single transaction.
        let writer = db.clone();
        db.begin_transaction()
            .await?
            .perform(|tx| {
                Box::pin(async move {
                    writer.upsert_channel(Some(tx), outgoing).await?;
                    writer.upsert_channel(Some(tx), incoming).await
                })
            })
            .await?;

        assert_eq!(vec![incoming], db.get_incoming_channels(None).await?);
        assert_eq!(vec![outgoing], db.get_outgoing_channels(None).await?);

        Ok(())
    }
}