From e59b2b13464717885029a7b5f893590cb88a8349 Mon Sep 17 00:00:00 2001
From: spikecodes <19519553+spikecodes@users.noreply.github.com>
Date: Tue, 9 Mar 2021 22:13:46 -0800
Subject: [PATCH] Custom HTTP client with Rustls

---
 Cargo.toml   |   2 +-
 src/main.rs  |   5 ++-
 src/proxy.rs |  68 ++++++++++++++++++++++++++-----
 src/utils.rs | 111 ++++++++++++++++++++++++++++++++++++---------------
 4 files changed, 141 insertions(+), 45 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 6739f69..e72deac 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -10,7 +10,7 @@ edition = "2018"
 [dependencies]
 tide = { version = "0.16.0", default-features = false, features = ["h1-server", "cookies"] }
 async-std = { version = "1.9.0", features = ["attributes"] }
-surf = { version = "2.2.0", default-features = false, features = ["curl-client", "encoding"] }
+async-tls = { version = "0.11.0", default-features = false, features = ["client"] }
 cached = "0.23.0"
 askama = { version = "0.10.5", default-features = false }
 serde = { version = "1.0.124", features = ["derive"] }
diff --git a/src/main.rs b/src/main.rs
index c312064..488d366 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -2,10 +2,11 @@
 #![forbid(unsafe_code)]
 #![warn(clippy::pedantic, clippy::all)]
 #![allow(
-	clippy::clippy::needless_pass_by_value,
+	clippy::needless_pass_by_value,
 	clippy::match_wildcard_for_single_variants,
 	clippy::cast_possible_truncation,
-	clippy::similar_names
+	clippy::similar_names,
+	clippy::cast_possible_wrap
 )]
 
 // Reference local files
diff --git a/src/proxy.rs b/src/proxy.rs
index de53fb5..0a62a0f 100644
--- a/src/proxy.rs
+++ b/src/proxy.rs
@@ -1,5 +1,6 @@
-use surf::Body;
-use tide::{Request, Response};
+use async_std::{io, net::TcpStream, prelude::*};
+use async_tls::TlsConnector;
+use tide::{http::url::Url, Request, Response};
 
 pub async fn handler(req: Request<()>, format: &str, params: Vec<&str>) -> tide::Result {
 	let mut url = format.to_string();
@@ -13,20 +14,69 @@ pub async fn handler(req: Request<()>, format: &str, params: Vec<&str>) -> tide:
 }
 
 async fn request(url: String) -> tide::Result {
-	match surf::get(url).await {
-		Ok(res) => {
-			let content_length = res.header("Content-Length").map(std::string::ToString::to_string).unwrap_or_default();
-			let content_type = res.content_type().map(|m| m.to_string()).unwrap_or_default();
+	// Parse url into parts
+	let parts = Url::parse(&url).unwrap();
+	let host = parts.host().unwrap().to_string();
+	let domain = parts.domain().unwrap_or_default();
+	let path = format!("{}?{}", parts.path(), parts.query().unwrap_or_default());
+	// Build reddit-compliant user agent for Libreddit
+	let user_agent = format!("web:libreddit:{}", env!("CARGO_PKG_VERSION"));
+
+	// Construct a request body
+	let req = format!(
+		"GET {} HTTP/1.1\r\nHost: {}\r\nAccept: */*\r\nConnection: close\r\nUser-Agent: {}\r\n\r\n",
+		path, host, user_agent
+	);
+
+	// Initialize TLS connector for requests
+	let connector = TlsConnector::default();
+
+	// Open a TCP connection
+	let tcp_stream = TcpStream::connect(format!("{}:443", domain)).await.unwrap();
+
+	// Use the connector to start the handshake process
+	let mut tls_stream = connector.connect(domain, tcp_stream).await.unwrap();
+
+	// Write the aforementioned HTTP request to the stream
+	tls_stream.write_all(req.as_bytes()).await.unwrap();
+
+	// And read the response
+	let mut writer = Vec::new();
+	io::copy(&mut tls_stream, &mut writer).await.unwrap();
+
+	// Find the delimiter which separates the body and headers
+	match (0..writer.len()).find(|i| writer[i.to_owned()] == 10_u8 && writer[i - 2] == 10_u8) {
+		Some(delim) => {
+			// Split the response into the body and headers
+			let split = writer.split_at(delim);
+			let headers_str = String::from_utf8_lossy(split.0);
+			let headers = headers_str.split("\r\n").collect::<Vec<&str>>();
+			let body = split.1[1..split.1.len()].to_vec();
+
+			// Parse the status code from the first header line
+			let status: u16 = headers[0].split(' ').collect::<Vec<&str>>()[1].parse().unwrap_or_default();
+
+			// Define a closure for easier header fetching
+			let header = |name: &str| {
+				headers
+					.iter()
+					.find(|x| x.starts_with(name))
+					.map(|f| f.split(": ").collect::<Vec<&str>>()[1])
+					.unwrap_or_default()
+			};
+
+			let content_length = header("Content-Length");
+			let content_type = header("Content-Type");
 
 			Ok(
-				Response::builder(res.status())
-					.body(Body::from_reader(res, None))
+				Response::builder(status)
+					.body(tide::http::Body::from_bytes(body))
 					.header("Cache-Control", "public, max-age=1209600, s-maxage=86400")
 					.header("Content-Length", content_length)
 					.header("Content-Type", content_type)
 					.build(),
 			)
 		}
-		Err(e) => Ok(Response::builder(503).body(e.to_string()).build()),
+		None => Ok(Response::builder(503).body("Couldn't parse media".to_string()).build()),
 	}
 }
diff --git a/src/utils.rs b/src/utils.rs
index 5aaa8fc..6b42513 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -2,6 +2,9 @@
 // CRATES
 //
 use askama::Template;
+use async_recursion::async_recursion;
+use async_std::{io, net::TcpStream, prelude::*};
+use async_tls::TlsConnector;
 use cached::proc_macro::cached;
 use regex::Regex;
 use serde_json::{from_str, Error, Value};
@@ -510,54 +513,96 @@ pub async fn error(req: Request<()>, msg: String) -> tide::Result {
 	Ok(Response::builder(404).content_type("text/html").body(body).build())
 }
 
+#[async_recursion]
+async fn connect(path: String) -> io::Result<(i16, String)> {
+	// Build reddit-compliant user agent for Libreddit
+	let user_agent = format!("web:libreddit:{}", env!("CARGO_PKG_VERSION"));
+
+	// Construct an HTTP request body
+	let req = format!(
+		"GET {} HTTP/1.1\r\nHost: www.reddit.com\r\nAccept: */*\r\nConnection: close\r\nUser-Agent: {}\r\n\r\n",
+		path, user_agent
+	);
+
+	// Open a TCP connection
+	let tcp_stream = TcpStream::connect("www.reddit.com:443").await?;
+
+	// Initialize TLS connector for requests
+	let connector = TlsConnector::default();
+
+	// Use the connector to start the handshake process
+	let mut tls_stream = connector.connect("www.reddit.com", tcp_stream).await?;
+
+	// Write the crafted HTTP request to the stream
+	tls_stream.write_all(req.as_bytes()).await?;
+
+	// And read the response
+	let mut writer = Vec::new();
+	io::copy(&mut tls_stream, &mut writer).await?;
+	let response = String::from_utf8_lossy(&writer).to_string();
+
+	let split = response.split("\r\n\r\n").collect::<Vec<&str>>();
+
+	let headers = split[0].split("\r\n").collect::<Vec<&str>>();
+	let status: i16 = headers[0].split(' ').collect::<Vec<&str>>()[1].parse().unwrap_or(200);
+	let body = split[1].to_string();
+
+	if (300..400).contains(&status) {
+		let location = headers
+			.iter()
+			.find(|header| header.starts_with("location:"))
+			.map(|f| f.to_owned())
+			.unwrap_or_default()
+			.split(": ")
+			.collect::<Vec<&str>>()[1];
+		connect(location.replace("https://www.reddit.com", "")).await
+	} else {
+		Ok((status, body))
+	}
+}
+
 // Make a request to a Reddit API and parse the JSON response
 #[cached(size = 100, time = 30, result = true)]
 pub async fn request(path: String) -> Result<Value, String> {
 	let url = format!("https://www.reddit.com{}", path);
-	// Build reddit-compliant user agent for Libreddit
-	let user_agent = format!("web:libreddit:{}", env!("CARGO_PKG_VERSION"));
-
-	// Send request using surf
-	let req = surf::get(&url).header("User-Agent", user_agent.as_str());
-	let client = surf::client().with(surf::middleware::Redirect::new(5));
-
-	let res = client.send(req).await;
 
 	let err = |msg: &str, e: String| -> Result<Value, String> {
 		eprintln!("{} - {}: {}", url, msg, e);
 		Err(msg.to_string())
 	};
 
-	match res {
-		Ok(mut response) => match response.take_body().into_string().await {
-			// If response is success
-			Ok(body) => {
-				// Parse the response from Reddit as JSON
-				let parsed: Result<Value, Error> = from_str(&body);
-				match parsed {
-					Ok(json) => {
-						// If Reddit returned an error
-						if json["error"].is_i64() {
-							Err(
-								json["reason"]
-									.as_str()
-									.unwrap_or_else(|| {
-										json["message"].as_str().unwrap_or_else(|| {
-											eprintln!("{} - Error parsing reddit error", url);
-											"Error parsing reddit error"
-										})
-									})
-									.to_string(),
-							)
-						} else {
-							Ok(json)
-						}
-					}
-					Err(e) => err("Failed to parse page JSON data", e.to_string()),
-				}
-			}
-			Err(e) => err("Couldn't parse request body", e.to_string()),
-		},
+	match connect(path).await {
+		Ok((status, body)) => {
+			match status {
+				// If response is success
+				200 => {
+					// Parse the response from Reddit as JSON
+					let parsed: Result<Value, Error> = from_str(&body);
+					match parsed {
+						Ok(json) => {
+							// If Reddit returned an error
+							if json["error"].is_i64() {
+								Err(
+									json["reason"]
+										.as_str()
+										.unwrap_or_else(|| {
+											json["message"].as_str().unwrap_or_else(|| {
+												eprintln!("{} - Error parsing reddit error", url);
+												"Error parsing reddit error"
+											})
+										})
+										.to_string(),
+								)
+							} else {
+								Ok(json)
+							}
+						}
+						Err(e) => err("Failed to parse page JSON data", e.to_string()),
+					}
+				}
+				_ => err("Couldn't send request to Reddit", status.to_string()),
+			}
+		}
 		Err(e) => err("Couldn't send request to Reddit", e.to_string()),
 	}
 }
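
Notes on the technique, for review. The core of this patch is hand-rolling an HTTPS GET over Rustls (via async-tls) instead of going through surf. Stripped of the Reddit-specific parts, the pattern used in both proxy.rs and utils.rs boils down to the following self-contained sketch, assuming the crate versions pinned in Cargo.toml above; the host, path, and User-Agent here are placeholders, not part of the patch:

use async_std::{io, net::TcpStream, prelude::*};
use async_tls::TlsConnector;

#[async_std::main]
async fn main() -> io::Result<()> {
	// `Connection: close` lets us read the body until EOF,
	// sidestepping keep-alive framing logic.
	let req = "GET / HTTP/1.1\r\nHost: example.com\r\nAccept: */*\r\nConnection: close\r\nUser-Agent: demo\r\n\r\n";

	// Plain TCP first, then a Rustls handshake on top of it.
	let tcp = TcpStream::connect("example.com:443").await?;
	let mut tls = TlsConnector::default().connect("example.com", tcp).await?;

	// Write the request, then read the raw response to EOF.
	tls.write_all(req.as_bytes()).await?;
	let mut buf = Vec::new();
	io::copy(&mut tls, &mut buf).await?;

	println!("{}", String::from_utf8_lossy(&buf));
	Ok(())
}

Note this, like the patch, does not decode chunked transfer encoding; reading to EOF only yields a clean body when the server honors `Connection: close` with plain framing.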
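The delimiter scan in proxy.rs looks for an LF byte (10) sitting two positions after another LF, which only happens at the trailing "\n" of the "\r\n\r\n" separator between headers and body. (It leans on `&&` short-circuiting: `writer[i - 2]` is never evaluated for `i < 2`, since a response never starts with LF.) An equivalent formulation using slice windows, as a sketch rather than a proposed change:

// Equivalent to the scan in proxy.rs: index of the trailing '\n'
// of the first "\r\n\r\n" header/body separator, if present.
fn find_delimiter(response: &[u8]) -> Option<usize> {
	response.windows(4).position(|w| w == b"\r\n\r\n").map(|start| start + 3)
}

With that index, `response.split_at(delim)` gives the header block on the left and the body (after skipping the LF itself) on the right, exactly as the patch does.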
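In utils.rs, connect() splits the raw response on the first blank line and takes the second whitespace-separated token of the status line ("HTTP/1.1 301 Moved Permanently" yields 301), recursing on the `location` header for 3xx responses. The parsing step, restated as a hypothetical pure helper (not in the patch), so it can be reasoned about in isolation:

// Sketch of the response parsing in utils.rs::connect.
fn parse_response(response: &str) -> (i16, String) {
	// Head and body are separated by the first blank line.
	let split = response.split("\r\n\r\n").collect::<Vec<&str>>();
	let headers = split[0].split("\r\n").collect::<Vec<&str>>();
	// "HTTP/1.1 301 Moved Permanently" -> "301"
	let status: i16 = headers[0].split(' ').collect::<Vec<&str>>()[1].parse().unwrap_or(200);
	(status, split.get(1).unwrap_or(&"").to_string())
}

For "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\n\r\n{}" this returns (200, "{}"); unlike the patch's `split[1]`, the `get(1)` here won't panic on a response with no blank line.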
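On the new `async_recursion` import: connect() calls itself when following redirects, and a plain `async fn` cannot recurse directly, because its returned future would have to contain itself, which the compiler rejects as an infinitely sized type. The `#[async_recursion]` attribute works around that by boxing the returned future, roughly in the spirit of the manual version below (a sketch, not the macro's exact expansion):

use std::future::Future;
use std::pin::Pin;

// Returning a boxed, pinned future gives the recursive
// call site a sized type, so the recursion compiles.
fn countdown(n: u32) -> Pin<Box<dyn Future<Output = u32> + Send>> {
	Box::pin(async move {
		if n == 0 { 0 } else { countdown(n - 1).await }
	})
}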