Skip to main content

hashiverse_lib/tools/
tools.rs

1//! # Grab-bag of tiny cross-platform helpers
2//!
3//! Utility functions that don't fit in any of the more focused modules:
4//!
5//! - **Async yielding** ([`yield_now`]) — maps to `tokio::task::yield_now` on native,
6//!   `gloo_timers` on wasm32-unknown, and `tokio::time::sleep(0)` on wasi.
7//! - **Randomness** ([`random_fill_bytes`], [`random_bytes`], [`random_u32`]) — OS RNG
8//!   helpers used by key generation and PoW salt selection.
9//! - **Base64 and hex parsing** — consistent helpers used wherever we need to emit or
10//!   accept textual byte blobs (key persistence, URLs, HTML attributes).
11//! - **Byte reversal** used by server-id PoW hash-to-id mapping.
12//! - **`LeadingAgreementBits`** typedef for the XOR-distance metric used by the DHT.
13//! - **Logging bootstrap** (`tracing_subscriber` initialisation with consistent
14//!   formatting across native and wasm).
15//! - **`Cancellable`-style async helpers** that plug into `CancellationToken`.
16//!
17//! Anything here that grows a meaningful amount of functionality should graduate to
18//! its own module.
19
20use crate::tools::json;
21use crate::tools::time::DurationMillis;
22use crate::tools::time_provider::time_provider::{RealTimeProvider, TimeProvider};
23use crate::tools::BytesGatherer;
24use argon2::password_hash::rand_core::{OsRng, RngCore};
25use base64::Engine;
26use bytes::Bytes;
27use log::info;
28use std::fmt;
29use std::future::Future;
30use std::sync::Arc;
31use tokio_util::sync::CancellationToken;
32use tracing_subscriber::fmt::time::FormatTime;
33use tracing_subscriber::layer::SubscriberExt;
34use tracing_subscriber::util::SubscriberInitExt;
35
/// Count of leading bits on which two byte keys agree, as returned by
/// [`leading_agreement_bits_xor`].
pub type LeadingAgreementBits = i32;
37
38pub async fn yield_now() {
39    #[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
40    {
41        send_wrapper::SendWrapper::new(gloo_timers::future::TimeoutFuture::new(0)).await;
42    }
43    #[cfg(all(target_arch = "wasm32", target_os = "wasi"))]
44    {
45        tokio::time::sleep(std::time::Duration::from_millis(0u64)).await;
46    }
47    #[cfg(not(target_arch = "wasm32"))]
48    {
49        // On native platforms, use Tokio's optimized yield.
50        tokio::task::yield_now().await;
51    }
52}
53
54pub fn random_fill_bytes(dest: &mut [u8]) {
55    OsRng.fill_bytes(dest);
56}
57
58pub fn random_bytes(n: usize) -> Vec<u8> {
59    let mut dest = vec![0u8; n];
60    random_fill_bytes(&mut dest);
61    dest
62}
63
/// Returns a copy of `bytes` with the byte order reversed (the first byte becomes the last).
pub fn reverse_bytes<const N: usize>(bytes: &[u8; N]) -> [u8; N] {
    let mut reversed = *bytes;
    reversed.reverse();
    reversed
}
71
72pub fn random_u32() -> u32 {
73    OsRng.next_u32()
74}
75
76#[cfg(target_pointer_width = "64")]
77pub fn random_usize() -> usize {
78    OsRng.next_u64() as usize
79}
80
/// Returns one `usize` drawn from [`OsRng`] (32-bit pointer-width variant).
#[cfg(target_pointer_width = "32")]
pub fn random_usize() -> usize {
    // `usize` is 32 bits wide under this cfg, so the cast keeps every bit of the draw.
    let draw: u32 = RngCore::next_u32(&mut OsRng);
    draw as usize
}
85
86pub fn random_usize_bounded(upper: usize) -> usize {
87    // Rejection sampling to avoid modulo bias.
88    // We accept values in [0, zone) where zone is the largest multiple of `upper`.
89    let zone = usize::MAX - (usize::MAX % upper);
90
91    loop {
92        let r = random_usize();
93        if r < zone {
94            return r % upper;
95        }
96    }
97}
98
99pub fn random_u8() -> u8 {
100    OsRng.next_u32() as u8
101}
102
103pub fn random_base64(length: usize) -> String {
104    let mut bytes = vec![0u8; length];
105    random_fill_bytes(&mut bytes);
106    encode_base64(bytes)
107}
108
109pub fn are_all_zeros<T: PartialEq + num_traits::Zero>(src: &[T]) -> bool {
110    src.iter().all(|b| *b == T::zero())
111}
112
/// Returns `true` when both slices have the same length and are equal element by element.
pub fn are_all_equal<T: PartialEq>(src1: &[T], src2: &[T]) -> bool {
    let same_length = src1.len() == src2.len();
    same_length && src1.iter().zip(src2.iter()).all(|(lhs, rhs)| lhs == rhs)
}
119
/// Counts the zero bits preceding the first set bit of `bytes`, scanning each byte from
/// its most-significant bit.
///
/// The result saturates at 255. An empty slice yields 0; an all-zero slice yields
/// `8 * bytes.len()`, capped at 255.
pub fn count_leading_zero_bits(bytes: &[u8]) -> u8 {
    let mut total: u64 = 0;

    for &byte in bytes {
        // `leading_zeros` is 8 for a zero byte, so this handles both whole-byte and partial runs.
        total += u64::from(byte.leading_zeros());
        if byte != 0 {
            // The first non-zero byte ends the run of leading zeros.
            break;
        }
    }

    total.min(u64::from(u8::MAX)) as u8
}
141
142pub async fn cancellable_sleep_millis(time_provider: &dyn TimeProvider, millis: DurationMillis, cancellation_token: &CancellationToken) {
143    tokio::select! {
144        _ = time_provider.sleep_millis(millis) => {},
145        _ = cancellation_token.cancelled() => {},
146    }
147}
148
/// Renders `items` as `[ a, b, c ]` using each element's `Display` implementation.
///
/// An empty slice renders as `[  ]` (the brackets with two spaces between them).
pub fn format_vec<T: std::fmt::Display>(items: &[T]) -> String {
    let rendered: Vec<String> = items.iter().map(ToString::to_string).collect();
    let joined = rendered.join(", ");
    format!("[ {} ]", joined)
}
152
153pub fn leading_agreement_bits_xor(key1: &[u8], key2: &[u8]) -> LeadingAgreementBits {
154    let mut leading_bits_in_agreement: i32 = 0;
155
156    let min_len = std::cmp::min(key1.len(), key2.len());
157    for byte_idx in 0..min_len {
158        let xor = key1[byte_idx] ^ key2[byte_idx];
159
160        // Do we have differing bytes?
161        if xor != 0 {
162            leading_bits_in_agreement += xor.leading_zeros() as LeadingAgreementBits;
163            return leading_bits_in_agreement;
164        }
165        else {
166            leading_bits_in_agreement += 8;
167        }
168    }
169
170    leading_bits_in_agreement
171}
172
173pub fn encode_base64<T: AsRef<[u8]>>(input: T) -> String {
174    base64::engine::general_purpose::STANDARD.encode(&input)
175}
176
177pub fn decode_base64<T: AsRef<[u8]>>(input: T) -> anyhow::Result<Vec<u8>> {
178    Ok(base64::engine::general_purpose::STANDARD.decode(input)?)
179}
180
/// Encodes `v` as eight little-endian bytes after widening it to `u64`.
pub fn usize_encode_le64(v: usize) -> [u8; 8] {
    let widened = v as u64;
    widened.to_le_bytes()
}
184
185pub fn usize_decode_le64(v_bytes: &[u8]) -> anyhow::Result<usize> {
186    let v = u64::from_le_bytes(v_bytes.try_into()?);
187    Ok(v as usize)
188}
189
190pub fn write_length_prefixed_json<T: serde::Serialize>(bytes_gatherer: &mut BytesGatherer, value: &T) -> anyhow::Result<()> {
191    let json_bytes = json::struct_to_bytes(value)?;
192    bytes_gatherer.put_u64(json_bytes.len() as u64);
193    bytes_gatherer.put_bytes(json_bytes);
194    Ok(())
195}
196pub fn read_length_prefixed_json<T: serde::de::DeserializeOwned>(bytes: &mut Bytes) -> anyhow::Result<T> {
197    use bytes::Buf;
198
199    if bytes.remaining() < 8 {
200        anyhow::bail!("Invalid buffer: missing json length");
201    }
202
203    let len = bytes.get_u64() as usize;
204
205    if bytes.remaining() < len {
206        anyhow::bail!("Invalid buffer: json data truncated");
207    }
208
209    let json_bytes = bytes.copy_to_bytes(len);
210    json::bytes_to_struct::<T>(&json_bytes)
211}
212
213pub fn random_element<T>(range: &[T]) -> &T {
214    let index = random_usize_bounded(range.len());
215    &range[index]
216}
217
218pub fn shuffle<T>(source: &mut [T]) {
219    // Fisher–Yates / Knuth shuffle (uniform)
220    for i in 1..source.len() {
221        let j = random_usize_bounded(i + 1);
222        source.swap(i, j);
223    }
224}
225
226pub struct CustomTimeFormatter {
227    time_provider: Arc<dyn TimeProvider>,
228}
229
230impl CustomTimeFormatter {
231    pub fn new(time_provider: Arc<dyn TimeProvider>) -> Self {
232        Self { time_provider }
233    }
234}
235
236impl FormatTime for CustomTimeFormatter {
237    fn format_time(&self, w: &mut tracing_subscriber::fmt::format::Writer<'_>) -> fmt::Result {
238        write!(w, "{}", self.time_provider.current_time_str())
239    }
240}
241
242pub fn configure_logging() {
243    configure_logging_with_time_provider("trace", Arc::new(RealTimeProvider))
244}
245
246pub fn configure_logging_with_time_provider(level: &str, time_provider: Arc<dyn TimeProvider>) {
247    // The filter
248    let filter = format!("{},hyper=off,warp=off,reqwest=off,rustls=off,h2=off,h2=off,html5ever=off,selectors=off,fjall=off,lsm_tree=off,sfa=off,hickory_resolver=off,hickory_proto=off", level);
249    let env_filter = tracing_subscriber::EnvFilter::new(&filter);
250
251    // Prepare the Standard Logging Layer
252    let fmt_layer = tracing_subscriber::fmt::layer().with_timer(CustomTimeFormatter::new(time_provider));
253
254    let registry = tracing_subscriber::registry();
255
256    // Prepare the Console Layer (Conditional) - we only enable this if the 'tokio_unstable' cfg is present and we are not WASM
257    #[cfg(all(tokio_unstable, not(target_arch = "wasm32")))]
258    registry.with(console_subscriber::spawn());
259
260    // Register everything
261    registry.with(fmt_layer).with(env_filter).init();
262
263    info!("Logging initialized");
264}
265
266#[cfg(not(target_arch = "wasm32"))]
267pub type TempDirHandle = tempfile::TempDir;
268
269#[cfg(not(target_arch = "wasm32"))]
270pub fn get_temp_dir() -> anyhow::Result<(TempDirHandle, String)> {
271    let mut base = std::env::temp_dir();
272    base.push("hashiverse-temp");
273
274    // Ensure the base directory exists
275    std::fs::create_dir_all(&base)?;
276
277    let temp_dir = tempfile::Builder::new().prefix("hashiverse-").tempdir_in(&base)?;
278    let temp_dir_path = temp_dir.path().to_str().unwrap().to_string();
279    Ok((temp_dir, temp_dir_path))
280}
281
/// Placeholder handle type for `wasm32` targets: the unit type.
#[cfg(target_arch = "wasm32")]
pub type TempDirHandle = ();

