Skip to main content

hashiverse_lib/transport/
mem_transport.rs

1//! # In-memory transport for tests
2//!
//! A fully in-process implementation of
//! [`crate::transport::transport::TransportFactory`] and
//! [`crate::transport::transport::TransportServer`]: every "server" registers itself in
//! a registry keyed by port number (owned by the factory and shared by its clones), and
//! every "client" request is just an async channel send into the matching server's
//! request queue.
8//!
9//! This is what makes the integration-test harness fast and deterministic. A virtual
10//! network of dozens of servers + clients runs inside a single test binary, with no
11//! sockets, no TLS negotiation, no PoW relaxation fudge, and no flaky wall-clock
12//! ordering. Swap `MemTransportFactory` for the HTTPS factory and the same protocol code
13//! runs on the real network.
14
15use crate::tools::types::Id;
16use crate::transport::ddos::ddos::{DdosConnectionGuard, DdosProtection};
17use crate::transport::transport::{IncomingRequest, ServerState, TransportFactory, TransportServer};
18use crate::transport::transport_ownership_proof::{EmptyMarkerOwnershipProof, TransportOwnershipProof};
19use anyhow::{Result, anyhow};
20use bytes::Bytes;
21use log::info;
22use parking_lot::RwLock;
23use std::collections::HashMap;
24use std::sync::Arc;
25use tokio::sync::{mpsc, oneshot};
26use tokio_util::sync::CancellationToken;
27use tokio_util::task::TaskTracker;
28use crate::transport::bootstrap_provider::bootstrap_provider::BootstrapProvider;
29use crate::transport::bootstrap_provider::manual_bootstrap_provider::ManualBootstrapProvider;
30use crate::transport::ddos::noop_ddos::NoopDdosProtection;
31
/// A single request travelling from a client (`MemTransportFactory::rpc`) to a server's
/// `listen` loop.
#[derive(Debug)]
struct RpcMessage {
    /// Synthetic caller identity (`mem:<random id>`) that is passed to DDoS protection.
    caller_address: String,
    /// Raw request payload.
    bytes: Bytes,
    /// Carries the server's reply back to the client awaiting it in `rpc`.
    response_tx: oneshot::Sender<Result<Bytes>>,
}
38
/// Registry entry for one server: the sending half of that server's request queue.
struct ServerEntry {
    /// Used by `rpc` to enqueue requests for the server.
    command_tx: mpsc::Sender<RpcMessage>,
}
42
/// Registry of in-memory servers, keyed by port.
///
/// One registry is created per `MemTransportFactory::new`; clones of that factory and the
/// servers it creates share it through an `Arc`.
struct ServerManager {
    servers: Arc<RwLock<HashMap<u16, Arc<ServerEntry>>>>,
}
46
47impl ServerManager {
48    pub fn new() -> Self {
49        ServerManager {
50            servers: Arc::new(RwLock::new(HashMap::new())),
51        }
52    }
53    pub async fn remove_server(&self, port: u16) {
54        let mut servers_locked = self.servers.write();
55        servers_locked.remove(&port);
56    }
57}
58
/// An entirely in-process [`TransportServer`] used by the integration test harness.
///
/// Servers created by a `MemTransportFactory` are registered, keyed by port, in that
/// factory's registry. The registry is shared by clones of the factory but not by
/// independently constructed factories. "Sending a request" from one client to one server
/// becomes a channel send to the sender found in that registry.
/// There is no serialization to sockets, no DNS, no kernel — which makes this both
/// dramatically faster than a real network and fully deterministic when paired with a virtual
/// [`crate::tools::time_provider::time_provider::TimeProvider`]. Port `0` is translated to a
/// freshly-allocated port number, mirroring the semantics of a real OS bind.
///
/// Not for production use: there is nothing here that crosses a process or host boundary.
pub struct MemTransportServer {
    /// Port this server is registered under in the factory's registry.
    port: u16,
    /// Decimal string form of `port`; this is the address clients pass to `rpc`.
    address: String,
    /// Registry the server unregisters itself from when `listen` stops.
    server_manager: Arc<ServerManager>,
    /// Receiving half of the request queue; taken by the first successful `listen` call.
    command_rx: Arc<RwLock<Option<mpsc::Receiver<RpcMessage>>>>,
    /// Lifecycle state: `Created` -> `Listening` -> `Shutdown`.
    state: Arc<RwLock<ServerState>>,
    /// Admission control applied to every incoming request.
    ddos_protection: Arc<dyn DdosProtection>,
}
77
78#[async_trait::async_trait]
79impl TransportServer for MemTransportServer {
80    fn get_address(&self) -> &String {
81        &self.address
82    }
83
84    fn get_transport_ownership_proof(&self) -> Arc<dyn TransportOwnershipProof> {
85        Arc::new(EmptyMarkerOwnershipProof)
86    }
87
88    async fn listen(&self, cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>) -> Result<()> {
89        async fn process_connection(_cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>, message: RpcMessage, ddos_protection: Arc<dyn DdosProtection>) -> anyhow::Result<()> {
90            // trace!("accepted connection");
91            // scopeguard::defer! { trace!("dropped connection"); }
92            // trace!("received packet={:?}", message.bytes);
93            let ddos_connection_guard = match DdosConnectionGuard::try_new(ddos_protection, message.caller_address.as_str()) {
94                Some(guard) => Arc::new(guard),
95                None => return Ok(()),
96            };
97            let caller_address = ddos_connection_guard.ip().to_string();
98            let (reply_tx, reply_rx) = oneshot::channel();
99            handler.send(IncomingRequest::new(caller_address, message.bytes, reply_tx, ddos_connection_guard)).await?;
100            let response = reply_rx.await?;
101            let _ = message.response_tx.send(Ok(response.to_bytes()));
102
103            Ok(())
104        }
105
106        // Check that we can transition to listening
107        {
108            let mut state = self.state.write();
109            match *state {
110                ServerState::Listening => {
111                    anyhow::bail!("server is already listening");
112                }
113                ServerState::Shutdown => {
114                    anyhow::bail!("server has been shut down");
115                }
116                ServerState::Created => {
117                    *state = ServerState::Listening;
118                }
119            }
120        }
121
122        let task_tracker = TaskTracker::new();
123
124        info!("listening on address {}", self.address);
125
126        // Take ownership of the receiver.  If there's no receiver, we can't listen.  Should never happen!
127        let mut receiver = match self.command_rx.write().take() {
128            Some(r) => r,
129            None => {
130                return Err(anyhow!("no receiver available on address {}", self.address));
131            }
132        };
133
134        loop {
135            tokio::select! {
136                _ = cancellation_token.cancelled() => {
137                    break;
138                }
139
140                Some(msg) = receiver.recv() => {
141                    task_tracker.spawn(
142                        process_connection(cancellation_token.clone(), handler.clone(), msg, self.ddos_protection.clone())
143                    );
144                }
145            }
146        }
147
148        info!("stopped listening on port {}", self.address);
149        self.server_manager.remove_server(self.port).await;
150
151        // Wait for existing connections to complete
152        info!("waiting for open connections to complete");
153        task_tracker.close();
154        task_tracker.wait().await;
155
156        // Notify the "shutdown" coroutine that we have successfully shutdown
157        info!("all open connections complete");
158        *self.state.write() = ServerState::Shutdown;
159
160        Ok(())
161    }
162}
163
/// [`TransportFactory`] whose servers live in an in-memory registry instead of on sockets.
///
/// Cloning is cheap: clones share the same registry, DDoS protection and bootstrap provider,
/// so a server created through one clone is reachable through `rpc` on any other clone.
#[derive(Clone)]
pub struct MemTransportFactory {
    server_manager: Arc<ServerManager>,
    ddos_protection: Arc<dyn DdosProtection>,
    bootstrap_provider: Arc<dyn BootstrapProvider>,
}
170
171impl MemTransportFactory {
172    pub fn new(ddos_protection: Arc<dyn DdosProtection>, bootstrap_provider: Arc<dyn BootstrapProvider>) -> Self {
173        Self {
174            server_manager: Arc::new(ServerManager::new()),
175            ddos_protection,
176            bootstrap_provider,
177        }
178    }
179
180    #[allow(clippy::should_implement_trait)] // wraps Arc<Self>, can't satisfy the Default trait
181    pub fn default() -> Arc<Self> {
182        Arc::new(Self::new(NoopDdosProtection::default(), ManualBootstrapProvider::new_mem_multiple()))
183    }
184}
185
186#[async_trait::async_trait]
187impl TransportFactory for MemTransportFactory {
188    async fn get_bootstrap_addresses(&self) -> Vec<String> {
189        self.bootstrap_provider.get_bootstrap_addresses().await
190    }
191
192    async fn create_server(&self, _base_path: &str, port: u16, force_local_network: bool) -> anyhow::Result<Arc<dyn TransportServer>> {
193        if !force_local_network {
194            return Err(anyhow!("only local network is supported"));
195        }
196
197        let mut servers_locked = self.server_manager.servers.write();
198
199        if servers_locked.contains_key(&port) {
200            return Err(anyhow!("server already exists on port {}", port));
201        }
202
203        // If they have requested port 0, pick the first available empty slot
204        let bound_port = match port {
205            0 => {
206                servers_locked.keys().max().unwrap_or(&0u16) + 1
207            }
208            _ => port
209        };
210
211        let address = format!("{}", bound_port);
212
213        // Create channels for communication.  Buffer sized generously so bursts of
214        // concurrent in-memory RPCs don't trip capacity limits; backpressure is still
215        // applied via awaited `send` below, which is closer to the behaviour of a real
216        // TCP socket than `try_send`'s fail-fast.
217        let (tx, rx) = mpsc::channel::<RpcMessage>(256);
218
219        // Create the server
220        let mem_transport_server = Arc::new(MemTransportServer {
221            port: bound_port,
222            address,
223            server_manager: self.server_manager.clone(),
224            command_rx: Arc::new(RwLock::new(Some(rx))),
225            state: Arc::new(RwLock::new(ServerState::Created)),
226            ddos_protection: self.ddos_protection.clone(),
227        });
228
229        // Store the server and its sender in the map
230        servers_locked.insert(bound_port, Arc::new(ServerEntry { command_tx: tx }));
231
232        Ok(mem_transport_server)
233    }
234
235    async fn rpc(&self, address: &str, bytes: Bytes) -> Result<Bytes> {
236        let port: u16 = address.parse()?;
237
238        let server_entry = {
239            let servers = self.server_manager.servers.read();
240            let server_entry = servers.get(&port).ok_or_else(|| anyhow::anyhow!("no server found with port {}", port))?;
241            server_entry.clone()
242        };
243
244        // trace!("connected to: {:?}", address);
245        // defer! { trace!("disconnected from: {:?}", &address); }
246
247        // Create a oneshot channel for the response
248        let (response_tx, response_rx) = oneshot::channel();
249
250        // Create the message
251        let message = RpcMessage { caller_address: format!("mem:{}", Id::random()), bytes, response_tx };
252
253        // Send the message to the server using the sender from the server entry.
254        // Awaited `send` applies backpressure if the receiver is saturated, rather than
255        // dropping the request — mirrors how a real TCP transport would behave.
256        server_entry.command_tx.send(message).await.map_err(|e| anyhow::anyhow!("failed to send request: {}", e))?;
257
258        // Wait for the response
259        response_rx.await.map_err(|_| anyhow::anyhow!("server disconnected before responding"))?
260    }
261}
262
263
#[cfg(test)]
mod tests {
    use crate::transport::bootstrap_provider::manual_bootstrap_provider::ManualBootstrapProvider;
    use crate::transport::ddos::noop_ddos::NoopDdosProtection;
    use crate::transport::mem_transport::MemTransportFactory;
    use crate::transport::transport::TransportFactory;
    use crate::transport::transport::tests as shared_transport_tests;
    use std::sync::Arc;

    /// Builds a fresh in-memory factory, with its own empty registry, for each test, so
    /// tests running in parallel never see each other's servers.
    fn isolated_factory() -> Arc<dyn TransportFactory> {
        let factory = MemTransportFactory::new(NoopDdosProtection::default(), ManualBootstrapProvider::default());
        Arc::new(factory)
    }

    /// Runs the transport-agnostic request/response test against the in-memory transport.
    #[tokio::test]
    async fn rpc_test() -> anyhow::Result<()> {
        shared_transport_tests::rpc_test(isolated_factory()).await
    }

    /// Runs the transport-agnostic "bind to port 0" test against the in-memory transport.
    #[tokio::test]
    async fn bind_port_zero_test() -> anyhow::Result<()> {
        shared_transport_tests::bind_port_zero_test(isolated_factory()).await
    }
}
278    async fn bind_port_zero_test() -> anyhow::Result<()> {
279        let factory: Arc<dyn crate::transport::transport::TransportFactory> = Arc::new(MemTransportFactory::new(NoopDdosProtection::default(), ManualBootstrapProvider::default()));
280        crate::transport::transport::tests::bind_port_zero_test(factory).await
281    }
282}