1use std::{
32 fmt::Debug,
33 path::Path,
34 sync::{
35 Arc,
36 atomic::{AtomicU8, Ordering},
37 },
38 time::Duration,
39};
40
41use bytes::Bytes;
42use nautilus_cryptography::providers::install_cryptographic_provider;
43#[cfg(feature = "python")]
44use pyo3::prelude::*;
45use tokio::{
46 io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
47 net::TcpStream,
48};
49use tokio_tungstenite::{
50 MaybeTlsStream,
51 tungstenite::{Error, client::IntoClientRequest, stream::Mode},
52};
53
54use crate::{
55 backoff::ExponentialBackoff,
56 error::SendError,
57 fix::process_fix_buffer,
58 logging::{log_task_aborted, log_task_started, log_task_stopped},
59 mode::ConnectionMode,
60 tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
61};
62
/// Write half of a (possibly TLS-wrapped) TCP stream, as produced by `tokio::io::split`.
type TcpWriter = WriteHalf<MaybeTlsStream<TcpStream>>;
/// Read half of a (possibly TLS-wrapped) TCP stream, as produced by `tokio::io::split`.
type TcpReader = ReadHalf<MaybeTlsStream<TcpStream>>;
/// Callback type for received bytes; `Send + Sync` so it can be shared with spawned tasks.
pub type TcpMessageHandler = dyn Fn(&[u8]) + Send + Sync;
66
/// Configuration for a socket client connection.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
)]
pub struct SocketConfig {
    /// Address passed to `TcpStream::connect`; also converted into the client request
    /// used for the TLS step.
    pub url: String,
    /// Stream mode handed to `tcp_tls` when establishing the connection.
    pub mode: Mode,
    /// Byte sequence written after every outgoing message and used to split incoming
    /// data into individual messages.
    pub suffix: Vec<u8>,
    /// Optional Python callable invoked with each received message.
    #[cfg(feature = "python")]
    pub py_handler: Option<Arc<PyObject>>,
    /// Optional heartbeat as `(interval in seconds, message bytes)`.
    pub heartbeat: Option<(u64, Vec<u8>)>,
    /// Upper bound for a single reconnect attempt, in milliseconds (10_000 when `None`).
    pub reconnect_timeout_ms: Option<u64>,
    /// Initial reconnect backoff delay, in milliseconds (2_000 when `None`).
    pub reconnect_delay_initial_ms: Option<u64>,
    /// Maximum reconnect backoff delay, in milliseconds (30_000 when `None`).
    pub reconnect_delay_max_ms: Option<u64>,
    /// Backoff multiplier passed to `ExponentialBackoff` (1.5 when `None`).
    pub reconnect_backoff_factor: Option<f64>,
    /// Backoff jitter passed to `ExponentialBackoff`, in milliseconds (100 when `None`).
    pub reconnect_jitter_ms: Option<u64>,
    /// Optional directory from which the TLS configuration is created.
    pub certs_dir: Option<String>,
}
98
/// Commands accepted by the writer task over its channel.
#[derive(Debug)]
pub enum WriterCommand {
    /// Replace the writer task's active writer (sent after a successful reconnect).
    Update(TcpWriter),
    /// Write the given bytes, followed by the configured suffix, to the socket.
    Send(Bytes),
}
107
/// Internal connection state: owns the configuration, the background task handles and
/// everything needed to reconnect.
///
/// Instances are moved into the controller task spawned by `SocketClient`.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// Configuration the client was created with; re-read on every reconnect.
    config: SocketConfig,
    /// TLS connector built from `config.certs_dir`, if one was configured.
    connector: Option<Connector>,
    /// Handle of the current read task; replaced on reconnect.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Handle of the write task; it lives across reconnects and receives new writers
    /// through `WriterCommand::Update`.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender side of the writer task's command channel.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task, present only when a heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared connection mode (a `ConnectionMode` stored as `u8`).
    connection_mode: Arc<AtomicU8>,
    /// Maximum duration of a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Backoff state used between failed reconnect attempts.
    backoff: ExponentialBackoff,
    /// Optional Rust message handler passed to the read task.
    handler: Option<Arc<TcpMessageHandler>>,
}
139
140impl SocketClientInner {
141 pub async fn connect_url(
142 config: SocketConfig,
143 handler: Option<Arc<TcpMessageHandler>>,
144 ) -> anyhow::Result<Self> {
145 install_cryptographic_provider();
146
147 let SocketConfig {
148 url,
149 mode,
150 heartbeat,
151 suffix,
152 #[cfg(feature = "python")]
153 py_handler,
154 reconnect_timeout_ms,
155 reconnect_delay_initial_ms,
156 reconnect_delay_max_ms,
157 reconnect_backoff_factor,
158 reconnect_jitter_ms,
159 certs_dir,
160 } = &config;
161 let connector = if let Some(dir) = certs_dir {
162 let config = create_tls_config_from_certs_dir(Path::new(dir))?;
163 Some(Connector::Rustls(Arc::new(config)))
164 } else {
165 None
166 };
167
168 let (reader, writer) = Self::tls_connect_with_server(url, *mode, connector.clone()).await?;
169 tracing::debug!("Connected");
170
171 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
172
173 let read_task = Arc::new(Self::spawn_read_task(
174 connection_mode.clone(),
175 reader,
176 handler.clone(),
177 #[cfg(feature = "python")]
178 py_handler.clone(),
179 suffix.clone(),
180 ));
181
182 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
183
184 let write_task =
185 Self::spawn_write_task(connection_mode.clone(), writer, writer_rx, suffix.clone());
186
187 let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
189 Self::spawn_heartbeat_task(
190 connection_mode.clone(),
191 heartbeat.clone(),
192 writer_tx.clone(),
193 )
194 });
195
196 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
197 let backoff = ExponentialBackoff::new(
198 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
199 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
200 reconnect_backoff_factor.unwrap_or(1.5),
201 reconnect_jitter_ms.unwrap_or(100),
202 true, )?;
204
205 Ok(Self {
206 config,
207 connector,
208 read_task,
209 write_task,
210 writer_tx,
211 heartbeat_task,
212 connection_mode,
213 reconnect_timeout,
214 backoff,
215 handler,
216 })
217 }
218
219 pub async fn tls_connect_with_server(
220 url: &str,
221 mode: Mode,
222 connector: Option<Connector>,
223 ) -> Result<(TcpReader, TcpWriter), Error> {
224 tracing::debug!("Connecting to {url}");
225 let tcp_result = TcpStream::connect(url).await;
226
227 match tcp_result {
228 Ok(stream) => {
229 tracing::debug!("TCP connection established, proceeding with TLS");
230 let request = url.into_client_request()?;
231 tcp_tls(&request, mode, stream, connector)
232 .await
233 .map(tokio::io::split)
234 }
235 Err(e) => {
236 tracing::error!("TCP connection failed: {e:?}");
237 Err(Error::Io(e))
238 }
239 }
240 }
241
    /// Re-establishes the connection, hands the new writer to the writer task and
    /// replaces the read task.
    ///
    /// The connection mode is re-checked at several points; if it is found to be in the
    /// disconnect state the function returns `Ok(())` early without completing the
    /// reconnect.
    ///
    /// # Errors
    ///
    /// Returns the connection error if connecting fails, or a `TimedOut` I/O error if
    /// the whole sequence does not finish within `self.reconnect_timeout`.
    async fn reconnect(&mut self) -> Result<(), Error> {
        tracing::debug!("Reconnecting");

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            tracing::debug!("Reconnect aborted due to disconnect state");
            return Ok(());
        }

        tokio::time::timeout(self.reconnect_timeout, async {
            // Exhaustive destructuring: adding a field to `SocketConfig` forces a
            // decision here about whether the reconnect path needs it.
            let SocketConfig {
                url,
                mode,
                heartbeat: _,
                suffix,
                #[cfg(feature = "python")]
                py_handler,
                reconnect_timeout_ms: _,
                reconnect_delay_initial_ms: _,
                reconnect_backoff_factor: _,
                reconnect_delay_max_ms: _,
                reconnect_jitter_ms: _,
                certs_dir: _,
            } = &self.config;
            let connector = self.connector.clone();
            let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (after connect)");
                return Ok(());
            }
            tracing::debug!("Connected");

            // Hand the new writer to the long-lived writer task.
            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer)) {
                tracing::error!("{e}");
            }

            // Short pause before touching the read task.
            // NOTE(review): presumably gives the writer task time to pick up the new
            // writer; confirm intent.
            tokio::time::sleep(Duration::from_millis(100)).await;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (after delay)");
                return Ok(());
            }

            // Stop the previous read task if it is still running.
            if !self.read_task.is_finished() {
                self.read_task.abort();
                log_task_aborted("read");
            }

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                tracing::debug!("Reconnect aborted mid-flight (before spawn read)");
                return Ok(());
            }

            // Mark the connection active before spawning the reader, which exits as soon
            // as it observes a non-active mode.
            self.connection_mode
                .store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);

            self.read_task = Arc::new(Self::spawn_read_task(
                self.connection_mode.clone(),
                reader,
                self.handler.clone(),
                #[cfg(feature = "python")]
                py_handler.clone(),
                suffix.clone(),
            ));

            tracing::debug!("Reconnect succeeded");
            Ok(())
        })
        .await
        // Outer error: the timeout elapsed. The trailing `?` unwraps it and yields the
        // inner result of the async block as the return value.
        .map_err(|_| {
            Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "reconnection timed out after {}s",
                    self.reconnect_timeout.as_secs_f64()
                ),
            ))
        })?
    }
