hopr_transport_session/balancer/
pid.rs1use std::str::FromStr;
2
3use anyhow::anyhow;
4use pid::Pid;
5
6use crate::{
7 balancer::{BalancerControllerBounds, SurbBalancerController},
8 errors,
9 errors::SessionManagerError,
10};
11
12#[derive(Clone, Copy, Debug, PartialEq)]
14pub struct PidControllerGains(f64, f64, f64);
15
16impl PidControllerGains {
17 pub fn new(p: f64, i: f64, d: f64) -> errors::Result<Self> {
19 if p.is_finite() && i.is_finite() && d.is_finite() {
20 Ok(Self(p, i, d))
21 } else {
22 Err(SessionManagerError::other(anyhow!("gains must be finite")).into())
23 }
24 }
25
26 pub fn from_env_or_default() -> Self {
28 let default = Self::default();
29 Self(
30 std::env::var("HOPR_BALANCER_PID_P_GAIN")
31 .ok()
32 .and_then(|v| f64::from_str(&v).ok())
33 .unwrap_or(default.0),
34 std::env::var("HOPR_BALANCER_PID_I_GAIN")
35 .ok()
36 .and_then(|v| f64::from_str(&v).ok())
37 .unwrap_or(default.1),
38 std::env::var("HOPR_BALANCER_PID_D_GAIN")
39 .ok()
40 .and_then(|v| f64::from_str(&v).ok())
41 .unwrap_or(default.2),
42 )
43 }
44
45 #[inline]
47 pub fn p(&self) -> f64 {
48 self.0
49 }
50
51 #[inline]
53 pub fn i(&self) -> f64 {
54 self.1
55 }
56
57 #[inline]
59 pub fn d(&self) -> f64 {
60 self.2
61 }
62}
63
64impl Eq for PidControllerGains {}
66
67const DEFAULT_P_GAIN: f64 = 0.6;
70const DEFAULT_I_GAIN: f64 = 0.7;
71const DEFAULT_D_GAIN: f64 = 0.2;
72
73impl Default for PidControllerGains {
74 fn default() -> Self {
75 Self(DEFAULT_P_GAIN, DEFAULT_I_GAIN, DEFAULT_D_GAIN)
76 }
77}
78
79impl TryFrom<(f64, f64, f64)> for PidControllerGains {
80 type Error = errors::TransportSessionError;
81
82 fn try_from(value: (f64, f64, f64)) -> Result<Self, Self::Error> {
83 Self::new(value.0, value.1, value.2)
84 }
85}
86
87#[derive(Clone, Copy, Debug)]
89pub struct PidBalancerController(Pid<f64>);
90
91impl PidBalancerController {
92 pub fn new(setpoint: u64, output_limit: u64, gains: PidControllerGains) -> Self {
94 let mut pid = Pid::new(setpoint as f64, output_limit as f64);
95 pid.p(gains.p(), output_limit as f64);
96 pid.i(gains.i(), output_limit as f64);
97 pid.d(gains.d(), output_limit as f64);
98 Self(pid)
99 }
100
101 pub fn from_gains(gains: PidControllerGains) -> Self {
106 Self::new(0, 0, gains)
107 }
108}
109
110impl Default for PidBalancerController {
111 fn default() -> Self {
113 Self::new(0, 0, PidControllerGains::default())
114 }
115}
116
117impl SurbBalancerController for PidBalancerController {
118 fn bounds(&self) -> BalancerControllerBounds {
119 BalancerControllerBounds::new(self.0.setpoint as u64, self.0.output_limit as u64)
120 }
121
122 fn set_target_and_limit(&mut self, bounds: BalancerControllerBounds) {
123 let mut pid = Pid::new(bounds.target() as f64, bounds.output_limit() as f64);
124 pid.p(self.0.kp, bounds.output_limit() as f64);
125 pid.i(self.0.ki, bounds.output_limit() as f64);
126 pid.d(self.0.kd, bounds.output_limit() as f64);
127 self.0 = pid;
128 }
129
130 fn next_control_output(&mut self, current_buffer_level: u64) -> u64 {
131 self.0.next_control_output(current_buffer_level as f64).output.max(0.0) as u64
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// Guards against accidental changes of the default gains.
    #[test]
    fn gains_default_values_are_stable() {
        let gains = PidControllerGains::default();
        insta::assert_yaml_snapshot!((gains.p(), gains.i(), gains.d()));
    }

    #[test]
    fn gains_finite_values_are_accepted() -> anyhow::Result<()> {
        let gains = PidControllerGains::new(1.0, 2.0, 3.0)?;
        insta::assert_yaml_snapshot!((gains.p(), gains.i(), gains.d()));
        Ok(())
    }

    /// Both infinities and NaN must be rejected by the constructor.
    #[test]
    fn gains_infinity_is_rejected() {
        assert!(PidControllerGains::new(f64::INFINITY, 0.0, 0.0).is_err());
        assert!(PidControllerGains::new(0.0, f64::NEG_INFINITY, 0.0).is_err());
        assert!(PidControllerGains::new(0.0, 0.0, f64::NAN).is_err());
    }

    #[test]
    fn gains_try_from_tuple() -> anyhow::Result<()> {
        let gains = PidControllerGains::try_from((0.5, 0.3, 0.1))?;
        insta::assert_yaml_snapshot!((gains.p(), gains.i(), gains.d()));
        Ok(())
    }

    #[test]
    fn gains_try_from_tuple_with_nan_fails() {
        assert!(PidControllerGains::try_from((f64::NAN, 0.0, 0.0)).is_err());
    }

    #[test]
    fn gains_eq_works() -> anyhow::Result<()> {
        let a = PidControllerGains::new(1.0, 2.0, 3.0)?;
        let b = PidControllerGains::new(1.0, 2.0, 3.0)?;
        assert_eq!(a, b);
        Ok(())
    }

    #[test]
    fn controller_default_has_zero_bounds() {
        let ctrl = PidBalancerController::default();
        assert_eq!(ctrl.bounds().unzip(), (0, 0));
    }

    #[test]
    fn controller_new_stores_bounds() {
        let gains = PidControllerGains::default();
        let ctrl = PidBalancerController::new(100, 50, gains);
        assert_eq!(ctrl.bounds().unzip(), (100, 50));
    }

    #[test]
    fn controller_set_target_and_limit_updates_bounds() {
        let mut ctrl = PidBalancerController::default();
        ctrl.set_target_and_limit(BalancerControllerBounds::new(200, 100));
        assert_eq!(ctrl.bounds().unzip(), (200, 100));
    }

    /// Changing the bounds rebuilds the inner PID controller; the gains must survive that.
    #[test]
    fn controller_set_target_and_limit_preserves_gains() -> anyhow::Result<()> {
        let gains = PidControllerGains::new(0.5, 0.3, 0.1)?;
        let mut ctrl = PidBalancerController::new(100, 50, gains);
        ctrl.set_target_and_limit(BalancerControllerBounds::new(200, 100));
        assert_eq!((ctrl.0.kp, ctrl.0.ki, ctrl.0.kd), (0.5, 0.3, 0.1));
        Ok(())
    }

    /// Output sequence for a constant empty-buffer measurement.
    #[test]
    fn controller_step_response_snapshot() {
        let gains = PidControllerGains::default();
        let mut ctrl = PidBalancerController::new(100, 200, gains);

        let outputs: Vec<u64> = (0..10).map(|_| ctrl.next_control_output(0)).collect();
        insta::assert_yaml_snapshot!(outputs);
    }

    #[test]
    fn controller_at_setpoint_outputs_zero_or_near_zero() {
        let gains = PidControllerGains::default();
        let mut ctrl = PidBalancerController::new(100, 200, gains);

        let output = ctrl.next_control_output(100);
        assert_eq!(output, 0);
    }

    /// A measurement above the setpoint yields a negative PID output, which must be clamped.
    #[test]
    fn controller_above_setpoint_clamps_to_zero() {
        let gains = PidControllerGains::default();
        let mut ctrl = PidBalancerController::new(100, 200, gains);

        let output = ctrl.next_control_output(200);
        assert_eq!(output, 0);
    }

    /// Closed-loop simulation: each output is added to a simulated buffer (capped at 200).
    #[test]
    fn controller_convergence_from_empty_buffer() {
        let gains = PidControllerGains::default();
        let mut ctrl = PidBalancerController::new(100, 200, gains);

        let mut buffer: f64 = 0.0;
        let mut history = Vec::new();

        for _ in 0..20 {
            let output = ctrl.next_control_output(buffer as u64);
            buffer += output as f64;
            buffer = buffer.min(200.0);
            history.push(buffer as u64);
        }

        insta::assert_yaml_snapshot!(history);
    }

    #[test]
    fn controller_from_gains_uses_defaults_for_bounds() {
        let gains = PidControllerGains::default();
        let ctrl = PidBalancerController::from_gains(gains);
        assert_eq!(ctrl.bounds().unzip(), (0, 0));
    }
}