1use std::{
17 collections::HashMap,
18 fmt::Display,
19 ops::{Deref, DerefMut},
20};
21
22use rust_decimal::{Decimal, prelude::ToPrimitive};
23use serde::{Deserialize, Serialize};
24
25use crate::{
26 accounts::{Account, base::BaseAccount},
27 enums::{AccountType, LiquiditySide, OrderSide},
28 events::{AccountState, OrderFilled},
29 identifiers::AccountId,
30 instruments::InstrumentAny,
31 position::Position,
32 types::{AccountBalance, Currency, Money, Price, Quantity},
33};
34
/// A cash account: a thin wrapper around a [`BaseAccount`], which holds all of the
/// account's state (id, account type, base currency, balances and event history).
///
/// Fields and the `base_*` helpers of [`BaseAccount`] are reachable directly on a
/// `CashAccount` through its `Deref`/`DerefMut` implementations.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.model")
)]
pub struct CashAccount {
    /// Shared account state and logic; the `Account` trait methods read from or
    /// delegate to this value.
    pub base: BaseAccount,
}
43
44impl CashAccount {
45 pub fn new(event: AccountState, calculate_account_state: bool) -> Self {
47 Self {
48 base: BaseAccount::new(event, calculate_account_state),
49 }
50 }
51
52 #[must_use]
53 pub fn is_cash_account(&self) -> bool {
54 self.account_type == AccountType::Cash
55 }
56 #[must_use]
57 pub fn is_margin_account(&self) -> bool {
58 self.account_type == AccountType::Margin
59 }
60
61 #[must_use]
62 pub const fn is_unleveraged(&self) -> bool {
63 false
64 }
65
66 pub fn recalculate_balance(&mut self, currency: Currency) {
72 let current_balance = match self.balances.get(¤cy) {
73 Some(balance) => *balance,
74 None => {
75 return;
76 }
77 };
78
79 let total_locked = self
80 .balances
81 .values()
82 .filter(|balance| balance.currency == currency)
83 .fold(Decimal::ZERO, |acc, balance| {
84 acc + balance.locked.as_decimal()
85 });
86
87 let new_balance = AccountBalance::new(
88 current_balance.total,
89 Money::new(total_locked.to_f64().unwrap(), currency),
90 Money::new(
91 (current_balance.total.as_decimal() - total_locked)
92 .to_f64()
93 .unwrap(),
94 currency,
95 ),
96 );
97
98 self.balances.insert(currency, new_balance);
99 }
100}
101
102impl Account for CashAccount {
103 fn id(&self) -> AccountId {
104 self.id
105 }
106
107 fn account_type(&self) -> AccountType {
108 self.account_type
109 }
110
111 fn base_currency(&self) -> Option<Currency> {
112 self.base_currency
113 }
114
115 fn is_cash_account(&self) -> bool {
116 self.account_type == AccountType::Cash
117 }
118
119 fn is_margin_account(&self) -> bool {
120 self.account_type == AccountType::Margin
121 }
122
123 fn calculated_account_state(&self) -> bool {
124 false }
126
127 fn balance_total(&self, currency: Option<Currency>) -> Option<Money> {
128 self.base_balance_total(currency)
129 }
130
131 fn balances_total(&self) -> HashMap<Currency, Money> {
132 self.base_balances_total()
133 }
134
135 fn balance_free(&self, currency: Option<Currency>) -> Option<Money> {
136 self.base_balance_free(currency)
137 }
138
139 fn balances_free(&self) -> HashMap<Currency, Money> {
140 self.base_balances_free()
141 }
142
143 fn balance_locked(&self, currency: Option<Currency>) -> Option<Money> {
144 self.base_balance_locked(currency)
145 }
146
147 fn balances_locked(&self) -> HashMap<Currency, Money> {
148 self.base_balances_locked()
149 }
150
151 fn balance(&self, currency: Option<Currency>) -> Option<&AccountBalance> {
152 self.base_balance(currency)
153 }
154
155 fn last_event(&self) -> Option<AccountState> {
156 self.base_last_event()
157 }
158
159 fn events(&self) -> Vec<AccountState> {
160 self.events.clone()
161 }
162
163 fn event_count(&self) -> usize {
164 self.events.len()
165 }
166
167 fn currencies(&self) -> Vec<Currency> {
168 self.balances.keys().copied().collect()
169 }
170
171 fn starting_balances(&self) -> HashMap<Currency, Money> {
172 self.balances_starting.clone()
173 }
174
175 fn balances(&self) -> HashMap<Currency, AccountBalance> {
176 self.balances.clone()
177 }
178
179 fn apply(&mut self, event: AccountState) {
180 self.base_apply(event);
181 }
182
183 fn purge_account_events(&mut self, ts_now: nautilus_core::UnixNanos, lookback_secs: u64) {
184 self.base.base_purge_account_events(ts_now, lookback_secs);
185 }
186
187 fn calculate_balance_locked(
188 &mut self,
189 instrument: InstrumentAny,
190 side: OrderSide,
191 quantity: Quantity,
192 price: Price,
193 use_quote_for_inverse: Option<bool>,
194 ) -> anyhow::Result<Money> {
195 self.base_calculate_balance_locked(instrument, side, quantity, price, use_quote_for_inverse)
196 }
197
198 fn calculate_pnls(
199 &self,
200 instrument: InstrumentAny, fill: OrderFilled, position: Option<Position>,
203 ) -> anyhow::Result<Vec<Money>> {
204 self.base_calculate_pnls(instrument, fill, position)
205 }
206
207 fn calculate_commission(
208 &self,
209 instrument: InstrumentAny,
210 last_qty: Quantity,
211 last_px: Price,
212 liquidity_side: LiquiditySide,
213 use_quote_for_inverse: Option<bool>,
214 ) -> anyhow::Result<Money> {
215 self.base_calculate_commission(
216 instrument,
217 last_qty,
218 last_px,
219 liquidity_side,
220 use_quote_for_inverse,
221 )
222 }
223}
224
225impl Deref for CashAccount {
226 type Target = BaseAccount;
227
228 fn deref(&self) -> &Self::Target {
229 &self.base
230 }
231}
232
233impl DerefMut for CashAccount {
234 fn deref_mut(&mut self) -> &mut Self::Target {
235 &mut self.base
236 }
237}
238
239impl PartialEq for CashAccount {
240 fn eq(&self, other: &Self) -> bool {
241 self.id == other.id
242 }
243}
244
245impl Eq for CashAccount {}
246
247impl Display for CashAccount {
248 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 write!(
250 f,
251 "CashAccount(id={}, type={}, base={})",
252 self.id,
253 self.account_type,
254 self.base_currency.map_or_else(
255 || "None".to_string(),
256 |base_currency| format!("{}", base_currency.code)
257 ),
258 )
259 }
260}
261
// Unit tests for `CashAccount`. Accounts, account states and instruments come from
// `rstest` fixtures, assumed to be provided by the `stubs` modules glob-imported below
// and matched by parameter name.
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};

    use rstest::rstest;

    use crate::{
        accounts::{Account, CashAccount, stubs::*},
        enums::{AccountType, LiquiditySide, OrderSide, OrderType},
        events::{AccountState, account::stubs::*},
        identifiers::{AccountId, position_id::PositionId},
        instruments::{CryptoPerpetual, CurrencyPair, Equity, Instrument, InstrumentAny, stubs::*},
        orders::{builder::OrderTestBuilder, stubs::TestOrderEventStubs},
        position::Position,
        types::{Currency, Money, Price, Quantity},
    };

    // `Display` renders the account id, account type and base-currency code.
    #[rstest]
    fn test_display(cash_account: CashAccount) {
        assert_eq!(
            format!("{cash_account}"),
            "CashAccount(id=SIM-001, type=CASH, base=USD)"
        );
    }

    // The single-currency USD fixture exposes its initial state event and its USD
    // balance (1,525,000 total = 1,500,000 free + 25,000 locked) through every accessor;
    // `None` as the currency selects the base currency.
    #[rstest]
    fn test_instantiate_single_asset_cash_account(
        cash_account: CashAccount,
        cash_account_state: AccountState,
    ) {
        assert_eq!(cash_account.id, AccountId::from("SIM-001"));
        assert_eq!(cash_account.account_type, AccountType::Cash);
        assert_eq!(cash_account.base_currency, Some(Currency::from("USD")));
        assert_eq!(cash_account.last_event(), Some(cash_account_state.clone()));
        assert_eq!(cash_account.events(), vec![cash_account_state]);
        assert_eq!(cash_account.event_count(), 1);
        assert_eq!(
            cash_account.balance_total(None),
            Some(Money::from("1525000 USD"))
        );
        assert_eq!(
            cash_account.balance_free(None),
            Some(Money::from("1500000 USD"))
        );
        assert_eq!(
            cash_account.balance_locked(None),
            Some(Money::from("25000 USD"))
        );
        let mut balances_total_expected = HashMap::new();
        balances_total_expected.insert(Currency::from("USD"), Money::from("1525000 USD"));
        assert_eq!(cash_account.balances_total(), balances_total_expected);
        let mut balances_free_expected = HashMap::new();
        balances_free_expected.insert(Currency::from("USD"), Money::from("1500000 USD"));
        assert_eq!(cash_account.balances_free(), balances_free_expected);
        let mut balances_locked_expected = HashMap::new();
        balances_locked_expected.insert(Currency::from("USD"), Money::from("25000 USD"));
        assert_eq!(cash_account.balances_locked(), balances_locked_expected);
    }

    // The multi-currency fixture has no base currency and holds fully free BTC and ETH
    // balances, so each per-currency lookup passes an explicit currency.
    #[rstest]
    fn test_instantiate_multi_asset_cash_account(
        cash_account_multi: CashAccount,
        cash_account_state_multi: AccountState,
    ) {
        assert_eq!(cash_account_multi.id, AccountId::from("SIM-001"));
        assert_eq!(cash_account_multi.account_type, AccountType::Cash);
        assert_eq!(
            cash_account_multi.last_event(),
            Some(cash_account_state_multi.clone())
        );
        assert_eq!(cash_account_state_multi.base_currency, None);
        assert_eq!(cash_account_multi.events(), vec![cash_account_state_multi]);
        assert_eq!(cash_account_multi.event_count(), 1);
        assert_eq!(
            cash_account_multi.balance_total(Some(Currency::BTC())),
            Some(Money::from("10 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_total(Some(Currency::ETH())),
            Some(Money::from("20 ETH"))
        );
        assert_eq!(
            cash_account_multi.balance_free(Some(Currency::BTC())),
            Some(Money::from("10 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_free(Some(Currency::ETH())),
            Some(Money::from("20 ETH"))
        );
        assert_eq!(
            cash_account_multi.balance_locked(Some(Currency::BTC())),
            Some(Money::from("0 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_locked(Some(Currency::ETH())),
            Some(Money::from("0 ETH"))
        );
        let mut balances_total_expected = HashMap::new();
        balances_total_expected.insert(Currency::from("BTC"), Money::from("10 BTC"));
        balances_total_expected.insert(Currency::from("ETH"), Money::from("20 ETH"));
        assert_eq!(cash_account_multi.balances_total(), balances_total_expected);
        let mut balances_free_expected = HashMap::new();
        balances_free_expected.insert(Currency::from("BTC"), Money::from("10 BTC"));
        balances_free_expected.insert(Currency::from("ETH"), Money::from("20 ETH"));
        assert_eq!(cash_account_multi.balances_free(), balances_free_expected);
        let mut balances_locked_expected = HashMap::new();
        balances_locked_expected.insert(Currency::from("BTC"), Money::from("0 BTC"));
        balances_locked_expected.insert(Currency::from("ETH"), Money::from("0 ETH"));
        assert_eq!(
            cash_account_multi.balances_locked(),
            balances_locked_expected
        );
    }

    // Applying a newer state event appends it to the history and updates the BTC
    // balance; the ETH balance is left unchanged.
    #[rstest]
    fn test_apply_given_new_state_event_updates_correctly(
        mut cash_account_multi: CashAccount,
        cash_account_state_multi: AccountState,
        cash_account_state_multi_changed_btc: AccountState,
    ) {
        cash_account_multi.apply(cash_account_state_multi_changed_btc.clone());

        assert_eq!(
            cash_account_multi.last_event(),
            Some(cash_account_state_multi_changed_btc.clone())
        );
        assert_eq!(
            cash_account_multi.events,
            vec![
                cash_account_state_multi,
                cash_account_state_multi_changed_btc
            ]
        );
        assert_eq!(cash_account_multi.event_count(), 2);
        assert_eq!(
            cash_account_multi.balance_total(Some(Currency::BTC())),
            Some(Money::from("9 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_free(Some(Currency::BTC())),
            Some(Money::from("8.5 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_locked(Some(Currency::BTC())),
            Some(Money::from("0.5 BTC"))
        );
        assert_eq!(
            cash_account_multi.balance_total(Some(Currency::ETH())),
            Some(Money::from("20 ETH"))
        );
        assert_eq!(
            cash_account_multi.balance_free(Some(Currency::ETH())),
            Some(Money::from("20 ETH"))
        );
        assert_eq!(
            cash_account_multi.balance_locked(Some(Currency::ETH())),
            Some(Money::from("0 ETH"))
        );
    }

    // A buy locks the quote-currency notional: 1,000,000 x 0.8 = 800,000 USD.
    #[rstest]
    fn test_calculate_balance_locked_buy(
        mut cash_account_million_usd: CashAccount,
        audusd_sim: CurrencyPair,
    ) {
        let balance_locked = cash_account_million_usd
            .calculate_balance_locked(
                audusd_sim.into_any(),
                OrderSide::Buy,
                Quantity::from("1000000"),
                Price::from("0.8"),
                None,
            )
            .unwrap();
        assert_eq!(balance_locked, Money::from("800000 USD"));
    }

    // A sell locks the base-currency quantity itself: 1,000,000 AUD.
    #[rstest]
    fn test_calculate_balance_locked_sell(
        mut cash_account_million_usd: CashAccount,
        audusd_sim: CurrencyPair,
    ) {
        let balance_locked = cash_account_million_usd
            .calculate_balance_locked(
                audusd_sim.into_any(),
                OrderSide::Sell,
                Quantity::from("1000000"),
                Price::from("0.8"),
                None,
            )
            .unwrap();
        assert_eq!(balance_locked, Money::from("1000000 AUD"));
    }

    // Selling an instrument that (per the test name) has no base currency locks the
    // order quantity expressed in USD (100 USD), not the 150,000 USD notional.
    // NOTE(review): confirm this is the intended locking rule.
    #[rstest]
    fn test_calculate_balance_locked_sell_no_base_currency(
        mut cash_account_million_usd: CashAccount,
        equity_aapl: Equity,
    ) {
        let balance_locked = cash_account_million_usd
            .calculate_balance_locked(
                equity_aapl.into_any(),
                OrderSide::Sell,
                Quantity::from("100"),
                Price::from("1500.0"),
                None,
            )
            .unwrap();
        assert_eq!(balance_locked, Money::from("100 USD"));
    }

    // In a single-currency (USD) account, filling a 1,000,000 AUD/USD buy at 0.8 reports
    // only the quote-currency cash flow: -800,000 USD.
    #[rstest]
    fn test_calculate_pnls_for_single_currency_cash_account(
        cash_account_million_usd: CashAccount,
        audusd_sim: CurrencyPair,
    ) {
        let audusd_sim = InstrumentAny::CurrencyPair(audusd_sim);
        let order = OrderTestBuilder::new(OrderType::Market)
            .instrument_id(audusd_sim.id())
            .side(OrderSide::Buy)
            .quantity(Quantity::from("1000000"))
            .build();
        let fill = TestOrderEventStubs::filled(
            &order,
            &audusd_sim,
            None,
            Some(PositionId::new("P-123456")),
            Some(Price::from("0.8")),
            None,
            None,
            None,
            None,
            Some(AccountId::from("SIM-001")),
        );
        let position = Position::new(&audusd_sim, fill.clone().into());
        let pnls = cash_account_million_usd
            .calculate_pnls(audusd_sim, fill.into(), Some(position))
            .unwrap();
        assert_eq!(pnls, vec![Money::from("-800000 USD")]);
    }

    // In a multi-currency account a fill yields one PnL entry per affected currency:
    // selling 0.5 BTC at 45,500 gives +22,750 USDT / -0.5 BTC, and buying it back gives
    // the mirror image. Results are compared as sets because their order is not asserted.
    #[rstest]
    fn test_calculate_pnls_for_multi_currency_cash_account_btcusdt(
        cash_account_multi: CashAccount,
        currency_pair_btcusdt: CurrencyPair,
    ) {
        let btcusdt = InstrumentAny::CurrencyPair(currency_pair_btcusdt);
        let order1 = OrderTestBuilder::new(OrderType::Market)
            .instrument_id(currency_pair_btcusdt.id)
            .side(OrderSide::Sell)
            .quantity(Quantity::from("0.5"))
            .build();
        let fill1 = TestOrderEventStubs::filled(
            &order1,
            &btcusdt,
            None,
            Some(PositionId::new("P-123456")),
            Some(Price::from("45500.00")),
            None,
            None,
            None,
            None,
            Some(AccountId::from("SIM-001")),
        );
        let position = Position::new(&btcusdt, fill1.clone().into());
        let result1 = cash_account_multi
            .calculate_pnls(
                currency_pair_btcusdt.into_any(),
                fill1.into(),
                Some(position.clone()),
            )
            .unwrap();
        let order2 = OrderTestBuilder::new(OrderType::Market)
            .instrument_id(currency_pair_btcusdt.id)
            .side(OrderSide::Buy)
            .quantity(Quantity::from("0.5"))
            .build();
        let fill2 = TestOrderEventStubs::filled(
            &order2,
            &btcusdt,
            None,
            Some(PositionId::new("P-123456")),
            Some(Price::from("45500.00")),
            None,
            None,
            None,
            None,
            Some(AccountId::from("SIM-001")),
        );
        let result2 = cash_account_multi
            .calculate_pnls(
                currency_pair_btcusdt.into_any(),
                fill2.into(),
                Some(position),
            )
            .unwrap();
        let result1_set: HashSet<Money> = result1.into_iter().collect();
        let result1_expected: HashSet<Money> =
            vec![Money::from("22750 USDT"), Money::from("-0.5 BTC")]
                .into_iter()
                .collect();
        let result2_set: HashSet<Money> = result2.into_iter().collect();
        let result2_expected: HashSet<Money> =
            vec![Money::from("-22750 USDT"), Money::from("0.5 BTC")]
                .into_iter()
                .collect();
        assert_eq!(result1_set, result1_expected);
        assert_eq!(result2_set, result2_expected);
    }

    // Maker commission on the inverse perpetual is negative (a rebate). It is expressed
    // in BTC by default, or in USD (the quote currency) when `use_quote_for_inverse` is
    // true.
    #[rstest]
    #[case(false, Money::from("-0.00218331 BTC"))]
    #[case(true, Money::from("-25.0 USD"))]
    fn test_calculate_commission_for_inverse_maker_crypto(
        #[case] use_quote_for_inverse: bool,
        #[case] expected: Money,
        cash_account_million_usd: CashAccount,
        xbtusd_bitmex: CryptoPerpetual,
    ) {
        let result = cash_account_million_usd
            .calculate_commission(
                xbtusd_bitmex.into_any(),
                Quantity::from("100000"),
                Price::from("11450.50"),
                LiquiditySide::Maker,
                Some(use_quote_for_inverse),
            )
            .unwrap();
        assert_eq!(result, expected);
    }

    // Taker commission on an AUD/USD fill is charged in the quote currency (USD).
    #[rstest]
    fn test_calculate_commission_for_taker_fx(
        cash_account_million_usd: CashAccount,
        audusd_sim: CurrencyPair,
    ) {
        let result = cash_account_million_usd
            .calculate_commission(
                audusd_sim.into_any(),
                Quantity::from("1500000"),
                Price::from("0.8005"),
                LiquiditySide::Taker,
                None,
            )
            .unwrap();
        assert_eq!(result, Money::from("24.02 USD"));
    }

    // Without `use_quote_for_inverse`, taker commission on the inverse perpetual is
    // charged in BTC.
    #[rstest]
    fn test_calculate_commission_crypto_taker(
        cash_account_million_usd: CashAccount,
        xbtusd_bitmex: CryptoPerpetual,
    ) {
        let result = cash_account_million_usd
            .calculate_commission(
                xbtusd_bitmex.into_any(),
                Quantity::from("100000"),
                Price::from("11450.50"),
                LiquiditySide::Taker,
                None,
            )
            .unwrap();
        assert_eq!(result, Money::from("0.00654993 BTC"));
    }

    // Taker commission on a USD/JPY fill is charged in the quote currency (JPY).
    #[rstest]
    fn test_calculate_commission_fx_taker(cash_account_million_usd: CashAccount) {
        let instrument = usdjpy_idealpro();
        let result = cash_account_million_usd
            .calculate_commission(
                instrument.into_any(),
                Quantity::from("2200000"),
                Price::from("120.310"),
                LiquiditySide::Taker,
                None,
            )
            .unwrap();
        assert_eq!(result, Money::from("5294 JPY"));
    }
}