hashiverse_lib/transport/
mem_transport.rs1use crate::tools::types::Id;
16use crate::transport::ddos::ddos::{DdosConnectionGuard, DdosProtection};
17use crate::transport::transport::{IncomingRequest, ServerState, TransportFactory, TransportServer};
18use crate::transport::transport_ownership_proof::{EmptyMarkerOwnershipProof, TransportOwnershipProof};
19use anyhow::{Result, anyhow};
20use bytes::Bytes;
21use log::info;
22use parking_lot::RwLock;
23use std::collections::HashMap;
24use std::sync::Arc;
25use tokio::sync::{mpsc, oneshot};
26use tokio_util::sync::CancellationToken;
27use tokio_util::task::TaskTracker;
28use crate::transport::bootstrap_provider::bootstrap_provider::BootstrapProvider;
29use crate::transport::bootstrap_provider::manual_bootstrap_provider::ManualBootstrapProvider;
30use crate::transport::ddos::noop_ddos::NoopDdosProtection;
31
/// One request travelling from `MemTransportFactory::rpc` to the `listen` loop of a
/// `MemTransportServer`, together with the channel used to answer it.
#[derive(Debug)]
struct RpcMessage {
    /// Caller identity presented to DDoS protection; `rpc` fills it with `mem:<random id>`.
    caller_address: String,
    /// Raw request payload.
    bytes: Bytes,
    /// One-shot channel on which the server sends its result back to the waiting caller.
    response_tx: oneshot::Sender<Result<Bytes>>,
}
38
/// Registry entry for one in-memory server.
struct ServerEntry {
    /// Sending half of the channel whose receiver is consumed by the server's `listen` loop.
    command_tx: mpsc::Sender<RpcMessage>,
}
42
/// Port-keyed registry of in-memory servers.
///
/// It is shared (via `Arc`) between a factory, its clones, and the servers that factory creates.
struct ServerManager {
    /// Port -> entry. This is a synchronous `parking_lot` lock, so critical sections must not
    /// contain an `.await`.
    servers: Arc<RwLock<HashMap<u16, Arc<ServerEntry>>>>,
}
46
47impl ServerManager {
48 pub fn new() -> Self {
49 ServerManager {
50 servers: Arc::new(RwLock::new(HashMap::new())),
51 }
52 }
53 pub async fn remove_server(&self, port: u16) {
54 let mut servers_locked = self.servers.write();
55 servers_locked.remove(&port);
56 }
57}
58
/// Server half of the in-memory transport.
///
/// Requests reach it through an `mpsc` channel registered in the shared `ServerManager`
/// under its port.
pub struct MemTransportServer {
    /// Key of this server's entry in `server_manager`.
    port: u16,
    /// Address handed to callers; `create_server` sets it to `port` as a decimal string.
    address: String,
    /// Shared registry; `listen` removes this server's entry when it stops.
    server_manager: Arc<ServerManager>,
    /// Receiving half of the request channel. The first `listen` call takes it, leaving `None`.
    command_rx: Arc<RwLock<Option<mpsc::Receiver<RpcMessage>>>>,
    /// Lifecycle: `Created` -> `Listening` -> `Shutdown`, driven by `listen`.
    state: Arc<RwLock<ServerState>>,
    /// Admission control consulted for every incoming request.
    ddos_protection: Arc<dyn DdosProtection>,
}
77
78#[async_trait::async_trait]
79impl TransportServer for MemTransportServer {
80 fn get_address(&self) -> &String {
81 &self.address
82 }
83
84 fn get_transport_ownership_proof(&self) -> Arc<dyn TransportOwnershipProof> {
85 Arc::new(EmptyMarkerOwnershipProof)
86 }
87
88 async fn listen(&self, cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>) -> Result<()> {
89 async fn process_connection(_cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>, message: RpcMessage, ddos_protection: Arc<dyn DdosProtection>) -> anyhow::Result<()> {
90 let ddos_connection_guard = match DdosConnectionGuard::try_new(ddos_protection, message.caller_address.as_str()) {
94 Some(guard) => Arc::new(guard),
95 None => return Ok(()),
96 };
97 let caller_address = ddos_connection_guard.ip().to_string();
98 let (reply_tx, reply_rx) = oneshot::channel();
99 handler.send(IncomingRequest::new(caller_address, message.bytes, reply_tx, ddos_connection_guard)).await?;
100 let response = reply_rx.await?;
101 let _ = message.response_tx.send(Ok(response.to_bytes()));
102
103 Ok(())
104 }
105
106 {
108 let mut state = self.state.write();
109 match *state {
110 ServerState::Listening => {
111 anyhow::bail!("server is already listening");
112 }
113 ServerState::Shutdown => {
114 anyhow::bail!("server has been shut down");
115 }
116 ServerState::Created => {
117 *state = ServerState::Listening;
118 }
119 }
120 }
121
122 let task_tracker = TaskTracker::new();
123
124 info!("listening on address {}", self.address);
125
126 let mut receiver = match self.command_rx.write().take() {
128 Some(r) => r,
129 None => {
130 return Err(anyhow!("no receiver available on address {}", self.address));
131 }
132 };
133
134 loop {
135 tokio::select! {
136 _ = cancellation_token.cancelled() => {
137 break;
138 }
139
140 Some(msg) = receiver.recv() => {
141 task_tracker.spawn(
142 process_connection(cancellation_token.clone(), handler.clone(), msg, self.ddos_protection.clone())
143 );
144 }
145 }
146 }
147
148 info!("stopped listening on port {}", self.address);
149 self.server_manager.remove_server(self.port).await;
150
151 info!("waiting for open connections to complete");
153 task_tracker.close();
154 task_tracker.wait().await;
155
156 info!("all open connections complete");
158 *self.state.write() = ServerState::Shutdown;
159
160 Ok(())
161 }
162}
163
/// Factory for the in-memory transport.
///
/// It creates `MemTransportServer`s and delivers RPCs to them through in-process channels.
///
/// Clones share the same `ServerManager` (it is behind an `Arc`). Servers created through one
/// clone are therefore reachable via `rpc` on another.
#[derive(Clone)]
pub struct MemTransportFactory {
    /// Registry of servers, shared with clones of this factory and with the servers it creates.
    server_manager: Arc<ServerManager>,
    /// Handed to every server this factory creates.
    ddos_protection: Arc<dyn DdosProtection>,
    /// Backs `get_bootstrap_addresses`.
    bootstrap_provider: Arc<dyn BootstrapProvider>,
}
170
171impl MemTransportFactory {
172 pub fn new(ddos_protection: Arc<dyn DdosProtection>, bootstrap_provider: Arc<dyn BootstrapProvider>) -> Self {
173 Self {
174 server_manager: Arc::new(ServerManager::new()),
175 ddos_protection,
176 bootstrap_provider,
177 }
178 }
179
180 #[allow(clippy::should_implement_trait)] pub fn default() -> Arc<Self> {
182 Arc::new(Self::new(NoopDdosProtection::default(), ManualBootstrapProvider::new_mem_multiple()))
183 }
184}
185
186#[async_trait::async_trait]
187impl TransportFactory for MemTransportFactory {
188 async fn get_bootstrap_addresses(&self) -> Vec<String> {
189 self.bootstrap_provider.get_bootstrap_addresses().await
190 }
191
192 async fn create_server(&self, _base_path: &str, port: u16, force_local_network: bool) -> anyhow::Result<Arc<dyn TransportServer>> {
193 if !force_local_network {
194 return Err(anyhow!("only local network is supported"));
195 }
196
197 let mut servers_locked = self.server_manager.servers.write();
198
199 if servers_locked.contains_key(&port) {
200 return Err(anyhow!("server already exists on port {}", port));
201 }
202
203 let bound_port = match port {
205 0 => {
206 servers_locked.keys().max().unwrap_or(&0u16) + 1
207 }
208 _ => port
209 };
210
211 let address = format!("{}", bound_port);
212
213 let (tx, rx) = mpsc::channel::<RpcMessage>(256);
218
219 let mem_transport_server = Arc::new(MemTransportServer {
221 port: bound_port,
222 address,
223 server_manager: self.server_manager.clone(),
224 command_rx: Arc::new(RwLock::new(Some(rx))),
225 state: Arc::new(RwLock::new(ServerState::Created)),
226 ddos_protection: self.ddos_protection.clone(),
227 });
228
229 servers_locked.insert(bound_port, Arc::new(ServerEntry { command_tx: tx }));
231
232 Ok(mem_transport_server)
233 }
234
235 async fn rpc(&self, address: &str, bytes: Bytes) -> Result<Bytes> {
236 let port: u16 = address.parse()?;
237
238 let server_entry = {
239 let servers = self.server_manager.servers.read();
240 let server_entry = servers.get(&port).ok_or_else(|| anyhow::anyhow!("no server found with port {}", port))?;
241 server_entry.clone()
242 };
243
244 let (response_tx, response_rx) = oneshot::channel();
249
250 let message = RpcMessage { caller_address: format!("mem:{}", Id::random()), bytes, response_tx };
252
253 server_entry.command_tx.send(message).await.map_err(|e| anyhow::anyhow!("failed to send request: {}", e))?;
257
258 response_rx.await.map_err(|_| anyhow::anyhow!("server disconnected before responding"))?
260 }
261}
262
263
#[cfg(test)]
mod tests {
    use crate::transport::bootstrap_provider::manual_bootstrap_provider::ManualBootstrapProvider;
    use crate::transport::ddos::noop_ddos::NoopDdosProtection;
    use crate::transport::mem_transport::MemTransportFactory;
    use crate::transport::transport::tests as shared_tests;
    use crate::transport::transport::TransportFactory;
    use std::sync::Arc;

    /// Builds a fresh, isolated in-memory factory for each shared transport test.
    fn new_factory() -> Arc<dyn TransportFactory> {
        let factory = MemTransportFactory::new(NoopDdosProtection::default(), ManualBootstrapProvider::default());
        Arc::new(factory)
    }

    #[tokio::test]
    async fn rpc_test() -> anyhow::Result<()> {
        shared_tests::rpc_test(new_factory()).await
    }

    #[tokio::test]
    async fn bind_port_zero_test() -> anyhow::Result<()> {
        shared_tests::bind_port_zero_test(new_factory()).await
    }
}