hopr_db_sql/
corrupted_channels.rs

1use async_trait::async_trait;
2use futures::TryStreamExt;
3use hopr_crypto_types::prelude::*;
4use hopr_db_entity::{corrupted_channel, prelude::CorruptedChannel};
5use hopr_internal_types::{channels::CorruptedChannelEntry, prelude::*};
6use hopr_primitive_types::prelude::*;
7use sea_orm::{ActiveModelTrait, ActiveValue::Set, ColumnTrait, EntityTrait, QueryFilter};
8
9use crate::{
10    HoprDbGeneralModelOperations, OptTx,
11    db::HoprDb,
12    errors::{DbSqlError, Result},
13};
14
15/// Defines DB API for accessing information about HOPR payment channels.
16#[async_trait]
17pub trait HoprDbCorruptedChannelOperations {
18    /// Retrieves corrupted channel by its channel ID hash.
19    ///
20    /// See [generate_channel_id] on how to generate a channel ID hash from source and destination [Addresses](Address).
21    async fn get_corrupted_channel_by_id<'a>(
22        &'a self,
23        tx: OptTx<'a>,
24        id: &Hash,
25    ) -> Result<Option<CorruptedChannelEntry>>;
26
27    /// Retrieves all corrupted channels information from the DB.
28    async fn get_all_corrupted_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<CorruptedChannelEntry>>;
29
30    /// Inserts the given ChannelID as a corrupted channel entry.
31    async fn upsert_corrupted_channel<'a>(&'a self, tx: OptTx<'a>, channel_id: ChannelId) -> Result<()>;
32}
33
34#[async_trait]
35impl HoprDbCorruptedChannelOperations for HoprDb {
36    async fn get_corrupted_channel_by_id<'a>(
37        &'a self,
38        tx: OptTx<'a>,
39        id: &Hash,
40    ) -> Result<Option<CorruptedChannelEntry>> {
41        let id_hex = id.to_hex();
42        self.nest_transaction(tx)
43            .await?
44            .perform(|tx| {
45                Box::pin(async move {
46                    Ok::<_, DbSqlError>(
47                        if let Some(model) = CorruptedChannel::find()
48                            .filter(corrupted_channel::Column::ChannelId.eq(id_hex))
49                            .one(tx.as_ref())
50                            .await?
51                        {
52                            Some(model.try_into()?)
53                        } else {
54                            None
55                        },
56                    )
57                })
58            })
59            .await
60    }
61
62    async fn get_all_corrupted_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<CorruptedChannelEntry>> {
63        self.nest_transaction(tx)
64            .await?
65            .perform(|tx| {
66                Box::pin(async move {
67                    CorruptedChannel::find()
68                        .stream(tx.as_ref())
69                        .await?
70                        .map_err(DbSqlError::from)
71                        .try_filter_map(|m| async move { Ok(Some(CorruptedChannelEntry::try_from(m)?)) })
72                        .try_collect()
73                        .await
74                })
75            })
76            .await
77    }
78
79    async fn upsert_corrupted_channel<'a>(&'a self, tx: OptTx<'a>, channel_id: ChannelId) -> Result<()> {
80        self.nest_transaction(tx)
81            .await?
82            .perform(|tx| {
83                Box::pin(async move {
84                    let channel_entry = CorruptedChannelEntry::from(channel_id);
85                    let mut model = corrupted_channel::ActiveModel::from(channel_entry);
86                    if let Some(channel) = corrupted_channel::Entity::find()
87                        .filter(corrupted_channel::Column::ChannelId.eq(channel_entry.channel_id().to_hex()))
88                        .one(tx.as_ref())
89                        .await?
90                    {
91                        model.id = Set(channel.id);
92                    }
93
94                    Ok::<_, DbSqlError>(model.save(tx.as_ref()).await?)
95                })
96            })
97            .await?;
98
99        Ok(())
100    }
101}
102
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use hopr_crypto_random::random_bytes;
    use hopr_crypto_types::{keypairs::ChainKeypair, prelude::Keypair, types::Hash};

    use crate::{corrupted_channels::HoprDbCorruptedChannelOperations, db::HoprDb};

    #[tokio::test]
    async fn test_insert_get_by_id() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let channel_id = Hash::from(random_bytes());

        db.upsert_corrupted_channel(None, channel_id).await?;

        let from_db = db
            .get_corrupted_channel_by_id(None, &channel_id)
            .await?
            .expect("channel must be present");

        assert_eq!(channel_id, *from_db.channel_id(), "channels must be equal");

        Ok(())
    }

    #[tokio::test]
    async fn test_get_by_id_unknown_channel_should_return_none() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        db.upsert_corrupted_channel(None, Hash::from(random_bytes())).await?;

        let from_db = db
            .get_corrupted_channel_by_id(None, &Hash::from(random_bytes()))
            .await?;

        assert!(from_db.is_none(), "an unknown channel must not be found");

        Ok(())
    }

    #[tokio::test]
    async fn test_insert_duplicates_should_not_insert() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;
        let channel_id = Hash::from(random_bytes());

        db.upsert_corrupted_channel(None, channel_id)
            .await
            .context("Inserting a corrupted channel should not fail")?;

        db.upsert_corrupted_channel(None, channel_id)
            .await
            .context("Inserting a duplicate corrupted channel should not fail")?;

        let all_channels = db.get_all_corrupted_channels(None).await?;

        assert_eq!(all_channels.len(), 1, "There should be only one corrupted channel");
        assert_eq!(
            all_channels[0].channel_id(),
            &channel_id,
            "The channel ID should match the inserted one"
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_get_all_should_return_every_distinct_channel() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        assert!(
            db.get_all_corrupted_channels(None).await?.is_empty(),
            "a fresh database must not contain corrupted channels"
        );

        let first = Hash::from(random_bytes());
        let second = Hash::from(random_bytes());
        db.upsert_corrupted_channel(None, first).await?;
        db.upsert_corrupted_channel(None, second).await?;

        let all_channels = db.get_all_corrupted_channels(None).await?;

        assert_eq!(all_channels.len(), 2, "both distinct channels must be stored");
        // The retrieval order is not part of the contract, so only check membership.
        assert!(
            all_channels.iter().any(|c| c.channel_id() == &first),
            "the first channel must be present"
        );
        assert!(
            all_channels.iter().any(|c| c.channel_id() == &second),
            "the second channel must be present"
        );
        Ok(())
    }
}