1use async_trait::async_trait;
2use futures::{StreamExt, TryStreamExt, stream::BoxStream};
3use hopr_crypto_types::prelude::*;
4use hopr_db_entity::{channel, conversions::channels::ChannelStatusUpdate, prelude::Channel};
5use hopr_internal_types::prelude::*;
6use hopr_primitive_types::prelude::*;
7use sea_orm::{ActiveModelTrait, ActiveValue::Set, ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter};
8use tracing::instrument;
9
10use crate::{
11 HoprDbGeneralModelOperations, OptTx,
12 cache::ChannelParties,
13 db::HoprDb,
14 errors::{DbSqlError, Result},
15};
16
17pub struct ChannelEditor {
19 orig: ChannelEntry,
20 model: channel::ActiveModel,
21 delete: bool,
22}
23
24impl ChannelEditor {
25 pub fn entry(&self) -> &ChannelEntry {
27 &self.orig
28 }
29
30 pub fn change_balance(mut self, balance: HoprBalance) -> Self {
32 self.model.balance = Set(balance.amount().to_be_bytes().to_vec());
33 self
34 }
35
36 pub fn change_status(mut self, status: ChannelStatus) -> Self {
38 self.model.set_status(status);
39 self
40 }
41
42 pub fn change_ticket_index(mut self, index: impl Into<U256>) -> Self {
44 self.model.ticket_index = Set(index.into().to_be_bytes().to_vec());
45 self
46 }
47
48 pub fn change_epoch(mut self, epoch: impl Into<U256>) -> Self {
50 self.model.epoch = Set(epoch.into().to_be_bytes().to_vec());
51 self
52 }
53
54 pub fn delete(mut self) -> Self {
56 self.delete = true;
57 self
58 }
59}
60
#[async_trait]
/// Database operations on payment channels.
///
/// Methods taking an [`OptTx`] accept an optional, already open transaction to run within
/// (callers pass `None` to let the implementation manage the transaction).
pub trait HoprDbChannelOperations {
    /// Retrieves the channel with the given ID, or `None` if no such channel is stored.
    async fn get_channel_by_id<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEntry>>;

    /// Loads the channel with the given ID for editing, or returns `None` if it is not stored.
    ///
    /// Changes staged on the returned [`ChannelEditor`] are persisted only when it is passed to
    /// [`finish_channel_update`](Self::finish_channel_update).
    async fn begin_channel_update<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEditor>>;

    /// Persists the changes staged in `editor`, or deletes the channel if
    /// [`ChannelEditor::delete`] was called.
    ///
    /// Returns the channel as stored after the update; for a deletion, the channel as it was
    /// when the edit began.
    async fn finish_channel_update<'a>(&'a self, tx: OptTx<'a>, editor: ChannelEditor) -> Result<Option<ChannelEntry>>;