/// `wasm32` variant of `get_temp_dir`: creates nothing and always succeeds with a unit
/// handle and an empty path string.
#[cfg(target_arch = "wasm32")]
pub fn get_temp_dir() -> anyhow::Result<(TempDirHandle, String)> {
    let empty_path = String::new();
    Ok(((), empty_path))
}
289
290pub fn from_hex_str<T, const T_BYTES: usize>(str: &str, ctor: impl FnOnce([u8; T_BYTES]) -> T) -> anyhow::Result<T> {
291    if str.len() != 2 * T_BYTES {
292        anyhow::bail!("Invalid hex string length: expected {} hex characters ({} bytes), got {} characters.", 2 * T_BYTES, T_BYTES, str.len(),);
293    }
294
295    // Try to decode the hex string
296    let decoded = hex::decode(str)?;
297
298    // Check if the decoded bytes are exactly xxx bytes
299    if decoded.len() != T_BYTES {
300        anyhow::bail!("Invalid hex string length: expected {} bytes, got {} bytes", T_BYTES, decoded.len());
301    }
302
303    // Convert Vec<u8> to [u8; xxx]
304    let mut decoded_bytes = [0u8; T_BYTES];
305    decoded_bytes.copy_from_slice(&decoded);
306
307    Ok(ctor(decoded_bytes))
308}
309
310/// Spawn a background async task, using the appropriate runtime for the current target.
311///
312/// Only the native path requires `Send`: `tokio::spawn` runs the task on the multi-threaded
313/// runtime. Both wasm targets are single-threaded and impose no `Send` bound — important because
314/// browser-backed futures (IndexedDB storage, `gloo-net` requests) and the `async_trait(?Send)`
315/// client traits are not `Send`. The browser (`wasm32-unknown-unknown`) drives the task on the JS
316/// event loop via `wasm_bindgen_futures`; the wasmtime test target (`wasm32-wasip1`) uses tokio's
317/// thread-local spawner.
318#[cfg(not(target_arch = "wasm32"))]
319pub fn spawn_background_task<F>(task: F)
320where
321    F: Future<Output = ()> + Send + 'static,
322{
323    tokio::spawn(task);
324}
325
/// Browser (`wasm32` with `target_os = "unknown"`) variant: starts `task` with
/// `wasm_bindgen_futures::spawn_local`. No `Send` bound is required.
#[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
pub fn spawn_background_task<F: Future<Output = ()> + 'static>(task: F) {
    wasm_bindgen_futures::spawn_local(task)
}
333
/// Non-browser `wasm32` variant: starts `task` with `tokio::task::spawn_local`.
/// No `Send` bound is required.
#[cfg(all(target_arch = "wasm32", not(target_os = "unknown")))]
pub fn spawn_background_task<F: Future<Output = ()> + 'static>(task: F) {
    // The join handle is dropped right away; the task is not awaited by the caller.
    let _detached_handle = tokio::task::spawn_local(task);
}
341
#[cfg(test)]
mod tests {
    /// Checks `leading_agreement_bits_xor` against a table of hex-encoded key pairs,
    /// in both argument orders (the result must not depend on the order).
    ///
    /// A plain `#[test]` is enough: nothing here awaits, so no async runtime is needed.
    #[test]
    fn xor_distance_bits_test() -> anyhow::Result<()> {
        use crate::tools::tools::leading_agreement_bits_xor;

        let tests = [
            // Identical
            ("0000", "0000", 16),
            ("ffff", "ffff", 16),
            ("1234", "1234", 16),
            ("abcd", "abcd", 16),
            // MSB
            ("0000", "ffff", 0),
            ("0000", "0fff", 4),
            ("0000", "00ff", 8),
            ("0000", "000f", 12),
            // Units
            ("0000", "efff", 0),
            ("0000", "7fff", 1),
            ("0000", "3fff", 2),
            ("0000", "1fff", 3),
            ("0000", "0fff", 4),
            ("0000", "07ff", 5),
            ("0000", "03ff", 6),
            ("0000", "01ff", 7),
            ("0000", "00ff", 8),
            ("0000", "007f", 9),
            ("0000", "003f", 10),
            ("0000", "001f", 11),
            ("0000", "000f", 12),
            ("0000", "0007", 13),
            ("0000", "0003", 14),
            ("0000", "0001", 15),
            // MSB + random
            ("0000", "fff9", 0),
            ("0000", "0ff9", 4),
            ("0000", "00f9", 8),
            // Different lengths
            ("", "0000", 0),
            ("00", "0000", 8),
            ("0000", "000000", 16),
        ];

        for (a, b, expected) in tests {
            let a_binary = hex::decode(a)?;
            let b_binary = hex::decode(b)?;
            {
                let distance = leading_agreement_bits_xor(&a_binary, &b_binary);
                assert_eq!(distance, expected, "Failed for {} and {}.  Got {} expected {}.", a, b, distance, expected);
            }
            {
                // Same pair with the arguments swapped; the message lists them in call order.
                let distance = leading_agreement_bits_xor(&b_binary, &a_binary);
                assert_eq!(distance, expected, "Failed for {} and {}.  Got {} expected {}.", b, a, distance, expected);
            }
        }
        Ok(())
    }
}
401