Skip to main content

hopr_transport_session/balancer/
pid.rs

1use std::str::FromStr;
2
3use anyhow::anyhow;
4use pid::Pid;
5
6use crate::{
7    balancer::{BalancerControllerBounds, SurbBalancerController},
8    errors,
9    errors::SessionManagerError,
10};
11
12/// Carries finite Proportional, Integral and Derivative controller gains for a PID controller.
13#[derive(Clone, Copy, Debug, PartialEq)]
14pub struct PidControllerGains(f64, f64, f64);
15
16impl PidControllerGains {
17    /// Creates PID controller gains, returns an error if the gains are not finite.
18    pub fn new(p: f64, i: f64, d: f64) -> errors::Result<Self> {
19        if p.is_finite() && i.is_finite() && d.is_finite() {
20            Ok(Self(p, i, d))
21        } else {
22            Err(SessionManagerError::other(anyhow!("gains must be finite")).into())
23        }
24    }
25
26    /// Uses PID controller gains from the env variables or uses the defaults if not set.
27    pub fn from_env_or_default() -> Self {
28        let default = Self::default();
29        Self(
30            std::env::var("HOPR_BALANCER_PID_P_GAIN")
31                .ok()
32                .and_then(|v| f64::from_str(&v).ok())
33                .unwrap_or(default.0),
34            std::env::var("HOPR_BALANCER_PID_I_GAIN")
35                .ok()
36                .and_then(|v| f64::from_str(&v).ok())
37                .unwrap_or(default.1),
38            std::env::var("HOPR_BALANCER_PID_D_GAIN")
39                .ok()
40                .and_then(|v| f64::from_str(&v).ok())
41                .unwrap_or(default.2),
42        )
43    }
44
45    /// P gain.
46    #[inline]
47    pub fn p(&self) -> f64 {
48        self.0
49    }
50
51    /// I gain.
52    #[inline]
53    pub fn i(&self) -> f64 {
54        self.1
55    }
56
57    /// D gain.
58    #[inline]
59    pub fn d(&self) -> f64 {
60        self.2
61    }
62}
63
// `Eq` requires `==` to be reflexive, and among `f64` values only NaN violates that.
// This impl is therefore sound only as long as every constructor of `PidControllerGains`
// keeps the gains finite (in particular, NaN must never get in).
impl Eq for PidControllerGains {}
66
/// Default proportional (P) gain of the balancer's PID controller.
///
/// The default coefficients might be tweaked in the future.
const DEFAULT_P_GAIN: f64 = 0.6;

/// Default integral (I) gain of the balancer's PID controller.
const DEFAULT_I_GAIN: f64 = 0.7;

/// Default derivative (D) gain of the balancer's PID controller.
const DEFAULT_D_GAIN: f64 = 0.2;
72
73impl Default for PidControllerGains {
74    fn default() -> Self {
75        Self(DEFAULT_P_GAIN, DEFAULT_I_GAIN, DEFAULT_D_GAIN)
76    }
77}
78
79impl TryFrom<(f64, f64, f64)> for PidControllerGains {
80    type Error = errors::TransportSessionError;
81
82    fn try_from(value: (f64, f64, f64)) -> Result<Self, Self::Error> {
83        Self::new(value.0, value.1, value.2)
84    }
85}
86
87/// Implementation of [`SurbBalancerController`] using a PID controller.
88#[derive(Clone, Copy, Debug)]
89pub struct PidBalancerController(Pid<f64>);
90
91impl PidBalancerController {
92    /// Creates new instance given the `setpoint`, `output_limit` and PID gains (P,I and D).
93    pub fn new(setpoint: u64, output_limit: u64, gains: PidControllerGains) -> Self {
94        let mut pid = Pid::new(setpoint as f64, output_limit as f64);
95        pid.p(gains.p(), output_limit as f64);
96        pid.i(gains.i(), output_limit as f64);
97        pid.d(gains.d(), output_limit as f64);
98        Self(pid)
99    }
100
101    /// Creates new instance with setpoint and output limit set to 0.
102    ///
103    /// Needs to be [reconfigured](SurbBalancerController::set_target_and_limit) in order to function
104    /// correctly.
105    pub fn from_gains(gains: PidControllerGains) -> Self {
106        Self::new(0, 0, gains)
107    }
108}
109
110impl Default for PidBalancerController {
111    /// The default instance does nothing unless [reconfigured](SurbBalancerController::set_target_and_limit).
112    fn default() -> Self {
113        Self::new(0, 0, PidControllerGains::default())
114    }
115}
116
117impl SurbBalancerController for PidBalancerController {
118    fn bounds(&self) -> BalancerControllerBounds {
119        BalancerControllerBounds::new(self.0.setpoint as u64, self.0.output_limit as u64)
120    }
121
122    fn set_target_and_limit(&mut self, bounds: BalancerControllerBounds) {
123        let mut pid = Pid::new(bounds.target() as f64, bounds.output_limit() as f64);
124        pid.p(self.0.kp, bounds.output_limit() as f64);
125        pid.i(self.0.ki, bounds.output_limit() as f64);
126        pid.d(self.0.kd, bounds.output_limit() as f64);
127        self.0 = pid;
128    }
129
130    fn next_control_output(&mut self, current_buffer_level: u64) -> u64 {
131        self.0.next_control_output(current_buffer_level as f64).output.max(0.0) as u64
132    }
133}
134
// Unit tests for `PidControllerGains` and `PidBalancerController`.
//
// NOTE(review): several tests use insta snapshots. insta names snapshots after the test
// function by default, so renaming these tests presumably requires re-accepting their
// snapshots.
#[cfg(test)]
mod tests {
    use super::*;

