nautilus_network/ratelimiter/
mod.rs1pub mod clock;
20mod gcra;
21mod nanos;
22pub mod quota;
23
24use std::{
25 fmt::Debug,
26 hash::Hash,
27 num::NonZeroU64,
28 sync::atomic::{AtomicU64, Ordering},
29 time::Duration,
30};
31
32use dashmap::DashMap;
33use futures_util::StreamExt;
34use tokio::time::sleep;
35
36use self::{
37 clock::{Clock, FakeRelativeClock, MonotonicClock},
38 gcra::{Gcra, NotUntil},
39 nanos::Nanos,
40 quota::Quota,
41};
42
43#[derive(Debug, Default)]
52pub struct InMemoryState(AtomicU64);
53
54impl InMemoryState {
55 pub(crate) fn measure_and_replace_one<T, F, E>(&self, mut f: F) -> Result<T, E>
61 where
62 F: FnMut(Option<Nanos>) -> Result<(T, Nanos), E>,
63 {
64 let mut prev = self.0.load(Ordering::Acquire);
65 let mut decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
66 while let Ok((result, new_data)) = decision {
67 match self.0.compare_exchange_weak(
68 prev,
69 new_data.into(),
70 Ordering::Release,
71 Ordering::Relaxed,
72 ) {
73 Ok(_) => return Ok(result),
74 Err(next_prev) => prev = next_prev,
75 }
76 decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
77 }
78 decision.map(|(result, _)| result)
81 }
82}
83
84pub type DashMapStateStore<K> = DashMap<K, InMemoryState>;
86
87pub trait StateStore {
98 type Key;
100
101 fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
118 where
119 F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>;
120}
121
122impl<K: Hash + Eq + Clone> StateStore for DashMapStateStore<K> {
123 type Key = K;
124
125 fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
126 where
127 F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>,
128 {
129 if let Some(v) = self.get(key) {
130 return v.measure_and_replace_one(f);
132 }
133 let entry = self.entry(key.clone()).or_default();
135 (*entry).measure_and_replace_one(f)
136 }
137}
138
139pub struct RateLimiter<K, C>
140where
141 C: Clock,
142{
143 default_gcra: Option<Gcra>,
144 state: DashMapStateStore<K>,
145 gcra: DashMap<K, Gcra>,
146 clock: C,
147 start: C::Instant,
148}
149
150impl<K, C> Debug for RateLimiter<K, C>
151where
152 K: Debug,
153 C: Clock,
154{
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 f.debug_struct(stringify!(RateLimiter)).finish()
157 }
158}
159
160impl<K> RateLimiter<K, MonotonicClock>
161where
162 K: Eq + Hash,
163{
164 pub fn new_with_quota(base_quota: Option<Quota>, keyed_quotas: Vec<(K, Quota)>) -> Self {
165 let clock = MonotonicClock {};
166 let start = MonotonicClock::now(&clock);
167 let gcra = DashMap::from_iter(keyed_quotas.into_iter().map(|(k, q)| (k, Gcra::new(q))));
168 Self {
169 default_gcra: base_quota.map(Gcra::new),
170 state: DashMapStateStore::new(),
171 gcra,
172 clock,
173 start,
174 }
175 }
176}
177
178impl<K> RateLimiter<K, FakeRelativeClock>
179where
180 K: Hash + Eq + Clone,
181{
182 pub fn advance_clock(&self, by: Duration) {
183 self.clock.advance(by);
184 }
185}
186
187impl<K, C> RateLimiter<K, C>
188where
189 K: Hash + Eq + Clone,
190 C: Clock,
191{
192 pub fn add_quota_for_key(&self, key: K, value: Quota) {
193 self.gcra.insert(key, Gcra::new(value));
194 }
195
196 pub fn check_key(&self, key: &K) -> Result<(), NotUntil<C::Instant>> {
202 match self.gcra.get(key) {
203 Some(quota) => quota.test_and_update(self.start, key, &self.state, self.clock.now()),
204 None => self.default_gcra.as_ref().map_or(Ok(()), |gcra| {
205 gcra.test_and_update(self.start, key, &self.state, self.clock.now())
206 }),
207 }
208 }
209
210 pub async fn until_key_ready(&self, key: &K) {
211 loop {
212 match self.check_key(key) {
213 Ok(()) => {
214 break;
215 }
216 Err(neg) => {
217 sleep(neg.wait_time_from(self.clock.now())).await;
218 }
219 }
220 }
221 }
222
223 pub async fn await_keys_ready(&self, keys: Option<Vec<K>>) {
224 let keys = keys.unwrap_or_default();
225 let tasks = keys.iter().map(|key| self.until_key_ready(key));
226
227 futures::stream::iter(tasks)
228 .for_each_concurrent(None, |key_future| async move {
229 key_future.await;
230 })
231 .await;
232 }
233}
234
#[cfg(test)]
mod tests {
    use std::{num::NonZeroU32, time::Duration};