    /// Retrieves the channel going from `src` to `dst`, if any.
    ///
    /// When `use_cache` is `true`, the `HoprDb` implementation answers from its
    /// parties-to-channel cache and queries the database only on a cache miss; when `false`, the
    /// database is always queried.
    async fn get_channel_by_parties<'a>(
        &'a self,
        tx: OptTx<'a>,
        src: &Address,
        dst: &Address,
        use_cache: bool,
    ) -> Result<Option<ChannelEntry>>;

    /// Returns all channels in `direction` relative to `target`.
    ///
    /// [`ChannelDirection::Incoming`] selects channels whose destination is `target`;
    /// [`ChannelDirection::Outgoing`] selects channels whose source is `target`.
    async fn get_channels_via<'a>(
        &'a self,
        tx: OptTx<'a>,
        direction: ChannelDirection,
        target: &Address,
    ) -> Result<Vec<ChannelEntry>>;

    /// Returns all channels whose destination is this node's own on-chain address.
    async fn get_incoming_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;

    /// Returns all channels whose source is this node's own on-chain address.
    async fn get_outgoing_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;

    /// Returns every stored channel, regardless of status or parties.
    async fn get_all_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;

    /// Streams channels that are `Open`, or `PendingToClose` with a closure time still in the
    /// future.
    ///
    /// Unlike the other methods, this takes no transaction argument.
    async fn stream_active_channels<'a>(&'a self) -> Result<BoxStream<'a, Result<ChannelEntry>>>;

    /// Inserts `channel_entry`, or overwrites the stored row that has the same channel ID.
    async fn upsert_channel<'a>(&'a self, tx: OptTx<'a>, channel_entry: ChannelEntry) -> Result<()>;
}
114
115#[async_trait]
116impl HoprDbChannelOperations for HoprDb {
117 async fn get_channel_by_id<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEntry>> {
118 let id_hex = id.to_hex();
119 self.nest_transaction(tx)
120 .await?
121 .perform(|tx| {
122 Box::pin(async move {
123 Ok::<_, DbSqlError>(
124 if let Some(model) = Channel::find()
125 .filter(channel::Column::ChannelId.eq(id_hex))
126 .one(tx.as_ref())
127 .await?
128 {
129 Some(model.try_into()?)
130 } else {
131 None
132 },
133 )
134 })
135 })
136 .await
137 }
138
139 async fn begin_channel_update<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEditor>> {
140 let id_hex = id.to_hex();
141 self.nest_transaction(tx)
142 .await?
143 .perform(|tx| {
144 Box::pin(async move {
145 match Channel::find()
146 .filter(channel::Column::ChannelId.eq(id_hex.clone()))
147 .one(tx.as_ref())
148 .await?
149 {
150 Some(model) => Ok(Some(ChannelEditor {
151 orig: ChannelEntry::try_from(model.clone())?,
152 model: model.into_active_model(),
153 delete: false,
154 })),
155 None => Ok(None),
156 }
157 })
158 })
159 .await
160 }
161
162 async fn finish_channel_update<'a>(&'a self, tx: OptTx<'a>, editor: ChannelEditor) -> Result<Option<ChannelEntry>> {
163 let epoch = editor.model.epoch.clone();
164
165 let parties = ChannelParties(editor.orig.source, editor.orig.destination);
166 let ret = self
167 .nest_transaction(tx)
168 .await?
169 .perform(|tx| {
170 Box::pin(async move {
171 if !editor.delete {
172 let model = editor.model.update(tx.as_ref()).await?;
173 match ChannelEntry::try_from(model) {
174 Ok(channel) => Ok::<_, DbSqlError>(Some(channel)),
175 Err(e) => Err(DbSqlError::from(e)),
176 }
177 } else {
178 editor.model.delete(tx.as_ref()).await?;
179 Ok::<_, DbSqlError>(Some(editor.orig))
180 }
181 })
182 })
183 .await?;
184 self.caches.src_dst_to_channel.invalidate(&parties).await;
185
186 let channel_id = editor.orig.get_id();
190 if let Some(channel_epoch) = epoch.try_as_ref() {
191 self.caches
192 .unrealized_value
193 .invalidate(&(channel_id, U256::from_big_endian(channel_epoch.as_slice())))
194 .await;
195 }
196
197 Ok(ret)
198 }
199
200 #[instrument(level = "trace", skip(self, tx), err)]
201 async fn get_channel_by_parties<'a>(
202 &'a self,
203 tx: OptTx<'a>,
204 src: &Address,
205 dst: &Address,
206 use_cache: bool,
207 ) -> Result<Option<ChannelEntry>> {
208 let fetch_channel = async move {
209 let src_hex = src.to_hex();
210 let dst_hex = dst.to_hex();
211 tracing::warn!(%src, %dst, "cache miss on get_channel_by_parties");
212 self.nest_transaction(tx)
213 .await?
214 .perform(|tx| {
215 Box::pin(async move {
216 Ok::<_, DbSqlError>(
217 if let Some(model) = Channel::find()
218 .filter(channel::Column::Source.eq(src_hex))
219 .filter(channel::Column::Destination.eq(dst_hex))
220 .one(tx.as_ref())
221 .await?
222 {
223 Some(model.try_into()?)
224 } else {
225 None
226 },
227 )
228 })
229 })
230 .await
231 };
232
233 if use_cache {
234 Ok(self
235 .caches
236 .src_dst_to_channel
237 .try_get_with(ChannelParties(*src, *dst), fetch_channel)
238 .await?)
239 } else {
240 fetch_channel.await
241 }
242 }
243
244 async fn get_channels_via<'a>(
245 &'a self,
246 tx: OptTx<'a>,
247 direction: ChannelDirection,
248 target: &Address,
249 ) -> Result<Vec<ChannelEntry>> {
250 let target_hex = target.to_hex();
251 self.nest_transaction(tx)
252 .await?
253 .perform(|tx| {
254 Box::pin(async move {
255 Channel::find()
256 .filter(match direction {
257 ChannelDirection::Incoming => channel::Column::Destination.eq(target_hex),
258 ChannelDirection::Outgoing => channel::Column::Source.eq(target_hex),
259 })
260 .all(tx.as_ref())
261 .await?
262 .into_iter()
263 .map(|x| ChannelEntry::try_from(x).map_err(DbSqlError::from))
264 .collect::<Result<Vec<_>>>()
265 })
266 })
267 .await
268 }
269
270 async fn get_incoming_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
271 self.get_channels_via(tx, ChannelDirection::Incoming, &self.me_onchain)
272 .await
273 }
274
275 async fn get_outgoing_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
276 self.get_channels_via(tx, ChannelDirection::Outgoing, &self.me_onchain)
277 .await
278 }
279
280 async fn get_all_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
281 self.nest_transaction(tx)
282 .await?
283 .perform(|tx| {
284 Box::pin(async move {
285 Channel::find()
286 .stream(tx.as_ref())
287 .await?
288 .map_err(DbSqlError::from)
289 .try_filter_map(|m| async move { Ok(Some(ChannelEntry::try_from(m)?)) })
290 .try_collect()
291 .await
292 })
293 })
294 .await
295 }
296
297 async fn stream_active_channels<'a>(&'a self) -> Result<BoxStream<'a, Result<ChannelEntry>>> {
298 Ok(Channel::find()
299 .filter(
300 channel::Column::Status
301 .eq(i8::from(ChannelStatus::Open))
302 .or(channel::Column::Status
303 .eq(i8::from(ChannelStatus::PendingToClose(
304 hopr_platform::time::native::current_time(), )))
306 .and(channel::Column::ClosureTime.gt(Utc::now()))),
307 )
308 .stream(self.index_db.read_only())
309 .await?
310 .map_err(DbSqlError::from)
311 .and_then(|m| async move { Ok(ChannelEntry::try_from(m)?) })
312 .boxed())
313 }
314
315 async fn upsert_channel<'a>(&'a self, tx: OptTx<'a>, channel_entry: ChannelEntry) -> Result<()> {
316 let parties = ChannelParties(channel_entry.source, channel_entry.destination);
317 self.nest_transaction(tx)
318 .await?
319 .perform(|tx| {
320 Box::pin(async move {
321 let mut model = channel::ActiveModel::from(channel_entry);
322 if let Some(channel) = channel::Entity::find()
323 .filter(channel::Column::ChannelId.eq(channel_entry.get_id().to_hex()))
324 .one(tx.as_ref())
325 .await?
326 {
327 model.id = Set(channel.id);
328 }
329
330 Ok::<_, DbSqlError>(model.save(tx.as_ref()).await?)
331 })
332 })
333 .await?;
334
335 self.caches.src_dst_to_channel.invalidate(&parties).await;
336
337 let channel_id = channel_entry.get_id();
341 let channel_epoch = channel_entry.channel_epoch;
342 self.caches
343 .unrealized_value
344 .invalidate(&(channel_id, channel_epoch))
345 .await;
346
347 Ok(())
348 }
349}
350
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use hopr_crypto_random::random_bytes;
    use hopr_crypto_types::{keypairs::ChainKeypair, prelude::Keypair};
    use hopr_internal_types::{
        channels::ChannelStatus,
        prelude::{ChannelDirection, ChannelEntry},
    };
    use hopr_primitive_types::prelude::Address;

    use crate::{HoprDbGeneralModelOperations, channels::HoprDbChannelOperations, db::HoprDb};

    #[tokio::test]
    async fn test_insert_get_by_id() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let ce = ChannelEntry::new(
            Address::default(),
            Address::default(),
            0.into(),
            0_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        db.upsert_channel(None, ce).await?;
        let from_db = db
            .get_channel_by_id(None, &ce.get_id())
            .await?
            .expect("channel must be present");

        assert_eq!(ce, from_db, "channels must be equal");

        Ok(())
    }

    #[tokio::test]
    async fn test_insert_get_by_parties() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let a = Address::from(random_bytes());
        let b = Address::from(random_bytes());

        let ce = ChannelEntry::new(a, b, 0.into(), 0_u32.into(), ChannelStatus::Open, 0_u32.into());

        db.upsert_channel(None, ce).await?;
        let from_db = db
            .get_channel_by_parties(None, &a, &b, false)
            .await?
            .context("channel must be present")?;

        assert_eq!(ce, from_db, "channels must be equal");

        Ok(())
    }

    #[tokio::test]
    async fn test_channel_get_for_destination_that_does_not_exist_returns_none() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let from_db = db
            .get_channels_via(None, ChannelDirection::Incoming, &Address::default())
            .await?
            .first()
            .cloned();

        assert_eq!(None, from_db, "should return None");

        Ok(())
    }

    #[tokio::test]
    async fn test_channel_get_for_destination_that_exists_should_be_returned() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let expected_destination = Address::default();

        let ce = ChannelEntry::new(
            Address::default(),
            expected_destination,
            0.into(),
            0_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        db.upsert_channel(None, ce).await?;
        let from_db = db
            .get_channels_via(None, ChannelDirection::Incoming, &Address::default())
            .await?
            .first()
            .cloned();

        assert_eq!(Some(ce), from_db, "should return a valid channel");

        Ok(())
    }

    #[tokio::test]
    async fn test_incoming_outgoing_channels() -> anyhow::Result<()> {
        let ckp = ChainKeypair::random();
        let addr_1 = ckp.public().to_address();
        let addr_2 = ChainKeypair::random().public().to_address();

        let db = HoprDb::new_in_memory(ckp).await?;

        let ce_1 = ChannelEntry::new(
            addr_1,
            addr_2,
            0.into(),
            1_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        let ce_2 = ChannelEntry::new(
            addr_2,
            addr_1,
            0.into(),
            2_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        let db_clone = db.clone();
        db.begin_transaction()
            .await?
            .perform(|tx| {
                Box::pin(async move {
                    db_clone.upsert_channel(Some(tx), ce_1).await?;
                    db_clone.upsert_channel(Some(tx), ce_2).await
                })
            })
            .await?;

        assert_eq!(vec![ce_2], db.get_incoming_channels(None).await?);
        assert_eq!(vec![ce_1], db.get_outgoing_channels(None).await?);

        Ok(())
    }

    /// The editor round-trip must persist staged changes and return the updated entry.
    #[tokio::test]
    async fn test_channel_editor_updates_ticket_index() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let a = Address::from(random_bytes());
        let b = Address::from(random_bytes());
        let ce = ChannelEntry::new(a, b, 0.into(), 0_u32.into(), ChannelStatus::Open, 0_u32.into());
        db.upsert_channel(None, ce).await?;

        let editor = db
            .begin_channel_update(None, &ce.get_id())
            .await?
            .context("channel must be present")?;
        assert_eq!(&ce, editor.entry(), "editor must expose the stored channel");

        let updated = db
            .finish_channel_update(None, editor.change_ticket_index(5_u32))
            .await?
            .context("updated channel must be returned")?;

        let expected = ChannelEntry::new(a, b, 0.into(), 5_u32.into(), ChannelStatus::Open, 0_u32.into());
        assert_eq!(expected, updated, "returned channel must reflect the update");

        let from_db = db
            .get_channel_by_id(None, &ce.get_id())
            .await?
            .context("channel must still be present")?;
        assert_eq!(expected, from_db, "stored channel must reflect the update");

        Ok(())
    }

    /// Finishing an editor marked for deletion must remove the row and return the original entry.
    #[tokio::test]
    async fn test_channel_editor_delete() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let a = Address::from(random_bytes());
        let b = Address::from(random_bytes());
        let ce = ChannelEntry::new(a, b, 0.into(), 0_u32.into(), ChannelStatus::Open, 0_u32.into());
        db.upsert_channel(None, ce).await?;

        let editor = db
            .begin_channel_update(None, &ce.get_id())
            .await?
            .context("channel must be present")?;
        let deleted = db.finish_channel_update(None, editor.delete()).await?;

        assert_eq!(Some(ce), deleted, "deletion must return the original channel");
        assert_eq!(
            None,
            db.get_channel_by_id(None, &ce.get_id()).await?,
            "channel must be gone by id"
        );
        assert_eq!(
            None,
            db.get_channel_by_parties(None, &a, &b, true).await?,
            "channel must be gone by parties"
        );

        Ok(())
    }
}