// hopr_db_sql/corrupted_channels.rs

1use async_trait::async_trait;
2use futures::TryStreamExt;
3use hopr_crypto_types::prelude::*;
4use hopr_db_entity::{corrupted_channel, prelude::CorruptedChannel};
5use hopr_internal_types::{channels::CorruptedChannelEntry, prelude::*};
6use hopr_primitive_types::prelude::*;
7use sea_orm::{ActiveModelTrait, ActiveValue::Set, ColumnTrait, EntityTrait, QueryFilter};
8
9use crate::{
10    HoprDbGeneralModelOperations, HoprIndexerDb, OptTx,
11    errors::{DbSqlError, Result},
12};
13
14/// Defines DB API for accessing information about HOPR payment channels.
15#[async_trait]
16pub trait HoprDbCorruptedChannelOperations {
17    /// Retrieves corrupted channel by its channel ID hash.
18    ///
19    /// See [generate_channel_id] on how to generate a channel ID hash from source and destination [Addresses](Address).
20    async fn get_corrupted_channel_by_id<'a>(
21        &'a self,
22        tx: OptTx<'a>,
23        id: &Hash,
24    ) -> Result<Option<CorruptedChannelEntry>>;
25
26    /// Retrieves all corrupted channels information from the DB.
27    async fn get_all_corrupted_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<CorruptedChannelEntry>>;
28
29    /// Inserts the given ChannelID as a corrupted channel entry.
30    async fn upsert_corrupted_channel<'a>(&'a self, tx: OptTx<'a>, channel_id: ChannelId) -> Result<()>;
31}
32
33#[async_trait]
34impl HoprDbCorruptedChannelOperations for HoprIndexerDb {
35    async fn get_corrupted_channel_by_id<'a>(
36        &'a self,
37        tx: OptTx<'a>,
38        id: &Hash,
39    ) -> Result<Option<CorruptedChannelEntry>> {
40        let id_hex = id.to_hex();
41        self.nest_transaction(tx)
42            .await?
43            .perform(|tx| {
44                Box::pin(async move {
45                    Ok::<_, DbSqlError>(
46                        if let Some(model) = CorruptedChannel::find()
47                            .filter(corrupted_channel::Column::ChannelId.eq(id_hex))
48                            .one(tx.as_ref())
49                            .await?
50                        {
51                            Some(model.try_into()?)
52                        } else {
53                            None
54                        },
55                    )
56                })
57            })
58            .await
59    }
60
61    async fn get_all_corrupted_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<CorruptedChannelEntry>> {
62        self.nest_transaction(tx)
63            .await?
64            .perform(|tx| {
65                Box::pin(async move {
66                    CorruptedChannel::find()
67                        .stream(tx.as_ref())
68                        .await?
69                        .map_err(DbSqlError::from)
70                        .try_filter_map(|m| async move { Ok(Some(CorruptedChannelEntry::try_from(m)?)) })
71                        .try_collect()
72                        .await
73                })
74            })
75            .await
76    }
77
78    async fn upsert_corrupted_channel<'a>(&'a self, tx: OptTx<'a>, channel_id: ChannelId) -> Result<()> {
79        self.nest_transaction(tx)
80            .await?
81            .perform(|tx| {
82                Box::pin(async move {
83                    let channel_entry = CorruptedChannelEntry::from(channel_id);
84                    let mut model = corrupted_channel::ActiveModel::from(channel_entry);
85                    if let Some(channel) = corrupted_channel::Entity::find()
86                        .filter(corrupted_channel::Column::ChannelId.eq(channel_entry.channel_id().to_hex()))
87                        .one(tx.as_ref())
88                        .await?
89                    {
90                        model.id = Set(channel.id);
91                    }
92
93                    Ok::<_, DbSqlError>(model.save(tx.as_ref()).await?)
94                })
95            })
96            .await?;
97
98        Ok(())
99    }
100}
101
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use hopr_crypto_random::random_bytes;
    use hopr_crypto_types::{keypairs::ChainKeypair, prelude::Keypair, types::Hash};

    use super::*;
    use crate::corrupted_channels::HoprDbCorruptedChannelOperations;

    #[tokio::test]
    async fn test_insert_get_by_id() -> anyhow::Result<()> {
        let db = HoprIndexerDb::new_in_memory(ChainKeypair::random()).await?;

        let channel_id = Hash::from(random_bytes());

        db.upsert_corrupted_channel(None, channel_id).await?;

        // Propagate a missing entry as a test error instead of panicking.
        let from_db = db
            .get_corrupted_channel_by_id(None, &channel_id)
            .await?
            .context("channel must be present")?;

        assert_eq!(channel_id, *from_db.channel_id(), "channels must be equal");

        Ok(())
    }

    #[tokio::test]
    async fn test_get_by_id_should_return_none_for_unknown_channel() -> anyhow::Result<()> {
        let db = HoprIndexerDb::new_in_memory(ChainKeypair::random()).await?;

        // Store some other channel so the lookup runs against a non-empty table.
        db.upsert_corrupted_channel(None, Hash::from(random_bytes()))
            .await
            .context("Inserting a corrupted channel should not fail")?;

        let unknown_id = Hash::from(random_bytes());
        let from_db = db.get_corrupted_channel_by_id(None, &unknown_id).await?;

        assert!(from_db.is_none(), "An unknown channel ID must not be found");

        Ok(())
    }

    #[tokio::test]
    async fn test_get_all_on_empty_db_should_return_empty() -> anyhow::Result<()> {
        let db = HoprIndexerDb::new_in_memory(ChainKeypair::random()).await?;

        let all_channels = db.get_all_corrupted_channels(None).await?;

        assert!(all_channels.is_empty(), "A fresh DB should contain no corrupted channels");

        Ok(())
    }

    #[tokio::test]
    async fn test_insert_duplicates_should_not_insert() -> anyhow::Result<()> {
        let db = HoprIndexerDb::new_in_memory(ChainKeypair::random()).await?;
        let channel_id = Hash::from(random_bytes());

        db.upsert_corrupted_channel(None, channel_id)
            .await
            .context("Inserting a corrupted channel should not fail")?;

        db.upsert_corrupted_channel(None, channel_id)
            .await
            .context("Inserting a duplicate corrupted channel should not fail")?;

        let all_channels = db.get_all_corrupted_channels(None).await?;

        assert_eq!(all_channels.len(), 1, "There should be only one corrupted channel");
        assert_eq!(
            all_channels[0].channel_id(),
            &channel_id,
            "The channel ID should match the inserted one"
        );
        Ok(())
    }
}