hopr_chain_connector/backend/tempdb.rs

use hopr_api::chain::HoprKeyIdent;
use hopr_crypto_types::prelude::OffchainPublicKey;
use hopr_internal_types::{
    account::AccountEntry,
    channels::{ChannelEntry, ChannelId},
};
use hopr_primitive_types::prelude::{Address, BytesRepresentable};
use redb::{ReadableDatabase, ReadableTable, TableDefinition};
9
10/// A backend that is implemented via [`redb`](https://docs.rs/redb/latest/redb/) database stored in a temporary file.
11///
12/// The database file is dropped once the last instance is dropped.
13#[derive(Clone)]
14pub struct TempDbBackend {
15    db: std::sync::Arc<redb::Database>,
16    _tmp: std::sync::Arc<tempfile::NamedTempFile>,
17}
18
19impl TempDbBackend {
20    pub fn new() -> Result<Self, std::io::Error> {
21        let file = tempfile::NamedTempFile::new().map_err(std::io::Error::other)?;
22
23        Ok(Self {
24            db: std::sync::Arc::new(redb::Database::create(file.path()).map_err(std::io::Error::other)?),
25            _tmp: std::sync::Arc::new(file),
26        })
27    }
28}
29
30const ACCOUNTS_TABLE_DEF: TableDefinition<u32, Vec<u8>> = TableDefinition::new("id_accounts");
31const CHANNELS_TABLE_DEF: TableDefinition<[u8; ChannelId::SIZE], Vec<u8>> = TableDefinition::new("id_channels");
32const ADDRESS_TO_ID: TableDefinition<[u8; Address::SIZE], u32> = TableDefinition::new("address_to_id");
33const KEY_TO_ID: TableDefinition<[u8; OffchainPublicKey::SIZE], u32> = TableDefinition::new("key_to_id");
34
35impl super::Backend for TempDbBackend {
36    type Error = redb::Error;
37
38    fn insert_account(&self, account: AccountEntry) -> Result<Option<AccountEntry>, Self::Error> {
39        let write_tx = self.db.begin_write()?;
40        let old_value = {
41            let mut accounts = write_tx.open_table(ACCOUNTS_TABLE_DEF)?;
42            let old_value = accounts
43                .insert(
44                    u32::from(account.key_id),
45                    postcard::to_allocvec(&account)
46                        .map_err(|e| redb::Error::Corrupted(format!("account serialization failed: {e}")))?,
47                )?
48                .map(|v| postcard::from_bytes::<AccountEntry>(&v.value()))
49                .transpose()
50                .map_err(|e| redb::Error::Corrupted(format!("account decoding failed: {e}")))?;
51
52            let mut address_to_id = write_tx.open_table(ADDRESS_TO_ID)?;
53            let mut key_to_id = write_tx.open_table(KEY_TO_ID)?;
54
55            // Remove old account entry references not to create stale mappings if keys changed
56            if let Some(old_entry) = &old_value {
57                let chain_addr: [u8; Address::SIZE] = old_entry.chain_addr.into();
58                let packet_addr: [u8; OffchainPublicKey::SIZE] = old_entry.public_key.into();
59                address_to_id.remove(&chain_addr)?;
60                key_to_id.remove(&packet_addr)?;
61            }
62
63            let chain_addr: [u8; Address::SIZE] = account.chain_addr.into();
64            address_to_id.insert(chain_addr, u32::from(account.key_id))?;
65
66            let packet_addr: [u8; OffchainPublicKey::SIZE] = account.public_key.into();
67            key_to_id.insert(packet_addr, u32::from(account.key_id))?;
68
69            old_value
70        };
71        write_tx.commit()?;
72
73        tracing::debug!(new = %account, old = ?old_value, "upserted account");
74        Ok(old_value)
75    }
76
77    fn insert_channel(&self, channel: ChannelEntry) -> Result<Option<ChannelEntry>, Self::Error> {
78        let write_tx = self.db.begin_write()?;
79        let old_value = {
80            let mut channels = write_tx.open_table(CHANNELS_TABLE_DEF)?;
81            let channel_id: [u8; ChannelId::SIZE] = channel.get_id().into();
82            channels
83                .insert(
84                    channel_id,
85                    postcard::to_allocvec(&channel)
86                        .map_err(|e| redb::Error::Corrupted(format!("channel encoding failed: {e}")))?,
87                )?
88                .map(|v| postcard::from_bytes::<ChannelEntry>(&v.value()))
89                .transpose()
90                .map_err(|e| redb::Error::Corrupted(format!("channel decoding failed: {e}")))?
91        };
92        write_tx.commit()?;
93
94        tracing::debug!(new = %channel, old = ?old_value, "upserted channel");
95        Ok(old_value)
96    }
97
98    fn get_account_by_id(&self, id: &HoprKeyIdent) -> Result<Option<AccountEntry>, Self::Error> {
99        let read_tx = self.db.begin_read()?;
100        let accounts = read_tx.open_table(ACCOUNTS_TABLE_DEF)?;
101        accounts
102            .get(u32::from(*id))?
103            .map(|v| postcard::from_bytes::<AccountEntry>(&v.value()))
104            .transpose()
105            .map_err(|e| redb::Error::Corrupted(format!("account decoding failed: {e}")))
106    }
107
108    fn get_account_by_key(&self, key: &OffchainPublicKey) -> Result<Option<AccountEntry>, Self::Error> {
109        let id = {
110            let read_tx = self.db.begin_read()?;
111            let keys_to_id = read_tx.open_table(KEY_TO_ID)?;
112            let packet_addr: [u8; OffchainPublicKey::SIZE] = (*key).into();
113            let Some(id) = keys_to_id.get(packet_addr)?.map(|v| v.value()) else {
114                return Ok(None);
115            };
116            id
117        };
118
119        self.get_account_by_id(&id.into())
120    }
121
122    fn get_account_by_address(&self, chain_key: &Address) -> Result<Option<AccountEntry>, Self::Error> {
123        let id = {
124            let read_tx = self.db.begin_read()?;
125            let address_to_id = read_tx.open_table(ADDRESS_TO_ID)?;
126            let chain_key: [u8; Address::SIZE] = (*chain_key).into();
127            let Some(id) = address_to_id.get(chain_key)?.map(|v| v.value()) else {
128                return Ok(None);
129            };
130            id
131        };
132
133        self.get_account_by_id(&id.into())
134    }
135
136    fn get_channel_by_id(&self, id: &ChannelId) -> Result<Option<ChannelEntry>, Self::Error> {
137        let read_tx = self.db.begin_read()?;
138        let accounts = read_tx.open_table(CHANNELS_TABLE_DEF)?;
139        let id: [u8; ChannelId::SIZE] = (*id).into();
140        accounts
141            .get(id)?
142            .map(|v| postcard::from_bytes::<ChannelEntry>(&v.value()))
143            .transpose()
144            .map_err(|e| redb::Error::Corrupted(format!("channel decoding failed: {e}")))
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::tests::test_backend;

    /// Runs the generic `test_backend` suite against a freshly created temporary database.
    #[test]
    fn test_tempdb() -> anyhow::Result<()> {
        test_backend(TempDbBackend::new()?)
    }
}
158}