nautilus_indicators/average/
sma.rs1use std::fmt::Display;
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::{
20 data::{Bar, QuoteTick, TradeTick},
21 enums::PriceType,
22};
23
24use crate::indicator::{Indicator, MovingAverage};
25
/// Largest window `SimpleMovingAverage::new` accepts; also the fixed capacity of `buf`.
const MAX_PERIOD: usize = 1_024;

/// Simple moving average over the most recent `period` samples.
///
/// Each update appends the new price. Once `period` samples are held, it first evicts the
/// oldest one. `value` is the arithmetic mean of the samples currently held.
// NOTE(review): `#[repr(C)]` fixes the field order/layout; presumably relied on across an
// FFI boundary — confirm before reordering or adding fields.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.indicators")
)]
pub struct SimpleMovingAverage {
    /// Window length (number of samples averaged); validated to `1..=MAX_PERIOD` in `new`.
    pub period: usize,
    /// Which quote price `handle_quote` extracts.
    pub price_type: PriceType,
    /// Mean of the samples currently in the window (`0.0` before any input / after reset).
    pub value: f64,
    /// Running sum of the samples held in `buf`.
    sum: f64,
    /// Number of samples currently in the window; never exceeds `period`.
    pub count: usize,
    /// The samples in the window, oldest at the front.
    buf: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    /// `true` once `count` has reached `period`.
    pub initialized: bool,
}
43
44impl Display for SimpleMovingAverage {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 write!(f, "{}({})", self.name(), self.period)
47 }
48}
49
50impl Indicator for SimpleMovingAverage {
51 fn name(&self) -> String {
52 stringify!(SimpleMovingAverage).into()
53 }
54
55 fn has_inputs(&self) -> bool {
56 self.count > 0
57 }
58
59 fn initialized(&self) -> bool {
60 self.initialized
61 }
62
63 fn handle_quote(&mut self, quote: &QuoteTick) {
64 self.process_raw(quote.extract_price(self.price_type).into());
65 }
66
67 fn handle_trade(&mut self, trade: &TradeTick) {
68 self.process_raw(trade.price.into());
69 }
70
71 fn handle_bar(&mut self, bar: &Bar) {
72 self.process_raw(bar.close.into());
73 }
74
75 fn reset(&mut self) {
76 self.value = 0.0;
77 self.sum = 0.0;
78 self.count = 0;
79 self.buf.clear();
80 self.initialized = false;
81 }
82}
83
84impl MovingAverage for SimpleMovingAverage {
85 fn value(&self) -> f64 {
86 self.value
87 }
88
89 fn count(&self) -> usize {
90 self.count
91 }
92
93 fn update_raw(&mut self, value: f64) {
94 self.process_raw(value);
95 }
96}
97
98impl SimpleMovingAverage {
99 #[must_use]
105 pub fn new(period: usize, price_type: Option<PriceType>) -> Self {
106 assert!(period > 0, "SimpleMovingAverage: period must be > 0");
107 assert!(
108 period <= MAX_PERIOD,
109 "SimpleMovingAverage: period {period} exceeds MAX_PERIOD ({MAX_PERIOD})"
110 );
111
112 Self {
113 period,
114 price_type: price_type.unwrap_or(PriceType::Last),
115 value: 0.0,
116 sum: 0.0,
117 count: 0,
118 buf: ArrayDeque::new(),
119 initialized: false,
120 }
121 }
122
123 fn process_raw(&mut self, price: f64) {
124 if self.count == self.period {
125 if let Some(oldest) = self.buf.pop_front() {
126 self.sum -= oldest;
127 }
128 } else {
129 self.count += 1;
130 }
131
132 let _ = self.buf.push_back(price);
133 self.sum += price;
134
135 self.value = self.sum / self.count as f64;
136 self.initialized = self.count >= self.period;
137 }
138}
139
#[cfg(test)]
mod tests {
    //! Unit tests for `SimpleMovingAverage`: construction, windowing, reset, market-data
    //! handlers and internal buffer/sum invariants.

    use arraydeque::{ArrayDeque, Wrapping};
    use nautilus_model::{
        data::{QuoteTick, TradeTick},
        enums::PriceType,
    };
    use rstest::rstest;

    use super::MAX_PERIOD;
    use crate::{
        average::sma::SimpleMovingAverage,
        indicator::{Indicator, MovingAverage},
        stubs::*,
    };

    #[rstest]
    fn sma_initialized_state(indicator_sma_10: SimpleMovingAverage) {
        let display_str = format!("{indicator_sma_10}");
        assert_eq!(display_str, "SimpleMovingAverage(10)");
        assert_eq!(indicator_sma_10.period, 10);
        assert_eq!(indicator_sma_10.price_type, PriceType::Mid);
        assert_eq!(indicator_sma_10.value, 0.0);
        assert_eq!(indicator_sma_10.count, 0);
        assert!(!indicator_sma_10.initialized());
        assert!(!indicator_sma_10.has_inputs());
    }

