hashiverse_server_lib/transport/
tcp_transport.rs1use crate::tools::tools::get_public_ipv4;
15use anyhow::anyhow;
16use bytes::Bytes;
17use futures::{SinkExt, StreamExt};
18use hashiverse_lib::tools::config;
19use hashiverse_lib::transport::ddos::ddos::{DdosConnectionGuard, DdosProtection};
20use hashiverse_lib::transport::transport::{IncomingRequest, ServerState, TransportFactory, TransportServer};
21use hashiverse_lib::transport::transport_ownership_proof::{EmptyMarkerOwnershipProof, TransportOwnershipProof};
22use log::{info, trace, warn};
23use parking_lot::RwLock;
24use std::net::SocketAddr;
25use std::sync::Arc;
26use std::time::Duration;
27use tokio::net::{TcpListener, TcpStream};
28use tokio::sync::{Mutex, mpsc, oneshot};
29use tokio::time::sleep;
30use tokio_util::codec::{Framed, LengthDelimitedCodec};
31use tokio_util::sync::CancellationToken;
32use tokio_util::task::TaskTracker;
33use hashiverse_lib::transport::bootstrap_provider::bootstrap_provider::BootstrapProvider;
34
35#[derive(Clone)]
36pub struct TcpTransportFactory {
37 ddos_protection: Arc<dyn DdosProtection>,
38 bootstrap_provider: Arc<dyn BootstrapProvider>,
39}
40
41impl TcpTransportFactory {
42 pub fn new(ddos_protection: Arc<dyn DdosProtection>, bootstrap_provider: Arc<dyn BootstrapProvider>) -> Self {
43 Self { ddos_protection, bootstrap_provider }
44 }
45}
46
47pub struct TcpTransportServer {
48 address: String,
49 listener: Arc<Mutex<TcpListener>>,
50 state: Arc<RwLock<ServerState>>,
51 ddos_protection: Arc<dyn DdosProtection>,
52}
53
54impl TcpTransportServer {
55 async fn new(address: String, listener: TcpListener, ddos_protection: Arc<dyn DdosProtection>) -> anyhow::Result<Self> {
56 Ok(TcpTransportServer {
57 address,
58 listener: Arc::new(Mutex::new(listener)),
59 state: Arc::new(RwLock::new(ServerState::Created)),
60 ddos_protection,
61 })
62 }
63}
64
65#[async_trait::async_trait]
66impl TransportServer for TcpTransportServer {
67 fn get_address(&self) -> &String {
68 &self.address
69 }
70
71 fn get_transport_ownership_proof(&self) -> Arc<dyn TransportOwnershipProof> {
72 Arc::new(EmptyMarkerOwnershipProof)
76 }
77
78 async fn listen(&self, cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>) -> anyhow::Result<()> {
79 {
81 let mut state = self.state.write();
82 match *state {
83 ServerState::Listening => {
84 anyhow::bail!("server is already listening");
85 }
86 ServerState::Shutdown => {
87 anyhow::bail!("server has been shut down");
88 }
89 ServerState::Created => {
90 *state = ServerState::Listening;
91 }
92 }
93 }
94
95 async fn process_connection(cancellation_token: CancellationToken, handler: mpsc::Sender<IncomingRequest>, socket: TcpStream, socket_addr: SocketAddr, ddos_protection: Arc<dyn DdosProtection>) -> anyhow::Result<()> {
96 let ip = socket_addr.ip().to_string();
100 let ddos_connection_guard = match DdosConnectionGuard::try_new(ddos_protection, &ip) {
101 Some(guard) => Arc::new(guard),
102 None => {
103 trace!("DDoS: dropping TCP connection from {}", ip);
104 return Ok(());
105 }
106 };
107 let caller_address = ddos_connection_guard.ip().to_string();
108 let mut framed = LengthDelimitedCodec::builder().max_frame_length(config::PROTOCOL_MAX_BLOB_SIZE_REQUEST).new_framed(socket);
109
110 let result = tokio::select! {
111 _ = cancellation_token.cancelled() => { return Err(anyhow!("cancelled")) },
112
113 _ = sleep(Duration::from_secs(2)) => {
114 Err(anyhow::anyhow!("timeout waiting for request"))
115 },
116
117 next = framed.next() => {
118 match next {
119 None => Ok(()),
120 Some(Ok(bytes)) => {
121 let (reply_tx, reply_rx) = oneshot::channel();
123 handler.send(IncomingRequest::new(caller_address, bytes.into(), reply_tx, ddos_connection_guard)).await?;
124 let response = reply_rx.await?;
125 framed.send(response.to_bytes()).await?;
126 Ok(())
127 },
128 Some(Err(e)) => Err(anyhow!("error reading string from framed stream: {}", e)),
129 }
130 }
131 };
132
133 if let Err(e) = result {
134 warn!("error processing connection: {}", e);
135 }
136
137 Ok(())
138 }
139
140 let task_tracker = TaskTracker::new();
141
142 info!("listening on address {}", self.address);
143
144 loop {
145 let listener = self.listener.lock().await;
146
147 tokio::select! {
148 _ = cancellation_token.cancelled() => {
149 break;
150 },
151 Ok((socket, socket_addr)) = listener.accept() => {
152 task_tracker.spawn(
153 process_connection(cancellation_token.clone(), handler.clone(), socket, socket_addr, self.ddos_protection.clone())
154 );
155 },
156 }
157 }
158
159 info!("stopped listening on address {}", self.address);
161 drop(self.listener.lock().await);
162
163 info!("waiting for open connections to complete");
165 task_tracker.close();
166 task_tracker.wait().await;
167
168 info!("all open connections complete");
170 *self.state.write() = ServerState::Shutdown;
171
172 Ok(())
173 }
174}
175
176#[async_trait::async_trait]
177impl TransportFactory for TcpTransportFactory {
178 async fn get_bootstrap_addresses(&self) -> Vec<String> {
179 self.bootstrap_provider.get_bootstrap_addresses().await
180 }
181
182 async fn create_server(&self, _base_path: &str, port: u16, force_local_network: bool) -> anyhow::Result<Arc<dyn TransportServer>> {
183 let address_to_bind = format!("0.0.0.0:{}", port);
185 info!("bind on: {}", address_to_bind);
186 let listener = TcpListener::bind(address_to_bind).await?;
187
188 let address_bound_ip = get_public_ipv4(force_local_network).await?;
189 let address_bound_port = listener.local_addr()?.port();
190 let address = format!("{}:{}", address_bound_ip, address_bound_port);
191
192 let tcp_transport_server = Arc::new(TcpTransportServer::new(address, listener, self.ddos_protection.clone()).await?);
193 Ok(tcp_transport_server)
194 }
195
196 async fn rpc(&self, address: &str, bytes: Bytes) -> anyhow::Result<Bytes> {
197 let stream = TcpStream::connect(address).await?;
198 let mut framed: Framed<TcpStream, LengthDelimitedCodec> = Framed::new(stream, LengthDelimitedCodec::new());
202 framed.send(bytes).await?;
203
204 trace!("awaiting response");
206 tokio::select! {
207 _ = sleep(Duration::from_secs(2)) => {
208 trace!("timeout");
209 Err(anyhow::anyhow!("timeout waiting for response"))
210 },
211
212 next_frame = framed.next() => {
213 match next_frame {
214 Some(Ok(bytes)) => {
215 Ok(bytes.into())
216 }
217 Some(Err(e)) => {
218 Err(anyhow::anyhow!("error reading response: {}", e)) },
219 None => {
220 Err(anyhow::anyhow!("no response")) },
221 }
222 }
223 }
224 }
225}
226
227
#[cfg(test)]
mod tests {
    use crate::transport::tcp_transport::TcpTransportFactory;
    use hashiverse_lib::transport::bootstrap_provider::manual_bootstrap_provider::ManualBootstrapProvider;
    use hashiverse_lib::transport::ddos::noop_ddos::NoopDdosProtection;
    use hashiverse_lib::transport::transport::TransportFactory;
    use std::sync::Arc;

    /// Builds a TCP factory with the no-op DDoS protection and a default manual bootstrap
    /// provider, for use with the shared transport test suite.
    fn tcp_factory() -> Arc<dyn TransportFactory> {
        Arc::new(TcpTransportFactory::new(NoopDdosProtection::default(), ManualBootstrapProvider::default()))
    }

    /// Runs the shared request/response round-trip test against the TCP transport.
    #[tokio::test]
    async fn rpc_test() -> anyhow::Result<()> {
        let factory = tcp_factory();
        hashiverse_lib::transport::transport::tests::rpc_test(factory).await
    }

    /// Runs the shared "bind to port 0" test against the TCP transport.
    #[tokio::test]
    async fn bind_port_zero_test() -> anyhow::Result<()> {
        let factory = tcp_factory();
        hashiverse_lib::transport::transport::tests::bind_port_zero_test(factory).await
    }
}