Skip to main content

hashiverse_server_lib/transport/
tcp_transport.rs

1//! # Plain-text TCP transport
2//!
3//! An unencrypted transport for local testing and private-LAN deployments where TLS
4//! is unnecessary. Frames requests and responses with `tokio-util`'s
5//! `LengthDelimitedCodec` — each message is prefixed with a u32 length, so there's
6//! no application-level ambiguity about where one message ends and the next begins.
7//!
8//! Uses the same pluggable
9//! [`hashiverse_lib::transport::ddos::ddos::DdosProtection`] trait as the HTTPS
10//! transport, so `NoopDdosProtection`, `MemDdos`, or the ipset-backed protection can
11//! all drop in unchanged. Per-request timeout is 2 seconds; anything slower is
12//! considered either a buggy client or a slow-loris probe.
13
14use crate::tools::tools::get_public_ipv4;
15use anyhow::anyhow;
16use bytes::Bytes;
17use futures::{SinkExt, StreamExt};
18use hashiverse_lib::tools::config;
19use hashiverse_lib::transport::ddos::ddos::{DdosConnectionGuard, DdosProtection};
20use hashiverse_lib::transport::transport::{IncomingRequest, ServerState, TransportFactory, TransportServer};
21use hashiverse_lib::transport::transport_ownership_proof::{EmptyMarkerOwnershipProof, TransportOwnershipProof};
22use log::{info, trace, warn};
23use parking_lot::RwLock;
24use std::net::SocketAddr;
25use std::sync::Arc;
26use std::time::Duration;
27use tokio::net::{TcpListener, TcpStream};
28use tokio::sync::{Mutex, mpsc, oneshot};
29use tokio::time::sleep;
30use tokio_util::codec::{Framed, LengthDelimitedCodec};
31use tokio_util::sync::CancellationToken;
32use tokio_util::task::TaskTracker;
33use hashiverse_lib::transport::bootstrap_provider::bootstrap_provider::BootstrapProvider;
34
35#[derive(Clone)]
36pub struct TcpTransportFactory {
37    ddos_protection: Arc<dyn DdosProtection>,
38    bootstrap_provider: Arc<dyn BootstrapProvider>,
39}
40
41impl TcpTransportFactory {
42    pub fn new(ddos_protection: Arc<dyn DdosProtection>, bootstrap_provider: Arc<dyn BootstrapProvider>) -> Self {
43        Self { ddos_protection, bootstrap_provider }
44    }
45}
46
47pub struct TcpTransportServer {
48    address: String,
49    listener: Arc<Mutex<TcpListener>>,
50    state: Arc<RwLock<ServerState>>,
51    ddos_protection: Arc<dyn DdosProtection>,
52}
53
54impl TcpTransportServer {
55    async fn new(address: String, listener: TcpListener, ddos_protection: Arc<dyn DdosProtection>) -> anyhow::Result<Self> {
56        Ok(TcpTransportServer {
57            address,
58            listener: Arc::new(Mutex::new(listener)),
59            state: Arc::new(RwLock::new(ServerState::Created)),
60            ddos_protection,
61        })
62    }
63}
64
65#[async_trait::async_trait]
66impl TransportServer for TcpTransportServer {
67    fn get_address(&self) -> &String {
68        &self.address
69    }
70
71    fn get_transport_ownership_proof(&self) -> Arc<dyn TransportOwnershipProof> {
72        // Plain TCP is for trusted-network deployments only — no crypto proof exists, so
73        // the empty-marker proof matches mem-transport's behaviour and lets V2 announces
74        // flow within a TCP-only network without crossing wires with HTTPS peers.
75        Arc::new(EmptyMarkerOwnershipProof)
76    }
77
78    async fn listen(&self, cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>) -> anyhow::Result<()> {
79        // Check that we can transition to listening
80        {
81            let mut state = self.state.write();
82            match *state {
83                ServerState::Listening => {
84                    anyhow::bail!("server is already listening");
85                }
86                ServerState::Shutdown => {
87                    anyhow::bail!("server has been shut down");
88                }
89                ServerState::Created => {
90                    *state = ServerState::Listening;
91                }
92            }
93        }
94
95        async fn process_connection(cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>, socket: TcpStream, socket_addr: SocketAddr, ddos_protection: Arc<dyn DdosProtection>) -> anyhow::Result<()> {
96            // trace!("accepted connection on: {socket_addr}");
97            // defer! { trace!("dropped connection from: {socket_addr}"); }
98
99            let ip = socket_addr.ip().to_string();
100            let ddos_connection_guard = match DdosConnectionGuard::try_new(ddos_protection, &ip) {
101                Some(guard) => Arc::new(guard),
102                None => {
103                    trace!("DDoS: dropping TCP connection from {}", ip);
104                    return Ok(());
105                }
106            };
107            let caller_address = ddos_connection_guard.ip().to_string();
108            let mut framed = LengthDelimitedCodec::builder().max_frame_length(config::PROTOCOL_MAX_BLOB_SIZE_REQUEST).new_framed(socket);
109
110            let result = tokio::select! {
111                _ = cancellation_token.cancelled() => { return Err(anyhow!("cancelled")) },
112
113                _ = sleep(Duration::from_secs(2)) => {
114                    Err(anyhow::anyhow!("timeout waiting for request"))
115                },
116
117                next = framed.next() => {
118                    match next {
119                        None => Ok(()),
120                        Some(Ok(bytes)) => {
121                            // trace!("received bytes={:?}", bytes);
122                            let (reply_tx, reply_rx) = oneshot::channel();
123                            handler.send(IncomingRequest::new(caller_address, bytes.into(), reply_tx, ddos_connection_guard)).await?;
124                            let response = reply_rx.await?;
125                            framed.send(response.to_bytes()).await?;
126                            Ok(())
127                        },
128                        Some(Err(e)) => Err(anyhow!("error reading string from framed stream: {}", e)),
129                    }
130                }
131            };
132
133            if let Err(e) = result {
134                warn!("error processing connection: {}", e);
135            }
136
137            Ok(())
138        }
139
140        let task_tracker = TaskTracker::new();
141
142        info!("listening on address {}", self.address);
143
144        loop {
145            let listener = self.listener.lock().await;
146
147            tokio::select! {
148                _ = cancellation_token.cancelled() => {
149                    break;
150                },
151                Ok((socket, socket_addr)) = listener.accept() => {
152                    task_tracker.spawn(
153                        process_connection(cancellation_token.clone(), handler.clone(), socket, socket_addr, self.ddos_protection.clone())
154                    );
155                },
156            }
157        }
158
159        // Stop accepting new connections
160        info!("stopped listening on address {}", self.address);
161        drop(self.listener.lock().await);
162
163        // Wait for existing connections to complete
164        info!("waiting for open connections to complete");
165        task_tracker.close();
166        task_tracker.wait().await;
167
168        // Notify the "shutdown" coroutine that we have successfully shutdown
169        info!("all open connections complete");
170        *self.state.write() = ServerState::Shutdown;
171
172        Ok(())
173    }
174}
175
176#[async_trait::async_trait]
177impl TransportFactory for TcpTransportFactory {
178    async fn get_bootstrap_addresses(&self) -> Vec<String> {
179        self.bootstrap_provider.get_bootstrap_addresses().await
180    }
181
182    async fn create_server(&self, _base_path: &str, port: u16, force_local_network: bool) -> anyhow::Result<Arc<dyn TransportServer>> {
183        // Deliberately IPv4-only.  See https_transport.rs for the reasoning.
184        let address_to_bind = format!("0.0.0.0:{}", port);
185        info!("bind on: {}", address_to_bind);
186        let listener = TcpListener::bind(address_to_bind).await?;
187
188        let address_bound_ip = get_public_ipv4(force_local_network).await?;
189        let address_bound_port = listener.local_addr()?.port();
190        let address = format!("{}:{}", address_bound_ip, address_bound_port);
191
192        let tcp_transport_server = Arc::new(TcpTransportServer::new(address, listener, self.ddos_protection.clone()).await?);
193        Ok(tcp_transport_server)
194    }
195
196    async fn rpc(&self, address: &str, bytes: Bytes) -> anyhow::Result<Bytes> {
197        let stream = TcpStream::connect(address).await?;
198        // trace!("connected to: {}", address.address);
199        // defer! { trace!("disconnected from: {}", &address.address); }
200
201        let mut framed: Framed<TcpStream, LengthDelimitedCodec> = Framed::new(stream, LengthDelimitedCodec::new());
202        framed.send(bytes).await?;
203
204        // Return the response
205        trace!("awaiting response");
206        tokio::select! {
207            _ = sleep(Duration::from_secs(2)) => {
208                trace!("timeout");
209                Err(anyhow::anyhow!("timeout waiting for response"))
210            },
211
212            next_frame = framed.next() => {
213                match next_frame {
214                    Some(Ok(bytes)) => {
215                        Ok(bytes.into())
216                    }
217                    Some(Err(e)) => {
218                        Err(anyhow::anyhow!("error reading response: {}", e)) },
219                    None => {
220                        Err(anyhow::anyhow!("no response")) },
221                }
222           }
223        }
224    }
225}
226
227
#[cfg(test)]
mod tests {
    use crate::transport::tcp_transport::TcpTransportFactory;
    use hashiverse_lib::transport::bootstrap_provider::manual_bootstrap_provider::ManualBootstrapProvider;
    use hashiverse_lib::transport::ddos::noop_ddos::NoopDdosProtection;
    use hashiverse_lib::transport::transport::TransportFactory;
    use std::sync::Arc;

    /// Builds the TCP factory used by the shared transport tests. It is configured with a
    /// `NoopDdosProtection` and a default `ManualBootstrapProvider`.
    fn make_factory() -> Arc<dyn TransportFactory> {
        Arc::new(TcpTransportFactory::new(NoopDdosProtection::default(), ManualBootstrapProvider::default()))
    }

    /// Runs the generic request/response round-trip test against the TCP transport.
    #[tokio::test]
    async fn rpc_test() -> anyhow::Result<()> {
        hashiverse_lib::transport::transport::tests::rpc_test(make_factory()).await
    }

    /// Runs the generic "bind to port 0" test against the TCP transport.
    #[tokio::test]
    async fn bind_port_zero_test() -> anyhow::Result<()> {
        hashiverse_lib::transport::transport::tests::bind_port_zero_test(make_factory()).await
    }
}