From 8bb247af3bc3cc9c5a3de9aeb3f9f179a5229cf7 Mon Sep 17 00:00:00 2001 From: curlpipe <11898833+curlpipe@users.noreply.github.com> Date: Sun, 16 May 2021 16:53:39 +0100 Subject: [PATCH] Added support for quarantined subreddits (#219) * Added support for quarantined subreddits * Added confirmation wall for quarantined subreddits * Added quarantine walls to other routes and fixed case issue * Correct obsolete use of cookie() * Refactor param() and quarantine() Co-authored-by: Spike <19519553+spikecodes@users.noreply.github.com> --- src/client.rs | 8 ++-- src/main.rs | 5 +- src/post.rs | 35 +++++++++----- src/search.rs | 41 +++++++++------- src/subreddit.rs | 111 ++++++++++++++++++++++++++++++++++---------- src/user.rs | 10 ++-- src/utils.rs | 21 +++++---- templates/wall.html | 13 ++++++ 8 files changed, 174 insertions(+), 70 deletions(-) create mode 100644 templates/wall.html diff --git a/src/client.rs b/src/client.rs index 9ec5215..76bae36 100644 --- a/src/client.rs +++ b/src/client.rs @@ -60,7 +60,7 @@ async fn stream(url: &str, req: &Request) -> Result, String .map_err(|e| e.to_string()) } -fn request(url: String) -> Boxed, String>> { +fn request(url: String, quarantine: bool) -> Boxed, String>> { // Prepare the HTTPS connector. let https = hyper_rustls::HttpsConnector::with_native_roots(); @@ -75,6 +75,7 @@ fn request(url: String) -> Boxed, String>> { .header("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8") .header("Accept-Language", "en-US,en;q=0.5") .header("Connection", "keep-alive") + .header("Cookie", if quarantine { "_options=%7B%22pref_quarantine_optin%22%3A%20true%7D" } else { "" }) .body(Body::empty()); async move { @@ -89,6 +90,7 @@ fn request(url: String) -> Boxed, String>> { .map(|val| val.to_str().unwrap_or_default()) .unwrap_or_default() .to_string(), + quarantine, ) .await } else { @@ -105,7 +107,7 @@ fn request(url: String) -> Boxed, String>> { // Make a request to a Reddit API and parse the JSON response #[cached(size = 100, time = 30, result = true)] -pub async fn json(path: String) -> Result { +pub async fn json(path: String, quarantine: bool) -> Result { // Build Reddit url from path let url = format!("https://www.reddit.com{}", path); @@ -116,7 +118,7 @@ pub async fn json(path: String) -> Result { }; // Fetch the url... - match request(url.clone()).await { + match request(url.clone(), quarantine).await { Ok(response) => { // asynchronously aggregate the chunks of the body match hyper::body::aggregate(response).await { diff --git a/src/main.rs b/src/main.rs index c26d093..5034b83 100644 --- a/src/main.rs +++ b/src/main.rs @@ -194,7 +194,10 @@ async fn main() { app.at("/settings/update").get(|r| settings::update(r).boxed()); // Subreddit services - app.at("/r/:sub").get(|r| subreddit::community(r).boxed()); + app + .at("/r/:sub") + .get(|r| subreddit::community(r).boxed()) + .post(|r| subreddit::add_quarantine_exception(r).boxed()); app .at("/r/u_:name") diff --git a/src/post.rs b/src/post.rs index 94d87e9..f6f3226 100644 --- a/src/post.rs +++ b/src/post.rs @@ -2,7 +2,9 @@ use crate::client::json; use crate::esc; use crate::server::RequestExt; +use crate::subreddit::{can_access_quarantine, quarantine}; use crate::utils::{error, format_num, format_url, param, rewrite_urls, setting, template, time, val, Author, Comment, Flags, Flair, FlairPart, Media, Post, Preferences}; + use hyper::{Body, Request, Response}; use async_recursion::async_recursion; @@ -23,18 +25,22 @@ struct PostTemplate { pub async fn item(req: Request) -> Result, String> { // Build Reddit API path let mut path: String = format!("{}.json?{}&raw_json=1", req.uri().path(), req.uri().query().unwrap_or_default()); + let sub = req.param("sub").unwrap_or_default(); + let quarantined = can_access_quarantine(&req, &sub); // Set sort to sort query parameter - let mut sort: String = param(&path, "sort"); + let sort = param(&path, "sort").unwrap_or_else(|| { + // Grab default comment sort method from Cookies + let default_sort = setting(&req, "comment_sort"); - // Grab default comment sort method from Cookies - let default_sort = setting(&req, "comment_sort"); - - // If there's no sort query but there's a default sort, set sort to default_sort - if sort.is_empty() && !default_sort.is_empty() { - sort = default_sort; - path = format!("{}.json?{}&sort={}&raw_json=1", req.uri().path(), req.uri().query().unwrap_or_default(), sort); - } + // If there's no sort query but there's a default sort, set sort to default_sort + if !default_sort.is_empty() { + path = format!("{}.json?{}&sort={}&raw_json=1", req.uri().path(), req.uri().query().unwrap_or_default(), default_sort); + default_sort + } else { + String::new() + } + }); // Log the post ID being fetched in debug mode #[cfg(debug_assertions)] @@ -44,7 +50,7 @@ pub async fn item(req: Request) -> Result, String> { let highlighted_comment = &req.param("comment_id").unwrap_or_default(); // Send a request to the url, receive JSON in response - match json(path).await { + match json(path, quarantined).await { // Otherwise, grab the JSON output from the request Ok(res) => { // Parse the JSON into Post and Comment structs @@ -61,7 +67,14 @@ pub async fn item(req: Request) -> Result, String> { }) } // If the Reddit API returns an error, exit and send error page to user - Err(msg) => error(req, msg).await, + Err(msg) => { + if msg == "quarantined" { + let sub = req.param("sub").unwrap_or_default(); + quarantine(req, sub) + } else { + error(req, msg).await + } + } } } diff --git a/src/search.rs b/src/search.rs index de75f0b..c0a7e99 100644 --- a/src/search.rs +++ b/src/search.rs @@ -1,6 +1,10 @@ // CRATES use crate::utils::{catch_random, error, format_num, format_url, param, setting, template, val, Post, Preferences}; -use crate::{client::json, RequestExt}; +use crate::{ + client::json, + subreddit::{can_access_quarantine, quarantine}, + RequestExt, +}; use askama::Template; use hyper::{Body, Request, Response}; @@ -39,27 +43,23 @@ pub async fn find(req: Request) -> Result, String> { let nsfw_results = if setting(&req, "show_nsfw") == "on" { "&include_over_18=on" } else { "" }; let path = format!("{}.json?{}{}", req.uri().path(), req.uri().query().unwrap_or_default(), nsfw_results); let sub = req.param("sub").unwrap_or_default(); + let quarantined = can_access_quarantine(&req, &sub); // Handle random subreddits if let Ok(random) = catch_random(&sub, "/find").await { return Ok(random); } - let query = param(&path, "q"); + let query = param(&path, "q").unwrap_or_default(); - let sort = if param(&path, "sort").is_empty() { - "relevance".to_string() - } else { - param(&path, "sort") - }; + let sort = param(&path, "sort").unwrap_or("relevance".to_string()); - let subreddits = if param(&path, "restrict_sr").is_empty() { - search_subreddits(&query).await - } else { - Vec::new() + let subreddits = match param(&path, "restrict_sr") { + None => search_subreddits(&query).await, + Some(_) => Vec::new() }; let url = String::from(req.uri().path_and_query().map_or("", |val| val.as_str())); - match Post::fetch(&path, String::new()).await { + match Post::fetch(&path, String::new(), quarantined).await { Ok((posts, after)) => template(SearchTemplate { posts, subreddits, @@ -67,15 +67,22 @@ pub async fn find(req: Request) -> Result, String> { params: SearchParams { q: query.replace('"', """), sort, - t: param(&path, "t"), - before: param(&path, "after"), + t: param(&path, "t").unwrap_or_default(), + before: param(&path, "after").unwrap_or_default(), after, - restrict_sr: param(&path, "restrict_sr"), + restrict_sr: param(&path, "restrict_sr").unwrap_or_default(), }, prefs: Preferences::new(req), url, }), - Err(msg) => error(req, msg).await, + Err(msg) => { + if msg == "quarantined" { + let sub = req.param("sub").unwrap_or_default(); + quarantine(req, sub) + } else { + error(req, msg).await + } + } } } @@ -83,7 +90,7 @@ async fn search_subreddits(q: &str) -> Vec { let subreddit_search_path = format!("/subreddits/search.json?q={}&limit=3", q.replace(' ', "+")); // Send a request to the url - match json(subreddit_search_path).await { + match json(subreddit_search_path, false).await { // If success, receive JSON in response Ok(response) => { match response["data"]["children"].as_array() { diff --git a/src/subreddit.rs b/src/subreddit.rs index 8c4a228..0477eb2 100644 --- a/src/subreddit.rs +++ b/src/subreddit.rs @@ -28,9 +28,20 @@ struct WikiTemplate { prefs: Preferences, } +#[derive(Template)] +#[template(path = "wall.html", escape = "none")] +struct WallTemplate { + title: String, + sub: String, + msg: String, + prefs: Preferences, + url: String, +} + // SERVICES pub async fn community(req: Request) -> Result, String> { // Build Reddit API path + let root = req.uri().path() == "/"; let subscribed = setting(&req, "subscriptions"); let front_page = setting(&req, "front_page"); let post_sort = req.cookie("post_sort").map_or_else(|| "hot".to_string(), |c| c.value().to_string()); @@ -45,6 +56,7 @@ pub async fn community(req: Request) -> Result, String> { } else { front_page.to_owned() }); + let quarantined = can_access_quarantine(&req, &sub) || root; // Handle random subreddits if let Ok(random) = catch_random(&sub, "").await { @@ -57,16 +69,16 @@ pub async fn community(req: Request) -> Result, String> { let path = format!("/r/{}/{}.json?{}&raw_json=1", sub, sort, req.uri().query().unwrap_or_default()); - match Post::fetch(&path, String::new()).await { + match Post::fetch(&path, String::new(), quarantined).await { Ok((posts, after)) => { // If you can get subreddit posts, also request subreddit metadata let sub = if !sub.contains('+') && sub != subscribed && sub != "popular" && sub != "all" { // Regular subreddit - subreddit(&sub).await.unwrap_or_default() + subreddit(&sub, quarantined).await.unwrap_or_default() } else if sub == subscribed { // Subscription feed if req.uri().path().starts_with("/r/") { - subreddit(&sub).await.unwrap_or_default() + subreddit(&sub, quarantined).await.unwrap_or_default() } else { Subreddit::default() } @@ -85,14 +97,14 @@ pub async fn community(req: Request) -> Result, String> { template(SubredditTemplate { sub, posts, - sort: (sort, param(&path, "t")), - ends: (param(&path, "after"), after), + sort: (sort, param(&path, "t").unwrap_or_default()), + ends: (param(&path, "after").unwrap_or_default(), after), prefs: Preferences::new(req), url, }) } Err(msg) => match msg.as_str() { - "quarantined" => error(req, format!("r/{} has been quarantined by Reddit", sub)).await, + "quarantined" => quarantine(req, sub), "private" => error(req, format!("r/{} is a private community", sub)).await, "banned" => error(req, format!("r/{} has been banned from Reddit", sub)).await, _ => error(req, msg).await, @@ -100,6 +112,43 @@ pub async fn community(req: Request) -> Result, String> { } } +pub fn quarantine(req: Request, sub: String) -> Result, String> { + let wall = WallTemplate { + title: format!("r/{} is quarantined", sub), + msg: "Please click the button below to continue to this subreddit.".to_string(), + url: req.uri().to_string(), + sub, + prefs: Preferences::new(req), + }; + + Ok( + Response::builder() + .status(403) + .header("content-type", "text/html") + .body(wall.render().unwrap_or_default().into()) + .unwrap_or_default(), + ) +} + +pub async fn add_quarantine_exception(req: Request) -> Result, String> { + let subreddit = req.param("sub").ok_or("Invalid URL")?; + let redir = param(&format!("?{}", req.uri().query().unwrap_or_default()), "redir").ok_or("Invalid URL")?; + let mut res = redirect(redir.to_owned()); + res.insert_cookie( + Cookie::build(&format!("allow_quaran_{}", subreddit.to_lowercase()), "true") + .path("/") + .http_only(true) + .expires(cookie::Expiration::Session) + .finish(), + ); + Ok(res) +} + +pub fn can_access_quarantine(req: &Request, sub: &str) -> bool { + // Determine if the subreddit can be accessed + setting(&req, &format!("allow_quaran_{}", sub.to_lowercase())).parse().unwrap_or_default() +} + // Sub or unsub by setting subscription cookie using response "Set-Cookie" header pub async fn subscriptions(req: Request) -> Result, String> { let sub = req.param("sub").unwrap_or_default(); @@ -114,7 +163,7 @@ pub async fn subscriptions(req: Request) -> Result, String> let mut sub_list = Preferences::new(req).subscriptions; // Retrieve list of posts for these subreddits to extract display names - let posts = json(format!("/r/{}/hot.json?raw_json=1", sub)).await?; + let posts = json(format!("/r/{}/hot.json?raw_json=1", sub), true).await?; let display_lookup: Vec<(String, &str)> = posts["data"]["children"] .as_array() .map(|list| { @@ -138,7 +187,7 @@ pub async fn subscriptions(req: Request) -> Result, String> } else { // This subreddit display name isn't known, retrieve it let path: String = format!("/r/{}/about.json?raw_json=1", part); - display = json(path).await?; + display = json(path, true).await?; display["data"]["display_name"].as_str().ok_or_else(|| "Failed to query subreddit name".to_string())? }; @@ -156,11 +205,9 @@ pub async fn subscriptions(req: Request) -> Result, String> // Redirect back to subreddit // check for redirect parameter if unsubscribing from outside sidebar - let redirect_path = param(&format!("/?{}", query), "redirect"); - let path = if redirect_path.is_empty() { - format!("/r/{}", sub) - } else { - format!("/{}/", redirect_path) + let path = match param(&format!("?{}", query), "redirect") { + Some(redirect_path) => format!("/{}/", redirect_path), + None => format!("/r/{}", sub) }; let mut res = redirect(path); @@ -183,6 +230,7 @@ pub async fn subscriptions(req: Request) -> Result, String> pub async fn wiki(req: Request) -> Result, String> { let sub = req.param("sub").unwrap_or_else(|| "reddit.com".to_string()); + let quarantined = can_access_quarantine(&req, &sub); // Handle random subreddits if let Ok(random) = catch_random(&sub, "/wiki").await { return Ok(random); @@ -191,19 +239,26 @@ pub async fn wiki(req: Request) -> Result, String> { let page = req.param("page").unwrap_or_else(|| "index".to_string()); let path: String = format!("/r/{}/wiki/{}.json?raw_json=1", sub, page); - match json(path).await { + match json(path, quarantined).await { Ok(response) => template(WikiTemplate { sub, wiki: rewrite_urls(response["data"]["content_html"].as_str().unwrap_or("

Wiki not found

")), page, prefs: Preferences::new(req), }), - Err(msg) => error(req, msg).await, + Err(msg) => { + if msg == "quarantined" { + quarantine(req, sub) + } else { + error(req, msg).await + } + } } } pub async fn sidebar(req: Request) -> Result, String> { let sub = req.param("sub").unwrap_or_else(|| "reddit.com".to_string()); + let quarantined = can_access_quarantine(&req, &sub); // Handle random subreddits if let Ok(random) = catch_random(&sub, "/about/sidebar").await { return Ok(random); @@ -213,26 +268,32 @@ pub async fn sidebar(req: Request) -> Result, String> { let path: String = format!("/r/{}/about.json?raw_json=1", sub); // Send a request to the url - match json(path).await { + match json(path, quarantined).await { // If success, receive JSON in response Ok(response) => template(WikiTemplate { wiki: format!( "{}

Moderators


    {}
", rewrite_urls(&val(&response, "description_html").replace("\\", "")), - moderators(&sub).await?.join(""), + moderators(&sub, quarantined).await?.join(""), ), sub, page: "Sidebar".to_string(), prefs: Preferences::new(req), }), - Err(msg) => error(req, msg).await, + Err(msg) => { + if msg == "quarantined" { + quarantine(req, sub) + } else { + error(req, msg).await + } + } } } -pub async fn moderators(sub: &str) -> Result, String> { +pub async fn moderators(sub: &str, quarantined: bool) -> Result, String> { // Retrieve and format the html for the moderators list Ok( - moderators_list(sub) + moderators_list(sub, quarantined) .await? .iter() .map(|m| format!("
  • {name}
  • ", name = m)) @@ -240,12 +301,12 @@ pub async fn moderators(sub: &str) -> Result, String> { ) } -async fn moderators_list(sub: &str) -> Result, String> { +async fn moderators_list(sub: &str, quarantined: bool) -> Result, String> { // Build the moderator list URL let path: String = format!("/r/{}/about/moderators.json?raw_json=1", sub); // Retrieve response - let response = json(path).await?["data"]["children"].clone(); + let response = json(path, quarantined).await?["data"]["children"].clone(); Ok( // Traverse json tree and format into list of strings response @@ -265,12 +326,12 @@ async fn moderators_list(sub: &str) -> Result, String> { } // SUBREDDIT -async fn subreddit(sub: &str) -> Result { +async fn subreddit(sub: &str, quarantined: bool) -> Result { // Build the Reddit JSON API url let path: String = format!("/r/{}/about.json?raw_json=1", sub); // Send a request to the url - match json(path).await { + match json(path, quarantined).await { // If success, receive JSON in response Ok(res) => { // Metadata regarding the subreddit @@ -286,7 +347,7 @@ async fn subreddit(sub: &str) -> Result { title: esc!(&res, "title"), description: esc!(&res, "public_description"), info: rewrite_urls(&val(&res, "description_html").replace("\\", "")), - moderators: moderators_list(sub).await?, + moderators: moderators_list(sub, quarantined).await?, icon: format_url(&icon), members: format_num(members), active: format_num(active), diff --git a/src/user.rs b/src/user.rs index 2ba024f..199653a 100644 --- a/src/user.rs +++ b/src/user.rs @@ -29,11 +29,11 @@ pub async fn profile(req: Request) -> Result, String> { ); // Retrieve other variables from Libreddit request - let sort = param(&path, "sort"); + let sort = param(&path, "sort").unwrap_or_default(); let username = req.param("name").unwrap_or_default(); // Request user posts/comments from Reddit - let posts = Post::fetch(&path, "Comment".to_string()).await; + let posts = Post::fetch(&path, "Comment".to_string(), false).await; let url = String::from(req.uri().path_and_query().map_or("", |val| val.as_str())); match posts { @@ -44,8 +44,8 @@ pub async fn profile(req: Request) -> Result, String> { template(UserTemplate { user, posts, - sort: (sort, param(&path, "t")), - ends: (param(&path, "after"), after), + sort: (sort, param(&path, "t").unwrap_or_default()), + ends: (param(&path, "after").unwrap_or_default(), after), prefs: Preferences::new(req), url, }) @@ -61,7 +61,7 @@ async fn user(name: &str) -> Result { let path: String = format!("/user/{}/about.json?raw_json=1", name); // Send a request to the url - match json(path).await { + match json(path, false).await { // If success, receive JSON in response Ok(res) => { // Grab creation date as unix timestamp diff --git a/src/utils.rs b/src/utils.rs index e65e2b3..8147815 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -217,12 +217,12 @@ pub struct Post { impl Post { // Fetch posts of a user or subreddit and return a vector of posts and the "after" value - pub async fn fetch(path: &str, fallback_title: String) -> Result<(Vec, String), String> { + pub async fn fetch(path: &str, fallback_title: String, quarantine: bool) -> Result<(Vec, String), String> { let res; let post_list; // Send a request to the url - match json(path.to_string()).await { + match json(path.to_string(), quarantine).await { // If success, receive JSON in response Ok(response) => { res = response; @@ -416,11 +416,16 @@ impl Preferences { // // Grab a query parameter from a url -pub fn param(path: &str, value: &str) -> String { - match Url::parse(format!("https://libredd.it/{}", path).as_str()) { - Ok(url) => url.query_pairs().into_owned().collect::>().get(value).unwrap_or(&String::new()).to_owned(), - _ => String::new(), - } +pub fn param(path: &str, value: &str) -> Option { + Some( + Url::parse(format!("https://libredd.it/{}", path).as_str()) + .ok()? + .query_pairs() + .into_owned() + .collect::>() + .get(value)? + .to_owned(), + ) } // Retrieve the value of a setting by name @@ -443,7 +448,7 @@ pub fn setting(req: &Request, name: &str) -> String { // Detect and redirect in the event of a random subreddit pub async fn catch_random(sub: &str, additional: &str) -> Result, String> { if (sub == "random" || sub == "randnsfw") && !sub.contains('+') { - let new_sub = json(format!("/r/{}/about.json?raw_json=1", sub)).await?["data"]["display_name"] + let new_sub = json(format!("/r/{}/about.json?raw_json=1", sub), false).await?["data"]["display_name"] .as_str() .unwrap_or_default() .to_string(); diff --git a/templates/wall.html b/templates/wall.html new file mode 100644 index 0000000..309dac2 --- /dev/null +++ b/templates/wall.html @@ -0,0 +1,13 @@ +{% extends "base.html" %} +{% block title %}{{ msg }}{% endblock %} +{% block sortstyle %}{% endblock %} +{% block content %} +
    +

    {{ title }}

    +
    +

    {{ msg }}

    +
    + +
    +
    +{% endblock %}