    use dashmap::DashMap;
    use rstest::rstest;

    use super::{
        DashMapStateStore, RateLimiter,
        clock::{Clock, FakeRelativeClock},
        gcra::Gcra,
        quota::Quota,
    };

    /// Shorthand for an `n`-per-second quota; `n` must be non-zero.
    fn per_second(n: u32) -> Quota {
        Quota::per_second(NonZeroU32::new(n).unwrap())
    }

    /// Builds a limiter on a fake clock with a default quota of two per
    /// second and no per-key quotas.
    fn initialize_mock_rate_limiter() -> RateLimiter<String, FakeRelativeClock> {
        let clock = FakeRelativeClock::default();
        let start = clock.now();
        RateLimiter {
            default_gcra: Some(Gcra::new(per_second(2))),
            state: DashMapStateStore::new(),
            gcra: DashMap::new(),
            clock,
            start,
        }
    }

    #[rstest]
    fn test_default_quota() {
        let limiter = initialize_mock_rate_limiter();
        let key = String::from("default");

        // Two checks fit in the default quota.
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());

        // The third is rejected.
        assert!(limiter.check_key(&key).is_err());

        // After one second the key is accepted again.
        limiter.advance_clock(Duration::from_secs(1));
        assert!(limiter.check_key(&key).is_ok());
    }

    #[rstest]
    fn test_custom_key_quota() {
        let limiter = initialize_mock_rate_limiter();
        let custom = String::from("custom");
        let default = String::from("default");

        limiter.add_quota_for_key(custom.clone(), per_second(1));

        // The custom key is limited to one per second.
        assert!(limiter.check_key(&custom).is_ok());
        assert!(limiter.check_key(&custom).is_err());

        // Other keys still use the default quota of two per second.
        assert!(limiter.check_key(&default).is_ok());
        assert!(limiter.check_key(&default).is_ok());
        assert!(limiter.check_key(&default).is_err());
    }

    #[rstest]
    fn test_multiple_keys() {
        let limiter = initialize_mock_rate_limiter();
        let key1 = String::from("key1");
        let key2 = String::from("key2");

        limiter.add_quota_for_key(key1.clone(), per_second(1));
        limiter.add_quota_for_key(key2.clone(), per_second(3));

        // key1: one per second.
        assert!(limiter.check_key(&key1).is_ok());
        assert!(limiter.check_key(&key1).is_err());

        // key2: three per second, tracked independently of key1.
        assert!(limiter.check_key(&key2).is_ok());
        assert!(limiter.check_key(&key2).is_ok());
        assert!(limiter.check_key(&key2).is_ok());
        assert!(limiter.check_key(&key2).is_err());
    }

    #[rstest]
    fn test_quota_reset() {
        let limiter = initialize_mock_rate_limiter();
        let key = String::from("reset");

        // Exhaust the default quota.
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_err());

        // 499 ms later the key is still rejected.
        limiter.advance_clock(Duration::from_millis(499));
        assert!(limiter.check_key(&key).is_err());

        // A further 501 ms (one second in total) and it is accepted.
        limiter.advance_clock(Duration::from_millis(501));
        assert!(limiter.check_key(&key).is_ok());
    }

    #[rstest]
    fn test_different_quotas() {
        let limiter = initialize_mock_rate_limiter();
        let fast = String::from("per_second");
        let slow = String::from("per_minute");

        limiter.add_quota_for_key(fast.clone(), per_second(2));
        limiter.add_quota_for_key(
            slow.clone(),
            Quota::per_minute(NonZeroU32::new(3).unwrap()),
        );

        // Two per second.
        assert!(limiter.check_key(&fast).is_ok());
        assert!(limiter.check_key(&fast).is_ok());
        assert!(limiter.check_key(&fast).is_err());

        // Three per minute.
        assert!(limiter.check_key(&slow).is_ok());
        assert!(limiter.check_key(&slow).is_ok());
        assert!(limiter.check_key(&slow).is_ok());
        assert!(limiter.check_key(&slow).is_err());

        // One second later only the per-second key is accepted.
        limiter.advance_clock(Duration::from_secs(1));
        assert!(limiter.check_key(&fast).is_ok());
        assert!(limiter.check_key(&slow).is_err());
    }

    #[tokio::test]
    async fn test_await_keys_ready() {
        let limiter = initialize_mock_rate_limiter();
        let key = String::from("default");

        // Exhaust the default quota.
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_err());

        // After one second, waiting on the key completes and a further check passes.
        limiter.advance_clock(Duration::from_secs(1));
        limiter.await_keys_ready(Some(vec![key.clone()])).await;
        assert!(limiter.check_key(&key).is_ok());
    }
}