1use async_trait::async_trait;
2use futures::{StreamExt, TryStreamExt, stream::BoxStream};
3use hopr_crypto_types::prelude::*;
4use hopr_db_entity::{channel, conversions::channels::ChannelStatusUpdate, prelude::Channel};
5use hopr_internal_types::prelude::*;
6use hopr_primitive_types::prelude::*;
7use sea_orm::{ActiveModelTrait, ActiveValue::Set, ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter};
8
9use crate::{
10 HoprDbGeneralModelOperations, OptTx,
11 cache::ChannelParties,
12 db::HoprDb,
13 errors::{DbSqlError, Result},
14};
15
16pub struct ChannelEditor {
18 orig: ChannelEntry,
19 model: channel::ActiveModel,
20 delete: bool,
21}
22
23impl ChannelEditor {
24 pub fn entry(&self) -> &ChannelEntry {
26 &self.orig
27 }
28
29 pub fn change_balance(mut self, balance: HoprBalance) -> Self {
31 self.model.balance = Set(balance.amount().to_be_bytes().to_vec());
32 self
33 }
34
35 pub fn change_status(mut self, status: ChannelStatus) -> Self {
37 self.model.set_status(status);
38 self
39 }
40
41 pub fn change_ticket_index(mut self, index: impl Into<U256>) -> Self {
43 self.model.ticket_index = Set(index.into().to_be_bytes().to_vec());
44 self
45 }
46
47 pub fn change_epoch(mut self, epoch: impl Into<U256>) -> Self {
49 self.model.epoch = Set(epoch.into().to_be_bytes().to_vec());
50 self
51 }
52
53 pub fn delete(mut self) -> Self {
55 self.delete = true;
56 self
57 }
58}
59
60#[async_trait]
62pub trait HoprDbChannelOperations {
63 async fn get_channel_by_id<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEntry>>;
67
68 async fn begin_channel_update<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEditor>>;
72
73 async fn finish_channel_update<'a>(&'a self, tx: OptTx<'a>, editor: ChannelEditor) -> Result<ChannelEntry>;
76
77 async fn get_channel_by_parties<'a>(
81 &'a self,
82 tx: OptTx<'a>,
83 src: &Address,
84 dst: &Address,
85 use_cache: bool,
86 ) -> Result<Option<ChannelEntry>>;
87
88 async fn get_channels_via<'a>(
90 &'a self,
91 tx: OptTx<'a>,
92 direction: ChannelDirection,
93 target: &Address,
94 ) -> Result<Vec<ChannelEntry>>;
95
96 async fn get_incoming_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;
99
100 async fn get_outgoing_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;
103
104 async fn get_all_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>>;
106
107 async fn stream_active_channels<'a>(&'a self) -> Result<BoxStream<'a, Result<ChannelEntry>>>;
109
110 async fn upsert_channel<'a>(&'a self, tx: OptTx<'a>, channel_entry: ChannelEntry) -> Result<()>;
112}
113
114#[async_trait]
115impl HoprDbChannelOperations for HoprDb {
116 async fn get_channel_by_id<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEntry>> {
117 let id_hex = id.to_hex();
118 self.nest_transaction(tx)
119 .await?
120 .perform(|tx| {
121 Box::pin(async move {
122 Ok::<_, DbSqlError>(
123 if let Some(model) = Channel::find()
124 .filter(channel::Column::ChannelId.eq(id_hex))
125 .one(tx.as_ref())
126 .await?
127 {
128 Some(model.try_into()?)
129 } else {
130 None
131 },
132 )
133 })
134 })
135 .await
136 }
137
138 async fn begin_channel_update<'a>(&'a self, tx: OptTx<'a>, id: &Hash) -> Result<Option<ChannelEditor>> {
139 let id_hex = id.to_hex();
140 self.nest_transaction(tx)
141 .await?
142 .perform(|tx| {
143 Box::pin(async move {
144 Ok::<_, DbSqlError>(
145 if let Some(model) = Channel::find()
146 .filter(channel::Column::ChannelId.eq(id_hex))
147 .one(tx.as_ref())
148 .await?
149 {
150 Some(ChannelEditor {
151 orig: model.clone().try_into()?,
152 model: model.into_active_model(),
153 delete: false,
154 })
155 } else {
156 None
157 },
158 )
159 })
160 })
161 .await
162 }
163
164 async fn finish_channel_update<'a>(&'a self, tx: OptTx<'a>, editor: ChannelEditor) -> Result<ChannelEntry> {
165 let epoch = editor.model.epoch.clone();
166 let parties = ChannelParties(editor.orig.source, editor.orig.destination);
167 let ret = self
168 .nest_transaction(tx)
169 .await?
170 .perform(|tx| {
171 Box::pin(async move {
172 if !editor.delete {
173 let model = editor.model.update(tx.as_ref()).await?;
174 Ok::<_, DbSqlError>(model.try_into()?)
175 } else {
176 editor.model.delete(tx.as_ref()).await?;
177 Ok::<_, DbSqlError>(editor.orig)
178 }
179 })
180 })
181 .await?;
182 self.caches.src_dst_to_channel.invalidate(&parties).await;
183
184 let channel_id = editor.orig.get_id();
188 if let Some(channel_epoch) = epoch.try_as_ref() {
189 self.caches
190 .unrealized_value
191 .invalidate(&(channel_id, U256::from_big_endian(channel_epoch.as_slice())))
192 .await;
193 }
194
195 Ok(ret)
196 }
197
198 async fn get_channel_by_parties<'a>(
199 &'a self,
200 tx: OptTx<'a>,
201 src: &Address,
202 dst: &Address,
203 use_cache: bool,
204 ) -> Result<Option<ChannelEntry>> {
205 let fetch_channel = async move {
206 let src_hex = src.to_hex();
207 let dst_hex = dst.to_hex();
208 self.nest_transaction(tx)
209 .await?
210 .perform(|tx| {
211 Box::pin(async move {
212 Ok::<_, DbSqlError>(
213 if let Some(model) = Channel::find()
214 .filter(channel::Column::Source.eq(src_hex))
215 .filter(channel::Column::Destination.eq(dst_hex))
216 .one(tx.as_ref())
217 .await?
218 {
219 Some(model.try_into()?)
220 } else {
221 None
222 },
223 )
224 })
225 })
226 .await
227 };
228
229 if use_cache {
230 Ok(self
231 .caches
232 .src_dst_to_channel
233 .try_get_with(ChannelParties(*src, *dst), fetch_channel)
234 .await?)
235 } else {
236 fetch_channel.await
237 }
238 }
239
240 async fn get_channels_via<'a>(
241 &'a self,
242 tx: OptTx<'a>,
243 direction: ChannelDirection,
244 target: &Address,
245 ) -> Result<Vec<ChannelEntry>> {
246 let target_hex = target.to_hex();
247 self.nest_transaction(tx)
248 .await?
249 .perform(|tx| {
250 Box::pin(async move {
251 Channel::find()
252 .filter(match direction {
253 ChannelDirection::Incoming => channel::Column::Destination.eq(target_hex),
254 ChannelDirection::Outgoing => channel::Column::Source.eq(target_hex),
255 })
256 .all(tx.as_ref())
257 .await?
258 .into_iter()
259 .map(|x| ChannelEntry::try_from(x).map_err(DbSqlError::from))
260 .collect::<Result<Vec<_>>>()
261 })
262 })
263 .await
264 }
265
266 async fn get_incoming_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
267 self.get_channels_via(tx, ChannelDirection::Incoming, &self.me_onchain)
268 .await
269 }
270
271 async fn get_outgoing_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
272 self.get_channels_via(tx, ChannelDirection::Outgoing, &self.me_onchain)
273 .await
274 }
275
276 async fn get_all_channels<'a>(&'a self, tx: OptTx<'a>) -> Result<Vec<ChannelEntry>> {
277 self.nest_transaction(tx)
278 .await?
279 .perform(|tx| {
280 Box::pin(async move {
281 Channel::find()
282 .stream(tx.as_ref())
283 .await?
284 .map_err(DbSqlError::from)
285 .try_filter_map(|m| async move { Ok(Some(ChannelEntry::try_from(m)?)) })
286 .try_collect()
287 .await
288 })
289 })
290 .await
291 }
292
293 async fn stream_active_channels<'a>(&'a self) -> Result<BoxStream<'a, Result<ChannelEntry>>> {
294 Ok(Channel::find()
295 .filter(
296 channel::Column::Status
297 .eq(i8::from(ChannelStatus::Open))
298 .or(channel::Column::Status
299 .eq(i8::from(ChannelStatus::PendingToClose(
300 hopr_platform::time::native::current_time(), )))
302 .and(channel::Column::ClosureTime.gt(Utc::now()))),
303 )
304 .stream(&self.index_db)
305 .await?
306 .map_err(DbSqlError::from)
307 .and_then(|m| async move { Ok(ChannelEntry::try_from(m)?) })
308 .boxed())
309 }
310
311 async fn upsert_channel<'a>(&'a self, tx: OptTx<'a>, channel_entry: ChannelEntry) -> Result<()> {
312 let parties = ChannelParties(channel_entry.source, channel_entry.destination);
313 self.nest_transaction(tx)
314 .await?
315 .perform(|tx| {
316 Box::pin(async move {
317 let mut model = channel::ActiveModel::from(channel_entry);
318 if let Some(channel) = channel::Entity::find()
319 .filter(channel::Column::ChannelId.eq(channel_entry.get_id().to_hex()))
320 .one(tx.as_ref())
321 .await?
322 {
323 model.id = Set(channel.id);
324 }
325
326 Ok::<_, DbSqlError>(model.save(tx.as_ref()).await?)
327 })
328 })
329 .await?;
330
331 self.caches.src_dst_to_channel.invalidate(&parties).await;
332
333 let channel_id = channel_entry.get_id();
337 let channel_epoch = channel_entry.channel_epoch;
338 self.caches
339 .unrealized_value
340 .invalidate(&(channel_id, channel_epoch))
341 .await;
342
343 Ok(())
344 }
345}
346
#[cfg(test)]
mod tests {
    use anyhow::Context;
    use hopr_crypto_random::random_bytes;
    use hopr_crypto_types::{keypairs::ChainKeypair, prelude::Keypair};
    use hopr_internal_types::{
        channels::ChannelStatus,
        prelude::{ChannelDirection, ChannelEntry},
    };
    use hopr_primitive_types::prelude::Address;

