// HTTP(S) server that answers every `/{id}` path with a page of generated text.
// Some words become links to further generated paths. Each response is delayed
// and then sent in small timed chunks. Request totals and per-IP request counts
// are periodically written to `count.txt` / `stats.txt`.
//
// NOTE(review): this file appears to have lost every `<...>` span during
// extraction. That covers generic type arguments (e.g. `Option,`, `Vec> = vec![`)
// and the HTML markup inside string literals. The code below is kept
// token-for-token as found; each damaged spot is flagged inline. Restore them
// from version control before building.
mod generator;
use axum::{
    body::{Body, Bytes},
    extract::{ConnectInfo, Path},
    response::{Html, IntoResponse, Response},
    routing::get,
    Router,
};
use axum_server::{bind_rustls, tls_rustls::RustlsConfig};
use clap::Parser;
use hashbrown::HashMap;
use http::HeaderMap;
use http_body::Frame;
use itertools::Itertools;
use rand::{prelude::*, Rng as _};
use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng};
// NOTE(review): `Cow` is not referenced anywhere in the visible code. It may have
// been used in one of the stripped type arguments, so check before removing it.
use std::{
    borrow::Cow,
    convert::Infallible,
    net::{IpAddr, SocketAddr},
    path::PathBuf,
    pin::Pin,
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    },
    task::{Context, Poll},
    time::Duration,
};
use tokio::time::{interval, Interval};

/// RNG used for all page generation. It is always seeded explicitly via
/// [`create_rng`], so a given seed reproduces the same output.
pub type Rng = ChaCha8Rng;

// Command-line options. New comments here are plain `//` on purpose: clap's
// derive uses `///` doc comments as the `--help` text, so adding any would
// change the CLI output.
#[derive(Parser)]
pub struct Args {
    /// Socket to bind to, defaults to 0.0.0.0:3000
    #[arg(long)]
    // NOTE(review): type argument missing. The `.as_deref().unwrap_or("0.0.0.0:3000")`
    // call in `main` implies `Option<String>`.
    sock: Option,
    #[arg(long)]
    /// Path of the certificate .pem
    // NOTE(review): type argument missing; presumably `Option<PathBuf>`
    // (`PathBuf` is imported and otherwise unused) — confirm.
    cert: Option,
    /// Path of the private key .pem
    #[arg(long)]
    // NOTE(review): type argument missing; presumably `Option<PathBuf>` — confirm.
    key: Option,
}

/// Builds an [`Rng`] from a 32-byte seed filled with the first items of
/// `seed_bytes`.
///
/// Shorter inputs leave the remaining seed bytes as 0, and anything past 32
/// items is ignored. As a result, two path ids that share their first 32 bytes
/// produce the same generated title and content.
// NOTE(review): the `Item` binding is missing. Callers pass `id.bytes()` and
// `count.to_le_bytes()`, so presumably `impl IntoIterator<Item = u8>`.
#[inline]
fn create_rng(seed_bytes: impl IntoIterator) -> Rng {
    let mut seed = [0; 32];
    for (i, b) in seed_bytes.into_iter().take(seed.len()).enumerate() {
        seed[i] = b;
    }
    Rng::from_seed(seed)
}

/// Request counter file (relative to the working directory). It is read once
/// at startup and rewritten by the stats task whenever the count has changed.
const COUNT_FILE: &str = "count.txt";
/// Per-IP request table written by the stats task (relative to the working directory).
const STATS_FILE: &str = "stats.txt";
/// Maximum number of bytes [`SlowBody`] emits per frame.
const SLOW_CHUNK_SIZE: usize = 100;
/// Tick period of the interval that paces [`SlowBody`] frames.
const SLOW_DURATION: Duration = Duration::from_millis(100);

/// Response body that releases `bytes` in chunks of at most `SLOW_CHUNK_SIZE`,
/// one chunk per tick of `interval`, to slow down how fast the client receives
/// the page.
struct SlowBody {
    bytes: Bytes,
    interval: Interval,
}

impl http_body::Body for SlowBody {
    type Data = Bytes;
    type Error = Infallible;

