1use bytes::Bytes;
19use hashiverse_lib::protocol::payload::payload::CacheRequestTokenV1;
20use hashiverse_lib::protocol::peer::Peer;
21use hashiverse_lib::tools::buckets::BucketLocation;
22use hashiverse_lib::tools::server_id::ServerId;
23use hashiverse_lib::tools::time::{TimeMillis, MILLIS_IN_MINUTE};
24use hashiverse_lib::tools::time_provider::moka_clock::TimeProviderMokaClock;
25use hashiverse_lib::tools::time_provider::time_provider::TimeProvider;
26use hashiverse_lib::tools::tools::leading_agreement_bits_xor;
27use hashiverse_lib::tools::types::Id;
28use moka::sync::Cache;
29use std::collections::HashMap;
30use std::sync::{Arc, Mutex};
31
32use crate::server::post_bundle_caching_shared::{CachedBundle, GetCacheResult, CACHE_HIT_THRESHOLD, CACHE_LOCATION_TTI, CACHE_REQUEST_TOKEN_TTL_DURATION, CACHE_REQUEST_TOKEN_TTL_DURATION_MILLIS};
33
/// Weight (4 MiB) charged to a location entry whose cached bundles total zero bytes,
/// e.g. a location that has been read but has not received any uploads yet.
const POST_BUNDLE_PLACEHOLDER_WEIGHT: u32 = 4 << 20;
37
38struct CachedPostBundleLocationEntry {
45 bundles: HashMap<Id, CachedBundle>,
47 hit_count: u32,
48}
49
50impl CachedPostBundleLocationEntry {
51 fn placeholder() -> Self {
52 Self { bundles: HashMap::new(), hit_count: 0 }
53 }
54
55 fn weight(&self) -> u32 {
56 let total: u32 = self.bundles.values().map(|b| b.bytes.len() as u32).sum();
57 if total == 0 { POST_BUNDLE_PLACEHOLDER_WEIGHT } else { total }
58 }
59}
60
61pub struct PostBundleCache {
75 max_originators_per_location: usize,
76 bundles: Cache<Id, Arc<Mutex<CachedPostBundleLocationEntry>>>,
77 inflight: Cache<Id, ()>,
78}
79
80impl PostBundleCache {
81 pub fn new(max_originators_per_location: usize, max_bytes: u64, time_provider: Arc<dyn TimeProvider>) -> Self {
82 let clock = Arc::new(TimeProviderMokaClock::new(time_provider));
84
85 let bundles = Cache::builder()
86 .weigher(|_key: &Id, entry: &Arc<Mutex<CachedPostBundleLocationEntry>>| {
87 entry.lock().map(|e| e.weight()).unwrap_or(POST_BUNDLE_PLACEHOLDER_WEIGHT)
88 })
89 .max_capacity(max_bytes)
90 .time_to_idle(CACHE_LOCATION_TTI)
91 .external_clock(clock.clone())
92 .build();
93
94 let inflight = Cache::builder()
95 .time_to_live(CACHE_REQUEST_TOKEN_TTL_DURATION)
96 .external_clock(clock)
97 .build();
98
99 Self { max_originators_per_location, bundles, inflight }
100 }
101
102 pub fn on_get(
109 &self,
110 bucket_location: &BucketLocation,
111 already_retrieved_peer_ids: &[Id],
112 peer_self: &Peer,
113 server_id: &ServerId,
114 now: TimeMillis,
115 ) -> GetCacheResult {
116 let location_id = bucket_location.location_id;
117 let entry_arc = self.bundles.get_with(location_id, || Arc::new(Mutex::new(CachedPostBundleLocationEntry::placeholder())));
118
119 let (cached_items, already_cached_peer_ids, should_issue_token) = {
120 let mut entry = entry_arc.lock().unwrap();
121 entry.hit_count += 1;
122
123 let already_retrieved_set: std::collections::HashSet<Id> = already_retrieved_peer_ids.iter().copied().collect();
124 let cached_items: Vec<Bytes> = entry.bundles
125 .iter()
126 .filter(|(originator_id, bundle)| !already_retrieved_set.contains(originator_id) && !bundle.is_stale(now))
127 .map(|(_, bundle)| bundle.bytes.clone())
128 .collect();
129
130 let already_cached_peer_ids: Vec<Id> = entry.bundles.keys().copied().collect();
131 let should_issue_token = entry.hit_count >= CACHE_HIT_THRESHOLD && !self.inflight.contains_key(&location_id);
132 (cached_items, already_cached_peer_ids, should_issue_token)
133 };
134
135 let cache_request_token = if should_issue_token {
136 self.inflight.insert(location_id, ());
137 let expires_at = now + CACHE_REQUEST_TOKEN_TTL_DURATION_MILLIS;
138 Some(CacheRequestTokenV1::new(peer_self.clone(), bucket_location.clone(), expires_at, already_cached_peer_ids, &server_id.keys.signature_key))
139 } else {
140 None
141 };
142
143 GetCacheResult { cached_items, cache_request_token }
144 }
145
146 pub fn on_upload(
150 &self,
151 location_id: Id,
152 originator_peer_id: Id,
153 bundle_bytes: Bytes,
154 server_time: TimeMillis,
155 is_sealed: bool,
156 ) -> bool {
157 let entry_arc = match self.bundles.get(&location_id) {
158 Some(e) => e,
159 None => return false, };
161
162 let mut entry = entry_arc.lock().unwrap();
163 let expires_at = if is_sealed { None } else { Some(server_time + MILLIS_IN_MINUTE.const_mul(5)) };
164 let bundle = CachedBundle { bytes: bundle_bytes, expires_at };
165
166 entry.bundles.insert(originator_peer_id, bundle);
168
169 while entry.bundles.len() > self.max_originators_per_location {
174 let evict_key = entry.bundles
175 .iter()
176 .min_by(|(id_a, bundle_a), (id_b, bundle_b)| {
177 let distance_a = leading_agreement_bits_xor(id_a.as_ref(), location_id.as_ref());
178 let distance_b = leading_agreement_bits_xor(id_b.as_ref(), location_id.as_ref());
179 distance_a.cmp(&distance_b).then_with(|| {
180 let expires_a = bundle_a.expires_at.unwrap_or(TimeMillis(i64::MAX));
181 let expires_b = bundle_b.expires_at.unwrap_or(TimeMillis(i64::MAX));
182 expires_a.cmp(&expires_b)
183 })
184 })
185 .map(|(id, _)| *id);
186 if let Some(k) = evict_key {
187 entry.bundles.remove(&k);
188 }
189 }
190
191 entry.bundles.contains_key(&originator_peer_id)
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hashiverse_lib::tools::buckets::{BucketLocation, BucketType, BUCKET_DURATIONS};
    use hashiverse_lib::tools::server_id::ServerId;
    use hashiverse_lib::tools::time::TimeMillis;
    use hashiverse_lib::tools::time_provider::time_provider::RealTimeProvider;
    use hashiverse_lib::tools::pow_generator::single_threaded_pow_generator::SingleThreadedPowGenerator;
    use hashiverse_lib::tools::types::{Id, Pow};

    /// Builds a test `ServerId` (created with `Pow(0)`) and the `Peer` derived from it.
    async fn make_test_server_and_peer() -> anyhow::Result<(ServerId, hashiverse_lib::protocol::peer::Peer)> {
        let time_provider = RealTimeProvider;
        let pow_generator = SingleThreadedPowGenerator::new();
        let server_id = ServerId::new("own_pow", &time_provider, Pow(0), true, &pow_generator).await?;
        let peer = server_id.to_peer(&time_provider)?;
        Ok((server_id, peer))
    }

    /// A user bucket location with a random location id, using the first bucket duration.
    fn make_test_bucket_location() -> BucketLocation {
        BucketLocation::new(BucketType::User, Id::random(), BUCKET_DURATIONS[0], TimeMillis(1_000_000)).unwrap()
    }

    /// Reads below the hit threshold never yield a token, and an empty location serves nothing.
    #[tokio::test]
    async fn test_below_threshold_no_token() -> anyhow::Result<()> {
        let (server_id, peer_self) = make_test_server_and_peer().await?;
        let cache = PostBundleCache::new(5, 64 * 1024 * 1024, Arc::new(RealTimeProvider));
        let bucket_location = make_test_bucket_location();
        let now = TimeMillis(1_000_000);

        for _ in 0..(CACHE_HIT_THRESHOLD - 1) {
            let result = cache.on_get(&bucket_location, &[], &peer_self, &server_id, now);
            assert!(result.cache_request_token.is_none());
            assert!(result.cached_items.is_empty());
        }

        Ok(())
    }

    /// The read that reaches the threshold gets a token. The next read does not, because a token
    /// for that location is still in flight.
    #[tokio::test]
    async fn test_at_threshold_token_issued_then_deduplicated() -> anyhow::Result<()> {
        let (server_id, peer_self) = make_test_server_and_peer().await?;
        let cache = PostBundleCache::new(5, 64 * 1024 * 1024, Arc::new(RealTimeProvider));
        let bucket_location = make_test_bucket_location();
        let now = TimeMillis(1_000_000);

        for _ in 0..(CACHE_HIT_THRESHOLD - 1) {
            let result = cache.on_get(&bucket_location, &[], &peer_self, &server_id, now);
            assert!(result.cache_request_token.is_none());
        }

        // This read reaches the threshold.
        let result = cache.on_get(&bucket_location, &[], &peer_self, &server_id, now);
        assert!(result.cache_request_token.is_some());

        // The token issued above is still in flight, so no second token.
        let result2 = cache.on_get(&bucket_location, &[], &peer_self, &server_id, now);
        assert!(result2.cache_request_token.is_none());

        Ok(())
    }

    /// An upload to a location that has been read is accepted and served on the next read.
    #[tokio::test]
    async fn test_upload_and_retrieval() -> anyhow::Result<()> {
        let (server_id, peer_self) = make_test_server_and_peer().await?;
        let cache = PostBundleCache::new(5, 64 * 1024 * 1024, Arc::new(RealTimeProvider));
        let bucket_location = make_test_bucket_location();
        let location_id = bucket_location.location_id;
        let now = TimeMillis(1_000_000);
        let originator_id = Id::random();
        let bundle_bytes = Bytes::from_static(b"test_bundle");

        // The first read creates the location entry so the upload has somewhere to land.
        cache.on_get(&bucket_location, &[], &peer_self, &server_id, now);

        let accepted = cache.on_upload(location_id, originator_id, bundle_bytes.clone(), now, false);
        assert!(accepted);

        let result = cache.on_get(&bucket_location, &[], &peer_self, &server_id, now);
        assert_eq!(result.cached_items, vec![bundle_bytes]);

        Ok(())
    }

    /// Bundles from originators the caller already has are not returned again.
    #[tokio::test]
    async fn test_already_retrieved_filtered() -> anyhow::Result<()> {
        let (server_id, peer_self) = make_test_server_and_peer().await?;
        let cache = PostBundleCache::new(5, 64 * 1024 * 1024, Arc::new(RealTimeProvider));
        let bucket_location = make_test_bucket_location();
        let location_id = bucket_location.location_id;
        let now = TimeMillis(1_000_000);
        let originator_id = Id::random();

        cache.on_get(&bucket_location, &[], &peer_self, &server_id, now);
        cache.on_upload(location_id, originator_id, Bytes::from_static(b"bundle"), now, false);

        let result = cache.on_get(&bucket_location, &[originator_id], &peer_self, &server_id, now);
        assert!(result.cached_items.is_empty());

        Ok(())
    }

    /// Uploads for a location that has never been read are rejected.
    #[tokio::test]
    async fn test_upload_returns_false_when_not_in_cache() -> anyhow::Result<()> {
        let cache = PostBundleCache::new(5, 64 * 1024 * 1024, Arc::new(RealTimeProvider));
        let location_id = Id::random();
        let originator_id = Id::random();

        let accepted = cache.on_upload(location_id, originator_id, Bytes::from_static(b"bundle"), TimeMillis(1_000_000), false);
        assert!(!accepted);

        Ok(())
    }

    /// When more originators upload than the per-location limit allows, those whose ids share
    /// the most leading bits with the location id are kept.
    #[tokio::test]
    async fn test_overflow_keeps_closest_originators() -> anyhow::Result<()> {
        let (server_id, peer_self) = make_test_server_and_peer().await?;
        let cache = PostBundleCache::new(3, 64 * 1024 * 1024, Arc::new(RealTimeProvider));
        let bucket_location = make_test_bucket_location();
        let location_id = bucket_location.location_id;
        let now = TimeMillis(1_000_000);

        cache.on_get(&bucket_location, &[], &peer_self, &server_id, now);

        // Flipping bit `flip_bit` (MSB-first) yields an id that agrees with `location_id` on
        // exactly the first `flip_bit` bits.
        let originator_at = |flip_bit: usize| -> Id {
            let mut bytes = location_id.0;
            bytes[flip_bit / 8] ^= 1 << (7 - (flip_bit % 8));
            Id(bytes)
        };

        for &p in &[20usize, 40, 60, 80, 100] {
            let bytes = Bytes::from(format!("bundle-agreement-{}", p));
            cache.on_upload(location_id, originator_at(p), bytes, now, true);
        }

        let result = cache.on_get(&bucket_location, &[], &peer_self, &server_id, now);
        let cached: std::collections::HashSet<Vec<u8>> = result.cached_items.iter().map(|b| b.to_vec()).collect();
        assert_eq!(3, cached.len(), "cache must keep exactly max_originators_per_location entries");
        for &p in &[60usize, 80, 100] {
            assert!(cached.contains(format!("bundle-agreement-{}", p).as_bytes()), "closest originator (agreement {}) must be kept", p);
        }
        for &p in &[20usize, 40] {
            assert!(!cached.contains(format!("bundle-agreement-{}", p).as_bytes()), "furthest originator (agreement {}) must be evicted", p);
        }
        Ok(())
    }

    /// An issued token is valid up to its TTL and expired after it.
    #[tokio::test]
    async fn test_cache_request_token_expiry() -> anyhow::Result<()> {
        let (server_id, peer_self) = make_test_server_and_peer().await?;
        let cache = PostBundleCache::new(5, 64 * 1024 * 1024, Arc::new(RealTimeProvider));
        let bucket_location = make_test_bucket_location();
        let now = TimeMillis(1_000_000);

        // Keep the first token issued across the threshold reads.
        let mut token = None;
        for _ in 0..CACHE_HIT_THRESHOLD {
            token = cache.on_get(&bucket_location, &[], &peer_self, &server_id, now).cache_request_token.or(token);
        }
        let token = token.expect("server issues a token at the hit threshold");

        assert!(!token.is_expired(now), "token must be valid at issue time");
        assert!(!token.is_expired(TimeMillis(now.0 + CACHE_REQUEST_TOKEN_TTL_DURATION_MILLIS.0 - 1)), "token must be valid just before its TTL elapses");
        assert!(token.is_expired(TimeMillis(now.0 + CACHE_REQUEST_TOKEN_TTL_DURATION_MILLIS.0 + 1)), "token must be expired once its TTL has elapsed");

        Ok(())
    }
}