    #[rstest]
    fn sma_update_raw_exact_period(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        for i in 1..=10 {
            sma.update_raw(f64::from(i));
        }
        assert!(sma.has_inputs());
        assert!(sma.initialized());
        assert_eq!(sma.count, 10);
        assert_eq!(sma.value, 5.5);
    }

    #[rstest]
    fn sma_reset_smoke(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        sma.update_raw(1.0);
        assert_eq!(sma.count, 1);
        sma.reset();
        assert_eq!(sma.count, 0);
        assert_eq!(sma.value, 0.0);
        assert!(!sma.initialized());
    }

    #[rstest]
    fn sma_handle_single_quote(indicator_sma_10: SimpleMovingAverage, stub_quote: QuoteTick) {
        let mut sma = indicator_sma_10;
        sma.handle_quote(&stub_quote);
        assert_eq!(sma.count, 1);
        assert_eq!(sma.value, 1501.0);
    }

    #[rstest]
    fn sma_handle_multiple_quotes(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        let q1 = stub_quote("1500.0", "1502.0");
        let q2 = stub_quote("1502.0", "1504.0");

        sma.handle_quote(&q1);
        sma.handle_quote(&q2);
        assert_eq!(sma.count, 2);
        assert_eq!(sma.value, 1502.0);
    }

    #[rstest]
    fn sma_handle_trade(indicator_sma_10: SimpleMovingAverage, stub_trade: TradeTick) {
        let mut sma = indicator_sma_10;
        sma.handle_trade(&stub_trade);
        assert_eq!(sma.count, 1);
        assert_eq!(sma.value, 1500.0);
    }

    #[rstest]
    #[case(1)]
    #[case(3)]
    #[case(5)]
    #[case(16)]
    fn count_progression_respects_period(#[case] period: usize) {
        let mut sma = SimpleMovingAverage::new(period, None);

        for i in 0..(period * 3) {
            sma.update_raw(i as f64);

            assert!(
                sma.count() <= period,
                "period={period}, step={i}, count={}",
                sma.count()
            );

            let expected = usize::min(i + 1, period);
            assert_eq!(
                sma.count(),
                expected,
                "period={period}, step={i}, expected={expected}, got={}",
                sma.count()
            );
        }
    }

    #[rstest]
    #[case(1)]
    #[case(4)]
    #[case(10)]
    fn count_after_reset_is_zero(#[case] period: usize) {
        let mut sma = SimpleMovingAverage::new(period, None);

        for i in 0..(period + 2) {
            sma.update_raw(i as f64);
        }
        assert_eq!(sma.count(), period, "pre-reset saturation failed");

        sma.reset();
        assert_eq!(sma.count(), 0, "count not reset to zero");
        assert_eq!(sma.value(), 0.0, "value not reset to zero");
        assert!(!sma.initialized(), "initialized flag not cleared");
    }

    #[rstest]
    fn count_edge_case_period_one() {
        let mut sma = SimpleMovingAverage::new(1, None);

        sma.update_raw(10.0);
        assert_eq!(sma.count(), 1);
        assert_eq!(sma.value(), 10.0);

        sma.update_raw(20.0);
        assert_eq!(sma.count(), 1, "count exceeded 1 with period==1");
        assert_eq!(sma.value(), 20.0, "value not equal to latest price");
    }

    #[rstest]
    fn sliding_window_correctness() {
        let mut sma = SimpleMovingAverage::new(3, None);

        let prices = [1.0, 2.0, 3.0, 4.0, 5.0];
        let expect_avg = [1.0, 1.5, 2.0, 3.0, 4.0];

        for (i, &p) in prices.iter().enumerate() {
            sma.update_raw(p);
            assert!(
                (sma.value() - expect_avg[i]).abs() < 1e-9,
                "step {i}: expected {}, got {}",
                expect_avg[i],
                sma.value()
            );
        }
    }

    #[rstest]
    #[case(2)]
    #[case(6)]
    fn initialized_transitions_with_count(#[case] period: usize) {
        let mut sma = SimpleMovingAverage::new(period, None);

        for i in 0..(period - 1) {
            sma.update_raw(i as f64);
            assert!(
                !sma.initialized(),
                "initialized early at i={i} (period={period})"
            );
        }

        sma.update_raw(42.0);
        assert_eq!(sma.count(), period);
        assert!(sma.initialized(), "initialized flag not set at period");
    }

    #[rstest]
    #[should_panic(expected = "period must be > 0")]
    fn sma_new_with_zero_period_panics() {
        let _ = SimpleMovingAverage::new(0, None);
    }

