1use async_trait::async_trait;
2use futures::stream::BoxStream;
3use futures::{StreamExt, TryStreamExt};
4use sea_orm::ActiveValue::Set;
5use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter};
6
7use hopr_crypto_types::prelude::*;
8use hopr_db_entity::channel;
9use hopr_db_entity::conversions::channels::ChannelStatusUpdate;
10use hopr_db_entity::prelude::Channel;
11use hopr_internal_types::prelude::*;
12use hopr_primitive_types::prelude::*;
13
14use crate::cache::ChannelParties;
15use crate::db::HoprDb;
16use crate::errors::{DbSqlError, Result};
17use crate::{HoprDbGeneralModelOperations, OptTx};
18
/// Staged, not-yet-persisted modifications of a single channel row.
///
/// Obtained from [`HoprDbChannelOperations::begin_channel_update`]; the staged changes are
/// written (or the row is deleted) only when the editor is passed to
/// [`HoprDbChannelOperations::finish_channel_update`].
pub struct ChannelEditor {
    // Channel entry decoded from the row at the time the editor was created.
    orig: ChannelEntry,
    // Active model accumulating the column changes staged by the `change_*` methods.
    model: channel::ActiveModel,
    // When `true`, `finish_channel_update` deletes the row instead of updating it.
    delete: bool,
}
25
26impl ChannelEditor {
27 pub fn entry(&self) -> &ChannelEntry {
29 &self.orig
30 }
31
32 pub fn change_balance(mut self, balance: Balance) -> Self {
34 assert_eq!(BalanceType::HOPR, balance.balance_type());
35 self.model.balance = Set(balance.amount().to_be_bytes().to_vec());
36 self
37 }
38
39 pub fn change_status(mut self, status: ChannelStatus) -> Self {
41 self.model.set_status(status);
42 self
43 }
44
45 pub fn change_ticket_index(mut self, index: impl Into<U256>) -> Self {
47 self.model.ticket_index = Set(index.into().to_be_bytes().to_vec());
48 self
49 }
50
51 pub fn change_epoch(mut self, epoch: impl Into<U256>) -> Self {
53 self.model.epoch = Set(epoch.into().to_be_bytes().to_vec());
54 self
55 }
56
57 pub fn delete(mut self) -> Self {
59 self.delete = true;
60 self
61 }
62}
63
#[async_trait]
/// Database operations on HOPR payment channels.
///
/// Methods taking an [`OptTx`] run their queries inside that transaction when one is supplied
/// (the `HoprDb` implementation nests into it via `nest_transaction`).
pub trait HoprDbChannelOperations {
    /// Retrieves the channel with the given channel ID.
    ///
    /// Returns `Ok(None)` if no such channel is stored.
    async fn get_channel_by_id<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEntry>>;

    /// Loads the channel with the given channel ID and returns a [`ChannelEditor`] for staging
    /// changes to it.
    ///
    /// Nothing is written until the editor is passed to
    /// [`finish_channel_update`](HoprDbChannelOperations::finish_channel_update).
    /// Returns `Ok(None)` if no such channel is stored.
    async fn begin_channel_update<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEditor>>;

    /// Persists the changes staged in `editor`, or deletes the channel if
    /// [`ChannelEditor::delete`] was called.
    ///
    /// Returns the updated entry; on deletion, returns the entry as it was when the editor was
    /// created. Cached data for the affected channel is invalidated afterwards.
    async fn finish_channel_update<'a>(&'a self, tx: OptTx<'a>, editor: ChannelEditor) -> Result<ChannelEntry>;

