nautilus_coinbase_intx/websocket/
client.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Posei Systems Pty Ltd. All rights reserved.
//  https://poseitrader.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use std::{
17    collections::HashMap,
18    sync::{
19        Arc,
20        atomic::{AtomicBool, Ordering},
21    },
22    time::{Duration, SystemTime},
23};
24
25use chrono::Utc;
26use futures_util::{Stream, StreamExt};
27use nautilus_common::{logging::log_task_stopped, runtime::get_runtime};
28use nautilus_core::{
29    consts::NAUTILUS_USER_AGENT, env::get_env_var, time::get_atomic_clock_realtime,
30};
31use nautilus_model::{
32    data::{BarType, Data, OrderBookDeltas_API},
33    identifiers::InstrumentId,
34    instruments::{Instrument, InstrumentAny},
35};
36use nautilus_network::websocket::{Consumer, MessageReader, WebSocketClient, WebSocketConfig};
37use reqwest::header::USER_AGENT;
38use tokio::sync::Mutex;
39use tokio_tungstenite::tungstenite::{Error, Message};
40use ustr::Ustr;
41
42use super::{
43    enums::{CoinbaseIntxWsChannel, WsOperation},
44    error::CoinbaseIntxWsError,
45    messages::{CoinbaseIntxSubscription, CoinbaseIntxWsMessage, PoseiWsMessage},
46    parse::{
47        parse_candle_msg, parse_index_price_msg, parse_mark_price_msg,
48        parse_orderbook_snapshot_msg, parse_orderbook_update_msg, parse_quote_msg,
49    },
50};
51use crate::{
52    common::{
53        consts::COINBASE_INTX_WS_URL, credential::Credential, parse::bar_spec_as_coinbase_channel,
54    },
55    websocket::parse::{parse_instrument_any, parse_trade_msg},
56};
57
/// Provides a WebSocket client for connecting to [Coinbase International](https://www.coinbase.com/en/international-exchange).
///
/// Clones share the underlying connection, the stop signal and the subscription registry,
/// since those are all held behind `Arc`.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.adapters")
)]
pub struct CoinbaseIntxWebSocketClient {
    /// WebSocket endpoint URL used when connecting.
    url: String,
    /// API credentials; the key/passphrase are sent with, and the secret is used to sign,
    /// subscribe/unsubscribe requests.
    credential: Credential,
    /// Heartbeat setting forwarded unchanged to the `WebSocketConfig`
    /// (NOTE(review): unit is defined by `WebSocketConfig` — confirm).
    heartbeat: Option<u64>,
    /// Underlying WebSocket client; `None` until `connect` completes.
    inner: Option<Arc<WebSocketClient>>,
    /// Receiver for parsed messages produced by the handler task; taken by `stream`.
    rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<PoseiWsMessage>>>,
    /// Stop flag shared with the message handler task; set by `close`.
    signal: Arc<AtomicBool>,
    /// Handle of the spawned message handling task.
    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    /// Product IDs currently subscribed per channel; replayed after the socket reconnects.
    subscriptions: Arc<Mutex<HashMap<CoinbaseIntxWsChannel, Vec<Ustr>>>>,
}
74
75impl Default for CoinbaseIntxWebSocketClient {
76    fn default() -> Self {
77        Self::new(None, None, None, None, Some(10)).expect("Failed to create client")
78    }
79}
80
81impl CoinbaseIntxWebSocketClient {
82    /// Creates a new [`CoinbaseIntxWebSocketClient`] instance.
83    ///
84    /// # Errors
85    ///
86    /// Returns an error if required environment variables are missing or invalid.
87    pub fn new(
88        url: Option<String>,
89        api_key: Option<String>,
90        api_secret: Option<String>,
91        api_passphrase: Option<String>,
92        heartbeat: Option<u64>,
93    ) -> anyhow::Result<Self> {
94        let url = url.unwrap_or(COINBASE_INTX_WS_URL.to_string());
95        let api_key = api_key.unwrap_or(get_env_var("COINBASE_INTX_API_KEY")?);
96        let api_secret = api_secret.unwrap_or(get_env_var("COINBASE_INTX_API_SECRET")?);
97        let api_passphrase = api_passphrase.unwrap_or(get_env_var("COINBASE_INTX_API_PASSPHRASE")?);
98
99        let credential = Credential::new(api_key, api_secret, api_passphrase);
100        let signal = Arc::new(AtomicBool::new(false));
101        let subscriptions = Arc::new(Mutex::new(HashMap::new()));
102
103        Ok(Self {
104            url,
105            credential,
106            heartbeat,
107            inner: None,
108            rx: None,
109            signal,
110            task_handle: None,
111            subscriptions,
112        })
113    }
114
115    /// Creates a new authenticated [`CoinbaseIntxWebSocketClient`] using environment variables and
116    /// the default Coinbase International production websocket url.
117    ///
118    /// # Errors
119    ///
120    /// Returns an error if required environment variables are missing or invalid.
121    pub fn from_env() -> anyhow::Result<Self> {
122        Self::new(None, None, None, None, None)
123    }
124
125    /// Returns the websocket url being used by the client.
126    #[must_use]
127    pub const fn url(&self) -> &str {
128        self.url.as_str()
129    }
130
131    /// Returns the public API key being used by the client.
132    #[must_use]
133    pub fn api_key(&self) -> &str {
134        self.credential.api_key.as_str()
135    }
136
137    /// Returns a value indicating whether the client is active.
138    #[must_use]
139    pub fn is_active(&self) -> bool {
140        match &self.inner {
141            Some(inner) => inner.is_active(),
142            None => false,
143        }
144    }
145
146    /// Returns a value indicating whether the client is closed.
147    #[must_use]
148    pub fn is_closed(&self) -> bool {
149        match &self.inner {
150            Some(inner) => inner.is_closed(),
151            None => true,
152        }
153    }
154
155    /// Connects the client to the server and caches the given instruments.
156    ///
157    /// # Errors
158    ///
159    /// Returns an error if the WebSocket connection or initial subscription fails.
160    pub async fn connect(&mut self, instruments: Vec<InstrumentAny>) -> anyhow::Result<()> {
161        let client = self.clone();
162        let post_reconnect = Arc::new(move || {
163            let client = client.clone();
164            tokio::spawn(async move { client.resubscribe_all().await });
165        });
166
167        let config = WebSocketConfig {
168            url: self.url.clone(),
169            headers: vec![(USER_AGENT.to_string(), NAUTILUS_USER_AGENT.to_string())],
170            heartbeat: self.heartbeat,
171            heartbeat_msg: None,
172            #[cfg(feature = "python")]
173            handler: Consumer::Python(None),
174            #[cfg(feature = "python")]
175            ping_handler: None,
176            reconnect_timeout_ms: Some(5_000),
177            reconnect_delay_initial_ms: None, // Use default
178            reconnect_delay_max_ms: None,     // Use default
179            reconnect_backoff_factor: None,   // Use default
180            reconnect_jitter_ms: None,        // Use default
181        };
182        let (reader, client) =
183            WebSocketClient::connect_stream(config, vec![], None, Some(post_reconnect)).await?;
184
185        self.inner = Some(Arc::new(client));
186
187        let mut instruments_map: HashMap<Ustr, InstrumentAny> = HashMap::new();
188        for inst in instruments {
189            instruments_map.insert(inst.raw_symbol().inner(), inst);
190        }
191
192        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<PoseiWsMessage>();
193        self.rx = Some(Arc::new(rx));
194        let signal = self.signal.clone();
195
196        let stream_handle = get_runtime().spawn(async move {
197            CoinbaseIntxWsMessageHandler::new(instruments_map, reader, signal, tx)
198                .run()
199                .await;
200        });
201
202        self.task_handle = Some(Arc::new(stream_handle));
203
204        Ok(())
205    }
206
207    /// Provides the internal data stream as a channel-based stream.
208    ///
209    /// # Panics
210    ///
211    /// This function panics if:
212    /// - The websocket is not connected.
213    /// - If `stream_data` has already been called somewhere else (stream receiver is then taken).
214    pub fn stream(&mut self) -> impl Stream<Item = PoseiWsMessage> + 'static {
215        let rx = self
216            .rx
217            .take()
218            .expect("Data stream receiver already taken or not connected"); // Design-time error
219        let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
220        async_stream::stream! {
221            while let Some(data) = rx.recv().await {
222                yield data;
223            }
224        }
225    }
226
227    /// Closes the client.
228    ///
229    /// # Errors
230    ///
231    /// Returns an error if the WebSocket fails to close properly.
232    pub async fn close(&mut self) -> Result<(), Error> {
233        tracing::debug!("Closing");
234        self.signal.store(true, Ordering::Relaxed);
235
236        match tokio::time::timeout(Duration::from_secs(5), async {
237            if let Some(inner) = &self.inner {
238                inner.disconnect().await;
239            } else {
240                log::error!("Error on close: not connected");
241            }
242        })
243        .await
244        {
245            Ok(()) => {
246                tracing::debug!("Inner disconnected");
247            }
248            Err(_) => {
249                tracing::error!("Timeout waiting for inner client to disconnect");
250            }
251        }
252
253        log::debug!("Closed");
254
255        Ok(())
256    }
257
258    /// Subscribes to the given channels and product IDs.
259    ///
260    /// # Errors
261    ///
262    /// Returns an error if the subscription message cannot be sent.
263    async fn subscribe(
264        &self,
265        channels: Vec<CoinbaseIntxWsChannel>,
266        product_ids: Vec<Ustr>,
267    ) -> Result<(), CoinbaseIntxWsError> {
268        // Update active subscriptions
269        let mut active_subs = self.subscriptions.lock().await;
270        for channel in &channels {
271            active_subs
272                .entry(*channel)
273                .or_insert_with(Vec::new)
274                .extend(product_ids.clone());
275        }
276        tracing::debug!(
277            "Added active subscription(s): channels={channels:?}, product_ids={product_ids:?}"
278        );
279
280        let time = chrono::DateTime::<Utc>::from(SystemTime::now())
281            .timestamp()
282            .to_string();
283        let signature = self.credential.sign_ws(&time);
284        let message = CoinbaseIntxSubscription {
285            op: WsOperation::Subscribe,
286            product_ids: Some(product_ids),
287            channels,
288            time,
289            key: self.credential.api_key,
290            passphrase: self.credential.api_passphrase,
291            signature,
292        };
293
294        let json_txt = serde_json::to_string(&message)
295            .map_err(|e| CoinbaseIntxWsError::JsonError(e.to_string()))?;
296
297        if let Some(inner) = &self.inner {
298            if let Err(err) = inner.send_text(json_txt, None).await {
299                tracing::error!("Error sending message: {err:?}");
300            }
301        } else {
302            return Err(CoinbaseIntxWsError::ClientError(
303                "Cannot send message: not connected".to_string(),
304            ));
305        }
306
307        Ok(())
308    }
309
310    /// Unsubscribes from the given channels and product IDs.
311    async fn unsubscribe(
312        &self,
313        channels: Vec<CoinbaseIntxWsChannel>,
314        product_ids: Vec<Ustr>,
315    ) -> Result<(), CoinbaseIntxWsError> {
316        // Update active subscriptions
317        let mut active_subs = self.subscriptions.lock().await;
318        for channel in &channels {
319            if let Some(subs) = active_subs.get_mut(channel) {
320                for product_id in &product_ids {
321                    subs.retain(|pid| pid != product_id);
322                }
323                if subs.is_empty() {
324                    active_subs.remove(channel);
325                }
326            }
327        }
328        tracing::debug!(
329            "Removed active subscription(s): channels={channels:?}, product_ids={product_ids:?}"
330        );
331
332        let time = chrono::DateTime::<Utc>::from(SystemTime::now())
333            .timestamp()
334            .to_string();
335        let signature = self.credential.sign_ws(&time);
336        let message = CoinbaseIntxSubscription {
337            op: WsOperation::Unsubscribe,
338            product_ids: Some(product_ids),
339            channels,
340            time,
341            key: self.credential.api_key,
342            passphrase: self.credential.api_passphrase,
343            signature,
344        };
345
346        let json_txt = serde_json::to_string(&message)
347            .map_err(|e| CoinbaseIntxWsError::JsonError(e.to_string()))?;
348
349        if let Some(inner) = &self.inner {
350            if let Err(err) = inner.send_text(json_txt, None).await {
351                tracing::error!("Error sending message: {err:?}");
352            }
353        } else {
354            return Err(CoinbaseIntxWsError::ClientError(
355                "Cannot send message: not connected".to_string(),
356            ));
357        }
358
359        Ok(())
360    }
361
362    /// Resubscribes for all active subscriptions.
363    async fn resubscribe_all(&self) {
364        let subs = self.subscriptions.lock().await.clone();
365
366        for (channel, product_ids) in subs {
367            if product_ids.is_empty() {
368                continue;
369            }
370
371            tracing::debug!("Resubscribing: channel={channel}, product_ids={product_ids:?}");
372
373            if let Err(e) = self.subscribe(vec![channel], product_ids).await {
374                tracing::error!("Failed to resubscribe to channel {channel}: {e}");
375            }
376        }
377    }
378
379    /// Subscribes to instrument definition updates for the given instrument IDs.
380    /// Subscribes to instrument updates for the specified instruments.
381    ///
382    /// # Errors
383    ///
384    /// Returns an error if the subscription fails.
385    pub async fn subscribe_instruments(
386        &self,
387        instrument_ids: Vec<InstrumentId>,
388    ) -> Result<(), CoinbaseIntxWsError> {
389        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
390        self.subscribe(vec![CoinbaseIntxWsChannel::Instruments], product_ids)
391            .await
392    }
393
394    /// Subscribes to funding message streams for the given instrument IDs.
395    /// Subscribes to funding rate updates for the specified instruments.
396    ///
397    /// # Errors
398    ///
399    /// Returns an error if the subscription fails.
400    pub async fn subscribe_funding(
401        &self,
402        instrument_ids: Vec<InstrumentId>,
403    ) -> Result<(), CoinbaseIntxWsError> {
404        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
405        self.subscribe(vec![CoinbaseIntxWsChannel::Funding], product_ids)
406            .await
407    }
408
409    /// Subscribes to risk message streams for the given instrument IDs.
410    /// Subscribes to risk updates for the specified instruments.
411    ///
412    /// # Errors
413    ///
414    /// Returns an error if the subscription fails.
415    pub async fn subscribe_risk(
416        &self,
417        instrument_ids: Vec<InstrumentId>,
418    ) -> Result<(), CoinbaseIntxWsError> {
419        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
420        self.subscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
421            .await
422    }
423
424    /// Subscribes to order book (level 2) streams for the given instrument IDs.
425    /// Subscribes to order book snapshots and updates for the specified instruments.
426    ///
427    /// # Errors
428    ///
429    /// Returns an error if the subscription fails.
430    pub async fn subscribe_order_book(
431        &self,
432        instrument_ids: Vec<InstrumentId>,
433    ) -> Result<(), CoinbaseIntxWsError> {
434        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
435        self.subscribe(vec![CoinbaseIntxWsChannel::Level2], product_ids)
436            .await
437    }
438
439    /// Subscribes to quote (level 1) streams for the given instrument IDs.
440    /// Subscribes to top-of-book quote updates for the specified instruments.
441    ///
442    /// # Errors
443    ///
444    /// Returns an error if the subscription fails.
445    pub async fn subscribe_quotes(
446        &self,
447        instrument_ids: Vec<InstrumentId>,
448    ) -> Result<(), CoinbaseIntxWsError> {
449        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
450        self.subscribe(vec![CoinbaseIntxWsChannel::Level1], product_ids)
451            .await
452    }
453
454    /// Subscribes to trade (match) streams for the given instrument IDs.
455    /// Subscribes to trade updates for the specified instruments.
456    ///
457    /// # Errors
458    ///
459    /// Returns an error if the subscription fails.
460    pub async fn subscribe_trades(
461        &self,
462        instrument_ids: Vec<InstrumentId>,
463    ) -> Result<(), CoinbaseIntxWsError> {
464        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
465        self.subscribe(vec![CoinbaseIntxWsChannel::Match], product_ids)
466            .await
467    }
468
469    /// Subscribes to risk streams (for mark prices) for the given instrument IDs.
470    /// Subscribes to mark price updates for the specified instruments.
471    ///
472    /// # Errors
473    ///
474    /// Returns an error if the subscription fails.
475    pub async fn subscribe_mark_prices(
476        &self,
477        instrument_ids: Vec<InstrumentId>,
478    ) -> Result<(), CoinbaseIntxWsError> {
479        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
480        self.subscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
481            .await
482    }
483
484    /// Subscribes to risk streams (for index prices) for the given instrument IDs.
485    /// Subscribes to index price updates for the specified instruments.
486    ///
487    /// # Errors
488    ///
489    /// Returns an error if the subscription fails.
490    pub async fn subscribe_index_prices(
491        &self,
492        instrument_ids: Vec<InstrumentId>,
493    ) -> Result<(), CoinbaseIntxWsError> {
494        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
495        self.subscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
496            .await
497    }
498
499    /// Subscribes to bar (candle) streams for the given instrument IDs.
500    /// Subscribes to candlestick bar updates for the specified bar type.
501    ///
502    /// # Errors
503    ///
504    /// Returns an error if the subscription fails.
505    pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), CoinbaseIntxWsError> {
506        let channel = bar_spec_as_coinbase_channel(bar_type.spec())
507            .map_err(|e| CoinbaseIntxWsError::ClientError(e.to_string()))?;
508        let product_ids = vec![bar_type.standard().instrument_id().symbol.inner()];
509        self.subscribe(vec![channel], product_ids).await
510    }
511
512    /// Unsubscribes from instrument definition streams for the given instrument IDs.
513    /// Unsubscribes from instrument updates for the specified instruments.
514    ///
515    /// # Errors
516    ///
517    /// Returns an error if the unsubscription fails.
518    pub async fn unsubscribe_instruments(
519        &self,
520        instrument_ids: Vec<InstrumentId>,
521    ) -> Result<(), CoinbaseIntxWsError> {
522        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
523        self.unsubscribe(vec![CoinbaseIntxWsChannel::Instruments], product_ids)
524            .await
525    }
526
527    /// Unsubscribes from risk message streams for the given instrument IDs.
528    /// Unsubscribes from risk updates for the specified instruments.
529    ///
530    /// # Errors
531    ///
532    /// Returns an error if the unsubscription fails.
533    pub async fn unsubscribe_risk(
534        &self,
535        instrument_ids: Vec<InstrumentId>,
536    ) -> Result<(), CoinbaseIntxWsError> {
537        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
538        self.unsubscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
539            .await
540    }
541
542    /// Unsubscribes from funding message streams for the given instrument IDs.
543    /// Unsubscribes from funding updates for the specified instruments.
544    ///
545    /// # Errors
546    ///
547    /// Returns an error if the unsubscription fails.
548    pub async fn unsubscribe_funding(
549        &self,
550        instrument_ids: Vec<InstrumentId>,
551    ) -> Result<(), CoinbaseIntxWsError> {
552        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
553        self.unsubscribe(vec![CoinbaseIntxWsChannel::Funding], product_ids)
554            .await
555    }
556
557    /// Unsubscribes from order book (level 2) streams for the given instrument IDs.
558    /// Unsubscribes from order book updates for the specified instruments.
559    ///
560    /// # Errors
561    ///
562    /// Returns an error if the unsubscription fails.
563    pub async fn unsubscribe_order_book(
564        &self,
565        instrument_ids: Vec<InstrumentId>,
566    ) -> Result<(), CoinbaseIntxWsError> {
567        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
568        self.unsubscribe(vec![CoinbaseIntxWsChannel::Level2], product_ids)
569            .await
570    }
571
572    /// Unsubscribes from quote (level 1) streams for the given instrument IDs.
573    /// Unsubscribes from quote updates for the specified instruments.
574    ///
575    /// # Errors
576    ///
577    /// Returns an error if the unsubscription fails.
578    pub async fn unsubscribe_quotes(
579        &self,
580        instrument_ids: Vec<InstrumentId>,
581    ) -> Result<(), CoinbaseIntxWsError> {
582        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
583        self.unsubscribe(vec![CoinbaseIntxWsChannel::Level1], product_ids)
584            .await
585    }
586
587    /// Unsubscribes from trade (match) streams for the given instrument IDs.
588    /// Unsubscribes from trade updates for the specified instruments.
589    ///
590    /// # Errors
591    ///
592    /// Returns an error if the unsubscription fails.
593    pub async fn unsubscribe_trades(
594        &self,
595        instrument_ids: Vec<InstrumentId>,
596    ) -> Result<(), CoinbaseIntxWsError> {
597        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
598        self.unsubscribe(vec![CoinbaseIntxWsChannel::Match], product_ids)
599            .await
600    }
601
602    /// Unsubscribes from risk streams (for mark prices) for the given instrument IDs.
603    /// Unsubscribes from mark price updates for the specified instruments.
604    ///
605    /// # Errors
606    ///
607    /// Returns an error if the unsubscription fails.
608    pub async fn unsubscribe_mark_prices(
609        &self,
610        instrument_ids: Vec<InstrumentId>,
611    ) -> Result<(), CoinbaseIntxWsError> {
612        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
613        self.unsubscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
614            .await
615    }
616
617    /// Unsubscribes from risk streams (for index prices) for the given instrument IDs.
618    /// Unsubscribes from index price updates for the specified instruments.
619    ///
620    /// # Errors
621    ///
622    /// Returns an error if the unsubscription fails.
623    pub async fn unsubscribe_index_prices(
624        &self,
625        instrument_ids: Vec<InstrumentId>,
626    ) -> Result<(), CoinbaseIntxWsError> {
627        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
628        self.unsubscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
629            .await
630    }
631
632    /// Unsubscribes from bar (candle) streams for the given instrument IDs.
633    /// Unsubscribes from bar updates for the specified bar type.
634    ///
635    /// # Errors
636    ///
637    /// Returns an error if the unsubscription fails.
638    pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), CoinbaseIntxWsError> {
639        let channel = bar_spec_as_coinbase_channel(bar_type.spec())
640            .map_err(|e| CoinbaseIntxWsError::ClientError(e.to_string()))?;
641        let product_id = bar_type.standard().instrument_id().symbol.inner();
642        self.unsubscribe(vec![channel], vec![product_id]).await
643    }
644}
645
646fn instrument_ids_to_product_ids(instrument_ids: &[InstrumentId]) -> Vec<Ustr> {
647    instrument_ids.iter().map(|x| x.symbol.inner()).collect()
648}
649
/// Provides a raw message handler for Coinbase International WebSocket feed.
///
/// Reads frames from the socket and deserializes text frames into [`CoinbaseIntxWsMessage`]s.
struct CoinbaseIntxFeedHandler {
    /// Source of incoming WebSocket frames.
    reader: MessageReader,
    /// Stop flag checked between reads; when set, the handler stops producing messages.
    signal: Arc<AtomicBool>,
}
655
656impl CoinbaseIntxFeedHandler {
657    /// Creates a new [`CoinbaseIntxFeedHandler`] instance.
658    pub const fn new(reader: MessageReader, signal: Arc<AtomicBool>) -> Self {
659        Self { reader, signal }
660    }
661
662    /// Gets the next message from the WebSocket message stream.
663    async fn next(&mut self) -> Option<CoinbaseIntxWsMessage> {
664        // Timeout awaiting the next message before checking signal
665        let timeout = Duration::from_millis(10);
666
667        loop {
668            if self.signal.load(Ordering::Relaxed) {
669                tracing::debug!("Stop signal received");
670                break;
671            }
672
673            match tokio::time::timeout(timeout, self.reader.next()).await {
674                Ok(Some(msg)) => match msg {
675                    Ok(Message::Pong(_)) => {
676                        tracing::trace!("Received pong");
677                    }
678                    Ok(Message::Ping(_)) => {
679                        tracing::trace!("Received pong"); // Coinbase send ping frames as pongs
680                    }
681                    Ok(Message::Text(text)) => {
682                        match serde_json::from_str(&text) {
683                            Ok(event) => match &event {
684                                CoinbaseIntxWsMessage::Reject(msg) => {
685                                    tracing::error!("{msg:?}");
686                                }
687                                CoinbaseIntxWsMessage::Confirmation(msg) => {
688                                    tracing::debug!("{msg:?}");
689                                    continue;
690                                }
691                                CoinbaseIntxWsMessage::Instrument(_) => return Some(event),
692                                CoinbaseIntxWsMessage::Funding(_) => return Some(event),
693                                CoinbaseIntxWsMessage::Risk(_) => return Some(event),
694                                CoinbaseIntxWsMessage::BookSnapshot(_) => return Some(event),
695                                CoinbaseIntxWsMessage::BookUpdate(_) => return Some(event),
696                                CoinbaseIntxWsMessage::Quote(_) => return Some(event),
697                                CoinbaseIntxWsMessage::Trade(_) => return Some(event),
698                                CoinbaseIntxWsMessage::CandleSnapshot(_) => return Some(event),
699                                CoinbaseIntxWsMessage::CandleUpdate(_) => continue, // Ignore
700                            },
701                            Err(e) => {
702                                tracing::error!("Failed to parse message: {e}: {text}");
703                                break;
704                            }
705                        }
706                    }
707                    Ok(Message::Binary(msg)) => {
708                        tracing::debug!("Raw binary: {msg:?}");
709                    }
710                    Ok(Message::Close(_)) => {
711                        tracing::debug!("Received close message");
712                        return None;
713                    }
714                    Ok(msg) => {
715                        tracing::warn!("Unexpected message: {msg:?}");
716                    }
717                    Err(e) => {
718                        tracing::error!("{e}, stopping client");
719                        break; // Break as indicates a bug in the code
720                    }
721                },
722                Ok(None) => {
723                    tracing::info!("WebSocket stream closed");
724                    break;
725                }
726                Err(_) => {} // Timeout occurred awaiting a message, continue loop to check signal
727            }
728        }
729
730        log_task_stopped("message-streaming");
731        None
732    }
733}
734
/// Provides a Posei parser for the Coinbase International WebSocket feed.
///
/// Converts raw feed messages into [`PoseiWsMessage`]s and forwards them over a channel.
struct CoinbaseIntxWsMessageHandler {
    /// Known instruments keyed by raw symbol; used to resolve product IDs to instrument IDs
    /// and precisions, and updated when instrument messages arrive.
    instruments: HashMap<Ustr, InstrumentAny>,
    /// Source of deserialized Coinbase feed messages.
    handler: CoinbaseIntxFeedHandler,
    /// Output channel for parsed messages.
    tx: tokio::sync::mpsc::UnboundedSender<PoseiWsMessage>,
}
741
742impl CoinbaseIntxWsMessageHandler {
743    /// Creates a new [`CoinbaseIntxWsMessageHandler`] instance.
744    pub const fn new(
745        instruments: HashMap<Ustr, InstrumentAny>,
746        reader: MessageReader,
747        signal: Arc<AtomicBool>,
748        tx: tokio::sync::mpsc::UnboundedSender<PoseiWsMessage>,
749    ) -> Self {
750        let handler = CoinbaseIntxFeedHandler::new(reader, signal);
751        Self {
752            instruments,
753            handler,
754            tx,
755        }
756    }
757
758    /// Runs the WebSocket message feed.
759    async fn run(&mut self) {
760        while let Some(data) = self.next().await {
761            if let Err(e) = self.tx.send(data) {
762                tracing::error!("Error sending data: {e}");
763                break; // Stop processing on channel error
764            }
765        }
766    }
767
768    /// Gets the next message from the WebSocket message handler.
769    async fn next(&mut self) -> Option<PoseiWsMessage> {
770        let clock = get_atomic_clock_realtime();
771
772        while let Some(event) = self.handler.next().await {
773            match event {
774                CoinbaseIntxWsMessage::Instrument(msg) => {
775                    if let Some(inst) = parse_instrument_any(&msg, clock.get_time_ns()) {
776                        // Update instruments map
777                        self.instruments
778                            .insert(inst.raw_symbol().inner(), inst.clone());
779                        return Some(PoseiWsMessage::Instrument(inst));
780                    }
781                }
782                CoinbaseIntxWsMessage::Funding(msg) => {
783                    tracing::warn!("Received {msg:?}"); // TODO: Implement
784                }
785                CoinbaseIntxWsMessage::BookSnapshot(msg) => {
786                    if let Some(inst) = self.instruments.get(&msg.product_id) {
787                        match parse_orderbook_snapshot_msg(
788                            &msg,
789                            inst.id(),
790                            inst.price_precision(),
791                            inst.size_precision(),
792                            clock.get_time_ns(),
793                        ) {
794                            Ok(deltas) => {
795                                let deltas = OrderBookDeltas_API::new(deltas);
796                                let data = Data::Deltas(deltas);
797                                return Some(PoseiWsMessage::Data(data));
798                            }
799                            Err(e) => {
800                                tracing::error!("Failed to parse orderbook snapshot: {e}");
801                                return None;
802                            }
803                        }
804                    }
805                    tracing::error!("No instrument found for {}", msg.product_id);
806                    return None;
807                }
808                CoinbaseIntxWsMessage::BookUpdate(msg) => {
809                    if let Some(inst) = self.instruments.get(&msg.product_id) {
810                        match parse_orderbook_update_msg(
811                            &msg,
812                            inst.id(),
813                            inst.price_precision(),
814                            inst.size_precision(),
815                            clock.get_time_ns(),
816                        ) {
817                            Ok(deltas) => {
818                                let deltas = OrderBookDeltas_API::new(deltas);
819                                let data = Data::Deltas(deltas);
820                                return Some(PoseiWsMessage::Data(data));
821                            }
822                            Err(e) => {
823                                tracing::error!("Failed to parse orderbook update: {e}");
824                            }
825                        }
826                    } else {
827                        tracing::error!("No instrument found for {}", msg.product_id);
828                    }
829                }
830                CoinbaseIntxWsMessage::Quote(msg) => {
831                    if let Some(inst) = self.instruments.get(&msg.product_id) {
832                        match parse_quote_msg(
833                            &msg,
834                            inst.id(),
835                            inst.price_precision(),
836                            inst.size_precision(),
837                            clock.get_time_ns(),
838                        ) {
839                            Ok(quote) => return Some(PoseiWsMessage::Data(Data::Quote(quote))),
840                            Err(e) => {
841                                tracing::error!("Failed to parse quote: {e}");
842                            }
843                        }
844                    } else {
845                        tracing::error!("No instrument found for {}", msg.product_id);
846                    }
847                }
848                CoinbaseIntxWsMessage::Trade(msg) => {
849                    if let Some(inst) = self.instruments.get(&msg.product_id) {
850                        match parse_trade_msg(
851                            &msg,
852                            inst.id(),
853                            inst.price_precision(),
854                            inst.size_precision(),
855                            clock.get_time_ns(),
856                        ) {
857                            Ok(trade) => return Some(PoseiWsMessage::Data(Data::Trade(trade))),
858                            Err(e) => {
859                                tracing::error!("Failed to parse trade: {e}");
860                            }
861                        }
862                    } else {
863                        tracing::error!("No instrument found for {}", msg.product_id);
864                    }
865                }
866                CoinbaseIntxWsMessage::Risk(msg) => {
867                    if let Some(inst) = self.instruments.get(&msg.product_id) {
868                        let mark_price = match parse_mark_price_msg(
869                            &msg,
870                            inst.id(),
871                            inst.price_precision(),
872                            clock.get_time_ns(),
873                        ) {
874                            Ok(mark_price) => Some(mark_price),
875                            Err(e) => {
876                                tracing::error!("Failed to parse mark price: {e}");
877                                None
878                            }
879                        };
880
881                        let index_price = match parse_index_price_msg(
882                            &msg,
883                            inst.id(),
884                            inst.price_precision(),
885                            clock.get_time_ns(),
886                        ) {
887                            Ok(index_price) => Some(index_price),
888                            Err(e) => {
889                                tracing::error!("Failed to parse index price: {e}");
890                                None
891                            }
892                        };
893
894                        match (mark_price, index_price) {
895                            (Some(mark), Some(index)) => {
896                                return Some(PoseiWsMessage::MarkAndIndex((mark, index)));
897                            }
898                            (Some(mark), None) => return Some(PoseiWsMessage::MarkPrice(mark)),
899                            (None, Some(index)) => {
900                                return Some(PoseiWsMessage::IndexPrice(index));
901                            }
902                            (None, None) => continue,
903                        };
904                    }
905                    tracing::error!("No instrument found for {}", msg.product_id);
906                }
907                CoinbaseIntxWsMessage::CandleSnapshot(msg) => {
908                    if let Some(inst) = self.instruments.get(&msg.product_id) {
909                        match parse_candle_msg(
910                            &msg,
911                            inst.id(),
912                            inst.price_precision(),
913                            inst.size_precision(),
914                            clock.get_time_ns(),
915                        ) {
916                            Ok(bar) => return Some(PoseiWsMessage::Data(Data::Bar(bar))),
917                            Err(e) => {
918                                tracing::error!("Failed to parse candle: {e}");
919                            }
920                        }
921                    } else {
922                        tracing::error!("No instrument found for {}", msg.product_id);
923                    }
924                }
925                _ => {
926                    tracing::warn!("Not implemented: {event:?}");
927                }
928            }
929        }
930        None // Connection closed
931    }
932}