331
332 #[inline]
339 #[must_use]
340 pub fn is_alive(&self) -> bool {
341 !self.read_task.is_finished()
342 }
343
344 #[must_use]
345 fn spawn_read_task(
346 connection_state: Arc<AtomicU8>,
347 mut reader: TcpReader,
348 handler: Option<Arc<TcpMessageHandler>>,
349 #[cfg(feature = "python")] py_handler: Option<Arc<PyObject>>,
350 suffix: Vec<u8>,
351 ) -> tokio::task::JoinHandle<()> {
352 log_task_started("read");
353
354 let check_interval = Duration::from_millis(10);
356
357 tokio::task::spawn(async move {
358 let mut buf = Vec::new();
359
360 loop {
361 if !ConnectionMode::from_atomic(&connection_state).is_active() {
362 break;
363 }
364
365 match tokio::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
366 Ok(Ok(0)) => {
368 tracing::debug!("Connection closed by server");
369 break;
370 }
371 Ok(Err(e)) => {
372 tracing::debug!("Connection ended: {e}");
373 break;
374 }
375 Ok(Ok(bytes)) => {
377 tracing::trace!("Received <binary> {bytes} bytes");
378
379 if let Some(handler) = &handler {
380 process_fix_buffer(&mut buf, handler);
381 } else {
382 while let Some((i, _)) = &buf
383 .windows(suffix.len())
384 .enumerate()
385 .find(|(_, pair)| pair.eq(&suffix))
386 {
387 let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
388 data.truncate(data.len() - suffix.len());
389
390 if let Some(handler) = &handler {
391 handler(&data);
392 }
393
394 #[cfg(feature = "python")]
395 if let Some(py_handler) = &py_handler {
396 if let Err(e) = Python::with_gil(|py| {
397 py_handler.call1(py, (data.as_slice(),))
398 }) {
399 tracing::error!("Call to handler failed: {e}");
400 break;
401 }
402 }
403 }
404 }
405 }
406 Err(_) => {
407 continue;
409 }
410 }
411 }
412
413 log_task_stopped("read");
414 })
415 }
416
417 fn spawn_write_task(
418 connection_state: Arc<AtomicU8>,
419 writer: TcpWriter,
420 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
421 suffix: Vec<u8>,
422 ) -> tokio::task::JoinHandle<()> {
423 log_task_started("write");
424
425 let check_interval = Duration::from_millis(10);
427
428 tokio::task::spawn(async move {
429 let mut active_writer = writer;
430
431 loop {
432 if matches!(
433 ConnectionMode::from_atomic(&connection_state),
434 ConnectionMode::Disconnect | ConnectionMode::Closed
435 ) {
436 break;
437 }
438
439 match tokio::time::timeout(check_interval, writer_rx.recv()).await {
440 Ok(Some(msg)) => {
441 let mode = ConnectionMode::from_atomic(&connection_state);
443 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
444 break;
445 }
446
447 match msg {
448 WriterCommand::Update(new_writer) => {
449 tracing::debug!("Received new writer");
450
451 tokio::time::sleep(Duration::from_millis(100)).await;
453
454 _ = active_writer.shutdown().await;
457
458 active_writer = new_writer;
459 tracing::debug!("Updated writer");
460 }
461 _ if mode.is_reconnect() => {
462 tracing::warn!("Skipping message while reconnecting, {msg:?}");
463 continue;
464 }
465 WriterCommand::Send(msg) => {
466 if let Err(e) = active_writer.write_all(&msg).await {
467 tracing::error!("Failed to send message: {e}");
468 tracing::warn!("Writer triggering reconnect");
470 connection_state
471 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
472 continue;
473 }
474 if let Err(e) = active_writer.write_all(&suffix).await {
475 tracing::error!("Failed to send message: {e}");
476 }
477 }
478 }
479 }
480 Ok(None) => {
481 tracing::debug!("Writer channel closed, terminating writer task");
483 break;
484 }
485 Err(_) => {
486 continue;
488 }
489 }
490 }
491
492 _ = active_writer.shutdown().await;
495
496 log_task_stopped("write");
497 })
498 }
499
500 fn spawn_heartbeat_task(
501 connection_state: Arc<AtomicU8>,
502 heartbeat: (u64, Vec<u8>),
503 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
504 ) -> tokio::task::JoinHandle<()> {
505 log_task_started("heartbeat");
506 let (interval_secs, message) = heartbeat;
507
508 tokio::task::spawn(async move {
509 let interval = Duration::from_secs(interval_secs);
510
511 loop {
512 tokio::time::sleep(interval).await;
513
514 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
515 ConnectionMode::Active => {
516 let msg = WriterCommand::Send(message.clone().into());
517
518 match writer_tx.send(msg) {
519 Ok(()) => tracing::trace!("Sent heartbeat to writer task"),
520 Err(e) => {
521 tracing::error!("Failed to send heartbeat to writer task: {e}");
522 }
523 }
524 }
525 ConnectionMode::Reconnect => continue,
526 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
527 }
528 }
529
530 log_task_stopped("heartbeat");
531 })
532 }
533}
534
535impl Drop for SocketClientInner {
536 fn drop(&mut self) {
537 if !self.read_task.is_finished() {
538 self.read_task.abort();
539 log_task_aborted("read");
540 }
541
542 if !self.write_task.is_finished() {
543 self.write_task.abort();
544 log_task_aborted("write");
545 }
546
547 if let Some(ref handle) = self.heartbeat_task.take() {
548 if !handle.is_finished() {
549 handle.abort();
550 log_task_aborted("heartbeat");
551 }
552 }
553 }
554}
555
/// Public handle to a socket connection.
///
/// The connection itself is driven by a controller task; this handle exposes the shared
/// connection mode and the channel used to queue outgoing data.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "posei_trader.core.nautilus_pyo3.network")
)]
pub struct SocketClient {
    /// Handle of the controller task that performs reconnects and disconnects.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared connection mode (a `ConnectionMode` stored as `u8`).
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Sender side of the writer task's command channel.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
565
566impl Debug for SocketClient {
567 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
568 f.debug_struct(stringify!(SocketClient)).finish()
569 }
570}
571
572impl SocketClient {
573 pub async fn connect(
579 config: SocketConfig,
580 handler: Option<Arc<TcpMessageHandler>>,
581 #[cfg(feature = "python")] post_connection: Option<PyObject>,
582 #[cfg(feature = "python")] post_reconnection: Option<PyObject>,
583 #[cfg(feature = "python")] post_disconnection: Option<PyObject>,
584 ) -> anyhow::Result<Self> {
585 let inner = SocketClientInner::connect_url(config, handler).await?;
586 let writer_tx = inner.writer_tx.clone();
587 let connection_mode = inner.connection_mode.clone();
588
589 let controller_task = Self::spawn_controller_task(
590 inner,
591 connection_mode.clone(),
592 #[cfg(feature = "python")]
593 post_reconnection,
594 #[cfg(feature = "python")]
595 post_disconnection,
596 );
597
598 #[cfg(feature = "python")]
599 if let Some(handler) = post_connection {
600 Python::with_gil(|py| match handler.call0(py) {
601 Ok(_) => tracing::debug!("Called `post_connection` handler"),
602 Err(e) => tracing::error!("Error calling `post_connection` handler: {e}"),
603 });
604 }
605
606 Ok(Self {
607 controller_task,
608 connection_mode,
609 writer_tx,
610 })
611 }
612
613 #[must_use]
615 pub fn connection_mode(&self) -> ConnectionMode {
616 ConnectionMode::from_atomic(&self.connection_mode)
617 }
618
619 #[inline]
624 #[must_use]
625 pub fn is_active(&self) -> bool {
626 self.connection_mode().is_active()
627 }
628
629 #[inline]
634 #[must_use]
635 pub fn is_reconnecting(&self) -> bool {
636 self.connection_mode().is_reconnect()
637 }
638
639 #[inline]
643 #[must_use]
644 pub fn is_disconnecting(&self) -> bool {
645 self.connection_mode().is_disconnect()
646 }
647
648 #[inline]
654 #[must_use]
655 pub fn is_closed(&self) -> bool {
656 self.connection_mode().is_closed()
657 }
658
659 pub async fn close(&self) {
664 self.connection_mode
665 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
666
667 match tokio::time::timeout(Duration::from_secs(5), async {
668 while !self.is_closed() {
669 tokio::time::sleep(Duration::from_millis(10)).await;
670 }
671
672 if !self.controller_task.is_finished() {
673 self.controller_task.abort();
674 log_task_aborted("controller");
675 }
676 })
677 .await
678 {
679 Ok(()) => {
680 log_task_stopped("controller");
681 }
682 Err(_) => {
683 tracing::error!("Timeout waiting for controller task to finish");
684 }
685 }
686 }
687
688 pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
694 if self.is_closed() {
695 return Err(SendError::Closed);
696 }
697
698 let timeout = Duration::from_secs(2);
699 let check_interval = Duration::from_millis(1);
700
701 if !self.is_active() {
702 tracing::debug!("Waiting for client to become ACTIVE before sending...");
703
704 let inner = tokio::time::timeout(timeout, async {
705 loop {
706 if self.is_active() {
707 return Ok(());
708 }
709 if matches!(
710 self.connection_mode(),
711 ConnectionMode::Disconnect | ConnectionMode::Closed
712 ) {
713 return Err(());
714 }
715 tokio::time::sleep(check_interval).await;
716 }
717 })
718 .await
719 .map_err(|_| SendError::Timeout)?;
720 inner.map_err(|()| SendError::Closed)?;
721 }
722
723 let msg = WriterCommand::Send(data.into());
724 self.writer_tx
725 .send(msg)
726 .map_err(|e| SendError::BrokenPipe(e.to_string()))
727 }
728
    /// Spawns the controller task that owns `inner` and drives the connection lifecycle.
    ///
    /// Every 10 ms the task inspects the shared connection mode:
    ///
    /// * In the disconnect state it aborts the read and heartbeat tasks, calls the
    ///   optional `post_disconnection` handler and leaves the loop.
    /// * In the reconnect state, or when active with a finished read task, it calls
    ///   `inner.reconnect()`, backing off after a failed attempt.
    ///
    /// After the loop the mode is set to `Closed`.
    fn spawn_controller_task(
        mut inner: SocketClientInner,
        connection_mode: Arc<AtomicU8>,
        #[cfg(feature = "python")] post_reconnection: Option<PyObject>,
        #[cfg(feature = "python")] post_disconnection: Option<PyObject>,
    ) -> tokio::task::JoinHandle<()> {
        tokio::task::spawn(async move {
            log_task_started("controller");

            let check_interval = Duration::from_millis(10);

            loop {
                tokio::time::sleep(check_interval).await;
                let mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    tracing::debug!("Disconnecting");

                    let timeout = Duration::from_secs(5);
                    if tokio::time::timeout(timeout, async {
                        // Short delay before aborting the tasks.
                        // NOTE(review): presumably lets in-flight work settle; confirm.
                        tokio::time::sleep(Duration::from_millis(100)).await;

                        if !inner.read_task.is_finished() {
                            inner.read_task.abort();
                            log_task_aborted("read");
                        }

                        if let Some(task) = &inner.heartbeat_task {
                            if !task.is_finished() {
                                task.abort();
                                log_task_aborted("heartbeat");
                            }
                        }
                    })
                    .await
                    .is_err()
                    {
                        tracing::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    tracing::debug!("Closed");

                    #[cfg(feature = "python")]
                    if let Some(ref handler) = post_disconnection {
                        Python::with_gil(|py| match handler.call0(py) {
                            Ok(_) => tracing::debug!("Called `post_disconnection` handler"),
                            Err(e) => {
                                tracing::error!("Error calling `post_disconnection` handler: {e}");
                            }
                        });
                    }
                    // Leave the controller loop; the mode is set to `Closed` below.
                    break;
                }

                // Reconnect when requested, or when the read task ended while the
                // connection is still considered active.
                if mode.is_reconnect() || (mode.is_active() && !inner.is_alive()) {
                    match inner.reconnect().await {
                        Ok(()) => {
                            tracing::debug!("Reconnected successfully");
                            inner.backoff.reset();
                            #[cfg(feature = "python")]
                            {
                                // `reconnect` also returns `Ok` when it bailed out due to
                                // a disconnect, so only notify when actually active.
                                if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                    if let Some(ref handler) = post_reconnection {
                                        Python::with_gil(|py| match handler.call0(py) {
                                            Ok(_) => tracing::debug!(
                                                "Called `post_reconnection` handler"
                                            ),
                                            Err(e) => tracing::error!(
                                                "Error calling `post_reconnection` handler: {e}"
                                            ),
                                        });
                                    }
                                } else {
                                    tracing::debug!(
                                        "Skipping post_reconnection handlers due to disconnect state"
                                    );
                                }
                            }
                        }
                        Err(e) => {
                            let duration = inner.backoff.next_duration();
                            tracing::warn!("Reconnect attempt failed: {e}");
                            if !duration.is_zero() {
                                tracing::warn!("Backing off for {}s...", duration.as_secs_f64());
                            }
                            tokio::time::sleep(duration).await;
                        }
                    }
                }
            }
            // Signals `SocketClient::close` (which polls `is_closed`) that shutdown is done.
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
828}
829
830impl Drop for SocketClient {
832 fn drop(&mut self) {
833 if !self.controller_task.is_finished() {
834 self.controller_task.abort();
835 log_task_aborted("controller");
836 }
837 }
838}
839
840#[cfg(test)]
844#[cfg(feature = "python")]
845#[cfg(target_os = "linux")] mod tests {
847 use std::ffi::CString;
848
849 use nautilus_common::testing::wait_until_async;
850 use nautilus_core::python::IntoPyObjectPoseiExt;
851 use pyo3::prepare_freethreaded_python;
852 use tokio::{
853 io::{AsyncReadExt, AsyncWriteExt},
854 net::{TcpListener, TcpStream},
855 sync::Mutex,
856 task,
857 time::{Duration, sleep},
858 };
859
860 use super::*;
861
862 fn create_handler() -> PyObject {
863 let code_raw = r"
864class Counter:
865 def __init__(self):
866 self.count = 0
867 self.check = False
868
869 def handler(self, bytes):
870 msg = bytes.decode()
871 if msg == 'ping':
872 self.count += 1
873 elif msg == 'heartbeat message':
874 self.check = True
875
876 def get_check(self):
877 return self.check
878
879 def get_count(self):
880 return self.count
881
882counter = Counter()
883";
884 let code = CString::new(code_raw).unwrap();
885 let filename = CString::new("test".to_string()).unwrap();
886 let module = CString::new("test".to_string()).unwrap();
887 Python::with_gil(|py| {
888 let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
889 let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
890
891 counter
892 .getattr(py, "handler")
893 .unwrap()
894 .into_py_any_unwrap(py)
895 })
896 }
897
898 async fn bind_test_server() -> (u16, TcpListener) {
899 let listener = TcpListener::bind("127.0.0.1:0")
900 .await
901 .expect("Failed to bind ephemeral port");
902 let port = listener.local_addr().unwrap().port();
903 (port, listener)
904 }
905
906 async fn run_echo_server(mut socket: TcpStream) {
907 let mut buf = Vec::new();
908 loop {
909 match socket.read_buf(&mut buf).await {
910 Ok(0) => {
911 break;
912 }
913 Ok(_n) => {
914 while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
915 let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
916 line.truncate(line.len() - 2);
918
919 if line == b"close" {
920 let _ = socket.shutdown().await;
921 return;
922 }
923
924 let mut echo_data = line;
925 echo_data.extend_from_slice(b"\r\n");
926 if socket.write_all(&echo_data).await.is_err() {
927 break;
928 }
929 }
930 }
931 Err(e) => {
932 eprintln!("Server read error: {e}");
933 break;
934 }
935 }
936 }
937 }
938
939 #[tokio::test]
940 async fn test_basic_send_receive() {
941 prepare_freethreaded_python();
942
943 let (port, listener) = bind_test_server().await;
944 let server_task = task::spawn(async move {
945 let (socket, _) = listener.accept().await.unwrap();
946 run_echo_server(socket).await;
947 });
948
949 let config = SocketConfig {
950 url: format!("127.0.0.1:{port}"),
951 mode: Mode::Plain,
952 suffix: b"\r\n".to_vec(),
953 py_handler: Some(Arc::new(create_handler())),
954 heartbeat: None,
955 reconnect_timeout_ms: None,
956 reconnect_delay_initial_ms: None,
957 reconnect_backoff_factor: None,
958 reconnect_delay_max_ms: None,
959 reconnect_jitter_ms: None,
960 certs_dir: None,
961 };
962
963 let client = SocketClient::connect(config, None, None, None, None)
964 .await
965 .expect("Client connect failed unexpectedly");
966
967 client.send_bytes(b"Hello".into()).await.unwrap();
968 client.send_bytes(b"World".into()).await.unwrap();
969
970 sleep(Duration::from_millis(100)).await;
972
973 client.send_bytes(b"close".into()).await.unwrap();
974 server_task.await.unwrap();
975 assert!(!client.is_closed());
976 }
977
978 #[tokio::test]
979 async fn test_reconnect_fail_exhausted() {
980 prepare_freethreaded_python();
981
982 let (port, listener) = bind_test_server().await;
983 drop(listener); let config = SocketConfig {
986 url: format!("127.0.0.1:{port}"),
987 mode: Mode::Plain,
988 suffix: b"\r\n".to_vec(),
989 py_handler: Some(Arc::new(create_handler())),
990 heartbeat: None,
991 reconnect_timeout_ms: None,
992 reconnect_delay_initial_ms: None,
993 reconnect_backoff_factor: None,
994 reconnect_delay_max_ms: None,
995 reconnect_jitter_ms: None,
996 certs_dir: None,
997 };
998
999 let client_res = SocketClient::connect(config, None, None, None, None).await;
1000 assert!(
1001 client_res.is_err(),
1002 "Should fail quickly with no server listening"
1003 );
1004 }
1005
1006 #[tokio::test]
1007 async fn test_user_disconnect() {
1008 prepare_freethreaded_python();
1009
1010 let (port, listener) = bind_test_server().await;
1011 let server_task = task::spawn(async move {
1012 let (socket, _) = listener.accept().await.unwrap();
1013 let mut buf = [0u8; 1024];
1014 let _ = socket.try_read(&mut buf);
1015
1016 loop {
1017 sleep(Duration::from_secs(1)).await;
1018 }
1019 });
1020
1021 let config = SocketConfig {
1022 url: format!("127.0.0.1:{port}"),
1023 mode: Mode::Plain,
1024 suffix: b"\r\n".to_vec(),
1025 py_handler: Some(Arc::new(create_handler())),
1026 heartbeat: None,
1027 reconnect_timeout_ms: None,
1028 reconnect_delay_initial_ms: None,
1029 reconnect_backoff_factor: None,
1030 reconnect_delay_max_ms: None,
1031 reconnect_jitter_ms: None,
1032 certs_dir: None,
1033 };
1034
1035 let client = SocketClient::connect(config, None, None, None, None)
1036 .await
1037 .unwrap();
1038
1039 client.close().await;
1040 assert!(client.is_closed());
1041 server_task.abort();
1042 }
1043
1044 #[tokio::test]
1045 async fn test_heartbeat() {
1046 prepare_freethreaded_python();
1047
1048 let (port, listener) = bind_test_server().await;
1049 let received = Arc::new(Mutex::new(Vec::new()));
1050 let received2 = received.clone();
1051
1052 let server_task = task::spawn(async move {
1053 let (socket, _) = listener.accept().await.unwrap();
1054
1055 let mut buf = Vec::new();
1056 loop {
1057 match socket.try_read_buf(&mut buf) {
1058 Ok(0) => break,
1059 Ok(_) => {
1060 while let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
1061 let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
1062 line.truncate(line.len() - 2);
1063 received2.lock().await.push(line);
1064 }
1065 }
1066 Err(_) => {
1067 tokio::time::sleep(Duration::from_millis(10)).await;
1068 }
1069 }
1070 }
1071 });
1072
1073 let heartbeat = Some((1, b"ping".to_vec()));
1075
1076 let config = SocketConfig {
1077 url: format!("127.0.0.1:{port}"),
1078 mode: Mode::Plain,
1079 suffix: b"\r\n".to_vec(),
1080 py_handler: Some(Arc::new(create_handler())),
1081 heartbeat,
1082 reconnect_timeout_ms: None,
1083 reconnect_delay_initial_ms: None,
1084 reconnect_backoff_factor: None,
1085 reconnect_delay_max_ms: None,
1086 reconnect_jitter_ms: None,
1087 certs_dir: None,
1088 };
1089
1090 let client = SocketClient::connect(config, None, None, None, None)
1091 .await
1092 .unwrap();
1093
1094 sleep(Duration::from_secs(3)).await;
1096
1097 {
1098 let lock = received.lock().await;
1099 let pings = lock
1100 .iter()
1101 .filter(|line| line == &&b"ping".to_vec())
1102 .count();
1103 assert!(
1104 pings >= 2,
1105 "Expected at least 2 heartbeat pings; got {pings}"
1106 );
1107 }
1108
1109 client.close().await;
1110 server_task.abort();
1111 }
1112
1113 #[tokio::test]
1114 async fn test_python_handler_error() {
1115 prepare_freethreaded_python();
1116
1117 let (port, listener) = bind_test_server().await;
1118 let server_task = task::spawn(async move {
1119 let (socket, _) = listener.accept().await.unwrap();
1120 run_echo_server(socket).await;
1121 });
1122
1123 let code_raw = r#"
1124def handler(bytes_data):
1125 txt = bytes_data.decode()
1126 if "ERR" in txt:
1127 raise ValueError("Simulated error in handler")
1128 return
1129"#;
1130 let code = CString::new(code_raw).unwrap();
1131 let filename = CString::new("test".to_string()).unwrap();
1132 let module = CString::new("test".to_string()).unwrap();
1133
1134 let py_handler = Some(Python::with_gil(|py| {
1135 let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
1136 let func = pymod.getattr("handler").unwrap();
1137 Arc::new(func.into_py_any_unwrap(py))
1138 }));
1139
1140 let config = SocketConfig {
1141 url: format!("127.0.0.1:{port}"),
1142 mode: Mode::Plain,
1143 suffix: b"\r\n".to_vec(),
1144 py_handler,
1145 heartbeat: None,
1146 reconnect_timeout_ms: None,
1147 reconnect_delay_initial_ms: None,
1148 reconnect_backoff_factor: None,
1149 reconnect_delay_max_ms: None,
1150 reconnect_jitter_ms: None,
1151 certs_dir: None,
1152 };
1153
1154 let client = SocketClient::connect(config, None, None, None, None)
1155 .await
1156 .expect("Client connect failed unexpectedly");
1157
1158 client.send_bytes(b"hello".into()).await.unwrap();
1159 sleep(Duration::from_millis(100)).await;
1160
1161 client.send_bytes(b"ERR".into()).await.unwrap();
1162 sleep(Duration::from_secs(1)).await;
1163
1164 assert!(client.is_active());
1165
1166 client.close().await;
1167
1168 assert!(client.is_closed());
1169 server_task.abort();
1170 }
1171
1172 #[tokio::test]
1173 async fn test_reconnect_success() {
1174 prepare_freethreaded_python();
1175
1176 let (port, listener) = bind_test_server().await;
1177
1178 let server_task = task::spawn(async move {
1182 let (mut socket, _) = listener.accept().await.expect("First accept failed");
1184
1185 sleep(Duration::from_millis(500)).await;
1187 let _ = socket.shutdown().await;
1188
1189 sleep(Duration::from_millis(500)).await;
1191
1192 let (socket, _) = listener.accept().await.expect("Second accept failed");
1194 run_echo_server(socket).await;
1195 });
1196
1197 let config = SocketConfig {
1198 url: format!("127.0.0.1:{port}"),
1199 mode: Mode::Plain,
1200 suffix: b"\r\n".to_vec(),
1201 py_handler: Some(Arc::new(create_handler())),
1202 heartbeat: None,
1203 reconnect_timeout_ms: Some(5_000),
1204 reconnect_delay_initial_ms: Some(500),
1205 reconnect_delay_max_ms: Some(5_000),
1206 reconnect_backoff_factor: Some(2.0),
1207 reconnect_jitter_ms: Some(50),
1208 certs_dir: None,
1209 };
1210
1211 let client = SocketClient::connect(config, None, None, None, None)
1212 .await
1213 .expect("Client connect failed unexpectedly");
1214
1215 assert!(client.is_active(), "Client should start as active");
1217
1218 wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;
1221
1222 client
1223 .send_bytes(b"TestReconnect".into())
1224 .await
1225 .expect("Send failed");
1226
1227 client.close().await;
1228 server_task.abort();
1229 }
1230}
1231
1232#[cfg(test)]
1233mod rust_tests {
1234 use tokio::{
1235 net::TcpListener,
1236 task,
1237 time::{Duration, sleep},
1238 };
1239
1240 use super::*;
1241
1242 #[tokio::test]
1243 async fn test_reconnect_then_close() {
1244 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1246 let port = listener.local_addr().unwrap().port();
1247
1248 let server = task::spawn(async move {
1250 if let Ok((mut sock, _)) = listener.accept().await {
1251 let _ = sock.shutdown();
1252 }
1253 sleep(Duration::from_secs(1)).await;
1255 });
1256
1257 let config = SocketConfig {
1259 url: format!("127.0.0.1:{port}"),
1260 mode: Mode::Plain,
1261 suffix: b"\r\n".to_vec(),
1262 #[cfg(feature = "python")]
1263 py_handler: None,
1264 heartbeat: None,
1265 reconnect_timeout_ms: Some(1_000),
1266 reconnect_delay_initial_ms: Some(50),
1267 reconnect_delay_max_ms: Some(100),
1268 reconnect_backoff_factor: Some(1.0),
1269 reconnect_jitter_ms: Some(0),
1270 certs_dir: None,
1271 };
1272
1273 let client = {
1275 #[cfg(feature = "python")]
1276 {
1277 SocketClient::connect(config.clone(), None, None, None, None)
1278 .await
1279 .unwrap()
1280 }
1281 #[cfg(not(feature = "python"))]
1282 {
1283 SocketClient::connect(config.clone(), None).await.unwrap()
1284 }
1285 };
1286
1287 sleep(Duration::from_millis(100)).await;
1289
1290 client.close().await;
1292 assert!(client.is_closed());
1293 server.abort();
1294 }
1295}