    /// Retrieves the channel whose source is `src` and whose destination is `dst`.
    ///
    /// With `use_cache` set, the lookup goes through an in-memory cache keyed by the two parties;
    /// the cached value may also be "no channel". Returns `Ok(None)` if no such channel exists.
    async fn get_channel_by_parties<'a>(
        &'a self,
        tx: OptTx<'a>,
        src: &Address,
        dst: &Address,
        use_cache: bool,
    ) -> Result<Option<ChannelEntry>>;

    /// Retrieves all channels that have `target` on the given side.
    ///
    /// [`ChannelDirection::Incoming`] selects channels whose destination is `target`,
    /// [`ChannelDirection::Outgoing`] those whose source is `target`.
    async fn get_channels_via<'a>(
        &'a self,
        tx: OptTx<'a>,
        direction: ChannelDirection,
        target: &Address,
    ) -> Result<Vec<ChannelEntry>>;

    /// Retrieves all channels whose destination is this node's on-chain address.
    async fn get_incoming_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;

    /// Retrieves all channels whose source is this node's on-chain address.
    async fn get_outgoing_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;

    /// Retrieves all stored channels.
    async fn get_all_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;

    /// Streams channels that are open, or pending to close with a closure time still in the future.
    ///
    /// This method does not take a transaction.
    async fn stream_active_channels<'a>(&'a self) -> Result<BoxStream<'a, Result<ChannelEntry>>>;

    /// Inserts `channel_entry`, or overwrites the stored channel with the same channel ID.
    ///
    /// Cached data for the affected channel is invalidated afterwards.
    async fn upsert_channel<'a>(&'a self, tx: OptTx<'a>, channel_entry: ChannelEntry) -> Result<()>;
}
117
118#[async_trait]
119impl HoprDbChannelOperations for HoprDb {
120 async fn get_channel_by_id<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEntry>> {
121 let id_hex = id.to_hex();
122 self.nest_transaction(tx)
123 .await?
124 .perform(|tx| {
125 Box::pin(async move {
126 Ok::<_, DbSqlError>(
127 if let Some(model) = Channel::find()
128 .filter(channel::Column::ChannelId.eq(id_hex))
129 .one(tx.as_ref())
130 .await?
131 {
132 Some(model.try_into()?)
133 } else {
134 None
135 },
136 )
137 })
138 })
139 .await
140 }
141
142 async fn begin_channel_update<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEditor>> {
143 let id_hex = id.to_hex();
144 self.nest_transaction(tx)
145 .await?
146 .perform(|tx| {
147 Box::pin(async move {
148 Ok::<_, DbSqlError>(
149 if let Some(model) = Channel::find()
150 .filter(channel::Column::ChannelId.eq(id_hex))
151 .one(tx.as_ref())
152 .await?
153 {
154 Some(ChannelEditor {
155 orig: model.clone().try_into()?,
156 model: model.into_active_model(),
157 delete: false,
158 })
159 } else {
160 None
161 },
162 )
163 })
164 })
165 .await
166 }
167
168 async fn finish_channel_update<'a>(&'a self, tx: OptTx<'a>, editor: ChannelEditor) -> Result<ChannelEntry> {
169 let epoch = editor.model.epoch.clone();
170 let parties = ChannelParties(editor.orig.source, editor.orig.destination);
171 let ret = self
172 .nest_transaction(tx)
173 .await?
174 .perform(|tx| {
175 Box::pin(async move {
176 if !editor.delete {
177 let model = editor.model.update(tx.as_ref()).await?;
178 Ok::<_, DbSqlError>(model.try_into()?)
179 } else {
180 editor.model.delete(tx.as_ref()).await?;
181 Ok::<_, DbSqlError>(editor.orig)
182 }
183 })
184 })
185 .await?;
186 self.caches.src_dst_to_channel.invalidate(&parties).await;
187
188 let channel_id = editor.orig.get_id();
192 if let Some(channel_epoch) = epoch.try_as_ref() {
193 self.caches
194 .unrealized_value
195 .invalidate(&(channel_id, channel_epoch.as_slice().into()))
196 .await;
197 }
198
199 Ok(ret)
200 }
201
202 async fn get_channel_by_parties<'a>(
203 &'a self,
204 tx: OptTx<'a>,
205 src: &Address,
206 dst: &Address,
207 use_cache: bool,
208 ) -> Result<Option<ChannelEntry>> {
209 let fetch_channel = async move {
210 let src_hex = src.to_hex();
211 let dst_hex = dst.to_hex();
212 self.nest_transaction(tx)
213 .await?
214 .perform(|tx| {
215 Box::pin(async move {
216 Ok::<_, DbSqlError>(
217 if let Some(model) = Channel::find()
218 .filter(channel::Column::Source.eq(src_hex))
219 .filter(channel::Column::Destination.eq(dst_hex))
220 .one(tx.as_ref())
221 .await?
222 {
223 Some(model.try_into()?)
224 } else {
225 None
226 },
227 )
228 })
229 })
230 .await
231 };
232
233 if use_cache {
234 Ok(self
235 .caches
236 .src_dst_to_channel
237 .try_get_with(ChannelParties(*src, *dst), fetch_channel)
238 .await?)
239 } else {
240 fetch_channel.await
241 }
242 }
243
244 async fn get_channels_via<'a>(
245 &'a self,
246 tx: OptTx<'a>,
247 direction: ChannelDirection,
248 target: &Address,
249 ) -> Result<Vec<ChannelEntry>> {
250 let target_hex = target.to_hex();
251 self.nest_transaction(tx)
252 .await?
253 .perform(|tx| {
254 Box::pin(async move {
255 Channel::find()
256 .filter(match direction {
257 ChannelDirection::Incoming => channel::Column::Destination.eq(target_hex),
258 ChannelDirection::Outgoing => channel::Column::Source.eq(target_hex),
259 })
260 .all(tx.as_ref())
261 .await?
262 .into_iter()
263 .map(|x| ChannelEntry::try_from(x).map_err(DbSqlError::from))
264 .collect::<Result<Vec<_>>>()
265 })
266 })
267 .await
268 }
269
270 async fn get_incoming_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
271 self.get_channels_via(tx, ChannelDirection::Incoming, &self.me_onchain)
272 .await
273 }
274
275 async fn get_outgoing_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
276 self.get_channels_via(tx, ChannelDirection::Outgoing, &self.me_onchain)
277 .await
278 }
279
280 async fn get_all_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
281 self.nest_transaction(tx)
282 .await?
283 .perform(|tx| {
284 Box::pin(async move {
285 Channel::find()
286 .stream(tx.as_ref())
287 .await?
288 .map_err(DbSqlError::from)
289 .try_filter_map(|m| async move { Ok(Some(ChannelEntry::try_from(m)?)) })
290 .try_collect()
291 .await
292 })
293 })
294 .await
295 }
296
297 async fn stream_active_channels<'a>(&'a self) -> Result<BoxStream<'a, Result<ChannelEntry>>> {
298 Ok(Channel::find()
299 .filter(
300 channel::Column::Status
301 .eq(i8::from(ChannelStatus::Open))
302 .or(channel::Column::Status
303 .eq(i8::from(ChannelStatus::PendingToClose(
304 hopr_platform::time::native::current_time(), )))
306 .and(channel::Column::ClosureTime.gt(Utc::now()))),
307 )
308 .stream(&self.index_db)
309 .await?
310 .map_err(DbSqlError::from)
311 .and_then(|m| async move { Ok(ChannelEntry::try_from(m)?) })
312 .boxed())
313 }
314
315 async fn upsert_channel<'a>(&'a self, tx: OptTx<'a>, channel_entry: ChannelEntry) -> Result<()> {
316 let parties = ChannelParties(channel_entry.source, channel_entry.destination);
317 self.nest_transaction(tx)
318 .await?
319 .perform(|tx| {
320 Box::pin(async move {
321 let mut model = channel::ActiveModel::from(channel_entry);
322 if let Some(channel) = channel::Entity::find()
323 .filter(channel::Column::ChannelId.eq(channel_entry.get_id().to_hex()))
324 .one(tx.as_ref())
325 .await?
326 {
327 model.id = Set(channel.id);
328 }
329
330 Ok::<_, DbSqlError>(model.save(tx.as_ref()).await?)
331 })
332 })
333 .await?;
334
335 self.caches.src_dst_to_channel.invalidate(&parties).await;
336
337 let channel_id = channel_entry.get_id();
341 let channel_epoch = channel_entry.channel_epoch;
342 self.caches
343 .unrealized_value
344 .invalidate(&(channel_id, channel_epoch))
345 .await;
346
347 Ok(())
348 }
349}
350
#[cfg(test)]
mod tests {
    use crate::channels::HoprDbChannelOperations;
    use crate::db::HoprDb;
    use crate::HoprDbGeneralModelOperations;
    use anyhow::Context;
    use hopr_crypto_random::random_bytes;
    use hopr_crypto_types::keypairs::ChainKeypair;
    use hopr_crypto_types::prelude::Keypair;
    use hopr_internal_types::channels::ChannelStatus;
    use hopr_internal_types::prelude::{ChannelDirection, ChannelEntry};
    use hopr_primitive_types::prelude::{Address, BalanceType};