    /// Waits for the next tick of `interval`, then yields the next chunk of up
    /// to `SLOW_CHUNK_SIZE` bytes. After the last chunk it yields `None`, one
    /// tick later.
    // NOTE(review): the return type lost its generic arguments. Per
    // `http_body::Body` it should be `Poll<Option<Result<Frame<Self::Data>, Self::Error>>>`.
    // NOTE(review): `interval` uses tokio's default missed-tick behavior (Burst,
    // per tokio docs — confirm). If the client stops reading and this body is
    // not polled for a while, the missed ticks fire back-to-back afterwards, so
    // the 100 ms spacing is an average rather than a minimum.
    fn poll_frame(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll, Self::Error>>> {
        match self.interval.poll_tick(cx) {
            Poll::Ready(_) => {
                if self.bytes.is_empty() {
                    Poll::Ready(None)
                } else {
                    // `split_to` detaches the front chunk and leaves the rest in `self.bytes`.
                    let split_len = self.bytes.len().min(SLOW_CHUNK_SIZE);
                    Poll::Ready(Some(Ok(Frame::data(self.bytes.split_to(split_len)))))
                }
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

/// Serves a [`SlowBody`] wrapped in axum's `Html` responder.
impl IntoResponse for SlowBody {
    fn into_response(self) -> Response {
        Html(Body::new(self)).into_response()
    }
}

/// Per-request record sent from the page handler to the stats task.
struct RequestStats {
    /// Client address: the first `X-Forwarded-For` entry if it parses as an IP,
    /// otherwise the TCP peer address.
    ip: IpAddr,
}

/// Entry point.
///
/// - Restores the request counter from `COUNT_FILE`.
/// - Registers the `/{id}` page handler.
/// - Spawns a task that folds per-request records into per-IP counts every
///   20 seconds and persists them.
/// - Serves over TLS when both `--cert` and `--key` are given, and over plain
///   HTTP otherwise.
///
/// # Panics
///
/// Panics in any of these cases:
/// - `--sock` does not parse as a socket address.
/// - The TLS PEM files cannot be loaded.
/// - The server returns an error.
#[tokio::main]
async fn main() {
    let args = Args::parse();
    // NOTE(review): element type missing. Two different generator types are
    // stored side by side, which implies `Vec<Arc<dyn <trait from generator>>>`.
    // Restore the trait name from the `generator` module.
    let generators: Vec> = vec![
        Arc::new(generator::Markov::new(include_str!("../wap.txt"))),
        Arc::new(generator::Ast::new()),
    ];
    // Resume the request count from the previous run. A missing or unparsable
    // file starts the count at 0.
    let counter = Arc::new(AtomicU64::new(
        if let Some(prev_count) = std::fs::read_to_string(COUNT_FILE)
            .ok()
            .and_then(|s| s.trim().parse().ok())
        {
            prev_count
        } else {
            0
        },
    ));
    // Handlers push one `RequestStats` per request; the stats task below drains them.
    let (stats_tx, stats_rx) = flume::unbounded();
    let app = {
        let counter = counter.clone();
        let stats_tx = stats_tx.clone();
        Router::new().route(
            "/{id}",
            get(
                // NOTE(review): extractor type arguments missing; presumably
                // `Path<String>` and `ConnectInfo<SocketAddr>`. The service is built
                // with `into_make_service_with_connect_info` below — confirm.
                |Path(id): Path, ConnectInfo(sock): ConnectInfo, headers: HeaderMap| async move {
                    // Create a RNG for this path (deterministic, to simulate static pages)
                    let mut rng = create_rng(id.bytes());
                    // NOTE(review): `X-Forwarded-For` is accepted from any client
                    // without verification, so the IP recorded in the stats can be
                    // spoofed unless a trusted proxy always overwrites this header.
                    let ip = headers
                        .get("X-Forwarded-For")
                        .and_then(|h| h.to_str().ok())
                        .and_then(|h| h.split(',').next())
                        .and_then(|s| s.trim().parse().ok())
                        .unwrap_or_else(|| sock.ip());
                    // The receiver lives in the stats task's endless loop, so this
                    // only panics if that task has died.
                    stats_tx.send(RequestStats { ip }).unwrap();
                    // Count the request. Also doubles as the non-deterministic seed
                    let count = counter.fetch_add(1, Ordering::Relaxed);
                    // Create a RNG for this session (non-deterministic)
                    let mut session_rng = create_rng(count.to_le_bytes());
                    // Artificially slow down connections as rudimentary DDoS protection, and to use up client resources
                    tokio::time::sleep(Duration::from_millis(session_rng.random_range(200..1000)))
                        .await;
                    // Choose a bullshit generator from our collection for this page
                    // (`unwrap` cannot fail: `generators` is non-empty).
                    let generator = generators.choose(&mut rng).unwrap();
                    // Word streams consume throwaway clones of `rng`. Drawing words
                    // therefore does not advance `rng`, which still supplies the
                    // lengths and the per-word link/newline rolls.
                    let title = generator
                        .word_stream(rng.random_range(2..10), &mut rng.clone())
                        .join(" ");
                    let stats = format!("Served rubbish to {count} clients so far");
                    let content = generator
                        .word_stream(rng.random_range(50..5_000), &mut rng.clone())
                        .fold(String::new(), |mut content, word| {
                            // Small chance of every word becoming a link back into the void
                            if rng.random_bool(0.05) {
                                let url = generator.word_stream(3, &mut rng.clone()).join("-");
                                // NOTE(review): one `{}` placeholder but two arguments,
                                // which `format!` rejects at compile time. The link
                                // markup (presumably an anchor with `url` as href
                                // around `word`) appears to have been stripped —
                                // restore it.
                                content += &format!(" {}", url, word);
                            } else {
                                // Also, a chance for every word to end with a newline. This should probably be controlled by the generator.
                                // NOTE(review): a bare newline renders as whitespace in
                                // HTML; any line-break markup in this literal appears to
                                // have been stripped — confirm.
                                content += if rng.random_bool(0.01) { ".
" } else { " " };
                                content += &word
                            }
                            content
                        });
                    // NOTE(review): the page markup of this template appears to have
                    // been stripped. Only the placeholders separated by blank lines
                    // remain — restore the HTML from version control.
                    let html = format!(
                        " {title}

{title}

{stats}

{content}

"
                    );
                    SlowBody {
                        bytes: html.into(),
                        interval: interval(SLOW_DURATION),
                    }
                },
            ),
        )
    };
    // Background stats task, run every 20 seconds:
    // - drain queued request records into per-IP counts;
    // - if the counter changed since the last write, rewrite `COUNT_FILE` and
    //   `STATS_FILE`.
    let mut interval = tokio::time::interval(Duration::from_secs(20));
    // NOTE(review): one entry per distinct client IP is kept for the whole life of
    // the process and never evicted.
    let mut worst_offenders = HashMap::<_, u64>::default();
    tokio::spawn(async move {
        let mut last = 0;
        loop {
            interval.tick().await;
            while let Ok(stats) = stats_rx.try_recv() {
                *worst_offenders.entry(stats.ip).or_default() += 1;
            }
            let count = counter.load(Ordering::Relaxed);
            if count != last {
                last = count;
                // Best-effort persistence: write errors are deliberately ignored.
                // NOTE(review): these are blocking `std::fs` calls inside an async
                // task; `tokio::fs` or `spawn_blocking` would keep the runtime
                // thread free.
                let _ = std::fs::write(COUNT_FILE, &format!("{count}"));
                // NOTE(review): type argument missing; presumably `collect::<Vec<_>>()`.
                let mut worst_offenders = worst_offenders.iter().collect::>();
                // NOTE(review): `sort_by_key` sorts ascending, so rank #1 below is
                // the IP with the FEWEST requests. Sort by
                // `std::cmp::Reverse(*n)` if the worst offender should be listed first.
                worst_offenders.sort_by_key(|(_, n)| *n);
                let stats = worst_offenders
                    .iter()
                    .enumerate()
                    .map(|(i, (ip, n))| format!("#{:>4} | {:>4} | {}\n", i + 1, n, ip))
                    // NOTE(review): type argument missing; presumably `collect::<String>()`.
                    .collect::();
                let _ = std::fs::write(STATS_FILE, &stats);
            }
        }
    });
    println!("Starting...");
    let sock = args
        .sock
        .as_deref()
        .unwrap_or("0.0.0.0:3000")
        .parse()
        .unwrap();
    // TLS is used only when BOTH paths are given. If only one of `--cert` /
    // `--key` is supplied, the server silently falls through to plain HTTP.
    if let (Some(cert), Some(key)) = (args.cert, args.key) {
        println!("Enabling TLS...");
        let config = RustlsConfig::from_pem_file(cert, key).await.unwrap();
        bind_rustls(sock, config)
            // NOTE(review): type argument missing here and in the branch below;
            // presumably `::<SocketAddr>()`, matching the handler's `ConnectInfo`.
            .serve(app.into_make_service_with_connect_info::())
            .await
            .unwrap();
    } else {
        println!("WARNING: TLS disabled.");
        axum_server::bind(sock)
            .serve(app.into_make_service_with_connect_info::())
            .await
            .unwrap()
    }
}