diff --git a/src/client.rs b/src/client.rs index 2a30fc9..0d075e1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -11,18 +11,37 @@ use crate::server::RequestExt; const REDDIT_URL_BASE: &str = "https://www.reddit.com"; -/// Gets the canonical path for a resource on Reddit. On success, a -/// `Some(Option)` will be returned. If Reddit responds with -/// anything other than an HTTP 3xx, a `None` will be returned. Any -/// other error results in an `Err(String)`. +/// Gets the canonical path for a resource on Reddit. This is accomplished by +/// making a `HEAD` request to Reddit at the path given in `path`. +/// +/// This function returns `Ok(Some(path))`, where `path`'s value is identical +/// to that of the value of the argument `path`, if Reddit responds to our +/// `HEAD` request with a 2xx-family HTTP code. It will also return an +/// `Ok(Some(String))` if Reddit responds to our `HEAD` request with a +/// `Location` header in the response, and the HTTP code is in the 3xx-family; +/// the `String` will contain the path as reported in `Location`. The return +/// value is `Ok(None)` if Reddit responded with a 3xx, but did not provide a +/// `Location` header. An `Err(String)` is returned if Reddit responds with a +/// 429, or if we were unable to decode the value in the `Location` header. #[cached(size = 1024, time = 600, result = true)] pub async fn canonical_path(path: String) -> Result, String> { - let res = reddit_head(path, true).await?; + let res = reddit_head(path.clone(), true).await?; if res.status() == 429 { return Err("Too many requests.".to_string()); }; + // If Reddit responds with a 2xx, then the path is already canonical. + if res.status().to_string().starts_with('2') { + return Ok(Some(path)); + } + + // If Reddit responds with anything other than 3xx (except for the 2xx as + // above), return a None. + if !res.status().to_string().starts_with('3') { + return Ok(None); + } + match res.headers().get(header::LOCATION) { None => Ok(None), Some(hdr) => match hdr.to_str() { @@ -111,8 +130,8 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo // Construct the hyper client from the HTTPS connector. let client: client::Client<_, hyper::Body> = client::Client::builder().build(https); - // Build request to Reddit. When making a GET, request gzip compression - // (Reddit doesn't do brotli yet) + // Build request to Reddit. When making a GET, request gzip compression. + // (Reddit doesn't do brotli yet.) let builder = Request::builder() .method(method) .uri(&url) @@ -187,9 +206,8 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo }; decompressed = Vec::::new(); - match io::copy(&mut decoder, &mut decompressed) { - Ok(_) => {} - Err(e) => return Err(e.to_string()), + if let Err(e) = io::copy(&mut decoder, &mut decompressed) { + return Err(e.to_string()); }; }