diff --git a/.gitignore b/.gitignore index 96ef6c0..b66d4be 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ -/target +/**/target Cargo.lock diff --git a/Cargo.toml b/Cargo.toml index 1504b14..1ca085c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,3 +21,6 @@ tracing = "0.1.26" x25519-dalek = "1.1.1" xoodoo = "0.1.0" zeroize = "1.3.0" + +[workspace] +members = ["examples/bench"] diff --git a/examples/bench/Cargo.toml b/examples/bench/Cargo.toml new file mode 100644 index 0000000..dad788e --- /dev/null +++ b/examples/bench/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "bench" +version = "0.1.0" +edition = "2018" +license = "MIT OR Apache-2.0" +publish = false + +[dependencies] +anyhow = "1.0.22" +bytes = "1" +hdrhistogram = { version = "7.2", default-features = false } +quinn = "0.9.1" +clap = { version = "3.2", features = ["derive"] } +tokio = { version = "1.0.1", features = ["rt", "sync"] } +tracing = "0.1.10" +tracing-subscriber = { version = "0.3.0", default-features = false, features = ["env-filter", "fmt", "ansi", "time", "local-time"] } +quinn-noise = { path = "../.." } +rand = "0.7.3" +ed25519-dalek = "1.0.1" diff --git a/examples/bench/src/bin/bulk.rs b/examples/bench/src/bin/bulk.rs new file mode 100644 index 0000000..2cb5217 --- /dev/null +++ b/examples/bench/src/bin/bulk.rs @@ -0,0 +1,212 @@ +use std::{ + net::SocketAddr, + sync::{Arc, Mutex}, + time::Instant, +}; + +use anyhow::{Context, Result}; +use clap::Parser; +use tokio::sync::Semaphore; +use tracing::{info, trace}; + +use bench::{ + configure_tracing_subscriber, connect_client, drain_stream, rt, send_data_on_stream, + server_endpoint, + stats::{Stats, TransferResult}, + Opt, +}; + +fn main() { + let opt = Opt::parse(); + configure_tracing_subscriber(); + + let mut csprng = rand::rngs::OsRng {}; + let keypair: ed25519_dalek::Keypair = ed25519_dalek::Keypair::generate(&mut csprng); + let public_key: ed25519_dalek::PublicKey = keypair.public; + + let runtime = rt(); + let (server_addr, endpoint) = server_endpoint(&runtime, keypair, &opt); + + let server_thread = std::thread::spawn(move || { + if let Err(e) = runtime.block_on(server(endpoint, opt)) { + eprintln!("server failed: {:#}", e); + } + }); + + let mut handles = Vec::new(); + for _ in 0..opt.clients { + handles.push(std::thread::spawn(move || { + let runtime = rt(); + match runtime.block_on(client(server_addr, public_key, opt)) { + Ok(stats) => Ok(stats), + Err(e) => { + eprintln!("client failed: {:#}", e); + Err(e) + } + } + })); + } + + for (id, handle) in handles.into_iter().enumerate() { + // We print all stats at the end of the test sequentially to avoid + // them being garbled due to being printed concurrently + if let Ok(stats) = handle.join().expect("client thread") { + stats.print(id); + } + } + + server_thread.join().expect("server thread"); +} + +async fn server(endpoint: quinn::Endpoint, opt: Opt) -> Result<()> { + let mut server_tasks = Vec::new(); + + // Handle only the expected amount of clients + for _ in 0..opt.clients { + let handshake = endpoint.accept().await.unwrap(); + let connection = handshake.await.context("handshake failed")?; + + server_tasks.push(tokio::spawn(async move { + loop { + let (mut send_stream, mut recv_stream) = match connection.accept_bi().await { + Err(quinn::ConnectionError::ApplicationClosed(_)) => break, + Err(e) => { + eprintln!("accepting stream failed: {:?}", e); + break; + } + Ok(stream) => stream, + }; + trace!("stream established"); + + let _: tokio::task::JoinHandle> = tokio::spawn(async move { + drain_stream(&mut recv_stream, opt.read_unordered).await?; + send_data_on_stream(&mut send_stream, opt.download_size).await?; + Ok(()) + }); + } + + if opt.stats { + println!("\nServer connection stats:\n{:#?}", connection.stats()); + } + })); + } + + // Await all the tasks. We have to do this to prevent the runtime getting dropped + // and all server tasks to be cancelled + for handle in server_tasks { + if let Err(e) = handle.await { + eprintln!("Server task error: {:?}", e); + }; + } + + Ok(()) +} + +async fn client( + server_addr: SocketAddr, + remote_public_key: ed25519_dalek::PublicKey, + opt: Opt, +) -> Result { + let (endpoint, connection) = connect_client(server_addr, remote_public_key, opt).await?; + + let start = Instant::now(); + + let connection = Arc::new(connection); + + let mut stats = ClientStats::default(); + let mut first_error = None; + + let sem = Arc::new(Semaphore::new(opt.max_streams)); + let results = Arc::new(Mutex::new(Vec::new())); + for _ in 0..opt.streams { + let permit = sem.clone().acquire_owned().await.unwrap(); + let results = results.clone(); + let connection = connection.clone(); + tokio::spawn(async move { + let result = + handle_client_stream(connection, opt.upload_size, opt.read_unordered).await; + info!("stream finished: {:?}", result); + results.lock().unwrap().push(result); + drop(permit); + }); + } + + // Wait for remaining streams to finish + let _ = sem.acquire_many(opt.max_streams as u32).await.unwrap(); + + for result in results.lock().unwrap().drain(..) { + match result { + Ok((upload_result, download_result)) => { + stats.upload_stats.stream_finished(upload_result); + stats.download_stats.stream_finished(download_result); + } + Err(e) => { + if first_error.is_none() { + first_error = Some(e); + } + } + } + } + + stats.upload_stats.total_duration = start.elapsed(); + stats.download_stats.total_duration = start.elapsed(); + + // Explicit close of the connection, since handles can still be around due + // to `Arc`ing them + connection.close(0u32.into(), b"Benchmark done"); + + endpoint.wait_idle().await; + + if opt.stats { + println!("\nClient connection stats:\n{:#?}", connection.stats()); + } + + match first_error { + None => Ok(stats), + Some(e) => Err(e), + } +} + +async fn handle_client_stream( + connection: Arc, + upload_size: u64, + read_unordered: bool, +) -> Result<(TransferResult, TransferResult)> { + let start = Instant::now(); + + let (mut send_stream, mut recv_stream) = connection + .open_bi() + .await + .context("failed to open stream")?; + + send_data_on_stream(&mut send_stream, upload_size).await?; + + let upload_result = TransferResult::new(start.elapsed(), upload_size); + + let start = Instant::now(); + let size = drain_stream(&mut recv_stream, read_unordered).await?; + let download_result = TransferResult::new(start.elapsed(), size as u64); + + Ok((upload_result, download_result)) +} + +#[derive(Default)] +struct ClientStats { + upload_stats: Stats, + download_stats: Stats, +} + +impl ClientStats { + pub fn print(&self, client_id: usize) { + println!(); + println!("Client {} stats:", client_id); + + if self.upload_stats.total_size != 0 { + self.upload_stats.print("upload"); + } + + if self.download_stats.total_size != 0 { + self.download_stats.print("download"); + } + } +} diff --git a/examples/bench/src/lib.rs b/examples/bench/src/lib.rs new file mode 100644 index 0000000..4fdd46e --- /dev/null +++ b/examples/bench/src/lib.rs @@ -0,0 +1,217 @@ +use std::{ + convert::TryInto, + net::{IpAddr, Ipv6Addr, SocketAddr}, + num::ParseIntError, + str::FromStr, + sync::Arc, +}; + +use anyhow::{Context, Result}; +use bytes::Bytes; +use clap::Parser; +use ed25519_dalek::Keypair; +use quinn::crypto::HandshakeTokenKey; +use rand::rngs::OsRng; +use tokio::runtime::{Builder, Runtime}; +use tracing::trace; + +pub mod stats; + +pub fn configure_tracing_subscriber() { + let filter = tracing_subscriber::EnvFilter::from_default_env(); + tracing::subscriber::set_global_default( + tracing_subscriber::FmtSubscriber::builder() + .with_env_filter(filter) + .finish(), + ) + .unwrap(); +} + +struct DummyHandshakeTokenYey; + +impl HandshakeTokenKey for DummyHandshakeTokenYey { + fn aead_from_hkdf(&self, _random_bytes: &[u8]) -> Box { + todo!() + } +} + +/// Creates a server endpoint which runs on the given runtime +pub fn server_endpoint( + rt: &tokio::runtime::Runtime, + keypair: ed25519_dalek::Keypair, + opt: &Opt, +) -> (SocketAddr, quinn::Endpoint) { + let crypto = Arc::new(quinn_noise::NoiseConfig::from( + quinn_noise::NoiseServerConfig { + keypair, + keylogger: None, + psk: None, + supported_protocols: vec![b"bench".to_vec()], + }, + )); + let mut server_config = quinn::ServerConfig::new(crypto, Arc::new(DummyHandshakeTokenYey)); + server_config.transport = Arc::new(transport_config(opt)); + + let endpoint = { + let _guard = rt.enter(); + quinn::Endpoint::server( + server_config, + SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 0), + ) + .unwrap() + }; + let server_addr = endpoint.local_addr().unwrap(); + (server_addr, endpoint) +} + +/// Create a client endpoint and client connection +pub async fn connect_client( + server_addr: SocketAddr, + remote_public_key: ed25519_dalek::PublicKey, + opt: Opt, +) -> Result<(quinn::Endpoint, quinn::Connection)> { + let endpoint = + quinn::Endpoint::client(SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 0)).unwrap(); + let mut csprng = OsRng {}; + let keypair: Keypair = Keypair::generate(&mut csprng); + let crypto = quinn_noise::NoiseConfig::from(quinn_noise::NoiseClientConfig { + remote_public_key, + alpn: b"bench".to_vec(), + keypair, + psk: None, + keylogger: None, + }); + + let mut client_config = quinn::ClientConfig::new(Arc::new(crypto)); + client_config.transport_config(Arc::new(transport_config(&opt))); + + let connection = endpoint + .connect_with(client_config, server_addr, "localhost") + .unwrap() + .await + .context("unable to connect")?; + trace!("connected"); + + Ok((endpoint, connection)) +} + +pub async fn drain_stream(stream: &mut quinn::RecvStream, read_unordered: bool) -> Result { + let mut read = 0; + + if read_unordered { + while let Some(chunk) = stream.read_chunk(usize::MAX, false).await? { + read += chunk.bytes.len(); + } + } else { + // These are 32 buffers, for reading approximately 32kB at once + #[rustfmt::skip] + let mut bufs = [ + Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), + Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), + Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), + Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), + Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), + Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), + Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), + Bytes::new(), Bytes::new(), Bytes::new(), Bytes::new(), + ]; + + while let Some(n) = stream.read_chunks(&mut bufs[..]).await? { + read += bufs.iter().take(n).map(|buf| buf.len()).sum::(); + } + } + + Ok(read) +} + +pub async fn send_data_on_stream(stream: &mut quinn::SendStream, stream_size: u64) -> Result<()> { + const DATA: &[u8] = &[0xAB; 1024 * 1024]; + let bytes_data = Bytes::from_static(DATA); + + let full_chunks = stream_size / (DATA.len() as u64); + let remaining = (stream_size % (DATA.len() as u64)) as usize; + + for _ in 0..full_chunks { + stream + .write_chunk(bytes_data.clone()) + .await + .context("failed sending data")?; + } + + if remaining != 0 { + stream + .write_chunk(bytes_data.slice(0..remaining)) + .await + .context("failed sending data")?; + } + + stream.finish().await.context("failed finishing stream")?; + + Ok(()) +} + +pub fn rt() -> Runtime { + Builder::new_current_thread().enable_all().build().unwrap() +} + +pub fn transport_config(opt: &Opt) -> quinn::TransportConfig { + // High stream windows are chosen because the amount of concurrent streams + // is configurable as a parameter. + let mut config = quinn::TransportConfig::default(); + config.max_concurrent_uni_streams(opt.max_streams.try_into().unwrap()); + config +} + +#[derive(Parser, Debug, Clone, Copy)] +#[clap(name = "bulk")] +pub struct Opt { + /// The total number of clients which should be created + #[clap(long = "clients", short = 'c', default_value = "1")] + pub clients: usize, + /// The total number of streams which should be created + #[clap(long = "streams", short = 'n', default_value = "1")] + pub streams: usize, + /// The amount of concurrent streams which should be used + #[clap(long = "max_streams", short = 'm', default_value = "1")] + pub max_streams: usize, + /// Number of bytes to transmit from server to client + /// + /// This can use SI prefixes for sizes. E.g. 1M will transfer 1MiB, 10GiB + /// will transfer 10GiB. + #[clap(long, default_value = "1G", parse(try_from_str = parse_byte_size))] + pub download_size: u64, + /// Number of bytes to transmit from client to server + /// + /// This can use SI prefixes for sizes. E.g. 1M will transfer 1MiB, 10GiB + /// will transfer 10GiB. + #[clap(long, default_value = "0", parse(try_from_str = parse_byte_size))] + pub upload_size: u64, + /// Show connection stats the at the end of the benchmark + #[clap(long = "stats")] + pub stats: bool, + /// Whether to use the unordered read API + #[clap(long = "unordered")] + pub read_unordered: bool, +} + +fn parse_byte_size(s: &str) -> Result { + let s = s.trim(); + + let multiplier = match s.chars().last() { + Some('T') => 1024 * 1024 * 1024 * 1024, + Some('G') => 1024 * 1024 * 1024, + Some('M') => 1024 * 1024, + Some('k') => 1024, + _ => 1, + }; + + let s = if multiplier != 1 { + &s[..s.len() - 1] + } else { + s + }; + + let base: u64 = u64::from_str(s)?; + + Ok(base * multiplier) +} diff --git a/examples/bench/src/stats.rs b/examples/bench/src/stats.rs new file mode 100644 index 0000000..4450724 --- /dev/null +++ b/examples/bench/src/stats.rs @@ -0,0 +1,95 @@ +use std::time::Duration; + +use hdrhistogram::Histogram; + +#[derive(Default)] +pub struct Stats { + pub total_size: u64, + pub total_duration: Duration, + pub streams: usize, + pub stream_stats: StreamStats, +} + +impl Stats { + pub fn stream_finished(&mut self, stream_result: TransferResult) { + self.total_size += stream_result.size; + self.streams += 1; + + self.stream_stats + .duration_hist + .record(stream_result.duration.as_millis() as u64) + .unwrap(); + self.stream_stats + .throughput_hist + .record(stream_result.throughput as u64) + .unwrap(); + } + + pub fn print(&self, stat_name: &str) { + println!("Overall {} stats:\n", stat_name); + println!( + "Transferred {} bytes on {} streams in {:4.2?} ({:.2} MiB/s)\n", + self.total_size, + self.streams, + self.total_duration, + throughput_bps(self.total_duration, self.total_size as u64) / 1024.0 / 1024.0 + ); + + println!("Stream {} metrics:\n", stat_name); + + println!(" │ Throughput │ Duration "); + println!("──────┼───────────────┼──────────"); + + let print_metric = |label: &'static str, get_metric: fn(&Histogram) -> u64| { + println!( + " {} │ {:7.2} MiB/s │ {:>9.2?}", + label, + get_metric(&self.stream_stats.throughput_hist) as f64 / 1024.0 / 1024.0, + Duration::from_millis(get_metric(&self.stream_stats.duration_hist)) + ); + }; + + print_metric("AVG ", |hist| hist.mean() as u64); + print_metric("P0 ", |hist| hist.value_at_quantile(0.00)); + print_metric("P10 ", |hist| hist.value_at_quantile(0.10)); + print_metric("P50 ", |hist| hist.value_at_quantile(0.50)); + print_metric("P90 ", |hist| hist.value_at_quantile(0.90)); + print_metric("P100", |hist| hist.value_at_quantile(1.00)); + } +} + +pub struct StreamStats { + pub duration_hist: Histogram, + pub throughput_hist: Histogram, +} + +impl Default for StreamStats { + fn default() -> Self { + Self { + duration_hist: Histogram::::new(3).unwrap(), + throughput_hist: Histogram::::new(3).unwrap(), + } + } +} + +#[derive(Debug)] +pub struct TransferResult { + pub duration: Duration, + pub size: u64, + pub throughput: f64, +} + +impl TransferResult { + pub fn new(duration: Duration, size: u64) -> Self { + let throughput = throughput_bps(duration, size); + TransferResult { + duration, + size, + throughput, + } + } +} + +pub fn throughput_bps(duration: Duration, size: u64) -> f64 { + (size as f64) / (duration.as_secs_f64()) +} diff --git a/src/keylog.rs b/src/keylog.rs index 95bdfaf..44fcbf8 100644 --- a/src/keylog.rs +++ b/src/keylog.rs @@ -60,6 +60,12 @@ impl KeyLogFile { } } +impl Default for KeyLogFile { + fn default() -> Self { + Self::new() + } +} + impl KeyLog for KeyLogFile { fn log(&self, label: &str, client_random: &[u8], secret: &[u8]) { match self diff --git a/src/session.rs b/src/session.rs index 7a413b0..1fde892 100644 --- a/src/session.rs +++ b/src/session.rs @@ -318,7 +318,7 @@ impl Session for NoiseSession { } let (remote_s, rest) = rest.split_at(32); let mut s = [0; 32]; - self.xoodyak.decrypt(&remote_s, &mut s); + self.xoodyak.decrypt(remote_s, &mut s); let s = PublicKey::from_bytes(&s) .map_err(|_| connection_refused("invalid static public key"))?; self.remote_s = Some(s); @@ -343,9 +343,8 @@ impl Session for NoiseSession { .supported_protocols .as_ref() .expect("invalid config") - .into_iter() - .find(|proto| proto.as_slice() == alpn) - .is_some(); + .iter() + .any(|proto| proto.as_slice() == alpn); if !is_supported { return Err(connection_refused("unsupported alpn")); } @@ -356,11 +355,11 @@ impl Session for NoiseSession { } let (params, auth) = rest.split_at(rest.len() - 16); let mut transport_parameters = vec![0; params.len()]; - self.xoodyak.decrypt(¶ms, &mut transport_parameters); + self.xoodyak.decrypt(params, &mut transport_parameters); // check tag let mut tag = [0; 16]; self.xoodyak.squeeze(&mut tag); - if !bool::from(tag.ct_eq(&auth)) { + if !bool::from(tag.ct_eq(auth)) { return Err(connection_refused("invalid authentication tag")); } self.remote_transport_parameters = Some(TransportParameters::read( @@ -377,7 +376,7 @@ impl Session for NoiseSession { } let (remote_e, rest) = handshake.split_at(32); let mut e = [0; 32]; - self.xoodyak.decrypt(&remote_e, &mut e); + self.xoodyak.decrypt(remote_e, &mut e); let e = PublicKey::from_bytes(&e) .map_err(|_| connection_refused("invalid ephemeral public key"))?; self.remote_e = Some(e); @@ -393,11 +392,11 @@ impl Session for NoiseSession { } let (params, auth) = rest.split_at(rest.len() - 16); let mut transport_parameters = vec![0; params.len()]; - self.xoodyak.decrypt(¶ms, &mut transport_parameters); + self.xoodyak.decrypt(params, &mut transport_parameters); // check tag let mut tag = [0; 16]; self.xoodyak.squeeze(&mut tag); - if !bool::from(tag.ct_eq(&auth)) { + if !bool::from(tag.ct_eq(auth)) { return Err(connection_refused("invalid authentication tag")); } self.remote_transport_parameters = Some(TransportParameters::read(