nautilus_infrastructure/python/sql/
cache.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Posei Systems Pty Ltd. All rights reserved.
3//  https://poseitrader.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::collections::HashMap;
17
18use bytes::Bytes;
19use nautilus_common::{
20    cache::database::CacheDatabaseAdapter, custom::CustomData, runtime::get_runtime, signal::Signal,
21};
22use nautilus_core::python::to_pyruntime_err;
23use nautilus_model::{
24    data::{Bar, DataType, QuoteTick, TradeTick},
25    events::{OrderSnapshot, PositionSnapshot},
26    identifiers::{AccountId, ClientId, ClientOrderId, InstrumentId, PositionId},
27    python::{
28        account::{account_any_to_pyobject, pyobject_to_account_any},
29        events::order::pyobject_to_order_event,
30        instruments::{instrument_any_to_pyobject, pyobject_to_instrument_any},
31        orders::{order_any_to_pyobject, pyobject_to_order_any},
32    },
33    types::Currency,
34};
35use pyo3::{IntoPyObjectExt, prelude::*};
36
37use crate::sql::{cache::PostgresCacheDatabase, queries::DatabaseQueries};
38
39#[pymethods]
40impl PostgresCacheDatabase {
41    #[staticmethod]
42    #[pyo3(name = "connect")]
43    #[pyo3(signature = (host=None, port=None, username=None, password=None, database=None))]
44    fn py_connect(
45        host: Option<String>,
46        port: Option<u16>,
47        username: Option<String>,
48        password: Option<String>,
49        database: Option<String>,
50    ) -> PyResult<Self> {
51        let result = get_runtime()
52            .block_on(async { Self::connect(host, port, username, password, database).await });
53        result.map_err(to_pyruntime_err)
54    }
55
56    #[pyo3(name = "close")]
57    fn py_close(&mut self) -> PyResult<()> {
58        self.close().map_err(to_pyruntime_err)
59    }
60
61    #[pyo3(name = "flush_db")]
62    fn py_flush_db(&mut self) -> PyResult<()> {
63        self.flush().map_err(to_pyruntime_err)
64    }
65
66    #[pyo3(name = "load")]
67    fn py_load(&self) -> PyResult<HashMap<String, Vec<u8>>> {
68        get_runtime()
69            .block_on(async { DatabaseQueries::load(&self.pool).await })
70            .map_err(to_pyruntime_err)
71    }
72
73    #[pyo3(name = "load_currency")]
74    fn py_load_currency(&self, code: &str) -> PyResult<Option<Currency>> {
75        let result = get_runtime()
76            .block_on(async { DatabaseQueries::load_currency(&self.pool, code).await });
77        result.map_err(to_pyruntime_err)
78    }
79
80    #[pyo3(name = "load_currencies")]
81    fn py_load_currencies(&self) -> PyResult<Vec<Currency>> {
82        let result =
83            get_runtime().block_on(async { DatabaseQueries::load_currencies(&self.pool).await });
84        result.map_err(to_pyruntime_err)
85    }
86
87    #[pyo3(name = "load_instrument")]
88    fn py_load_instrument(
89        &self,
90        py: Python,
91        instrument_id: InstrumentId,
92    ) -> PyResult<Option<PyObject>> {
93        get_runtime().block_on(async {
94            let result = DatabaseQueries::load_instrument(&self.pool, &instrument_id)
95                .await
96                .unwrap();
97            match result {
98                Some(instrument) => {
99                    let py_object = instrument_any_to_pyobject(py, instrument)?;
100                    Ok(Some(py_object))
101                }
102                None => Ok(None),
103            }
104        })
105    }
106
107    #[pyo3(name = "load_instruments")]
108    fn py_load_instruments(&self, py: Python) -> PyResult<Vec<PyObject>> {
109        get_runtime().block_on(async {
110            let result = DatabaseQueries::load_instruments(&self.pool).await.unwrap();
111            let mut instruments = Vec::new();
112            for instrument in result {
113                let py_object = instrument_any_to_pyobject(py, instrument)?;
114                instruments.push(py_object);
115            }
116            Ok(instruments)
117        })
118    }
119
120    #[pyo3(name = "load_order")]
121    fn py_load_order(
122        &self,
123        py: Python,
124        client_order_id: ClientOrderId,
125    ) -> PyResult<Option<PyObject>> {
126        get_runtime().block_on(async {
127            let result = DatabaseQueries::load_order(&self.pool, &client_order_id)
128                .await
129                .unwrap();
130            match result {
131                Some(order) => {
132                    let py_object = order_any_to_pyobject(py, order)?;
133                    Ok(Some(py_object))
134                }
135                None => Ok(None),
136            }
137        })
138    }
139
140    #[pyo3(name = "load_account")]
141    fn py_load_account(&self, py: Python, account_id: AccountId) -> PyResult<Option<PyObject>> {
142        get_runtime().block_on(async {
143            let result = DatabaseQueries::load_account(&self.pool, &account_id)
144                .await
145                .unwrap();
146            match result {
147                Some(account) => {
148                    let py_object = account_any_to_pyobject(py, account)?;
149                    Ok(Some(py_object))
150                }
151                None => Ok(None),
152            }
153        })
154    }
155
156    #[pyo3(name = "load_quotes")]
157    fn py_load_quotes(&self, py: Python, instrument_id: InstrumentId) -> PyResult<Vec<PyObject>> {
158        get_runtime().block_on(async {
159            let result = DatabaseQueries::load_quotes(&self.pool, &instrument_id)
160                .await
161                .unwrap();
162            let mut quotes = Vec::new();
163            for quote in result {
164                let py_object = quote.into_py_any(py)?;
165                quotes.push(py_object);
166            }
167            Ok(quotes)
168        })
169    }
170
171    #[pyo3(name = "load_trades")]
172    fn py_load_trades(&self, py: Python, instrument_id: InstrumentId) -> PyResult<Vec<PyObject>> {
173        get_runtime().block_on(async {
174            let result = DatabaseQueries::load_trades(&self.pool, &instrument_id)
175                .await
176                .unwrap();
177            let mut trades = Vec::new();
178            for trade in result {
179                let py_object = trade.into_py_any(py)?;
180                trades.push(py_object);
181            }
182            Ok(trades)
183        })
184    }
185
186    #[pyo3(name = "load_bars")]
187    fn py_load_bars(&self, py: Python, instrument_id: InstrumentId) -> PyResult<Vec<PyObject>> {
188        get_runtime().block_on(async {
189            let result = DatabaseQueries::load_bars(&self.pool, &instrument_id)
190                .await
191                .unwrap();
192            let mut bars = Vec::new();
193            for bar in result {
194                let py_object = bar.into_py_any(py)?;
195                bars.push(py_object);
196            }
197            Ok(bars)
198        })
199    }
200
201    #[pyo3(name = "load_signals")]
202    fn py_load_signals(&self, name: &str) -> PyResult<Vec<Signal>> {
203        get_runtime().block_on(async {
204            DatabaseQueries::load_signals(&self.pool, name)
205                .await
206                .map_err(to_pyruntime_err)
207        })
208    }
209
210    #[pyo3(name = "load_custom_data")]
211    fn py_load_custom_data(&self, data_type: DataType) -> PyResult<Vec<CustomData>> {
212        get_runtime().block_on(async {
213            DatabaseQueries::load_custom_data(&self.pool, &data_type)
214                .await
215                .map_err(to_pyruntime_err)
216        })
217    }
218
219    #[pyo3(name = "load_order_snapshot")]
220    fn py_load_order_snapshot(
221        &self,
222        client_order_id: ClientOrderId,
223    ) -> PyResult<Option<OrderSnapshot>> {
224        get_runtime().block_on(async {
225            DatabaseQueries::load_order_snapshot(&self.pool, &client_order_id)
226                .await
227                .map_err(to_pyruntime_err)
228        })
229    }
230
231    #[pyo3(name = "load_position_snapshot")]
232    fn py_load_position_snapshot(
233        &self,
234        position_id: PositionId,
235    ) -> PyResult<Option<PositionSnapshot>> {
236        get_runtime().block_on(async {
237            DatabaseQueries::load_position_snapshot(&self.pool, &position_id)
238                .await
239                .map_err(to_pyruntime_err)
240        })
241    }
242
243    #[pyo3(name = "add")]
244    fn py_add(&self, key: String, value: Vec<u8>) -> PyResult<()> {
245        self.add(key, Bytes::from(value)).map_err(to_pyruntime_err)
246    }
247
248    #[pyo3(name = "add_currency")]
249    fn py_add_currency(&self, currency: Currency) -> PyResult<()> {
250        self.add_currency(&currency).map_err(to_pyruntime_err)
251    }
252
253    #[pyo3(name = "add_instrument")]
254    fn py_add_instrument(&self, py: Python, instrument: PyObject) -> PyResult<()> {
255        let instrument_any = pyobject_to_instrument_any(py, instrument)?;
256        self.add_instrument(&instrument_any)
257            .map_err(to_pyruntime_err)
258    }
259
260    #[pyo3(name = "add_order")]
261    #[pyo3(signature = (order, client_id=None))]
262    fn py_add_order(
263        &self,
264        py: Python,
265        order: PyObject,
266        client_id: Option<ClientId>,
267    ) -> PyResult<()> {
268        let order_any = pyobject_to_order_any(py, order)?;
269        self.add_order(&order_any, client_id)
270            .map_err(to_pyruntime_err)
271    }
272
273    #[pyo3(name = "add_order_snapshot")]
274    fn py_add_order_snapshot(&self, snapshot: OrderSnapshot) -> PyResult<()> {
275        self.add_order_snapshot(&snapshot).map_err(to_pyruntime_err)
276    }
277
278    #[pyo3(name = "add_position_snapshot")]
279    fn py_add_position_snapshot(&self, snapshot: PositionSnapshot) -> PyResult<()> {
280        self.add_position_snapshot(&snapshot)
281            .map_err(to_pyruntime_err)
282    }
283
284    #[pyo3(name = "add_account")]
285    fn py_add_account(&self, py: Python, account: PyObject) -> PyResult<()> {
286        let account_any = pyobject_to_account_any(py, account)?;
287        self.add_account(&account_any).map_err(to_pyruntime_err)
288    }
289
290    #[pyo3(name = "add_quote")]
291    fn py_add_quote(&self, quote: QuoteTick) -> PyResult<()> {
292        self.add_quote(&quote).map_err(to_pyruntime_err)
293    }
294
295    #[pyo3(name = "add_trade")]
296    fn py_add_trade(&self, trade: TradeTick) -> PyResult<()> {
297        self.add_trade(&trade).map_err(to_pyruntime_err)
298    }
299
300    #[pyo3(name = "add_bar")]
301    fn py_add_bar(&self, bar: Bar) -> PyResult<()> {
302        self.add_bar(&bar).map_err(to_pyruntime_err)
303    }
304
305    #[pyo3(name = "add_signal")]
306    fn py_add_signal(&self, signal: Signal) -> PyResult<()> {
307        self.add_signal(&signal).map_err(to_pyruntime_err)
308    }
309
310    #[pyo3(name = "add_custom_data")]
311    fn py_add_custom_data(&self, data: CustomData) -> PyResult<()> {
312        self.add_custom_data(&data).map_err(to_pyruntime_err)
313    }
314
315    #[pyo3(name = "update_order")]
316    fn py_update_order(&self, py: Python, order_event: PyObject) -> PyResult<()> {
317        let event = pyobject_to_order_event(py, order_event)?;
318        self.update_order(&event).map_err(to_pyruntime_err)
319    }
320
321    #[pyo3(name = "update_account")]
322    fn py_update_account(&self, py: Python, order: PyObject) -> PyResult<()> {
323        let order_any = pyobject_to_account_any(py, order)?;
324        self.update_account(&order_any).map_err(to_pyruntime_err)
325    }
326}