    // The upper bound exists because the sample buffer has fixed capacity MAX_PERIOD.
    #[rstest]
    #[should_panic(expected = "exceeds MAX_PERIOD")]
    fn sma_new_with_period_above_max_panics() {
        let _ = SimpleMovingAverage::new(MAX_PERIOD + 1, None);
    }

    #[rstest]
    fn sma_rolling_mean_exact_values() {
        let mut sma = SimpleMovingAverage::new(3, None);
        let inputs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let expected = [1.0, 1.5, 2.0, 3.0, 4.0];

        for (&price, &exp_mean) in inputs.iter().zip(expected.iter()) {
            sma.update_raw(price);
            assert!(
                (sma.value() - exp_mean).abs() < 1e-12,
                "input={price}, expected={exp_mean}, got={}",
                sma.value()
            );
        }
    }

    #[rstest]
    fn sma_matches_reference_implementation() {
        const PERIOD: usize = 5;
        let mut sma = SimpleMovingAverage::new(PERIOD, None);
        let mut window: ArrayDeque<f64, PERIOD, Wrapping> = ArrayDeque::new();

        for step in 0..20 {
            let price = f64::from(step) * 10.0;
            sma.update_raw(price);

            if window.len() == PERIOD {
                window.pop_front();
            }
            let _ = window.push_back(price);

            let ref_mean: f64 = window.iter().sum::<f64>() / window.len() as f64;
            assert!(
                (sma.value() - ref_mean).abs() < 1e-12,
                "step={step}, expected={ref_mean}, got={}",
                sma.value()
            );
        }
    }

    #[rstest]
    #[case(f64::NAN)]
    #[case(f64::INFINITY)]
    #[case(f64::NEG_INFINITY)]
    fn sma_handles_bad_floats(#[case] bad: f64) {
        let mut sma = SimpleMovingAverage::new(3, None);
        sma.update_raw(1.0);
        sma.update_raw(bad);
        sma.update_raw(3.0);
        // `!is_finite()` already covers NaN, so no separate `is_nan()` check is needed.
        assert!(!sma.value().is_finite(), "bad float not propagated");
    }

    #[rstest]
    fn deque_and_count_always_match() {
        const PERIOD: usize = 8;
        let mut sma = SimpleMovingAverage::new(PERIOD, None);
        for i in 0..50 {
            sma.update_raw(f64::from(i));
            assert_eq!(sma.buf.len(), sma.count, "buf.len() != count at step {i}");
        }
    }

    #[rstest]
    fn sma_multiple_resets() {
        let mut sma = SimpleMovingAverage::new(4, None);
        for cycle in 0..5 {
            for x in 0..4 {
                sma.update_raw(f64::from(x));
            }
            assert!(sma.initialized(), "cycle {cycle}: not initialized");
            sma.reset();
            assert_eq!(sma.count(), 0);
            assert_eq!(sma.value(), 0.0);
            assert!(!sma.initialized());
        }
    }

    #[rstest]
    fn sma_buffer_never_exceeds_capacity() {
        const PERIOD: usize = MAX_PERIOD;
        let mut sma = SimpleMovingAverage::new(PERIOD, None);

        for i in 0..(PERIOD * 2) {
            sma.update_raw(i as f64);

            assert!(
                sma.buf.len() <= PERIOD,
                "step {i}: buf.len()={}, exceeds PERIOD={PERIOD}",
                sma.buf.len(),
            );
        }
        assert!(
            sma.buf.is_full(),
            "buffer not reported as full after saturation"
        );
        assert_eq!(
            sma.count(),
            PERIOD,
            "count diverged from logical window length"
        );
    }

    #[rstest]
    fn sma_deque_eviction_order() {
        let mut sma = SimpleMovingAverage::new(3, None);

        sma.update_raw(1.0);
        sma.update_raw(2.0);
        sma.update_raw(3.0);
        sma.update_raw(4.0);

        assert_eq!(sma.buf.front().copied(), Some(2.0), "oldest element wrong");
        assert_eq!(sma.buf.back().copied(), Some(4.0), "newest element wrong");

        assert!(
            (sma.value() - 3.0).abs() < 1e-12,
            "unexpected mean after eviction: {}",
            sma.value()
        );
    }

    #[rstest]
    fn sma_sum_consistent_with_buffer() {
        const PERIOD: usize = 7;
        let mut sma = SimpleMovingAverage::new(PERIOD, None);

        for i in 0..40 {
            sma.update_raw(f64::from(i));

            let deque_sum: f64 = sma.buf.iter().copied().sum();
            assert!(
                (sma.sum - deque_sum).abs() < 1e-12,
                "step {i}: internal sum={} differs from buf sum={}",
                sma.sum,
                deque_sum
            );
        }
    }
}