hopr_db_sql/
channels.rs

1use async_trait::async_trait;
2use futures::{StreamExt, TryStreamExt, stream::BoxStream};
3use hopr_crypto_types::prelude::*;
4use hopr_db_entity::{channel, conversions::channels::ChannelStatusUpdate, prelude::Channel};
5use hopr_internal_types::prelude::*;
6use hopr_primitive_types::prelude::*;
7use sea_orm::{ActiveModelTrait, ActiveValue::Set, ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter};
8use tracing::instrument;
9
10use crate::{
11    HoprDbGeneralModelOperations, OptTx,
12    cache::ChannelParties,
13    db::HoprDb,
14    errors::{DbSqlError, Result},
15};
16
17/// API for editing [ChannelEntry] in the DB.
18pub struct ChannelEditor {
19    orig: ChannelEntry,
20    model: channel::ActiveModel,
21    delete: bool,
22}
23
24impl ChannelEditor {
25    /// Original channel entry **before** the edits.
26    pub fn entry(&self) -> &ChannelEntry {
27        &self.orig
28    }
29
30    /// Change the HOPR balance of the channel.
31    pub fn change_balance(mut self, balance: HoprBalance) -> Self {
32        self.model.balance = Set(balance.amount().to_be_bytes().to_vec());
33        self
34    }
35
36    /// Change the channel status.
37    pub fn change_status(mut self, status: ChannelStatus) -> Self {
38        self.model.set_status(status);
39        self
40    }
41
42    /// Change the ticket index.
43    pub fn change_ticket_index(mut self, index: impl Into<U256>) -> Self {
44        self.model.ticket_index = Set(index.into().to_be_bytes().to_vec());
45        self
46    }
47
48    /// Change the channel epoch.
49    pub fn change_epoch(mut self, epoch: impl Into<U256>) -> Self {
50        self.model.epoch = Set(epoch.into().to_be_bytes().to_vec());
51        self
52    }
53
54    /// If set, the channel will be deleted, no other edits will be done.
55    pub fn delete(mut self) -> Self {
56        self.delete = true;
57        self
58    }
59}
60
61/// Defines DB API for accessing information about HOPR payment channels.
62#[async_trait]
63pub trait HoprDbChannelOperations {
64    /// Retrieves channel by its channel ID hash.
65    ///
66    /// See [generate_channel_id] on how to generate a channel ID hash from source and destination [Addresses](Address).
67    async fn get_channel_by_id<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEntry>>;
68
69    /// Start changes to channel entry.
70    /// If the channel with the given ID exists, the [ChannelEditor] is returned.
71    /// Use [`HoprDbChannelOperations::finish_channel_update`] to commit edits to the DB when done.
72    async fn begin_channel_update<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEditor>>;
73
74    /// Commits changes of the channel to the database.
75    /// Returns the updated channel, or on deletion, the deleted channel entry.
76    async fn finish_channel_update<'a>(&'a self, tx: OptTx<'a>, editor: ChannelEditor) -> Result<Option<ChannelEntry>>;
77
78    /// Retrieves the channel by source and destination.
79    /// This operation should be able to use cache since it can be also called from
80    /// performance-sensitive locations.
81    async fn get_channel_by_parties<'a>(
82        &'a self,
83        tx: OptTx<'a>,
84        src: &Address,
85        dst: &Address,
86        use_cache: bool,
87    ) -> Result<Option<ChannelEntry>>;
88
89    /// Fetches all channels that are `Incoming` to the given `target`, or `Outgoing` from the given `target`
90    async fn get_channels_via<'a>(
91        &'a self,
92        tx: OptTx<'a>,
93        direction: ChannelDirection,
94        target: &Address,
95    ) -> Result<Vec<ChannelEntry>>;
96
97    /// Fetches all channels that are `Incoming` to this node.
98    /// Shorthand for `get_channels_via(tx, ChannelDirection::Incoming, my_node)`
99    async fn get_incoming_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;
100
101    /// Fetches all channels that are `Outgoing` from this node.
102    /// Shorthand for `get_channels_via(tx, ChannelDirection::Outgoing, my_node)`
103    async fn get_outgoing_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;
104
105    /// Retrieves all channels information from the DB.
106    async fn get_all_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;
107
108    /// Returns a stream of all channels that are `Open` or `PendingToClose` with an active grace period.s
109    async fn stream_active_channels<'a>(&'a self) -> Result<BoxStream<'a, Result<ChannelEntry>>>;
110
111    /// Inserts or updates the given channel entry.
112    async fn upsert_channel<'a>(&'a self, tx: OptTx<'a>, channel_entry: ChannelEntry) -> Result<()>;
113}
114
115#[async_trait]
116impl HoprDbChannelOperations for HoprDb {
117    async fn get_channel_by_id<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEntry>> {
118        let id_hex = id.to_hex();
119        self.nest_transaction(tx)
120            .await?
121            .perform(|tx| {
122                Box::pin(async move {
123                    Ok::<_, DbSqlError>(
124                        if let Some(model) = Channel::find()
125                            .filter(channel::Column::ChannelId.eq(id_hex))
126                            .one(tx.as_ref())
127                            .await?
128                        {
129                            Some(model.try_into()?)
130                        } else {
131                            None
132                        },
133                    )
134                })
135            })
136            .await
137    }
138
139    async fn begin_channel_update<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEditor>> {
140        let id_hex = id.to_hex();
141        self.nest_transaction(tx)
142            .await?
143            .perform(|tx| {
144                Box::pin(async move {
145                    match Channel::find()
146                        .filter(channel::Column::ChannelId.eq(id_hex.clone()))
147                        .one(tx.as_ref())
148                        .await?
149                    {
150                        Some(model) => Ok(Some(ChannelEditor {
151                            orig: ChannelEntry::try_from(model.clone())?,
152                            model: model.into_active_model(),
153                            delete: false,
154                        })),
155                        None => Ok(None),
156                    }
157                })
158            })
159            .await
160    }
161
162    async fn finish_channel_update<'a>(&'a self, tx: OptTx<'a>, editor: ChannelEditor) -> Result<Option<ChannelEntry>> {
163        let epoch = editor.model.epoch.clone();
164
165        let parties = ChannelParties(editor.orig.source, editor.orig.destination);
166        let ret = self
167            .nest_transaction(tx)
168            .await?
169            .perform(|tx| {
170                Box::pin(async move {
171                    if !editor.delete {
172                        let model = editor.model.update(tx.as_ref()).await?;
173                        match ChannelEntry::try_from(model) {
174                            Ok(channel) => Ok::<_, DbSqlError>(Some(channel)),
175                            Err(e) => Err(DbSqlError::from(e)),
176                        }
177                    } else {
178                        editor.model.delete(tx.as_ref()).await?;
179                        Ok::<_, DbSqlError>(Some(editor.orig))
180                    }
181                })
182            })
183            .await?;
184        self.caches.src_dst_to_channel.invalidate(&parties).await;
185
186        // Finally invalidate any unrealized values from the cache.
187        // This might be a no-op if the channel was not in the cache
188        // like for channels that are not ours.
189        let channel_id = editor.orig.get_id();
190        if let Some(channel_epoch) = epoch.try_as_ref() {
191            self.caches
192                .unrealized_value
193                .invalidate(&(channel_id, U256::from_big_endian(channel_epoch.as_slice())))
194                .await;
195        }
196
197        Ok(ret)
198    }
199
200    #[instrument(level = "trace", skip(self, tx), err)]
201    async fn get_channel_by_parties<'a>(
202        &'a self,
203        tx: OptTx<'a>,
204        src: &Address,
205        dst: &Address,
206        use_cache: bool,
207    ) -> Result<Option<ChannelEntry>> {
208        let fetch_channel = async move {
209            let src_hex = src.to_hex();
210            let dst_hex = dst.to_hex();
211            tracing::warn!(%src, %dst, "cache miss on get_channel_by_parties");
212            self.nest_transaction(tx)
213                .await?
214                .perform(|tx| {
215                    Box::pin(async move {
216                        Ok::<_, DbSqlError>(
217                            if let Some(model) = Channel::find()
218                                .filter(channel::Column::Source.eq(src_hex))
219                                .filter(channel::Column::Destination.eq(dst_hex))
220                                .one(tx.as_ref())
221                                .await?
222                            {
223                                Some(model.try_into()?)
224                            } else {
225                                None
226                            },
227                        )
228                    })
229                })
230                .await
231        };
232
233        if use_cache {
234            Ok(self
235                .caches
236                .src_dst_to_channel
237                .try_get_with(ChannelParties(*src, *dst), fetch_channel)
238                .await?)
239        } else {
240            fetch_channel.await
241        }
242    }
243
244    async fn get_channels_via<'a>(
245        &'a self,
246        tx: OptTx<'a>,
247        direction: ChannelDirection,
248        target: &Address,
249    ) -> Result<Vec<ChannelEntry>> {
250        let target_hex = target.to_hex();
251        self.nest_transaction(tx)
252            .await?
253            .perform(|tx| {
254                Box::pin(async move {
255                    Channel::find()
256                        .filter(match direction {
257                            ChannelDirection::Incoming => channel::Column::Destination.eq(target_hex),
258                            ChannelDirection::Outgoing => channel::Column::Source.eq(target_hex),
259                        })
260                        .all(tx.as_ref())
261                        .await?
262                        .into_iter()
263                        .map(|x| ChannelEntry::try_from(x).map_err(DbSqlError::from))
264                        .collect::<Result<Vec<_>>>()
265                })
266            })
267            .await
268    }
269
270    async fn get_incoming_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
271        self.get_channels_via(tx, ChannelDirection::Incoming, &self.me_onchain)
272            .await
273    }
274
275    async fn get_outgoing_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
276        self.get_channels_via(tx, ChannelDirection::Outgoing, &self.me_onchain)
277            .await
278    }
279
280    async fn get_all_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
281        self.nest_transaction(tx)
282            .await?
283            .perform(|tx| {
284                Box::pin(async move {
285                    Channel::find()
286                        .stream(tx.as_ref())
287                        .await?
288                        .map_err(DbSqlError::from)
289                        .try_filter_map(|m| async move { Ok(Some(ChannelEntry::try_from(m)?)) })
290                        .try_collect()
291                        .await
292                })
293            })
294            .await
295    }
296
297    async fn stream_active_channels<'a>(&'a self) -> Result<BoxStream<'a, Result<ChannelEntry>>> {
298        Ok(Channel::find()
299            .filter(
300                channel::Column::Status
301                    .eq(i8::from(ChannelStatus::Open))
302                    .or(channel::Column::Status
303                        .eq(i8::from(ChannelStatus::PendingToClose(
304                            hopr_platform::time::native::current_time(), // irrelevant
305                        )))
306                        .and(channel::Column::ClosureTime.gt(Utc::now()))),
307            )
308            .stream(self.index_db.read_only())
309            .await?
310            .map_err(DbSqlError::from)
311            .and_then(|m| async move { Ok(ChannelEntry::try_from(m)?) })
312            .boxed())
313    }
314
315    async fn upsert_channel<'a>(&'a self, tx: OptTx<'a>, channel_entry: ChannelEntry) -> Result<()> {
316        let parties = ChannelParties(channel_entry.source, channel_entry.destination);
317        self.nest_transaction(tx)
318            .await?
319            .perform(|tx| {
320                Box::pin(async move {
321                    let mut model = channel::ActiveModel::from(channel_entry);
322                    if let Some(channel) = channel::Entity::find()
323                        .filter(channel::Column::ChannelId.eq(channel_entry.get_id().to_hex()))
324                        .one(tx.as_ref())
325                        .await?
326                    {
327                        model.id = Set(channel.id);
328                    }
329
330                    Ok::<_, DbSqlError>(model.save(tx.as_ref()).await?)
331                })
332            })
333            .await?;
334
335        self.caches.src_dst_to_channel.invalidate(&parties).await;
336
337        // Finally, invalidate any unrealized values from the cache.
338        // This might be a no-op if the channel was not in the cache
339        // like for channels that are not ours.
340        let channel_id = channel_entry.get_id();
341        let channel_epoch = channel_entry.channel_epoch;
342        self.caches
343            .unrealized_value
344            .invalidate(&(channel_id, channel_epoch))
345            .await;
346
347        Ok(())
348    }
349}
350
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use hopr_crypto_random::random_bytes;
    use hopr_crypto_types::{keypairs::ChainKeypair, prelude::Keypair};
    use hopr_internal_types::{
        channels::ChannelStatus,
        prelude::{ChannelDirection, ChannelEntry},
    };
    use hopr_primitive_types::prelude::Address;

    use crate::{HoprDbGeneralModelOperations, channels::HoprDbChannelOperations, db::HoprDb};

    #[tokio::test]
    async fn test_insert_get_by_id() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let ce = ChannelEntry::new(
            Address::default(),
            Address::default(),
            0.into(),
            0_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        db.upsert_channel(None, ce).await?;
        let from_db = db
            .get_channel_by_id(None, &ce.get_id())
            .await?
            .expect("channel must be present");

        assert_eq!(ce, from_db, "channels must be equal");

        Ok(())
    }

    #[tokio::test]
    async fn test_insert_get_by_parties() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let a = Address::from(random_bytes());
        let b = Address::from(random_bytes());

        let ce = ChannelEntry::new(a, b, 0.into(), 0_u32.into(), ChannelStatus::Open, 0_u32.into());

        db.upsert_channel(None, ce).await?;
        let from_db = db
            .get_channel_by_parties(None, &a, &b, false)
            .await?
            .context("channel must be present")?;

        assert_eq!(ce, from_db, "channels must be equal");

        Ok(())
    }

    #[tokio::test]
    async fn test_channel_get_for_destination_that_does_not_exist_returns_none() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let from_db = db
            .get_channels_via(None, ChannelDirection::Incoming, &Address::default())
            .await?
            .first()
            .cloned();

        assert_eq!(None, from_db, "should return None");

        Ok(())
    }

    #[tokio::test]
    async fn test_channel_get_for_destination_that_exists_should_be_returned() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        // Distinct parties, so that incoming and outgoing lookups can actually be told apart.
        let source = Address::from(random_bytes());
        let expected_destination = Address::from(random_bytes());

        let ce = ChannelEntry::new(
            source,
            expected_destination,
            0.into(),
            0_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        db.upsert_channel(None, ce).await?;
        let from_db = db
            .get_channels_via(None, ChannelDirection::Incoming, &expected_destination)
            .await?
            .first()
            .cloned();

        assert_eq!(Some(ce), from_db, "should return a valid channel");

        let outgoing_from_destination = db
            .get_channels_via(None, ChannelDirection::Outgoing, &expected_destination)
            .await?;
        assert!(
            outgoing_from_destination.is_empty(),
            "destination must not have any outgoing channels"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_incoming_outgoing_channels() -> anyhow::Result<()> {
        let ckp = ChainKeypair::random();
        let addr_1 = ckp.public().to_address();
        let addr_2 = ChainKeypair::random().public().to_address();

        let db = HoprDb::new_in_memory(ckp).await?;

        let ce_1 = ChannelEntry::new(
            addr_1,
            addr_2,
            0.into(),
            1_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        let ce_2 = ChannelEntry::new(
            addr_2,
            addr_1,
            0.into(),
            2_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        let db_clone = db.clone();
        db.begin_transaction()
            .await?
            .perform(|tx| {
                Box::pin(async move {
                    db_clone.upsert_channel(Some(tx), ce_1).await?;
                    db_clone.upsert_channel(Some(tx), ce_2).await
                })
            })
            .await?;

        assert_eq!(vec![ce_2], db.get_incoming_channels(None).await?);
        assert_eq!(vec![ce_1], db.get_outgoing_channels(None).await?);

        Ok(())
    }

    #[tokio::test]
    async fn test_channel_update_and_delete() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let a = Address::from(random_bytes());
        let b = Address::from(random_bytes());

        let ce = ChannelEntry::new(a, b, 0.into(), 0_u32.into(), ChannelStatus::Open, 0_u32.into());
        db.upsert_channel(None, ce).await?;

        // Update path: the editor starts from the stored entry and persists the edit.
        let editor = db
            .begin_channel_update(None, &ce.get_id())
            .await?
            .context("channel must be present")?;
        assert_eq!(&ce, editor.entry(), "editor must hold the original entry");

        let updated = db
            .finish_channel_update(None, editor.change_ticket_index(5_u32))
            .await?
            .context("updated channel must be returned")?;

        let expected = ChannelEntry::new(a, b, 0.into(), 5_u32.into(), ChannelStatus::Open, 0_u32.into());
        assert_eq!(expected, updated, "ticket index must be updated");
        assert_eq!(
            Some(expected),
            db.get_channel_by_id(None, &ce.get_id()).await?,
            "update must be persisted"
        );

        // Delete path: the channel disappears from the DB.
        let editor = db
            .begin_channel_update(None, &ce.get_id())
            .await?
            .context("channel must be present")?;
        db.finish_channel_update(None, editor.delete()).await?;

        assert_eq!(
            None,
            db.get_channel_by_id(None, &ce.get_id()).await?,
            "channel must be deleted"
        );

        Ok(())
    }
}