hoprd_api/
preconditions.rs

1use axum::{
2    extract::{OriginalUri, Request, State},
3    http::{
4        header::{HeaderName, AUTHORIZATION},
5        status::StatusCode,
6        HeaderMap,
7    },
8    middleware::Next,
9    response::IntoResponse,
10};
11use std::{str::FromStr, sync::atomic::Ordering::Relaxed};
12use urlencoding::decode;
13
14use crate::{ApiErrorStatus, Auth, InternalState, BASE_PATH};
15
16fn is_a_websocket_uri(uri: &OriginalUri) -> bool {
17    const SESSION_PATH: &str = const_format::formatcp!("{BASE_PATH}/session/websocket");
18
19    uri.path().starts_with(SESSION_PATH)
20}
21
22pub(crate) async fn cap_websockets(
23    State(state): State<InternalState>,
24    uri: OriginalUri,
25    _headers: HeaderMap,
26    request: Request,
27    next: Next,
28) -> impl IntoResponse {
29    let max_websocket_count = std::env::var("HOPR_INTERNAL_REST_API_MAX_CONCURRENT_WEBSOCKET_COUNT")
30        .and_then(|v| v.parse::<u16>().map_err(|_e| std::env::VarError::NotPresent))
31        .unwrap_or(10);
32
33    if is_a_websocket_uri(&uri) {
34        let ws_count = state.websocket_active_count;
35
36        if ws_count.fetch_add(1, Relaxed) > max_websocket_count {
37            ws_count.fetch_sub(1, Relaxed);
38
39            return (
40                StatusCode::TOO_MANY_REQUESTS,
41                ApiErrorStatus::TooManyOpenWebsocketConnections,
42            )
43                .into_response();
44        }
45    }
46
47    // Go forward to the next middleware or request handler
48    next.run(request).await
49}
50
51pub(crate) async fn authenticate(
52    State(state): State<InternalState>,
53    uri: OriginalUri,
54    headers: HeaderMap,
55    request: Request,
56    next: Next,
57) -> impl IntoResponse {
58    let auth = state.auth.clone();
59
60    let x_auth_header = HeaderName::from_str("x-auth-token").expect("Invalid header name: x-auth-token");
61    let websocket_proto_header =
62        HeaderName::from_str("Sec-Websocket-Protocol").expect("Invalid header name: Sec-Websocket-Protocol");
63
64    let is_authorized = match auth.as_ref() {
65        Auth::Token(expected_token) => {
66            let auth_headers = headers
67                .iter()
68                .filter_map(|(n, v)| {
69                    (AUTHORIZATION.eq(n) || x_auth_header.eq(n) || websocket_proto_header.eq(n))
70                        .then_some((n, v.to_str().expect("Invalid header value")))
71                })
72                .collect::<Vec<_>>();
73
74            let is_ws_auth = if is_a_websocket_uri(&uri) {
75                uri.query()
76                    .map(|q| {
77                        // Reasonable limit for query string
78                        if q.len() > 2048 {
79                            return false;
80                        }
81                        match decode(q) {
82                            Ok(decoded) => decoded.into_owned().contains(&format!("apiToken={}", expected_token)),
83                            Err(_) => false,
84                        }
85                    })
86                    .unwrap_or(false)
87            } else {
88                false
89            };
90            // Use "Authorization Bearer <token>" and "X-Auth-Token <token>" headers and "Sec-Websocket-Protocol"
91            (!auth_headers.is_empty()
92                    && (auth_headers.contains(&(&AUTHORIZATION, &format!("Bearer {}", expected_token)))
93                        || auth_headers.contains(&(&x_auth_header, expected_token)))
94                )
95                // The following line would never be needed, if the JavaScript browser was able to properly
96                // pass the x-auth-token or Bearer headers.
97                || is_ws_auth
98        }
99        Auth::None => true,
100    };
101
102    if !is_authorized {
103        return (StatusCode::UNAUTHORIZED, ApiErrorStatus::Unauthorized).into_response();
104    }
105
106    // Go forward to the next middleware or request handler
107    next.run(request).await
108}