hopr_crypto_sphinx/
shared_keys.rs1use blake2::Blake2s256;
2use generic_array::{ArrayLength, GenericArray};
3use hkdf::SimpleHkdf;
4use hopr_crypto_types::errors::CryptoError::CalculationError;
5use hopr_crypto_types::errors::Result;
6use hopr_crypto_types::keypairs::Keypair;
7use hopr_crypto_types::utils::SecretValue;
8use std::marker::PhantomData;
9use std::ops::Mul;
10
11pub type SharedSecret = SecretValue<typenum::U32>;
13
14pub trait Scalar: Mul<Output = Self> + Sized {
16 fn random() -> Self;
18
19 fn from_bytes(bytes: &[u8]) -> Result<Self>;
21
22 fn to_bytes(&self) -> Box<[u8]>;
24}
25
26pub type Alpha<A> = GenericArray<u8, A>;
29
30pub trait GroupElement<E: Scalar>: Clone + for<'a> Mul<&'a E, Output = Self> {
34 type AlphaLen: ArrayLength<u8>;
36
37 fn to_alpha(&self) -> Alpha<Self::AlphaLen>;
39
40 fn from_alpha(alpha: Alpha<Self::AlphaLen>) -> Result<Self>;
42
43 fn generate(scalar: &E) -> Self;
45
46 fn is_valid(&self) -> bool;
48
49 fn random_pair() -> (Self, E) {
53 let scalar = E::random();
54 (Self::generate(&scalar), scalar)
55 }
56
57 fn extract_key(&self, salt: &[u8]) -> SharedSecret {
59 let ikm = self.to_alpha();
60 SimpleHkdf::<Blake2s256>::extract(Some(salt), ikm.as_ref()).0.into()
61 }
62
63 fn expand_key(&self, salt: &[u8]) -> SharedSecret {
65 let mut out = GenericArray::default();
66 let ikm = self.to_alpha();
67 SimpleHkdf::<Blake2s256>::new(Some(salt), &ikm)
68 .expand(b"", &mut out)
69 .expect("invalid size of the shared secret output"); out.into()
72 }
73}
74
75pub struct SharedKeys<E: Scalar, G: GroupElement<E>> {
77 pub alpha: Alpha<G::AlphaLen>,
78 pub secrets: Vec<SharedSecret>,
79 _e: PhantomData<E>,
80 _g: PhantomData<G>,
81}
82
83impl<E: Scalar, G: GroupElement<E>> SharedKeys<E, G> {
84 pub fn generate(peer_group_elements: Vec<G>) -> Result<SharedKeys<E, G>> {
87 let mut shared_keys = Vec::new();
88
89 let (mut alpha_prev, mut coeff_prev) = G::random_pair();
92
93 let alpha = alpha_prev.to_alpha();
95
96 let keys_len = peer_group_elements.len();
98 for (i, group_element) in peer_group_elements.into_iter().enumerate() {
99 let salt = group_element.to_alpha();
101 let shared_secret = group_element.mul(&coeff_prev);
102
103 shared_keys.push(shared_secret.extract_key(&salt));
105
106 if i == keys_len - 1 {
108 break;
109 }
110
111 let b_k = shared_secret.expand_key(&alpha_prev.to_alpha());
113 let b_k_checked = E::from_bytes(b_k.as_ref())?;
114
115 alpha_prev = alpha_prev.mul(&b_k_checked);
117 coeff_prev = coeff_prev.mul(b_k_checked);
118
119 if !alpha_prev.is_valid() {
120 return Err(CalculationError);
121 }
122 }
123
124 Ok(SharedKeys {
125 alpha,
126 secrets: shared_keys,
127 _e: PhantomData,
128 _g: PhantomData,
129 })
130 }
131
132 pub fn forward_transform(
135 alpha: &Alpha<G::AlphaLen>,
136 private_scalar: &E,
137 public_group_element: &G,
138 ) -> Result<(Alpha<G::AlphaLen>, SharedSecret)> {
139 let alpha_point = G::from_alpha(alpha.clone())?;
140
141 let s_k = alpha_point.clone().mul(private_scalar);
142
143 let secret = s_k.extract_key(&public_group_element.to_alpha());
144
145 let b_k = s_k.expand_key(alpha);
146
147 let b_k_checked = E::from_bytes(b_k.as_ref())?;
148 let alpha_new = alpha_point.mul(&b_k_checked);
149
150 Ok((alpha_new.to_alpha(), secret))
151 }
152}
153
154pub trait SphinxSuite {
156 type P: Keypair;
158
159 type E: Scalar + for<'a> From<&'a Self::P>;
161
162 type G: GroupElement<Self::E> + for<'a> From<&'a <Self::P as Keypair>::Public>;
164
165 fn new_shared_keys(public_keys: &[<Self::P as Keypair>::Public]) -> Result<SharedKeys<Self::E, Self::G>> {
167 SharedKeys::generate(public_keys.iter().map(|pk| pk.into()).collect())
168 }
169}
170
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use subtle::ConstantTimeEq;

    /// Round-trip check for a [`SphinxSuite`]: every hop must recover, through
    /// `forward_transform`, exactly the secret that `SharedKeys::generate` computed for it.
    ///
    /// # Panics
    /// Fails the calling test if the number of secrets differs from `node_count`, if any
    /// derivation step errors, or if a recovered secret does not match.
    pub fn generic_sphinx_suite_test<S: SphinxSuite>(node_count: usize) {
        let (pub_keys, priv_keys): (Vec<S::G>, Vec<S::E>) =
            (0..node_count).map(|_| S::G::random_pair()).unzip();

        let shares = SharedKeys::<S::E, S::G>::generate(pub_keys.clone()).unwrap();

        assert_eq!(
            node_count,
            shares.secrets.len(),
            "number of generated keys should be equal to the number of nodes"
        );

        // Walk the route hop by hop, feeding each hop the alpha produced by the previous one.
        let mut alpha = shares.alpha.clone();
        let hops = priv_keys.iter().zip(pub_keys.iter()).zip(shares.secrets.iter());
        for ((priv_key, pub_key), expected) in hops {
            let (next_alpha, secret) =
                SharedKeys::<S::E, S::G>::forward_transform(&alpha, priv_key, pub_key).unwrap();

            assert_eq!(
                secret.ct_eq(expected).unwrap_u8(),
                1,
                "forward transform should yield the same shared secret"
            );

            alpha = next_alpha;
        }
    }
}