From cc69f947fb94cb7e6cda4cbbc496d1ac60ba1561 Mon Sep 17 00:00:00 2001 From: Daniel Valentine Date: Sun, 30 Oct 2022 11:40:48 -0600 Subject: [PATCH] Redirect /:id to canonical URL for post. This implements redirection of `/:id` (a short-form URL to a post) to the post's canonical URL. Libreddit issues a `HEAD /:id` to Reddit to get the canonical URL, and on success will send an HTTP 302 to a client with the canonical URL set in as the value of the `Location:` header. --- src/client.rs | 161 +++++++++++++++++++++++++++++++++----------------- src/main.rs | 31 ++++++---- src/utils.rs | 5 +- 3 files changed, 131 insertions(+), 66 deletions(-) diff --git a/src/client.rs b/src/client.rs index 5c32335..2a30fc9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,13 +1,37 @@ use cached::proc_macro::cached; use futures_lite::{future::Boxed, FutureExt}; -use hyper::{body, body::Buf, client, header, Body, Request, Response, Uri}; +use hyper::{body, body::Buf, client, header, Body, Method, Request, Response, Uri}; use libflate::gzip; use percent_encoding::{percent_encode, CONTROLS}; use serde_json::Value; use std::{io, result::Result}; +use crate::dbg_msg; use crate::server::RequestExt; +const REDDIT_URL_BASE: &str = "https://www.reddit.com"; + +/// Gets the canonical path for a resource on Reddit. On success, a +/// `Some(Option)` will be returned. If Reddit responds with +/// anything other than an HTTP 3xx, a `None` will be returned. Any +/// other error results in an `Err(String)`. +#[cached(size = 1024, time = 600, result = true)] +pub async fn canonical_path(path: String) -> Result, String> { + let res = reddit_head(path, true).await?; + + if res.status() == 429 { + return Err("Too many requests.".to_string()); + }; + + match res.headers().get(header::LOCATION) { + None => Ok(None), + Some(hdr) => match hdr.to_str() { + Ok(val) => Ok(Some(val.to_string().trim_start_matches(REDDIT_URL_BASE).to_string())), + Err(e) => Err(e.to_string()), + }, + } +} + pub async fn proxy(req: Request, format: &str) -> Result, String> { let mut url = format!("{}?{}", format, req.uri().query().unwrap_or_default()); @@ -63,21 +87,39 @@ async fn stream(url: &str, req: &Request) -> Result, String .map_err(|e| e.to_string()) } -fn request(url: String, quarantine: bool) -> Boxed, String>> { +/// Makes a GET request to Reddit at `path`. By default, this will honor HTTP +/// 3xx codes Reddit returns and will automatically redirect. +fn reddit_get(path: String, quarantine: bool) -> Boxed, String>> { + request(&Method::GET, path, true, quarantine) +} + +/// Makes a HEAD request to Reddit at `path`. This will not follow redirects. +fn reddit_head(path: String, quarantine: bool) -> Boxed, String>> { + request(&Method::HEAD, path, false, quarantine) +} + +/// Makes a request to Reddit. If `redirect` is `true`, request_with_redirect +/// will recurse on the URL that Reddit provides in the Location HTTP header +/// in its response. +fn request(method: &'static Method, path: String, redirect: bool, quarantine: bool) -> Boxed, String>> { + // Build Reddit URL from path. + let url = format!("{}{}", REDDIT_URL_BASE, path); + // Prepare the HTTPS connector. let https = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_or_http().enable_http1().build(); // Construct the hyper client from the HTTPS connector. let client: client::Client<_, hyper::Body> = client::Client::builder().build(https); - // Build request + // Build request to Reddit. When making a GET, request gzip compression + // (Reddit doesn't do brotli yet) let builder = Request::builder() - .method("GET") + .method(method) .uri(&url) .header("User-Agent", format!("web:libreddit:{}", env!("CARGO_PKG_VERSION"))) .header("Host", "www.reddit.com") .header("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8") - .header("Accept-Encoding", "gzip") // Reddit doesn't do brotli yet. + .header("Accept-Encoding", if method == Method::GET { "gzip" } else { "identity" }) .header("Accept-Language", "en-US,en;q=0.5") .header("Connection", "keep-alive") .header("Cookie", if quarantine { "_options=%7B%22pref_quarantine_optin%22%3A%20true%7D" } else { "" }) @@ -87,8 +129,15 @@ fn request(url: String, quarantine: bool) -> Boxed, String match builder { Ok(req) => match client.request(req).await { Ok(mut response) => { + // Reddit may respond with a 3xx. Decide whether or not to + // redirect based on caller params. if response.status().to_string().starts_with('3') { - request( + if !redirect { + return Ok(response); + }; + + return request( + method, response .headers() .get("Location") @@ -98,56 +147,65 @@ fn request(url: String, quarantine: bool) -> Boxed, String }) .unwrap_or_default() .to_string(), + true, quarantine, ) - .await - } else { - match response.headers().get(header::CONTENT_ENCODING) { - // Content not compressed. - None => Ok(response), + .await; + }; - // Content gzipped. - Some(hdr) => { - // Since we requested gzipped content, we expect - // to get back gzipped content. If we get - // back anything else, that's a problem. - if hdr.ne("gzip") { - return Err("Reddit response was encoded with an unsupported compressor".to_string()); - } + match response.headers().get(header::CONTENT_ENCODING) { + // Content not compressed. + None => Ok(response), - // The body must be something that implements - // std::io::Read, hence the conversion to - // bytes::buf::Buf and then transformation into a - // Reader. - let mut decompressed: Vec; - { - let mut aggregated_body = match body::aggregate(response.body_mut()).await { - Ok(b) => b.reader(), - Err(e) => return Err(e.to_string()), - }; - - let mut decoder = match gzip::Decoder::new(&mut aggregated_body) { - Ok(decoder) => decoder, - Err(e) => return Err(e.to_string()), - }; - - decompressed = Vec::::new(); - match io::copy(&mut decoder, &mut decompressed) { - Ok(_) => {} - Err(e) => return Err(e.to_string()), - }; - } - - response.headers_mut().remove(header::CONTENT_ENCODING); - response.headers_mut().insert(header::CONTENT_LENGTH, decompressed.len().into()); - *(response.body_mut()) = Body::from(decompressed); - - Ok(response) + // Content encoded (hopefully with gzip). + Some(hdr) => { + match hdr.to_str() { + Ok(val) => match val { + "gzip" => {} + "identity" => return Ok(response), + _ => return Err("Reddit response was encoded with an unsupported compressor".to_string()), + }, + Err(_) => return Err("Reddit response was invalid".to_string()), } + + // We get here if the body is gzip-compressed. + + // The body must be something that implements + // std::io::Read, hence the conversion to + // bytes::buf::Buf and then transformation into a + // Reader. + let mut decompressed: Vec; + { + let mut aggregated_body = match body::aggregate(response.body_mut()).await { + Ok(b) => b.reader(), + Err(e) => return Err(e.to_string()), + }; + + let mut decoder = match gzip::Decoder::new(&mut aggregated_body) { + Ok(decoder) => decoder, + Err(e) => return Err(e.to_string()), + }; + + decompressed = Vec::::new(); + match io::copy(&mut decoder, &mut decompressed) { + Ok(_) => {} + Err(e) => return Err(e.to_string()), + }; + } + + response.headers_mut().remove(header::CONTENT_ENCODING); + response.headers_mut().insert(header::CONTENT_LENGTH, decompressed.len().into()); + *(response.body_mut()) = Body::from(decompressed); + + Ok(response) } } } - Err(e) => Err(e.to_string()), + Err(e) => { + dbg_msg!("{} {}: {}", method, path, e); + + Err(e.to_string()) + } }, Err(_) => Err("Post url contains non-ASCII characters".to_string()), } @@ -158,9 +216,6 @@ fn request(url: String, quarantine: bool) -> Boxed, String // Make a request to a Reddit API and parse the JSON response #[cached(size = 100, time = 30, result = true)] pub async fn json(path: String, quarantine: bool) -> Result { - // Build Reddit url from path - let url = format!("https://www.reddit.com{}", path); - // Closure to quickly build errors let err = |msg: &str, e: String| -> Result { // eprintln!("{} - {}: {}", url, msg, e); @@ -168,7 +223,7 @@ pub async fn json(path: String, quarantine: bool) -> Result { }; // Fetch the url... - match request(url.clone(), quarantine).await { + match reddit_get(path.clone(), quarantine).await { Ok(response) => { let status = response.status(); @@ -186,7 +241,7 @@ pub async fn json(path: String, quarantine: bool) -> Result { .as_str() .unwrap_or_else(|| { json["message"].as_str().unwrap_or_else(|| { - eprintln!("{} - Error parsing reddit error", url); + eprintln!("{}{} - Error parsing reddit error", REDDIT_URL_BASE, path); "Error parsing reddit error" }) }) diff --git a/src/main.rs b/src/main.rs index 4ce4a96..8592e3c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,7 +17,7 @@ use futures_lite::FutureExt; use hyper::{header::HeaderValue, Body, Request, Response}; mod client; -use client::proxy; +use client::{canonical_path, proxy}; use server::RequestExt; use utils::{error, redirect, ThemeAssets}; @@ -259,9 +259,6 @@ async fn main() { app.at("/r/:sub/:sort").get(|r| subreddit::community(r).boxed()); - // Comments handler - app.at("/comments/:id").get(|r| post::item(r).boxed()); - // Front page app.at("/").get(|r| subreddit::community(r).boxed()); @@ -279,13 +276,25 @@ async fn main() { // Handle about pages app.at("/about").get(|req| error(req, "About pages aren't added yet".to_string()).boxed()); - app.at("/:id").get(|req: Request| match req.param("id").as_deref() { - // Sort front page - Some("best" | "hot" | "new" | "top" | "rising" | "controversial") => subreddit::community(req).boxed(), - // Short link for post - Some(id) if id.len() > 4 && id.len() < 7 => post::item(req).boxed(), - // Error message for unknown pages - _ => error(req, "Nothing here".to_string()).boxed(), + app.at("/:id").get(|req: Request| { + Box::pin(async move { + match req.param("id").as_deref() { + // Sort front page + Some("best" | "hot" | "new" | "top" | "rising" | "controversial") => subreddit::community(req).await, + + // Short link for post + Some(id) if (5..7).contains(&id.len()) => match canonical_path(format!("/{}", id)).await { + Ok(path_opt) => match path_opt { + Some(path) => Ok(redirect(path)), + None => error(req, "Post ID is invalid. It may point to a post on a community that has been banned.").await, + }, + Err(e) => error(req, e).await, + }, + + // Error message for unknown pages + _ => error(req, "Nothing here".to_string()).await, + } + }) }); // Default service in case no routes match diff --git a/src/utils.rs b/src/utils.rs index 42243d5..8e738ed 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -716,10 +716,11 @@ pub fn redirect(path: String) -> Response { .unwrap_or_default() } -pub async fn error(req: Request, msg: String) -> Result, String> { +/// Renders a generic error landing page. +pub async fn error(req: Request, msg: impl ToString) -> Result, String> { let url = req.uri().to_string(); let body = ErrorTemplate { - msg, + msg: msg.to_string(), prefs: Preferences::new(req), url, }