    /// Creates an open channel with zero HOPR balance, epoch 0 and the given ticket index.
    fn open_channel(source: Address, destination: Address, ticket_index: u32) -> ChannelEntry {
        ChannelEntry::new(
            source,
            destination,
            BalanceType::HOPR.zero(),
            ticket_index.into(),
            ChannelStatus::Open,
            0_u32.into(),
        )
    }

    /// Generates a random address.
    fn random_address() -> Address {
        Address::from(random_bytes())
    }

    #[async_std::test]
    async fn test_insert_get_by_id() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let ce = open_channel(Address::default(), Address::default(), 0);

        db.upsert_channel(None, ce).await?;
        let from_db = db
            .get_channel_by_id(None, &ce.get_id())
            .await?
            .context("channel must be present")?;

        assert_eq!(ce, from_db, "channels must be equal");

        Ok(())
    }

    #[async_std::test]
    async fn test_insert_get_by_parties() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let a = random_address();
        let b = random_address();

        let ce = open_channel(a, b, 0);

        db.upsert_channel(None, ce).await?;
        let from_db = db
            .get_channel_by_parties(None, &a, &b, false)
            .await?
            .context("channel must be present")?;

        assert_eq!(ce, from_db, "channels must be equal");

        Ok(())
    }

    #[async_std::test]
    async fn test_get_by_parties_cache_is_invalidated_on_upsert() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let a = random_address();
        let b = random_address();
        let ce = open_channel(a, b, 0);

        // Caches the absence of the channel ...
        assert_eq!(None, db.get_channel_by_parties(None, &a, &b, true).await?);

        // ... which must not be served anymore once the channel has been inserted.
        db.upsert_channel(None, ce).await?;
        assert_eq!(
            Some(ce),
            db.get_channel_by_parties(None, &a, &b, true).await?,
            "cached lookup must see the inserted channel"
        );

        Ok(())
    }

    #[async_std::test]
    async fn test_channel_get_for_destination_that_does_not_exist_returns_none() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let from_db = db
            .get_channels_via(None, ChannelDirection::Incoming, &Address::default())
            .await?;

        assert!(from_db.is_empty(), "should return no channels");

        Ok(())
    }

    #[async_std::test]
    async fn test_channel_get_for_destination_that_exists_should_be_returned() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let expected_destination = Address::default();

        let ce = open_channel(Address::default(), expected_destination, 0);

        db.upsert_channel(None, ce).await?;
        let from_db = db
            .get_channels_via(None, ChannelDirection::Incoming, &expected_destination)
            .await?;

        assert_eq!(vec![ce], from_db, "should return exactly the stored channel");

        Ok(())
    }

    #[async_std::test]
    async fn test_incoming_outgoing_channels() -> anyhow::Result<()> {
        let ckp = ChainKeypair::random();
        let addr_1 = ckp.public().to_address();
        let addr_2 = ChainKeypair::random().public().to_address();

        let db = HoprDb::new_in_memory(ckp).await?;

        let ce_1 = open_channel(addr_1, addr_2, 1);
        let ce_2 = open_channel(addr_2, addr_1, 2);

        let db_clone = db.clone();
        db.begin_transaction()
            .await?
            .perform(|tx| {
                Box::pin(async move {
                    db_clone.upsert_channel(Some(tx), ce_1).await?;
                    db_clone.upsert_channel(Some(tx), ce_2).await
                })
            })
            .await?;

        assert_eq!(vec![ce_2], db.get_incoming_channels(None).await?);
        assert_eq!(vec![ce_1], db.get_outgoing_channels(None).await?);
        assert_eq!(2, db.get_all_channels(None).await?.len(), "both channels must be listed");

        Ok(())
    }

    #[async_std::test]
    async fn test_channel_update_ticket_index() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let src = random_address();
        let dst = random_address();
        let ce = open_channel(src, dst, 0);
        db.upsert_channel(None, ce).await?;

        let editor = db
            .begin_channel_update(None, &ce.get_id())
            .await?
            .context("channel must be present")?;
        assert_eq!(&ce, editor.entry(), "editor must expose the stored entry");

        let updated = db
            .finish_channel_update(None, editor.change_ticket_index(5_u32))
            .await?;
        let expected = open_channel(src, dst, 5);

        assert_eq!(expected, updated, "update must return the modified entry");
        assert_eq!(
            Some(expected),
            db.get_channel_by_id(None, &ce.get_id()).await?,
            "update must be persisted"
        );

        Ok(())
    }

    #[async_std::test]
    async fn test_channel_update_delete() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let ce = open_channel(random_address(), random_address(), 0);
        db.upsert_channel(None, ce).await?;

        let editor = db
            .begin_channel_update(None, &ce.get_id())
            .await?
            .context("channel must be present")?;
        let deleted = db.finish_channel_update(None, editor.delete()).await?;

        assert_eq!(ce, deleted, "delete must return the original entry");
        assert_eq!(
            None,
            db.get_channel_by_id(None, &ce.get_id()).await?,
            "channel must be gone"
        );
        assert!(
            db.begin_channel_update(None, &ce.get_id()).await?.is_none(),
            "no editor for a deleted channel"
        );

        Ok(())
    }
}