    use crate::{HoprDbGeneralModelOperations, channels::HoprDbChannelOperations, db::HoprDb};

    #[tokio::test]
    async fn test_insert_get_by_id() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let ce = ChannelEntry::new(
            Address::default(),
            Address::default(),
            0.into(),
            0_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        db.upsert_channel(None, ce).await?;
        let from_db = db
            .get_channel_by_id(None, &ce.get_id())
            .await?
            .context("channel must be present")?;

        assert_eq!(ce, from_db, "channels must be equal");

        Ok(())
    }

    #[tokio::test]
    async fn test_insert_get_by_parties() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let a = Address::from(random_bytes());
        let b = Address::from(random_bytes());

        let ce = ChannelEntry::new(a, b, 0.into(), 0_u32.into(), ChannelStatus::Open, 0_u32.into());

        db.upsert_channel(None, ce).await?;
        let from_db = db
            .get_channel_by_parties(None, &a, &b, false)
            .await?
            .context("channel must be present")?;

        assert_eq!(ce, from_db, "channels must be equal");

        Ok(())
    }

    #[tokio::test]
    async fn test_channel_get_for_destination_that_does_not_exist_returns_none() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let from_db = db
            .get_channels_via(None, ChannelDirection::Incoming, &Address::default())
            .await?
            .first()
            .cloned();

        assert_eq!(None, from_db, "should return None");

        Ok(())
    }

    #[tokio::test]
    async fn test_channel_get_for_destination_that_exists_should_be_returned() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let expected_destination = Address::default();

        let ce = ChannelEntry::new(
            Address::default(),
            expected_destination,
            0.into(),
            0_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        db.upsert_channel(None, ce).await?;
        let from_db = db
            .get_channels_via(None, ChannelDirection::Incoming, &Address::default())
            .await?
            .first()
            .cloned();

        assert_eq!(Some(ce), from_db, "should return a valid channel");

        Ok(())
    }

    #[tokio::test]
    async fn test_incoming_outgoing_channels() -> anyhow::Result<()> {
        let ckp = ChainKeypair::random();
        let addr_1 = ckp.public().to_address();
        let addr_2 = ChainKeypair::random().public().to_address();

        let db = HoprDb::new_in_memory(ckp).await?;

        let ce_1 = ChannelEntry::new(
            addr_1,
            addr_2,
            0.into(),
            1_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        let ce_2 = ChannelEntry::new(
            addr_2,
            addr_1,
            0.into(),
            2_u32.into(),
            ChannelStatus::Open,
            0_u32.into(),
        );

        let db_clone = db.clone();
        db.begin_transaction()
            .await?
            .perform(|tx| {
                Box::pin(async move {
                    db_clone.upsert_channel(Some(tx), ce_1).await?;
                    db_clone.upsert_channel(Some(tx), ce_2).await
                })
            })
            .await?;

        assert_eq!(vec![ce_2], db.get_incoming_channels(None).await?);
        assert_eq!(vec![ce_1], db.get_outgoing_channels(None).await?);

        Ok(())
    }

    #[tokio::test]
    async fn test_update_channel_balance_via_editor() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let a = Address::from(random_bytes());
        let b = Address::from(random_bytes());

        let ce = ChannelEntry::new(a, b, 0.into(), 0_u32.into(), ChannelStatus::Open, 0_u32.into());
        db.upsert_channel(None, ce).await?;

        let editor = db
            .begin_channel_update(None, &ce.get_id())
            .await?
            .context("channel must be present")?;
        assert_eq!(&ce, editor.entry(), "editor must expose the stored entry");

        let updated = db
            .finish_channel_update(None, editor.change_balance(10.into()))
            .await?;

        let expected = ChannelEntry::new(a, b, 10.into(), 0_u32.into(), ChannelStatus::Open, 0_u32.into());
        assert_eq!(expected, updated, "finishing the update must return the modified entry");

        let from_db = db
            .get_channel_by_id(None, &ce.get_id())
            .await?
            .context("channel must still be present")?;
        assert_eq!(expected, from_db, "the modification must be persisted");

        Ok(())
    }

    #[tokio::test]
    async fn test_delete_channel_via_editor_invalidates_cache() -> anyhow::Result<()> {
        let db = HoprDb::new_in_memory(ChainKeypair::random()).await?;

        let a = Address::from(random_bytes());
        let b = Address::from(random_bytes());

        let ce = ChannelEntry::new(a, b, 0.into(), 0_u32.into(), ChannelStatus::Open, 0_u32.into());
        db.upsert_channel(None, ce).await?;

        // Populate the parties cache so that the deletion has something to invalidate.
        assert_eq!(Some(ce), db.get_channel_by_parties(None, &a, &b, true).await?);

        let editor = db
            .begin_channel_update(None, &ce.get_id())
            .await?
            .context("channel must be present")?;
        let deleted = db.finish_channel_update(None, editor.delete()).await?;
        assert_eq!(ce, deleted, "deletion must return the original entry");

        assert_eq!(None, db.get_channel_by_id(None, &ce.get_id()).await?);
        assert_eq!(
            None,
            db.get_channel_by_parties(None, &a, &b, true).await?,
            "cached lookup must not return a deleted channel"
        );

        Ok(())
    }
}