Skip to main content

hashiverse_lib/transport/ddos/
mem_ddos.rs

//! # In-memory DDoS accounting
//!
//! Implements [`crate::transport::ddos::ddos::DdosProtection`] purely in RAM: per-IP
//! `DdosScore`s live in a `moka` cache with time-based eviction so idle IPs get
//! collected automatically, and per-IP connection counts live in a `HashMap` guarded
//! by a `parking_lot::Mutex`.
//!
//! "Ban" here is purely a score check: requests and new connections from an IP are
//! refused once its decayed score reaches the threshold, and nothing is dropped at the
//! kernel level. That makes this implementation suitable for tests (the integration
//! harness stresses the scoring logic without wanting to touch host firewall state) and
//! for platforms where `ipset`/`iptables` aren't available. The production path in
//! `hashiverse-server-lib` wraps this with a real firewall-level ban via
//! [`crate::tools::config::SERVER_DDOS_IPSET_SET_NAME`].
14
15use crate::tools::time_provider::moka_clock::TimeProviderMokaClock;
16use crate::tools::time_provider::time_provider::TimeProvider;
17use crate::transport::ddos::ddos::{DdosProtection, DdosScore};
18use log::warn;
19use moka::sync::Cache;
20use parking_lot::Mutex;
21use std::collections::HashMap;
22use std::sync::Arc;
23use std::time::Duration;
24
25/// In-memory DDoS protection with linearly decaying per-IP scores.
26///
27/// Each `allow_request` adds 1.0 point, each `report_bad_request` adds
28/// `bad_request_penalty` points.  Between calls the score drains at
29/// `decay_per_second` points/second, so sustained low-rate traffic stabilises
30/// well below the threshold while bursts trigger quickly.
31///
32/// Scores are stored in a moka cache whose idle expiry is long enough for any
33/// maxed-out score to fully decay, keeping memory bounded.
34pub struct MemDdosProtection {
35    score_threshold: f64,
36    decay_per_second: f64,
37    bad_request_penalty: f64,
38    max_connections_per_ip: usize,
39    scores: Cache<String, Arc<Mutex<DdosScore>>>,
40    connections: Mutex<HashMap<String, usize>>,
41    time_provider: Arc<dyn TimeProvider>,
42}
43
44impl MemDdosProtection {
45    pub fn new(score_threshold: f64, decay_per_second: f64, bad_request_penalty: f64, max_connections_per_ip: usize, time_provider: Arc<dyn TimeProvider>) -> Self {
46        // Idle expiry: time for a maxed-out score to fully decay, with 2x margin
47        let idle_secs = if decay_per_second > 0.0 {
48            (score_threshold / decay_per_second * 2.0).ceil() as u64
49        } else {
50            3600 // fallback: 1 hour if no decay
51        };
52        // Score idle-eviction runs on our TimeProvider (scaled in tests), not wall time.
53        let scores = Cache::builder()
54            .time_to_idle(Duration::from_secs(idle_secs))
55            .external_clock(Arc::new(TimeProviderMokaClock::new(time_provider.clone())))
56            .build();
57        Self {
58            score_threshold,
59            decay_per_second,
60            bad_request_penalty,
61            max_connections_per_ip,
62            scores,
63            connections: Mutex::new(HashMap::new()),
64            time_provider,
65        }
66    }
67
68    fn increment_score(&self, ip: &str, points: f64) -> f64 {
69        let now = self.time_provider.current_time_millis();
70        let entry = self.scores.get_with(ip.to_string(), || Arc::new(Mutex::new(DdosScore::new())));
71        entry.lock().increment(points, self.decay_per_second, now)
72    }
73
74    fn is_score_banned(&self, ip: &str) -> bool {
75        let now = self.time_provider.current_time_millis();
76        self.scores
77            .get(ip)
78            .map(|entry| entry.lock().current(self.decay_per_second, now) >= self.score_threshold)
79            .unwrap_or(false)
80    }
81}
82
83impl DdosProtection for MemDdosProtection {
84    fn allow_request(&self, ip: &str) -> bool {
85        self.increment_score(ip, 1.0) < self.score_threshold
86    }
87
88    fn report_bad_request(&self, ip: &str) {
89        let score = self.increment_score(ip, self.bad_request_penalty);
90        if score >= self.score_threshold {
91            warn!("DDoS: {} blocked (score={:.1})", ip, score);
92        }
93    }
94
95    fn try_acquire_connection(&self, ip: &str) -> bool {
96        if self.is_score_banned(ip) {
97            return false;
98        }
99        let mut connections = self.connections.lock();
100        let count = connections.entry(ip.to_string()).or_insert(0);
101        if *count >= self.max_connections_per_ip {
102            return false;
103        }
104        *count += 1;
105        true
106    }
107
108    fn release_connection(&self, ip: &str) {
109        let mut connections = self.connections.lock();
110        if let Some(count) = connections.get_mut(ip) {
111            *count = count.saturating_sub(1);
112            if *count == 0 {
113                connections.remove(ip);
114            }
115        }
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::config;
    use crate::tools::time::TimeMillis;
    use crate::tools::time_provider::manual_time_provider::ManualTimeProvider;
    use crate::transport::ddos::ddos::DdosConnectionGuard;

    /// Protector configured with the production config values and a manual clock.
    fn make_ddos() -> Arc<MemDdosProtection> {
        ddos_with(
            config::SERVER_DDOS_SCORE_THRESHOLD,
            config::SERVER_DDOS_DECAY_PER_SECOND,
            config::SERVER_DDOS_BAD_REQUEST_PENALTY,
            config::SERVER_DDOS_MAX_CONNECTIONS_PER_IP,
        )
    }

    /// Protector with explicit tuning and a fresh manual clock left at its default time.
    fn ddos_with(threshold: f64, decay_per_second: f64, penalty: f64, max_connections_per_ip: usize) -> Arc<MemDdosProtection> {
        Arc::new(MemDdosProtection::new(
            threshold,
            decay_per_second,
            penalty,
            max_connections_per_ip,
            Arc::new(ManualTimeProvider::default()),
        ))
    }

    /// Issues requests for `ip` until the protector first refuses one.
    fn exhaust(ddos: &MemDdosProtection, ip: &str) {
        while ddos.allow_request(ip) {}
    }

    #[test]
    fn connection_guard_limits_per_ip() {
        let ddos = make_ddos();
        let ip = "1.2.3.4";

        let mut held: Vec<_> = (0..config::SERVER_DDOS_MAX_CONNECTIONS_PER_IP)
            .map(|_| DdosConnectionGuard::try_new(ddos.clone(), ip))
            .inspect(|slot| assert!(slot.is_some(), "should acquire slot within limit"))
            .flatten()
            .collect();

        assert!(DdosConnectionGuard::try_new(ddos.clone(), ip).is_none(), "should be blocked at per-IP cap");

        // Freeing a single slot must let one more connection through.
        drop(held.pop().unwrap());
        assert!(DdosConnectionGuard::try_new(ddos.clone(), ip).is_some(), "should acquire after release");
    }

    #[test]
    fn connection_guard_independent_ips() {
        let ddos = make_ddos();

        // Both guards are held simultaneously inside the Vec.
        let guards: Vec<_> = ["1.1.1.1", "2.2.2.2"]
            .iter()
            .map(|&ip| DdosConnectionGuard::try_new(ddos.clone(), ip))
            .collect();

        assert!(guards.iter().all(Option::is_some));
    }

    #[test]
    fn banned_ip_cannot_acquire_connection() {
        let ddos = ddos_with(3.0, 0.0, 3.0, 8);
        let ip = "1.2.3.4";

        // Drive the score up to the threshold so the IP is banned.
        exhaust(&ddos, ip);

        assert!(
            DdosConnectionGuard::try_new(ddos.clone(), ip).is_none(),
            "banned IP should not acquire a connection slot"
        );
    }

    #[test]
    fn guard_report_bad_request_delegates() {
        let ddos = ddos_with(100.0, 0.0, 5.0, 8);
        let guard = DdosConnectionGuard::try_new(ddos.clone(), "5.6.7.8").unwrap();

        guard.report_bad_request();
        // A single penalty (~5 points) leaves the score far below 100, so requests still pass.
        assert!(guard.allow_request());
    }

    #[test]
    fn connection_count_drops_to_zero_after_all_guards_released() {
        let ddos = make_ddos();
        let ip = "9.9.9.9";

        // Take a slot and hand it straight back.
        drop(DdosConnectionGuard::try_new(ddos.clone(), ip).unwrap());

        // The full quota must be available again; keep every guard alive in `_full`.
        let _full: Vec<_> = (0..config::SERVER_DDOS_MAX_CONNECTIONS_PER_IP)
            .map(|_| DdosConnectionGuard::try_new(ddos.clone(), ip).unwrap())
            .collect();
        assert!(DdosConnectionGuard::try_new(ddos.clone(), ip).is_none());
    }

    #[test]
    fn allow_request_returns_false_at_threshold() {
        // Zero decay keeps the clock out of the picture.
        let threshold = 5.0;
        let ddos = ddos_with(threshold, 0.0, 1.0, 8);
        let ip = "3.3.3.3";

        for attempt in 1..=4 {
            assert!(ddos.allow_request(ip), "request {} of 5 should be allowed", attempt);
        }
        // The 5th request lands exactly on the threshold; it and everything after are refused.
        assert!(!ddos.allow_request(ip), "request at threshold should be blocked");
        assert!(!ddos.allow_request(ip), "subsequent requests must also be blocked");
    }

    #[test]
    fn bad_request_penalty_causes_ban_faster_than_normal_requests() {
        let (threshold, penalty) = (20.0, 10.0);
        let ddos = ddos_with(threshold, 0.0, penalty, 8);
        let ip = "4.4.4.4";

        // Two penalties: 10 + 10 = 20, exactly the threshold.
        for _ in 0..2 {
            ddos.report_bad_request(ip);
        }

        assert!(!ddos.allow_request(ip), "IP should be banned after two penalty-weight bad requests");
        assert!(DdosConnectionGuard::try_new(ddos.clone(), ip).is_none(), "banned IP must not acquire a connection");
    }

    #[test]
    fn score_is_independent_per_ip() {
        let ddos = ddos_with(3.0, 0.0, 1.0, 8);
        let (ip_a, ip_b) = ("10.0.0.1", "10.0.0.2");

        exhaust(&ddos, ip_a);
        assert!(!ddos.allow_request(ip_a), "ip_a should be blocked");

        assert!(ddos.allow_request(ip_b), "ip_b should be unaffected by ip_a's exhaustion");
        assert!(
            DdosConnectionGuard::try_new(ddos.clone(), ip_b).is_some(),
            "ip_b should still acquire a connection after ip_a is banned"
        );
    }

    #[test]
    fn score_decays_over_time() {
        // High threshold, fast decay; the clock is advanced by hand so the test is deterministic.
        let clock = Arc::new(ManualTimeProvider::default());
        let ddos = Arc::new(MemDdosProtection::new(5.0, 1000.0, 1.0, 8, clock.clone()));
        let ip = "7.7.7.7";

        // Four requests at t=0 leave the score at 4, one short of the threshold of 5.
        (0..4).for_each(|_| assert!(ddos.allow_request(ip)));

        // 10 ms at 1000 points/s drains 10 points, emptying the score.
        clock.set_time(TimeMillis(10));
        assert!(ddos.allow_request(ip), "score should have decayed, allowing the request");
    }
}