nautilus_indicators/average/
sma.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Posei Systems Pty Ltd. All rights reserved.
//  https://poseitrader.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use std::fmt::Display;
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::{
20    data::{Bar, QuoteTick, TradeTick},
21    enums::PriceType,
22};
23
24use crate::indicator::{Indicator, MovingAverage};
25
/// Largest window length the indicator accepts.
///
/// This is also the fixed capacity of the ring buffer that stores the window,
/// so `SimpleMovingAverage::new` rejects any `period` above it.
const MAX_PERIOD: usize = 1024;
27
28#[repr(C)]
29#[derive(Debug)]
30#[cfg_attr(
31    feature = "python",
32    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.indicators")
33)]
34pub struct SimpleMovingAverage {
35    pub period: usize,
36    pub price_type: PriceType,
37    pub value: f64,
38    sum: f64,
39    pub count: usize,
40    buf: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
41    pub initialized: bool,
42}
43
44impl Display for SimpleMovingAverage {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        write!(f, "{}({})", self.name(), self.period)
47    }
48}
49
50impl Indicator for SimpleMovingAverage {
51    fn name(&self) -> String {
52        stringify!(SimpleMovingAverage).into()
53    }
54
55    fn has_inputs(&self) -> bool {
56        self.count > 0
57    }
58
59    fn initialized(&self) -> bool {
60        self.initialized
61    }
62
63    fn handle_quote(&mut self, quote: &QuoteTick) {
64        self.process_raw(quote.extract_price(self.price_type).into());
65    }
66
67    fn handle_trade(&mut self, trade: &TradeTick) {
68        self.process_raw(trade.price.into());
69    }
70
71    fn handle_bar(&mut self, bar: &Bar) {
72        self.process_raw(bar.close.into());
73    }
74
75    fn reset(&mut self) {
76        self.value = 0.0;
77        self.sum = 0.0;
78        self.count = 0;
79        self.buf.clear();
80        self.initialized = false;
81    }
82}
83
84impl MovingAverage for SimpleMovingAverage {
85    fn value(&self) -> f64 {
86        self.value
87    }
88
89    fn count(&self) -> usize {
90        self.count
91    }
92
93    fn update_raw(&mut self, value: f64) {
94        self.process_raw(value);
95    }
96}
97
98impl SimpleMovingAverage {
99    /// Creates a new [`SimpleMovingAverage`] instance.
100    ///
101    /// # Panics
102    ///
103    /// Panics if `period` is not positive (> 0).
104    #[must_use]
105    pub fn new(period: usize, price_type: Option<PriceType>) -> Self {
106        assert!(period > 0, "SimpleMovingAverage: period must be > 0");
107        assert!(
108            period <= MAX_PERIOD,
109            "SimpleMovingAverage: period {period} exceeds MAX_PERIOD ({MAX_PERIOD})"
110        );
111
112        Self {
113            period,
114            price_type: price_type.unwrap_or(PriceType::Last),
115            value: 0.0,
116            sum: 0.0,
117            count: 0,
118            buf: ArrayDeque::new(),
119            initialized: false,
120        }
121    }
122
123    fn process_raw(&mut self, price: f64) {
124        if self.count == self.period {
125            if let Some(oldest) = self.buf.pop_front() {
126                self.sum -= oldest;
127            }
128        } else {
129            self.count += 1;
130        }
131
132        let _ = self.buf.push_back(price);
133        self.sum += price;
134
135        self.value = self.sum / self.count as f64;
136        self.initialized = self.count >= self.period;
137    }
138}
139
////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    // Unit tests for `SimpleMovingAverage`. Fixtures such as `indicator_sma_10`,
    // `stub_quote` and `stub_trade` come from the `crate::stubs::*` glob import.
    use arraydeque::{ArrayDeque, Wrapping};
    use nautilus_model::{
        data::{QuoteTick, TradeTick},
        enums::PriceType,
    };
    use rstest::rstest;

    use super::MAX_PERIOD;
    use crate::{
        average::sma::SimpleMovingAverage,
        indicator::{Indicator, MovingAverage},
        stubs::*,
    };

    #[rstest]
    fn sma_initialized_state(indicator_sma_10: SimpleMovingAverage) {
        let display_str = format!("{indicator_sma_10}");
        assert_eq!(display_str, "SimpleMovingAverage(10)");
        assert_eq!(indicator_sma_10.period, 10);
        assert_eq!(indicator_sma_10.price_type, PriceType::Mid);
        assert_eq!(indicator_sma_10.value, 0.0);
        assert_eq!(indicator_sma_10.count, 0);
        assert!(!indicator_sma_10.initialized());
        assert!(!indicator_sma_10.has_inputs());
    }

    #[rstest]
    fn sma_update_raw_exact_period(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        for i in 1..=10 {
            sma.update_raw(f64::from(i));
        }
        assert!(sma.has_inputs());
        assert!(sma.initialized());
        assert_eq!(sma.count, 10);
        assert_eq!(sma.value, 5.5);
    }

    #[rstest]
    fn sma_reset_smoke(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        sma.update_raw(1.0);
        assert_eq!(sma.count, 1);
        sma.reset();
        assert_eq!(sma.count, 0);
        assert_eq!(sma.value, 0.0);
        assert!(!sma.initialized());
    }

    #[rstest]
    fn sma_handle_single_quote(indicator_sma_10: SimpleMovingAverage, stub_quote: QuoteTick) {
        let mut sma = indicator_sma_10;
        sma.handle_quote(&stub_quote);
        assert_eq!(sma.count, 1);
        assert_eq!(sma.value, 1501.0);
    }

    #[rstest]
    fn sma_handle_multiple_quotes(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        let q1 = stub_quote("1500.0", "1502.0");
        let q2 = stub_quote("1502.0", "1504.0");

        sma.handle_quote(&q1);
        sma.handle_quote(&q2);
        assert_eq!(sma.count, 2);
        assert_eq!(sma.value, 1502.0);
    }

    #[rstest]
    fn sma_handle_trade(indicator_sma_10: SimpleMovingAverage, stub_trade: TradeTick) {
        let mut sma = indicator_sma_10;
        sma.handle_trade(&stub_trade);
        assert_eq!(sma.count, 1);
        assert_eq!(sma.value, 1500.0);
    }

    #[rstest]
    #[case(1)]
    #[case(3)]
    #[case(5)]
    #[case(16)]
    fn count_progression_respects_period(#[case] period: usize) {
        let mut sma = SimpleMovingAverage::new(period, None);

        for i in 0..(period * 3) {
            sma.update_raw(i as f64);

            assert!(
                sma.count() <= period,
                "period={period}, step={i}, count={}",
                sma.count()
            );

            let expected = usize::min(i + 1, period);
            assert_eq!(
                sma.count(),
                expected,
                "period={period}, step={i}, expected={expected}, got={}",
                sma.count()
            );
        }
    }

    #[rstest]
    #[case(1)]
    #[case(4)]
    #[case(10)]
    fn count_after_reset_is_zero(#[case] period: usize) {
        let mut sma = SimpleMovingAverage::new(period, None);

        for i in 0..(period + 2) {
            sma.update_raw(i as f64);
        }
        assert_eq!(sma.count(), period, "pre-reset saturation failed");

        sma.reset();
        assert_eq!(sma.count(), 0, "count not reset to zero");
        assert_eq!(sma.value(), 0.0, "value not reset to zero");
        assert!(!sma.initialized(), "initialized flag not cleared");
    }

    #[rstest]
    fn count_edge_case_period_one() {
        let mut sma = SimpleMovingAverage::new(1, None);

        sma.update_raw(10.0);
        assert_eq!(sma.count(), 1);
        assert_eq!(sma.value(), 10.0);

        sma.update_raw(20.0);
        assert_eq!(sma.count(), 1, "count exceeded 1 with period==1");
        assert_eq!(sma.value(), 20.0, "value not equal to latest price");
    }

    #[rstest]
    fn sliding_window_correctness() {
        let mut sma = SimpleMovingAverage::new(3, None);

        let prices = [1.0, 2.0, 3.0, 4.0, 5.0];
        let expect_avg = [1.0, 1.5, 2.0, 3.0, 4.0];

        for (i, &p) in prices.iter().enumerate() {
            sma.update_raw(p);
            assert!(
                (sma.value() - expect_avg[i]).abs() < 1e-9,
                "step {i}: expected {}, got {}",
                expect_avg[i],
                sma.value()
            );
        }
    }

    #[rstest]
    #[case(2)]
    #[case(6)]
    fn initialized_transitions_with_count(#[case] period: usize) {
        let mut sma = SimpleMovingAverage::new(period, None);

        for i in 0..(period - 1) {
            sma.update_raw(i as f64);
            assert!(
                !sma.initialized(),
                "initialized early at i={i} (period={period})"
            );
        }

        sma.update_raw(42.0);
        assert_eq!(sma.count(), period);
        assert!(sma.initialized(), "initialized flag not set at period");
    }

    #[rstest]
    #[should_panic(expected = "period must be > 0")]
    fn sma_new_with_zero_period_panics() {
        let _ = SimpleMovingAverage::new(0, None);
    }

    #[rstest]
    fn sma_rolling_mean_exact_values() {
        let mut sma = SimpleMovingAverage::new(3, None);
        let inputs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let expected = [1.0, 1.5, 2.0, 3.0, 4.0];

        for (&price, &exp_mean) in inputs.iter().zip(expected.iter()) {
            sma.update_raw(price);
            assert!(
                (sma.value() - exp_mean).abs() < 1e-12,
                "input={price}, expected={exp_mean}, got={}",
                sma.value()
            );
        }
    }

    #[rstest]
    fn sma_matches_reference_implementation() {
        const PERIOD: usize = 5;
        let mut sma = SimpleMovingAverage::new(PERIOD, None);
        let mut window: ArrayDeque<f64, PERIOD, Wrapping> = ArrayDeque::new();

        for step in 0..20 {
            let price = f64::from(step) * 10.0;
            sma.update_raw(price);

            if window.len() == PERIOD {
                window.pop_front();
            }
            let _ = window.push_back(price);

            let ref_mean: f64 = window.iter().sum::<f64>() / window.len() as f64;
            assert!(
                (sma.value() - ref_mean).abs() < 1e-12,
                "step={step}, expected={ref_mean}, got={}",
                sma.value()
            );
        }
    }

    #[rstest]
    #[case(f64::NAN)]
    #[case(f64::INFINITY)]
    #[case(f64::NEG_INFINITY)]
    fn sma_handles_bad_floats(#[case] bad: f64) {
        let mut sma = SimpleMovingAverage::new(3, None);
        sma.update_raw(1.0);
        sma.update_raw(bad);
        sma.update_raw(3.0);
        // NaN and +/-inf are all non-finite, so a single check covers every case.
        assert!(!sma.value().is_finite(), "bad float not propagated");
    }

    #[rstest]
    fn deque_and_count_always_match() {
        const PERIOD: usize = 8;
        let mut sma = SimpleMovingAverage::new(PERIOD, None);
        for i in 0..50 {
            sma.update_raw(f64::from(i));
            assert_eq!(sma.buf.len(), sma.count, "buf.len() != count at step {i}");
        }
    }

    #[rstest]
    fn sma_multiple_resets() {
        let mut sma = SimpleMovingAverage::new(4, None);
        for cycle in 0..5 {
            for x in 0..4 {
                sma.update_raw(f64::from(x));
            }
            assert!(sma.initialized(), "cycle {cycle}: not initialized");
            sma.reset();
            assert_eq!(sma.count(), 0);
            assert_eq!(sma.value(), 0.0);
            assert!(!sma.initialized());
        }
    }

    #[rstest]
    fn sma_buffer_never_exceeds_capacity() {
        const PERIOD: usize = MAX_PERIOD;
        let mut sma = SimpleMovingAverage::new(PERIOD, None);

        for i in 0..(PERIOD * 2) {
            sma.update_raw(i as f64);

            assert!(
                sma.buf.len() <= PERIOD,
                "step {i}: buf.len()={}, exceeds PERIOD={PERIOD}",
                sma.buf.len(),
            );
        }
        assert!(
            sma.buf.is_full(),
            "buffer not reported as full after saturation"
        );
        assert_eq!(
            sma.count(),
            PERIOD,
            "count diverged from logical window length"
        );
    }

    #[rstest]
    fn sma_deque_eviction_order() {
        let mut sma = SimpleMovingAverage::new(3, None);

        sma.update_raw(1.0);
        sma.update_raw(2.0);
        sma.update_raw(3.0);
        sma.update_raw(4.0);

        assert_eq!(sma.buf.front().copied(), Some(2.0), "oldest element wrong");
        assert_eq!(sma.buf.back().copied(), Some(4.0), "newest element wrong");

        assert!(
            (sma.value() - 3.0).abs() < 1e-12,
            "unexpected mean after eviction: {}",
            sma.value()
        );
    }

    #[rstest]
    fn sma_sum_consistent_with_buffer() {
        const PERIOD: usize = 7;
        let mut sma = SimpleMovingAverage::new(PERIOD, None);

        for i in 0..40 {
            sma.update_raw(f64::from(i));

            let deque_sum: f64 = sma.buf.iter().copied().sum();
            assert!(
                (sma.sum - deque_sum).abs() < 1e-12,
                "step {i}: internal sum={} differs from buf sum={}",
                sma.sum,
                deque_sum
            );
        }
    }
}