diff --git a/Cargo.lock b/Cargo.lock
index 9937801..97739d5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2,6 +2,12 @@
 # It is not intended for manual editing.
 version = 3
 
+[[package]]
+name = "adler32"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234"
+
 [[package]]
 name = "aho-corasick"
 version = "0.7.19"
@@ -11,6 +17,21 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "alloc-no-stdlib"
+version = "2.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3"
+
+[[package]]
+name = "alloc-stdlib"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece"
+dependencies = [
+ "alloc-no-stdlib",
+]
+
 [[package]]
 name = "askama"
 version = "0.11.1"
@@ -109,6 +130,27 @@ dependencies = [
  "generic-array",
 ]
 
+[[package]]
+name = "brotli"
+version = "3.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a1a0b1dbcc8ae29329621f8d4f0d835787c1c38bb1401979b49d13b0b305ff68"
+dependencies = [
+ "alloc-no-stdlib",
+ "alloc-stdlib",
+ "brotli-decompressor",
+]
+
+[[package]]
+name = "brotli-decompressor"
+version = "2.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "59ad2d4653bf5ca36ae797b1f4bb4dbddb60ce49ca4aed8a2ce4829f60425b80"
+dependencies = [
+ "alloc-no-stdlib",
+ "alloc-stdlib",
+]
+
 [[package]]
 name = "bstr"
 version = "0.2.17"
@@ -233,6 +275,15 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "crc32fast"
+version = "1.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d"
+dependencies = [
+ "cfg-if",
+]
+
 [[package]]
 name = "crypto-common"
 version = "0.1.6"
@@ -398,6 +449,17 @@ dependencies = [
  "version_check",
 ]
 
+[[package]]
+name = "getrandom"
+version = "0.2.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "wasi",
+]
+
 [[package]]
 name = "globset"
 version = "0.4.9"
@@ -580,18 +642,41 @@ version = "0.2.137"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89"
 
+[[package]]
+name = "libflate"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "05605ab2bce11bcfc0e9c635ff29ef8b2ea83f29be257ee7d730cac3ee373093"
+dependencies = [
+ "adler32",
+ "crc32fast",
+ "libflate_lz77",
+]
+
+[[package]]
+name = "libflate_lz77"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "39a734c0493409afcd49deee13c006a04e3586b9761a03543c6272c9c51f2f5a"
+dependencies = [
+ "rle-decode-fast",
+]
+
 [[package]]
 name = "libreddit"
 version = "0.23.1"
 dependencies = [
  "askama",
  "async-recursion",
+ "brotli",
  "cached",
  "clap",
  "cookie",
  "futures-lite",
  "hyper",
  "hyper-rustls",
+ "libflate",
+ "lipsum",
  "percent-encoding",
  "regex",
  "route-recognizer",
@@ -603,6 +688,16 @@ dependencies = [
  "url",
 ]
 
+[[package]]
+name = "lipsum"
+version = "0.8.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a8451846f1f337e44486666989fbce40be804da139d5a4477d6b88ece5dc69f4"
+dependencies = [
+ "rand",
+ "rand_chacha",
+]
+
[[package]] name = "lock_api" version = "0.4.9" @@ -756,6 +851,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "ppv-lite86" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" + [[package]] name = "proc-macro2" version = "1.0.47" @@ -774,6 +875,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -815,6 +946,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "rle-decode-fast" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" + [[package]] name = "route-recognizer" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index adf135a..42b12db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,8 @@ tokio = { version = "1.21.2", features = ["full"] } time = "0.3.16" url = "2.3.1" rust-embed = { version = "6.4.2", features = ["include-exclude"] } +libflate = "1.2.0" +brotli = { version = "3.3.4", features = ["std"] } + +[dev-dependencies] +lipsum = "0.8.2" diff --git a/src/client.rs b/src/client.rs index da271dd..5c32335 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,9 +1,10 @@ use cached::proc_macro::cached; use futures_lite::{future::Boxed, FutureExt}; -use hyper::{body::Buf, client, Body, Request, Response, Uri}; +use hyper::{body, body::Buf, client, header, Body, Request, Response, Uri}; +use libflate::gzip; use percent_encoding::{percent_encode, CONTROLS}; use serde_json::Value; -use std::result::Result; +use std::{io, result::Result}; use crate::server::RequestExt; @@ -76,6 +77,7 @@ fn request(url: String, quarantine: bool) -> Boxed, String .header("User-Agent", format!("web:libreddit:{}", env!("CARGO_PKG_VERSION"))) .header("Host", "www.reddit.com") .header("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8") + .header("Accept-Encoding", "gzip") // Reddit doesn't do brotli yet. 
.header("Accept-Language", "en-US,en;q=0.5") .header("Connection", "keep-alive") .header("Cookie", if quarantine { "_options=%7B%22pref_quarantine_optin%22%3A%20true%7D" } else { "" }) @@ -84,7 +86,7 @@ fn request(url: String, quarantine: bool) -> Boxed, String async move { match builder { Ok(req) => match client.request(req).await { - Ok(response) => { + Ok(mut response) => { if response.status().to_string().starts_with('3') { request( response @@ -100,7 +102,49 @@ fn request(url: String, quarantine: bool) -> Boxed, String ) .await } else { - Ok(response) + match response.headers().get(header::CONTENT_ENCODING) { + // Content not compressed. + None => Ok(response), + + // Content gzipped. + Some(hdr) => { + // Since we requested gzipped content, we expect + // to get back gzipped content. If we get + // back anything else, that's a problem. + if hdr.ne("gzip") { + return Err("Reddit response was encoded with an unsupported compressor".to_string()); + } + + // The body must be something that implements + // std::io::Read, hence the conversion to + // bytes::buf::Buf and then transformation into a + // Reader. + let mut decompressed: Vec; + { + let mut aggregated_body = match body::aggregate(response.body_mut()).await { + Ok(b) => b.reader(), + Err(e) => return Err(e.to_string()), + }; + + let mut decoder = match gzip::Decoder::new(&mut aggregated_body) { + Ok(decoder) => decoder, + Err(e) => return Err(e.to_string()), + }; + + decompressed = Vec::::new(); + match io::copy(&mut decoder, &mut decompressed) { + Ok(_) => {} + Err(e) => return Err(e.to_string()), + }; + } + + response.headers_mut().remove(header::CONTENT_ENCODING); + response.headers_mut().insert(header::CONTENT_LENGTH, decompressed.len().into()); + *(response.body_mut()) = Body::from(decompressed); + + Ok(response) + } + } } } Err(e) => Err(e.to_string()), diff --git a/src/server.rs b/src/server.rs index 979dbd7..c277b6b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,17 +1,80 @@ +use brotli::enc::{BrotliCompress, BrotliEncoderParams}; +use cached::proc_macro::cached; use cookie::Cookie; +use core::f64; use futures_lite::{future::Boxed, Future, FutureExt}; use hyper::{ - header::HeaderValue, + body, + body::HttpBody, + header, service::{make_service_fn, service_fn}, HeaderMap, }; use hyper::{Body, Method, Request, Response, Server as HyperServer}; +use libflate::gzip; use route_recognizer::{Params, Router}; -use std::{pin::Pin, result::Result}; +use std::{ + cmp::Ordering, + io, + pin::Pin, + result::Result, + str::{from_utf8, Split}, + string::ToString, +}; use time::Duration; +use crate::dbg_msg; + type BoxResponse = Pin, String>> + Send>>; +/// Compressors for the response Body, in ascending order of preference. +#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +enum CompressionType { + Passthrough, + Gzip, + Brotli, +} + +/// All browsers support gzip, so if we are given `Accept-Encoding: *`, deliver +/// gzipped-content. +/// +/// Brotli would be nice universally, but Safari (iOS, iPhone, macOS) reportedly +/// doesn't support it yet. +const DEFAULT_COMPRESSOR: CompressionType = CompressionType::Gzip; + +impl CompressionType { + /// Returns a `CompressionType` given a content coding + /// in [RFC 7231](https://datatracker.ietf.org/doc/html/rfc7231#section-5.3.4) + /// format. + fn parse(s: &str) -> Option { + let c = match s { + // Compressors we support. 
+ "gzip" => CompressionType::Gzip, + "br" => CompressionType::Brotli, + + // The wildcard means that we can choose whatever + // compression we prefer. In this case, use the + // default. + "*" => DEFAULT_COMPRESSOR, + + // Compressor not supported. + _ => return None, + }; + + Some(c) + } +} + +impl ToString for CompressionType { + fn to_string(&self) -> String { + match self { + CompressionType::Gzip => "gzip".to_string(), + CompressionType::Brotli => "br".to_string(), + _ => String::new(), + } + } +} + pub struct Route<'a> { router: &'a mut Router) -> BoxResponse>, path: String, @@ -97,7 +160,7 @@ impl ResponseExt for Response { } fn insert_cookie(&mut self, cookie: Cookie) { - if let Ok(val) = HeaderValue::from_str(&cookie.to_string()) { + if let Ok(val) = header::HeaderValue::from_str(&cookie.to_string()) { self.headers_mut().append("Set-Cookie", val); } } @@ -106,7 +169,7 @@ impl ResponseExt for Response { let mut cookie = Cookie::named(name); cookie.set_path("/"); cookie.set_max_age(Duration::seconds(1)); - if let Ok(val) = HeaderValue::from_str(&cookie.to_string()) { + if let Ok(val) = header::HeaderValue::from_str(&cookie.to_string()) { self.headers_mut().append("Set-Cookie", val); } } @@ -156,10 +219,11 @@ impl Server { // let shared_router = router.clone(); async move { Ok::<_, String>(service_fn(move |req: Request| { - let headers = default_headers.clone(); + let req_headers = req.headers().clone(); + let def_headers = default_headers.clone(); // Remove double slashes and decode encoded slashes - let mut path = req.uri().path().replace("//", "/").replace("%2F","/"); + let mut path = req.uri().path().replace("//", "/").replace("%2F", "/"); // Remove trailing slashes if path != "/" && path.ends_with('/') { @@ -176,26 +240,20 @@ impl Server { // Run the route's function let func = (found.handler().to_owned().to_owned())(parammed); async move { - let res: Result, String> = func.await; - // Add default headers to response - res.map(|mut response| { - response.headers_mut().extend(headers); - response - }) + match func.await { + Ok(mut res) => { + res.headers_mut().extend(def_headers); + let _ = compress_response(req_headers, &mut res).await; + + Ok(res) + } + Err(msg) => new_boilerplate(def_headers, req_headers, 500, Body::from(msg)).await, + } } .boxed() } // If there was a routing error - Err(e) => async move { - // Return a 404 error - let res: Result, String> = Ok(Response::builder().status(404).body(e.into()).unwrap_or_default()); - // Add default headers to response - res.map(|mut response| { - response.headers_mut().extend(headers); - response - }) - } - .boxed(), + Err(e) => async move { new_boilerplate(def_headers, req_headers, 404, e.into()).await }.boxed(), } })) } @@ -213,3 +271,480 @@ impl Server { server.boxed() } } + +/// Create a boilerplate Response for error conditions. This response will be +/// compressed if requested by client. +async fn new_boilerplate( + default_headers: HeaderMap, + req_headers: HeaderMap, + status: u16, + body: Body, +) -> Result, String> { + match Response::builder().status(status).body(body) { + Ok(mut res) => { + let _ = compress_response(req_headers, &mut res).await; + + res.headers_mut().extend(default_headers.clone()); + Ok(res) + } + Err(msg) => Err(msg.to_string()), + } +} + +/// Determines the desired compressor based on the Accept-Encoding header. +/// +/// This function will honor the [q-value](https://developer.mozilla.org/en-US/docs/Glossary/Quality_values) +/// for each compressor. 
+/// The q-value is an optional parameter, a decimal value on \[0..1\], to
+/// order the compressors by preference. An Accept-Encoding value with no
+/// q-values is also accepted.
+///
+/// Here are [examples](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding#examples)
+/// of valid Accept-Encoding headers.
+///
+/// ```http
+/// Accept-Encoding: gzip
+/// Accept-Encoding: gzip, compress, br
+/// Accept-Encoding: br;q=1.0, gzip;q=0.8, *;q=0.1
+/// ```
+fn determine_compressor(accept_encoding: &str) -> Option<CompressionType> {
+	if accept_encoding.is_empty() {
+		return None;
+	};
+
+	// Keep track of the compressor candidate based on both the client's
+	// preference and our own. Concrete examples:
+	//
+	// 1. "Accept-Encoding: gzip, br" => assuming we like brotli more than
+	//    gzip, and the browser supports brotli, we choose brotli
+	//
+	// 2. "Accept-Encoding: gzip;q=0.8, br;q=0.3" => the client has stated a
+	//    preference for gzip over brotli, so we choose gzip
+	//
+	// To do this, we need to define a struct which contains the requested
+	// compressor (abstracted as a CompressionType enum) and the q-value. If
+	// no q-value is defined for the compressor, we assume one of 1.0. We
+	// first compare compressor candidates by comparing q-values, and then
+	// CompressionTypes. We keep track of whatever is the greatest per our
+	// ordering.
+
+	struct CompressorCandidate {
+		alg: CompressionType,
+		q: f64,
+	}
+
+	impl Ord for CompressorCandidate {
+		fn cmp(&self, other: &Self) -> Ordering {
+			// Compare q-values. Break ties with the
+			// CompressionType values.
+
+			match self.q.total_cmp(&other.q) {
+				Ordering::Equal => self.alg.cmp(&other.alg),
+				ord => ord,
+			}
+		}
+	}
+
+	impl PartialOrd for CompressorCandidate {
+		fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+			// Guard against NAN, both on our end and on the other.
+			if self.q.is_nan() || other.q.is_nan() {
+				return None;
+			};
+
+			// f64 and CompressionType are ordered, except in the case
+			// where the f64 is NAN (which we checked against), so we
+			// can safely return a Some here.
+			Some(self.cmp(other))
+		}
+	}
+
+	impl PartialEq for CompressorCandidate {
+		fn eq(&self, other: &Self) -> bool {
+			(self.q == other.q) && (self.alg == other.alg)
+		}
+	}
+
+	impl Eq for CompressorCandidate {}
+
+	// This is the current candidate.
+	//
+	// Assume no candidate so far. We do this by assigning the sentinel value
+	// of negative infinity to the q-value. If this value is negative infinity,
+	// that means there was no viable compressor candidate.
+	let mut cur_candidate = CompressorCandidate {
+		alg: CompressionType::Passthrough,
+		q: f64::NEG_INFINITY,
+	};
+
+	// This loop reads the requested compressors and keeps track of whichever
+	// one has the highest priority per our heuristic.
+	for val in accept_encoding.to_string().split(',') {
+		let mut q: f64 = 1.0;
+
+		// The compressor and q-value (if the latter is defined)
+		// will be delimited by semicolons.
+		let mut spl: Split<char> = val.split(';');
+
+		// Get the compressor. For example, in
+		//   gzip;q=0.8
+		// this grabs "gzip" in the string. It
+		// will further validate the compressor against the
+		// list of those we support. If it is not supported,
+		// we move onto the next one.
+		let compressor: CompressionType = match spl.next() {
+			// CompressionType::parse will return the appropriate enum given
+			// a string. For example, it will return CompressionType::Gzip
+			// when given "gzip".
+			Some(s) => match CompressionType::parse(s.trim()) {
+				Some(candidate) => candidate,
+
+				// We don't support the requested compression algorithm.
+				None => continue,
+			},
+
+			// We should never get here, but I'm paranoid.
+			None => continue,
+		};
+
+		// Get the q-value. This might not be defined, in which case assume
+		// 1.0.
+		if let Some(s) = spl.next() {
+			if !(s.len() > 2 && s.starts_with("q=")) {
+				// If the q-value is malformed, the header is malformed, so
+				// abort.
+				return None;
+			}
+
+			match s[2..].parse::<f64>() {
+				Ok(val) => {
+					if (0.0..=1.0).contains(&val) {
+						q = val;
+					} else {
+						// If the value is outside [0..1], header is malformed.
+						// Abort.
+						return None;
+					};
+				}
+				Err(_) => {
+					// If this isn't an f64, then assume a malformed header
+					// value and abort.
+					return None;
+				}
+			}
+		};
+
+		// If new_candidate > cur_candidate, make new_candidate the new
+		// cur_candidate. But do this safely! It is very possible that
+		// someone gave us the string "NAN", which (&str).parse::<f64>()
+		// will happily translate to f64::NAN.
+		let new_candidate = CompressorCandidate { alg: compressor, q };
+		if let Some(ord) = new_candidate.partial_cmp(&cur_candidate) {
+			if ord == Ordering::Greater {
+				cur_candidate = new_candidate;
+			}
+		};
+	}
+
+	if cur_candidate.q != f64::NEG_INFINITY {
+		Some(cur_candidate.alg)
+	} else {
+		None
+	}
+}
+
+/// Compress the response body, if possible or desirable. The Body will be
+/// compressed in place, and a new header Content-Encoding will be set
+/// indicating the compression algorithm.
+///
+/// This function deems the Body eligible for compression if and only if the
+/// following conditions are met:
+///
+/// 1. the HTTP client requests a compression encoding in the Accept-Encoding
+///    header (hence the need for the req_headers);
+///
+/// 2. the content encoding corresponds to a compression algorithm we support;
+///
+/// 3. the Media type in the Content-Type response header is text with any
+///    subtype (e.g. text/plain) or application/json.
+///
+/// compress_response returns Ok on successful compression, or if not all three
+/// conditions above are met. It returns Err if there was a problem decoding
+/// any header in either req_headers or res, but res will remain intact.
+///
+/// This function logs errors to stderr, but only in debug mode. No information
+/// is logged in release builds.
+async fn compress_response(req_headers: HeaderMap, res: &mut Response<Body>) -> Result<(), String> {
+	// Check if the data is eligible for compression.
+	if let Some(hdr) = res.headers().get(header::CONTENT_TYPE) {
+		match from_utf8(hdr.as_bytes()) {
+			Ok(val) => {
+				let s = val.to_string();
+
+				// TODO: better determination of what is eligible for compression
+				if !(s.starts_with("text/") || s.starts_with("application/json")) {
+					return Ok(());
+				};
+			}
+			Err(e) => {
+				dbg_msg!(e);
+				return Err(e.to_string());
+			}
+		};
+	} else {
+		// Response declares no Content-Type. Assume for simplicity that it
+		// cannot be compressed.
+		return Ok(());
+	};
+
+	// Don't bother if the size of the response body will fit within an IP
+	// frame (less the bytes that make up the TCP/IP and HTTP headers).
+	if res.body().size_hint().lower() < 1452 {
+		return Ok(());
+	};
+
+	// Quick and dirty closure for extracting a header from the request and
+	// returning it as a &str.
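+	// The closure returns None if the header is absent or is not valid UTF-8.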
+	let get_req_header = |k: header::HeaderName| -> Option<&str> {
+		match req_headers.get(k) {
+			Some(hdr) => match from_utf8(hdr.as_bytes()) {
+				Ok(val) => Some(val),
+
+				#[cfg(debug_assertions)]
+				Err(e) => {
+					dbg_msg!(e);
+					None
+				}
+
+				#[cfg(not(debug_assertions))]
+				Err(_) => None,
+			},
+			None => None,
+		}
+	};
+
+	// Check to see which compressor is requested, and if we can use it.
+	let accept_encoding: &str = match get_req_header(header::ACCEPT_ENCODING) {
+		Some(val) => val,
+		None => return Ok(()), // Client requested no compression.
+	};
+
+	let compressor: CompressionType = match determine_compressor(accept_encoding) {
+		Some(c) => c,
+		None => return Ok(()),
+	};
+
+	// Get the body from the response.
+	let body_bytes: Vec<u8> = match body::to_bytes(res.body_mut()).await {
+		Ok(b) => b.to_vec(),
+		Err(e) => {
+			dbg_msg!(e);
+			return Err(e.to_string());
+		}
+	};
+
+	// Compress!
+	match compress_body(compressor, body_bytes) {
+		Ok(compressed) => {
+			// We get here iff the compression was successful. Replace the body
+			// with the compressed payload, and add the appropriate
+			// Content-Encoding header in the response.
+			res.headers_mut().insert(header::CONTENT_ENCODING, compressor.to_string().parse().unwrap());
+			*(res.body_mut()) = Body::from(compressed);
+		}
+
+		Err(e) => return Err(e),
+	}
+
+	Ok(())
+}
+
+/// Compresses a `Vec<u8>` given a [`CompressionType`].
+///
+/// This is a helper function for [`compress_response`] and should not be
+/// called directly.
+
+// I've chosen a TTL of 600 (== 10 minutes) since compression is
+// computationally expensive and we don't want to be doing it often. This is
+// larger than client::json's TTL, but that's okay, because if client::json
+// returns a new serde_json::Value, body_bytes changes, so this function will
+// execute again.
+#[cached(size = 100, time = 600, result = true)]
+fn compress_body(compressor: CompressionType, body_bytes: Vec<u8>) -> Result<Vec<u8>, String> {
+	// io::Cursor implements io::Read, required for our encoders.
+	let mut reader = io::Cursor::new(body_bytes);
+
+	let compressed: Vec<u8> = match compressor {
+		CompressionType::Gzip => {
+			let mut gz: gzip::Encoder<Vec<u8>> = match gzip::Encoder::new(Vec::new()) {
+				Ok(gz) => gz,
+				Err(e) => {
+					dbg_msg!(e);
+					return Err(e.to_string());
+				}
+			};
+
+			match io::copy(&mut reader, &mut gz) {
+				Ok(_) => match gz.finish().into_result() {
+					Ok(compressed) => compressed,
+					Err(e) => {
+						dbg_msg!(e);
+						return Err(e.to_string());
+					}
+				},
+				Err(e) => {
+					dbg_msg!(e);
+					return Err(e.to_string());
+				}
+			}
+		}
+
+		CompressionType::Brotli => {
+			// We may want to make the compression parameters configurable
+			// in the future. For now, the defaults are sufficient.
+			let brotli_params = BrotliEncoderParams::default();
+
+			let mut compressed = Vec::<u8>::new();
+			match BrotliCompress(&mut reader, &mut compressed, &brotli_params) {
+				Ok(_) => compressed,
+				Err(e) => {
+					dbg_msg!(e);
+					return Err(e.to_string());
+				}
+			}
+		}
+
+		// This arm is for any requested compressor for which we don't yet
+		// have an implementation.
+		_ => {
+			let msg = "unsupported compressor".to_string();
+			return Err(msg);
+		}
+	};
+
+	Ok(compressed)
+}
+
+#[cfg(test)]
+mod tests {
+	use super::*;
+	use brotli::Decompressor as BrotliDecompressor;
+	use futures_lite::future::block_on;
+	use lipsum::lipsum;
+	use std::{boxed::Box, io};
+
+	#[test]
+	fn test_determine_compressor() {
+		// Single compressor given.
+		assert_eq!(determine_compressor("unsupported"), None);
+		assert_eq!(determine_compressor("gzip"), Some(CompressionType::Gzip));
+		assert_eq!(determine_compressor("*"), Some(DEFAULT_COMPRESSOR));
+
+		// Multiple compressors.
+		assert_eq!(determine_compressor("gzip, br"), Some(CompressionType::Brotli));
+		assert_eq!(determine_compressor("gzip;q=0.8, br;q=0.3"), Some(CompressionType::Gzip));
+		assert_eq!(determine_compressor("br, gzip"), Some(CompressionType::Brotli));
+		assert_eq!(determine_compressor("br;q=0.3, gzip;q=0.4"), Some(CompressionType::Gzip));
+
+		// Invalid q-values.
+		assert_eq!(determine_compressor("gzip;q=NAN"), None);
+	}
+
+	#[test]
+	fn test_compress_response() {
+		// This macro generates an Accept-Encoding header value given any number of
+		// compressors.
+		macro_rules! ae_gen {
+			($x:expr) => {
+				$x.to_string().as_str()
+			};
+
+			($x:expr, $($y:expr),+) => {
+				format!("{}, {}", $x.to_string(), ae_gen!($($y),+)).as_str()
+			};
+		}
+
+		for accept_encoding in [
+			"*",
+			ae_gen!(CompressionType::Gzip),
+			ae_gen!(CompressionType::Brotli, CompressionType::Gzip),
+			ae_gen!(CompressionType::Brotli),
+		] {
+			// Determine what the expected encoding should be based on the
+			// specific encodings we accept.
+			let expected_encoding: CompressionType = match determine_compressor(accept_encoding) {
+				Some(s) => s,
+				None => panic!("determine_compressor(accept_encoding) => None"),
+			};
+
+			// Build headers with our Accept-Encoding.
+			let mut req_headers = HeaderMap::new();
+			req_headers.insert(header::ACCEPT_ENCODING, header::HeaderValue::from_str(accept_encoding).unwrap());
+
+			// Build test response.
+			let lorem_ipsum: String = lipsum(10000);
+			let expected_lorem_ipsum = Vec::<u8>::from(lorem_ipsum.as_str());
+			let mut res = Response::builder()
+				.status(200)
+				.header(header::CONTENT_TYPE, "text/plain")
+				.body(Body::from(lorem_ipsum))
+				.unwrap();
+
+			// Perform the compression.
+			if let Err(e) = block_on(compress_response(req_headers, &mut res)) {
+				panic!("compress_response(req_headers, &mut res) => Err(\"{}\")", e);
+			};
+
+			// If the content was compressed, we expect the Content-Encoding
+			// header to be modified.
+			assert_eq!(
+				res
+					.headers()
+					.get(header::CONTENT_ENCODING)
+					.unwrap_or_else(|| panic!("missing content-encoding header"))
+					.to_str()
+					.unwrap_or_else(|_| panic!("failed to convert Content-Encoding header::HeaderValue to String")),
+				expected_encoding.to_string()
+			);
+
+			// Decompress body and make sure it's equal to what we started
+			// with.
+			//
+			// In the case of no compression, just make sure the "new" body in
+			// the Response is the same as what we started with.
+			let body_vec = match block_on(body::to_bytes(res.body_mut())) {
+				Ok(b) => b.to_vec(),
+				Err(e) => panic!("{}", e),
+			};
+
+			if expected_encoding == CompressionType::Passthrough {
+				assert!(body_vec.eq(&expected_lorem_ipsum));
+				continue;
+			}
+
+			// This provides an io::Read for the underlying body.
+			let mut body_cursor: io::Cursor<Vec<u8>> = io::Cursor::new(body_vec);
+
+			// Match the appropriate decompressor for the given
+			// expected_encoding.
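+			// Both decoders implement io::Read, so the concrete type can be
+			// erased behind a Box<dyn io::Read>.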
+			let mut decoder: Box<dyn io::Read> = match expected_encoding {
+				CompressionType::Gzip => match gzip::Decoder::new(&mut body_cursor) {
+					Ok(dgz) => Box::new(dgz),
+					Err(e) => panic!("{}", e),
+				},
+
+				CompressionType::Brotli => Box::new(BrotliDecompressor::new(body_cursor, expected_lorem_ipsum.len())),
+
+				_ => panic!("no decompressor for {}", expected_encoding.to_string()),
+			};
+
+			let mut decompressed = Vec::<u8>::new();
+			match io::copy(&mut decoder, &mut decompressed) {
+				Ok(_) => {}
+				Err(e) => panic!("{}", e),
+			};
+
+			assert!(decompressed.eq(&expected_lorem_ipsum));
+		}
+	}
+}
diff --git a/src/utils.rs b/src/utils.rs
index 2691d16..42243d5 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -13,6 +13,21 @@ use std::str::FromStr;
 use time::{macros::format_description, Duration, OffsetDateTime};
 use url::Url;
 
+/// Write a message to stderr in debug mode. This macro is a no-op in
+/// release builds.
+#[macro_export]
+macro_rules! dbg_msg {
+	($x:expr) => {
+		#[cfg(debug_assertions)]
+		eprintln!("{}:{}: {}", file!(), line!(), $x.to_string())
+	};
+
+	($($x:expr),+) => {
+		#[cfg(debug_assertions)]
+		dbg_msg!(format!($($x),+))
+	};
+}
+
 // Post flair with content, background color and foreground color
 pub struct Flair {
 	pub flair_parts: Vec<FlairPart>,
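
Note: a minimal usage sketch of the new dbg_msg! macro follows, not part of the patch itself. The function and error value below are hypothetical; in the patch, the macro is invoked from src/server.rs (e.g. dbg_msg!(e)) wherever an error is reduced to a String.

	use crate::dbg_msg;

	fn log_failure() {
		// A hypothetical error value, just to have something printable.
		let err = std::io::Error::new(std::io::ErrorKind::Other, "stream truncated");

		// Single-expression form: in debug builds this prints something like
		// "src/server.rs:123: stream truncated" to stderr; in release builds
		// the whole statement compiles away.
		dbg_msg!(err);

		// format!-style form; this forwards to the single-expression arm.
		dbg_msg!("compression failed: {}", err);
	}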