hashiverse_lib/transport/ddos/
mem_ddos.rs1use crate::tools::time_provider::moka_clock::TimeProviderMokaClock;
16use crate::tools::time_provider::time_provider::TimeProvider;
17use crate::transport::ddos::ddos::{DdosProtection, DdosScore};
18use log::warn;
19use moka::sync::Cache;
20use parking_lot::Mutex;
21use std::collections::HashMap;
22use std::sync::Arc;
23use std::time::Duration;
24
25pub struct MemDdosProtection {
35 score_threshold: f64,
36 decay_per_second: f64,
37 bad_request_penalty: f64,
38 max_connections_per_ip: usize,
39 scores: Cache<String, Arc<Mutex<DdosScore>>>,
40 connections: Mutex<HashMap<String, usize>>,
41 time_provider: Arc<dyn TimeProvider>,
42}
43
44impl MemDdosProtection {
45 pub fn new(score_threshold: f64, decay_per_second: f64, bad_request_penalty: f64, max_connections_per_ip: usize, time_provider: Arc<dyn TimeProvider>) -> Self {
46 let idle_secs = if decay_per_second > 0.0 {
48 (score_threshold / decay_per_second * 2.0).ceil() as u64
49 } else {
50 3600 };
52 let scores = Cache::builder()
54 .time_to_idle(Duration::from_secs(idle_secs))
55 .external_clock(Arc::new(TimeProviderMokaClock::new(time_provider.clone())))
56 .build();
57 Self {
58 score_threshold,
59 decay_per_second,
60 bad_request_penalty,
61 max_connections_per_ip,
62 scores,
63 connections: Mutex::new(HashMap::new()),
64 time_provider,
65 }
66 }
67
68 fn increment_score(&self, ip: &str, points: f64) -> f64 {
69 let now = self.time_provider.current_time_millis();
70 let entry = self.scores.get_with(ip.to_string(), || Arc::new(Mutex::new(DdosScore::new())));
71 entry.lock().increment(points, self.decay_per_second, now)
72 }
73
74 fn is_score_banned(&self, ip: &str) -> bool {
75 let now = self.time_provider.current_time_millis();
76 self.scores
77 .get(ip)
78 .map(|entry| entry.lock().current(self.decay_per_second, now) >= self.score_threshold)
79 .unwrap_or(false)
80 }
81}
82
83impl DdosProtection for MemDdosProtection {
84 fn allow_request(&self, ip: &str) -> bool {
85 self.increment_score(ip, 1.0) < self.score_threshold
86 }
87
88 fn report_bad_request(&self, ip: &str) {
89 let score = self.increment_score(ip, self.bad_request_penalty);
90 if score >= self.score_threshold {
91 warn!("DDoS: {} blocked (score={:.1})", ip, score);
92 }
93 }
94
95 fn try_acquire_connection(&self, ip: &str) -> bool {
96 if self.is_score_banned(ip) {
97 return false;
98 }
99 let mut connections = self.connections.lock();
100 let count = connections.entry(ip.to_string()).or_insert(0);
101 if *count >= self.max_connections_per_ip {
102 return false;
103 }
104 *count += 1;
105 true
106 }
107
108 fn release_connection(&self, ip: &str) {
109 let mut connections = self.connections.lock();
110 if let Some(count) = connections.get_mut(ip) {
111 *count = count.saturating_sub(1);
112 if *count == 0 {
113 connections.remove(ip);
114 }
115 }
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::config;
    use crate::tools::time::TimeMillis;
    use crate::tools::time_provider::manual_time_provider::ManualTimeProvider;
    use crate::transport::ddos::ddos::DdosConnectionGuard;

    /// Protector built from the server's configured defaults and a manual clock.
    fn make_ddos() -> Arc<MemDdosProtection> {
        Arc::new(MemDdosProtection::new(
            config::SERVER_DDOS_SCORE_THRESHOLD,
            config::SERVER_DDOS_DECAY_PER_SECOND,
            config::SERVER_DDOS_BAD_REQUEST_PENALTY,
            config::SERVER_DDOS_MAX_CONNECTIONS_PER_IP,
            Arc::new(ManualTimeProvider::default()),
        ))
    }

    /// Protector with explicit scoring parameters, a per-IP connection cap of
    /// 8, and its own manual clock.
    fn custom_ddos(score_threshold: f64, decay_per_second: f64, bad_request_penalty: f64) -> Arc<MemDdosProtection> {
        Arc::new(MemDdosProtection::new(
            score_threshold,
            decay_per_second,
            bad_request_penalty,
            8,
            Arc::new(ManualTimeProvider::default()),
        ))
    }

    #[test]
    fn connection_guard_limits_per_ip() {
        let ddos = make_ddos();
        let ip = "1.2.3.4";

        // Fill every slot the configuration allows for one IP.
        let mut held: Vec<_> = (0..config::SERVER_DDOS_MAX_CONNECTIONS_PER_IP)
            .map(|_| DdosConnectionGuard::try_new(ddos.clone(), ip).expect("should acquire slot within limit"))
            .collect();

        let over_limit = DdosConnectionGuard::try_new(ddos.clone(), ip);
        assert!(over_limit.is_none(), "should be blocked at per-IP cap");

        // Dropping one guard must free exactly one slot.
        drop(held.pop().unwrap());
        let recovered = DdosConnectionGuard::try_new(ddos.clone(), ip);
        assert!(recovered.is_some(), "should acquire after release");
    }

    #[test]
    fn connection_guard_independent_ips() {
        let ddos = make_ddos();

        let first = DdosConnectionGuard::try_new(ddos.clone(), "1.1.1.1");
        let second = DdosConnectionGuard::try_new(ddos.clone(), "2.2.2.2");

        assert!(first.is_some());
        assert!(second.is_some());
    }

    #[test]
    fn banned_ip_cannot_acquire_connection() {
        let ddos = custom_ddos(3.0, 0.0, 3.0);
        let ip = "1.2.3.4";

        // Keep sending requests until the protector starts refusing them.
        while ddos.allow_request(ip) {}

        let guard = DdosConnectionGuard::try_new(ddos.clone(), ip);
        assert!(guard.is_none(), "banned IP should not acquire a connection slot");
    }

    #[test]
    fn guard_report_bad_request_delegates() {
        let ddos = custom_ddos(100.0, 0.0, 5.0);
        let ip = "5.6.7.8";
        let guard = DdosConnectionGuard::try_new(ddos.clone(), ip).unwrap();

        // One penalty of 5 plus one request stays far below the threshold of 100.
        guard.report_bad_request();
        assert!(guard.allow_request());
    }

    #[test]
    fn connection_count_drops_to_zero_after_all_guards_released() {
        let ddos = make_ddos();
        let ip = "9.9.9.9";

        // Acquire and immediately release a single slot.
        drop(DdosConnectionGuard::try_new(ddos.clone(), ip).unwrap());

        // The full quota must be available again afterwards.
        let _held: Vec<_> = (0..config::SERVER_DDOS_MAX_CONNECTIONS_PER_IP)
            .map(|_| DdosConnectionGuard::try_new(ddos.clone(), ip).unwrap())
            .collect();
        assert!(DdosConnectionGuard::try_new(ddos.clone(), ip).is_none());
    }

    #[test]
    fn allow_request_returns_false_at_threshold() {
        let threshold = 5.0;
        let ddos = custom_ddos(threshold, 0.0, 1.0);
        let ip = "3.3.3.3";

        for attempt in 1..=4 {
            assert!(ddos.allow_request(ip), "request {} of 5 should be allowed", attempt);
        }

        // The fifth request brings the score up to the threshold itself.
        assert!(!ddos.allow_request(ip), "request at threshold should be blocked");
        assert!(!ddos.allow_request(ip), "subsequent requests must also be blocked");
    }

    #[test]
    fn bad_request_penalty_causes_ban_faster_than_normal_requests() {
        let threshold = 20.0;
        let penalty = 10.0;
        let ddos = custom_ddos(threshold, 0.0, penalty);
        let ip = "4.4.4.4";

        // Two penalties of 10 reach the threshold of 20 without any normal traffic.
        ddos.report_bad_request(ip);
        ddos.report_bad_request(ip);

        assert!(!ddos.allow_request(ip), "IP should be banned after two penalty-weight bad requests");
        assert!(DdosConnectionGuard::try_new(ddos.clone(), ip).is_none(), "banned IP must not acquire a connection");
    }

    #[test]
    fn score_is_independent_per_ip() {
        let threshold = 3.0;
        let ddos = custom_ddos(threshold, 0.0, 1.0);
        let ip_a = "10.0.0.1";
        let ip_b = "10.0.0.2";

        // Exhaust ip_a only.
        while ddos.allow_request(ip_a) {}
        assert!(!ddos.allow_request(ip_a), "ip_a should be blocked");

        assert!(ddos.allow_request(ip_b), "ip_b should be unaffected by ip_a's exhaustion");
        let guard_b = DdosConnectionGuard::try_new(ddos.clone(), ip_b);
        assert!(guard_b.is_some(), "ip_b should still acquire a connection after ip_a is banned");
    }

    #[test]
    fn score_decays_over_time() {
        let time_provider = Arc::new(ManualTimeProvider::default());
        let ddos = Arc::new(MemDdosProtection::new(5.0, 1000.0, 1.0, 8, time_provider.clone()));
        let ip = "7.7.7.7";

        // Four requests at a fixed clock leave the IP one request short of the threshold.
        for _ in 0..4 {
            assert!(ddos.allow_request(ip));
        }

        // After moving the manual clock, the test expects the accumulated score
        // to have decayed enough for a fifth request to pass.
        time_provider.set_time(TimeMillis(10));
        assert!(ddos.allow_request(ip), "score should have decayed, allowing the request");
    }
}