1use std::collections::HashMap;
17
18use nautilus_core::{UnixNanos, datetime::secs_to_nanos};
19use rust_decimal::{Decimal, prelude::ToPrimitive};
20use serde::{Deserialize, Serialize};
21
22use crate::{
23 enums::{AccountType, LiquiditySide, OrderSide},
24 events::{AccountState, OrderFilled},
25 identifiers::AccountId,
26 instruments::{Instrument, InstrumentAny},
27 position::Position,
28 types::{AccountBalance, Currency, Money, Price, Quantity},
29};
30
/// Core account state: current and starting balances per currency, accumulated
/// commissions, and the history of applied [`AccountState`] events.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.model")
)]
pub struct BaseAccount {
    /// The account identifier, taken from the seeding event's `account_id`.
    pub id: AccountId,
    /// The account type, taken from the seeding event.
    pub account_type: AccountType,
    /// The account base currency, if any. The `base_balance*` lookups fall back to it
    /// when no explicit currency is passed.
    pub base_currency: Option<Currency>,
    /// Flag supplied at construction and stored verbatim.
    /// NOTE(review): not read by the methods in this file; presumably consulted by the
    /// concrete account types wrapping this struct — verify.
    pub calculate_account_state: bool,
    /// Account state events in the order they were applied; seeded with the event passed to
    /// `new`, and purging never removes the last remaining event.
    pub events: Vec<AccountState>,
    /// Running commission total per currency, accumulated as `f64`.
    pub commissions: HashMap<Currency, f64>,
    /// Current balance per currency; replaced wholesale by each balance update.
    pub balances: HashMap<Currency, AccountBalance>,
    /// Total balance per currency as reported by the seeding event.
    pub balances_starting: HashMap<Currency, Money>,
}
46
47impl BaseAccount {
48 pub fn new(event: AccountState, calculate_account_state: bool) -> Self {
50 let mut balances_starting: HashMap<Currency, Money> = HashMap::new();
51 let mut balances: HashMap<Currency, AccountBalance> = HashMap::new();
52 event.balances.iter().for_each(|balance| {
53 balances_starting.insert(balance.currency, balance.total);
54 balances.insert(balance.currency, *balance);
55 });
56 Self {
57 id: event.account_id,
58 account_type: event.account_type,
59 base_currency: event.base_currency,
60 calculate_account_state,
61 events: vec![event],
62 commissions: HashMap::new(),
63 balances,
64 balances_starting,
65 }
66 }
67
68 #[must_use]
74 pub fn base_balance(&self, currency: Option<Currency>) -> Option<&AccountBalance> {
75 let currency = currency
76 .or(self.base_currency)
77 .expect("Currency must be specified");
78 self.balances.get(¤cy)
79 }
80
81 #[must_use]
87 pub fn base_balance_total(&self, currency: Option<Currency>) -> Option<Money> {
88 let currency = currency
89 .or(self.base_currency)
90 .expect("Currency must be specified");
91 let account_balance = self.balances.get(¤cy);
92 account_balance.map(|balance| balance.total)
93 }
94
95 #[must_use]
96 pub fn base_balances_total(&self) -> HashMap<Currency, Money> {
97 self.balances
98 .iter()
99 .map(|(currency, balance)| (*currency, balance.total))
100 .collect()
101 }
102
103 #[must_use]
109 pub fn base_balance_free(&self, currency: Option<Currency>) -> Option<Money> {
110 let currency = currency
111 .or(self.base_currency)
112 .expect("Currency must be specified");
113 let account_balance = self.balances.get(¤cy);
114 account_balance.map(|balance| balance.free)
115 }
116
117 #[must_use]
118 pub fn base_balances_free(&self) -> HashMap<Currency, Money> {
119 self.balances
120 .iter()
121 .map(|(currency, balance)| (*currency, balance.free))
122 .collect()
123 }
124
125 #[must_use]
131 pub fn base_balance_locked(&self, currency: Option<Currency>) -> Option<Money> {
132 let currency = currency
133 .or(self.base_currency)
134 .expect("Currency must be specified");
135 let account_balance = self.balances.get(¤cy);
136 account_balance.map(|balance| balance.locked)
137 }
138
139 #[must_use]
140 pub fn base_balances_locked(&self) -> HashMap<Currency, Money> {
141 self.balances
142 .iter()
143 .map(|(currency, balance)| (*currency, balance.locked))
144 .collect()
145 }
146
147 #[must_use]
148 pub fn base_last_event(&self) -> Option<AccountState> {
149 self.events.last().cloned()
150 }
151
152 pub fn update_balances(&mut self, balances: Vec<AccountBalance>) {
158 for balance in balances {
159 if balance.total.raw < 0 {
161 panic!("Cannot update balances with total less than 0.0")
163 } else {
164 self.balances.insert(balance.currency, balance);
166 }
167 }
168 }
169
170 pub fn update_commissions(&mut self, commission: Money) {
171 if commission.as_decimal() == Decimal::ZERO {
172 return;
173 }
174
175 let currency = commission.currency;
176 let total_commissions = self.commissions.get(¤cy).unwrap_or(&0.0);
177
178 self.commissions
179 .insert(currency, total_commissions + commission.as_f64());
180 }
181
182 pub fn base_apply(&mut self, event: AccountState) {
183 self.update_balances(event.balances.clone());
184 self.events.push(event);
185 }
186
187 pub fn base_purge_account_events(&mut self, ts_now: UnixNanos, lookback_secs: u64) {
195 let lookback_ns = UnixNanos::from(secs_to_nanos(lookback_secs as f64));
196
197 let mut retained_events = Vec::new();
198
199 for event in &self.events {
200 if event.ts_event + lookback_ns > ts_now {
201 retained_events.push(event.clone());
202 }
203 }
204
205 if retained_events.is_empty() && !self.events.is_empty() {
207 retained_events.push(self.events.last().unwrap().clone());
209 }
210
211 self.events = retained_events;
212 }
213
214 pub fn base_calculate_balance_locked(
224 &mut self,
225 instrument: InstrumentAny,
226 side: OrderSide,
227 quantity: Quantity,
228 price: Price,
229 use_quote_for_inverse: Option<bool>,
230 ) -> anyhow::Result<Money> {
231 let base_currency = instrument
232 .base_currency()
233 .unwrap_or(instrument.quote_currency());
234 let quote_currency = instrument.quote_currency();
235 let notional: f64 = match side {
236 OrderSide::Buy => instrument
237 .calculate_notional_value(quantity, price, use_quote_for_inverse)
238 .as_f64(),
239 OrderSide::Sell => quantity.as_f64(),
240 _ => panic!("Invalid `OrderSide` in `base_calculate_balance_locked`"),
241 };
242
243 if instrument.is_inverse() && !use_quote_for_inverse.unwrap_or(false) {
245 Ok(Money::new(notional, base_currency))
246 } else if side == OrderSide::Buy {
247 Ok(Money::new(notional, quote_currency))
248 } else if side == OrderSide::Sell {
249 Ok(Money::new(notional, base_currency))
250 } else {
251 panic!("Invalid `OrderSide` in `base_calculate_balance_locked`")
252 }
253 }
254
255 pub fn base_calculate_pnls(
265 &self,
266 instrument: InstrumentAny,
267 fill: OrderFilled,
268 position: Option<Position>,
269 ) -> anyhow::Result<Vec<Money>> {
270 let mut pnls: HashMap<Currency, Money> = HashMap::new();
271 let quote_currency = instrument.quote_currency();
272 let base_currency = instrument.base_currency();
273
274 let fill_px = fill.last_px.as_f64();
275 let fill_qty = position.map_or(fill.last_qty.as_f64(), |pos| {
276 pos.quantity.as_f64().min(fill.last_qty.as_f64())
277 });
278 if fill.order_side == OrderSide::Buy {
279 if let (Some(base_currency_value), None) = (base_currency, self.base_currency) {
280 pnls.insert(
281 base_currency_value,
282 Money::new(fill_qty, base_currency_value),
283 );
284 }
285 pnls.insert(
286 quote_currency,
287 Money::new(-(fill_qty * fill_px), quote_currency),
288 );
289 } else if fill.order_side == OrderSide::Sell {
290 if let (Some(base_currency_value), None) = (base_currency, self.base_currency) {
291 pnls.insert(
292 base_currency_value,
293 Money::new(-fill_qty, base_currency_value),
294 );
295 }
296 pnls.insert(
297 quote_currency,
298 Money::new(fill_qty * fill_px, quote_currency),
299 );
300 } else {
301 panic!("Invalid `OrderSide` in base_calculate_pnls")
302 }
303 Ok(pnls.into_values().collect())
304 }
305
306 pub fn base_calculate_commission(
316 &self,
317 instrument: InstrumentAny,
318 last_qty: Quantity,
319 last_px: Price,
320 liquidity_side: LiquiditySide,
321 use_quote_for_inverse: Option<bool>,
322 ) -> anyhow::Result<Money> {
323 assert!(
324 liquidity_side != LiquiditySide::NoLiquiditySide,
325 "Invalid `LiquiditySide`"
326 );
327 let notional = instrument
328 .calculate_notional_value(last_qty, last_px, use_quote_for_inverse)
329 .as_f64();
330 let commission = if liquidity_side == LiquiditySide::Maker {
331 notional * instrument.maker_fee().to_f64().unwrap()
332 } else if liquidity_side == LiquiditySide::Taker {
333 notional * instrument.taker_fee().to_f64().unwrap()
334 } else {
335 panic!("Invalid `LiquiditySide` {liquidity_side}")
336 };
337 if instrument.is_inverse() && !use_quote_for_inverse.unwrap_or(false) {
338 Ok(Money::new(commission, instrument.base_currency().unwrap()))
339 } else {
340 Ok(Money::new(commission, instrument.quote_currency()))
341 }
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "stubs")]
    #[test]
    fn test_base_purge_account_events_retains_latest_when_all_purged() {
        use crate::{
            enums::AccountType,
            events::account::stubs::cash_account_state,
            identifiers::stubs::{account_id, uuid4},
            types::{Currency, stubs::stub_account_balance},
        };

        // Builds a cash account state stamped with `ts` nanoseconds for both event and init time.
        let state_at = |ts: u64| {
            AccountState::new(
                account_id(),
                AccountType::Cash,
                vec![stub_account_balance()],
                vec![],
                true,
                uuid4(),
                UnixNanos::from(ts),
                UnixNanos::from(ts),
                Some(Currency::USD()),
            )
        };

        let mut account = BaseAccount::new(cash_account_state(), true);

        let timestamps: [u64; 3] = [100_000_000, 200_000_000, 300_000_000];
        for ts in timestamps {
            account.base_apply(state_at(ts));
        }
        let latest_ts = UnixNanos::from(timestamps[2]);

        // The seeding event plus the three applied above.
        assert_eq!(account.events.len(), 4);

        // A zero lookback at t = 1s leaves no event inside the window.
        account.base_purge_account_events(UnixNanos::from(1_000_000_000), 0);

        assert_eq!(account.events.len(), 1);
        assert_eq!(account.events[0].ts_event, latest_ts);
        assert_eq!(account.base_last_event().unwrap().ts_event, latest_ts);
    }
}
408}