    // --- PidControllerGains tests ---

    // Pins the default gains via a snapshot so accidental changes to the constants show up.
    #[test]
    fn gains_default_values_are_stable() {
        let gains = PidControllerGains::default();
        insta::assert_yaml_snapshot!((gains.p(), gains.i(), gains.d()));
    }

    // Finite gains pass through `new` unchanged and are exposed by the getters.
    #[test]
    fn gains_finite_values_are_accepted() -> anyhow::Result<()> {
        let gains = PidControllerGains::new(1.0, 2.0, 3.0)?;
        insta::assert_yaml_snapshot!((gains.p(), gains.i(), gains.d()));
        Ok(())
    }

    // Covers +infinity, -infinity and NaN, one per gain position.
    #[test]
    fn gains_infinity_is_rejected() {
        assert!(PidControllerGains::new(f64::INFINITY, 0.0, 0.0).is_err());
        assert!(PidControllerGains::new(0.0, f64::NEG_INFINITY, 0.0).is_err());
        assert!(PidControllerGains::new(0.0, 0.0, f64::NAN).is_err());
    }

    // The `(p, i, d)` tuple conversion maps positionally onto the getters.
    #[test]
    fn gains_try_from_tuple() -> anyhow::Result<()> {
        let gains = PidControllerGains::try_from((0.5, 0.3, 0.1))?;
        insta::assert_yaml_snapshot!((gains.p(), gains.i(), gains.d()));
        Ok(())
    }

    // The tuple conversion applies the same finiteness check as `new`.
    #[test]
    fn gains_try_from_tuple_with_nan_fails() {
        assert!(PidControllerGains::try_from((f64::NAN, 0.0, 0.0)).is_err());
    }

    #[test]
    fn gains_eq_works() -> anyhow::Result<()> {
        let a = PidControllerGains::new(1.0, 2.0, 3.0)?;
        let b = PidControllerGains::new(1.0, 2.0, 3.0)?;
        assert_eq!(a, b);
        Ok(())
    }

    // --- PidBalancerController tests ---

    #[test]
    fn controller_default_has_zero_bounds() {
        let ctrl = PidBalancerController::default();
        assert_eq!(ctrl.bounds().unzip(), (0, 0));
    }

    #[test]
    fn controller_new_stores_bounds() {
        let gains = PidControllerGains::default();
        let ctrl = PidBalancerController::new(100, 50, gains);
        assert_eq!(ctrl.bounds().unzip(), (100, 50));
    }

    #[test]
    fn controller_set_target_and_limit_updates_bounds() {
        let mut ctrl = PidBalancerController::default();
        ctrl.set_target_and_limit(BalancerControllerBounds::new(200, 100));
        assert_eq!(ctrl.bounds().unzip(), (200, 100));
    }

    #[test]
    fn controller_step_response_snapshot() {
        // Apply a step input (setpoint=100, buffer=0) and observe outputs over N steps.
        // This is open-loop: the buffer level stays 0 on every step, because outputs are
        // not fed back.
        let gains = PidControllerGains::default();
        let mut ctrl = PidBalancerController::new(100, 200, gains);

        let outputs: Vec<u64> = (0..10).map(|_| ctrl.next_control_output(0)).collect();
        insta::assert_yaml_snapshot!(outputs);
    }

    #[test]
    fn controller_at_setpoint_outputs_zero_or_near_zero() {
        let gains = PidControllerGains::default();
        let mut ctrl = PidBalancerController::new(100, 200, gains);

        // When the buffer level equals the setpoint, the error is zero.
        let output = ctrl.next_control_output(100);
        // First call at setpoint: P=0, I=0, D=0 → output=0
        assert_eq!(output, 0);
    }

    #[test]
    fn controller_above_setpoint_clamps_to_zero() {
        let gains = PidControllerGains::default();
        let mut ctrl = PidBalancerController::new(100, 200, gains);

        // Buffer well above setpoint — PID error is negative, output clamped to 0
        let output = ctrl.next_control_output(200);
        assert_eq!(output, 0);
    }

    #[test]
    fn controller_convergence_from_empty_buffer() {
        // Simulate filling a buffer from 0 toward setpoint=100.
        // Closed loop: each output is added to the buffer and fed back on the next step.
        let gains = PidControllerGains::default();
        let mut ctrl = PidBalancerController::new(100, 200, gains);

        let mut buffer: f64 = 0.0;
        let mut history = Vec::new();

        for _ in 0..20 {
            let output = ctrl.next_control_output(buffer as u64);
            buffer += output as f64;
            buffer = buffer.min(200.0); // clamp to limit
            history.push(buffer as u64);
        }

        insta::assert_yaml_snapshot!(history);
    }

    // `from_gains` keeps setpoint and output limit at 0 until reconfigured.
    #[test]
    fn controller_from_gains_uses_defaults_for_bounds() {
        let gains = PidControllerGains::default();
        let ctrl = PidBalancerController::from_gains(gains);
        assert_eq!(ctrl.bounds().unzip(), (0, 0));
    }
}