1use std::time::Duration;
25
26use anyhow;
27use nautilus_core::correctness::{check_in_range_inclusive_f64, check_predicate_true};
28use rand::Rng;
29
/// State for computing exponentially growing retry delays with optional random jitter
/// and an optional zero-delay first attempt.
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
    /// Starting base delay; restored by `reset`.
    delay_initial: Duration,
    /// Upper bound the base delay is clamped to as it grows.
    delay_max: Duration,
    /// Base delay (before jitter) that the next call will use.
    delay_current: Duration,
    /// Multiplier applied to the base delay after each non-immediate call.
    factor: f64,
    /// Maximum jitter in milliseconds added on top of the base delay.
    jitter_ms: u64,
    /// Whether the next call may return a zero delay (consumed after one use).
    immediate_reconnect: bool,
    /// Construction-time value of `immediate_reconnect`, used to re-arm it on reset.
    immediate_reconnect_original: bool,
}
47
48impl ExponentialBackoff {
56 pub fn new(
65 delay_initial: Duration,
66 delay_max: Duration,
67 factor: f64,
68 jitter_ms: u64,
69 immediate_first: bool,
70 ) -> anyhow::Result<Self> {
71 check_predicate_true(!delay_initial.is_zero(), "delay_initial must be non-zero")?;
72 check_predicate_true(
73 delay_max >= delay_initial,
74 "delay_max must be >= delay_initial",
75 )?;
76 check_in_range_inclusive_f64(factor, 1.0, 100.0, "factor")?;
77
78 Ok(Self {
79 delay_initial,
80 delay_max,
81 delay_current: delay_initial,
82 factor,
83 jitter_ms,
84 immediate_reconnect: immediate_first,
85 immediate_reconnect_original: immediate_first,
86 })
87 }
88
89 pub fn next_duration(&mut self) -> Duration {
95 if self.immediate_reconnect && self.delay_current == self.delay_initial {
96 self.immediate_reconnect = false;
97 return Duration::ZERO;
98 }
99
100 let jitter = rand::rng().random_range(0..=self.jitter_ms);
102 let delay_with_jitter = self.delay_current + Duration::from_millis(jitter);
103
104 let current_nanos = self.delay_current.as_nanos();
106 let max_nanos = self.delay_max.as_nanos() as u64;
107 let next_nanos = (current_nanos as f64 * self.factor) as u64;
108 self.delay_current = Duration::from_nanos(std::cmp::min(next_nanos, max_nanos));
109
110 delay_with_jitter
111 }
112
113 pub const fn reset(&mut self) {
115 self.delay_current = self.delay_initial;
116 self.immediate_reconnect = self.immediate_reconnect_original;
117 }
118
119 #[must_use]
123 pub const fn current_delay(&self) -> Duration {
124 self.delay_current
125 }
126}
127
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rstest::rstest;

    use super::*;

    #[rstest]
    fn test_no_jitter_exponential_growth() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::from_millis(100));

        let d2 = backoff.next_duration();
        assert_eq!(d2, Duration::from_millis(200));

        let d3 = backoff.next_duration();
        assert_eq!(d3, Duration::from_millis(400));

        let d4 = backoff.next_duration();
        assert_eq!(d4, Duration::from_millis(800));

        let d5 = backoff.next_duration();
        assert_eq!(d5, Duration::from_millis(1600));

        let d6 = backoff.next_duration();
        assert_eq!(d6, Duration::from_millis(1600));
    }

    #[rstest]
    fn test_reset() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        // Grow the delay a few steps so the reset has something to undo.
        let _ = backoff.next_duration();
        let _ = backoff.next_duration();
        let _ = backoff.next_duration();
        backoff.reset();

        let d = backoff.next_duration();
        assert_eq!(d, Duration::from_millis(100));
    }

    #[rstest]
    fn test_jitter_within_bounds() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1000);
        let factor = 2.0;
        let jitter = 50;

        // Repeat with fresh instances so the random jitter is sampled many times.
        for _ in 0..10 {
            let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

            // Walk past the cap (100, 200, 400, 800, 1000, 1000) so jitter bounds are
            // checked on grown and capped delays, not only on the initial one.
            for _ in 0..6 {
                let base = backoff.current_delay();
                let delay = backoff.next_duration();
                let min_expected = base;
                let max_expected = base + Duration::from_millis(jitter);
                assert!(
                    delay >= min_expected,
                    "Delay {delay:?} is less than expected minimum {min_expected:?}"
                );
                assert!(
                    delay <= max_expected,
                    "Delay {delay:?} exceeds expected maximum {max_expected:?}"
                );
            }
        }
    }

    #[rstest]
    fn test_factor_less_than_two() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(200);
        let factor = 1.5;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::from_millis(100));

        let d2 = backoff.next_duration();
        assert_eq!(d2, Duration::from_millis(150));

        let d3 = backoff.next_duration();
        assert_eq!(d3, Duration::from_millis(200));

        let d4 = backoff.next_duration();
        assert_eq!(d4, Duration::from_millis(200));
    }

    #[rstest]
    fn test_factor_one_keeps_delay_constant() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1000);
        let mut backoff = ExponentialBackoff::new(initial, max, 1.0, 0, false).unwrap();

        for _ in 0..5 {
            assert_eq!(backoff.next_duration(), initial);
        }
        assert_eq!(backoff.current_delay(), initial);
    }

    #[rstest]
    fn test_max_delay_is_respected() {
        let initial = Duration::from_millis(500);
        let max = Duration::from_millis(1000);
        let factor = 3.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::from_millis(500));

        let d2 = backoff.next_duration();
        assert_eq!(d2, Duration::from_millis(1000));

        let d3 = backoff.next_duration();
        assert_eq!(d3, Duration::from_millis(1000));
    }

    #[rstest]
    fn test_current_delay_getter() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        assert_eq!(backoff.current_delay(), initial);

        let _ = backoff.next_duration();
        assert_eq!(backoff.current_delay(), Duration::from_millis(200));

        let _ = backoff.next_duration();
        assert_eq!(backoff.current_delay(), Duration::from_millis(400));

        backoff.reset();
        assert_eq!(backoff.current_delay(), initial);
    }

    #[rstest]
    fn test_validation_zero_initial_delay() {
        let result =
            ExponentialBackoff::new(Duration::ZERO, Duration::from_millis(1000), 2.0, 0, false);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("delay_initial must be non-zero")
        );
    }

    #[rstest]
    fn test_validation_max_less_than_initial() {
        let result = ExponentialBackoff::new(
            Duration::from_millis(1000),
            Duration::from_millis(500),
            2.0,
            0,
            false,
        );
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("delay_max must be >= delay_initial")
        );
    }

    #[rstest]
    fn test_validation_factor_too_small() {
        let result = ExponentialBackoff::new(
            Duration::from_millis(100),
            Duration::from_millis(1000),
            0.5,
            0,
            false,
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("factor"));
    }

    #[rstest]
    fn test_validation_factor_too_large() {
        let result = ExponentialBackoff::new(
            Duration::from_millis(100),
            Duration::from_millis(1000),
            150.0,
            0,
            false,
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("factor"));
    }

    #[rstest]
    fn test_immediate_first() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, true).unwrap();

        let d1 = backoff.next_duration();
        assert_eq!(
            d1,
            Duration::ZERO,
            "Expected immediate reconnect (zero delay) on first call"
        );

        let d2 = backoff.next_duration();
        assert_eq!(
            d2, initial,
            "Expected the delay to be the initial delay after immediate reconnect"
        );

        let d3 = backoff.next_duration();
        let expected = initial * 2;
        assert_eq!(
            d3, expected,
            "Expected exponential growth from the initial delay"
        );
    }

    #[rstest]
    fn test_reset_restores_immediate_first() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, true).unwrap();

        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::ZERO);

        let d2 = backoff.next_duration();
        assert_eq!(d2, initial);

        backoff.reset();
        let d3 = backoff.next_duration();
        assert_eq!(
            d3,
            Duration::ZERO,
            "Reset should restore immediate_first behavior"
        );
    }
}