hoprd_api/
preconditions.rs1use axum::{
2 extract::{OriginalUri, Request, State},
3 http::{
4 header::{HeaderName, AUTHORIZATION},
5 status::StatusCode,
6 HeaderMap,
7 },
8 middleware::Next,
9 response::IntoResponse,
10};
11use std::{str::FromStr, sync::atomic::Ordering::Relaxed};
12use urlencoding::decode;
13
14use crate::{ApiErrorStatus, Auth, InternalState, BASE_PATH};
15
16fn is_a_websocket_uri(uri: &OriginalUri) -> bool {
17 const SESSION_PATH: &str = const_format::formatcp!("{BASE_PATH}/session/websocket");
18
19 uri.path().starts_with(SESSION_PATH)
20}
21
22pub(crate) async fn cap_websockets(
23 State(state): State<InternalState>,
24 uri: OriginalUri,
25 _headers: HeaderMap,
26 request: Request,
27 next: Next,
28) -> impl IntoResponse {
29 let max_websocket_count = std::env::var("HOPR_INTERNAL_REST_API_MAX_CONCURRENT_WEBSOCKET_COUNT")
30 .and_then(|v| v.parse::<u16>().map_err(|_e| std::env::VarError::NotPresent))
31 .unwrap_or(10);
32
33 if is_a_websocket_uri(&uri) {
34 let ws_count = state.websocket_active_count;
35
36 if ws_count.fetch_add(1, Relaxed) > max_websocket_count {
37 ws_count.fetch_sub(1, Relaxed);
38
39 return (
40 StatusCode::TOO_MANY_REQUESTS,
41 ApiErrorStatus::TooManyOpenWebsocketConnections,
42 )
43 .into_response();
44 }
45 }
46
47 next.run(request).await
49}
50
51pub(crate) async fn authenticate(
52 State(state): State<InternalState>,
53 uri: OriginalUri,
54 headers: HeaderMap,
55 request: Request,
56 next: Next,
57) -> impl IntoResponse {
58 let auth = state.auth.clone();
59
60 let x_auth_header = HeaderName::from_str("x-auth-token").expect("Invalid header name: x-auth-token");
61 let websocket_proto_header =
62 HeaderName::from_str("Sec-Websocket-Protocol").expect("Invalid header name: Sec-Websocket-Protocol");
63
64 let is_authorized = match auth.as_ref() {
65 Auth::Token(expected_token) => {
66 let auth_headers = headers
67 .iter()
68 .filter_map(|(n, v)| {
69 (AUTHORIZATION.eq(n) || x_auth_header.eq(n) || websocket_proto_header.eq(n))
70 .then_some((n, v.to_str().expect("Invalid header value")))
71 })
72 .collect::<Vec<_>>();
73
74 let is_ws_auth = if is_a_websocket_uri(&uri) {
75 uri.query()
76 .map(|q| {
77 if q.len() > 2048 {
79 return false;
80 }
81 match decode(q) {
82 Ok(decoded) => decoded.into_owned().contains(&format!("apiToken={}", expected_token)),
83 Err(_) => false,
84 }
85 })
86 .unwrap_or(false)
87 } else {
88 false
89 };
90 (!auth_headers.is_empty()
92 && (auth_headers.contains(&(&AUTHORIZATION, &format!("Bearer {}", expected_token)))
93 || auth_headers.contains(&(&x_auth_header, expected_token)))
94 )
95 || is_ws_auth
98 }
99 Auth::None => true,
100 };
101
102 if !is_authorized {
103 return (StatusCode::UNAUTHORIZED, ApiErrorStatus::Unauthorized).into_response();
104 }
105
106 next.run(request).await
108}