diff --git a/src/client.rs b/src/client.rs index 861b210..a6342d1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,4 +1,5 @@ use cached::proc_macro::cached; +use futures_lite::future::block_on; use futures_lite::{future::Boxed, FutureExt}; use hyper::client::HttpConnector; use hyper::{body, body::Buf, client, header, Body, Client, Method, Request, Response, Uri}; @@ -7,12 +8,12 @@ use libflate::gzip; use once_cell::sync::Lazy; use percent_encoding::{percent_encode, CONTROLS}; use serde_json::Value; -use std::sync::Arc; + use std::{io, result::Result}; use tokio::sync::RwLock; use crate::dbg_msg; -use crate::oauth::Oauth; +use crate::oauth::{token_daemon, Oauth}; use crate::server::RequestExt; const REDDIT_URL_BASE: &str = "https://oauth.reddit.com"; @@ -22,7 +23,11 @@ pub(crate) static CLIENT: Lazy>> = Lazy::ne client::Client::builder().build(https) }); -pub(crate) static OAUTH_CLIENT: Lazy>> = Lazy::new(|| Arc::new(RwLock::new(Oauth::new()))); +pub(crate) static OAUTH_CLIENT: Lazy> = Lazy::new(|| { + let client = block_on(Oauth::new()); + tokio::spawn(token_daemon()); + RwLock::new(client) +}); /// Gets the canonical path for a resource on Reddit. This is accomplished by /// making a `HEAD` request to Reddit at the path given in `path`. @@ -137,12 +142,12 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo let client: client::Client<_, hyper::Body> = CLIENT.clone(); let (token, vendor_id, device_id, user_agent, loid) = { - let client = tokio::task::block_in_place(move || OAUTH_CLIENT.blocking_read()); + let client = block_on(OAUTH_CLIENT.read()); ( client.token.clone(), - client.headers_map.get("Client-Vendor-Id").unwrap().clone(), - client.headers_map.get("X-Reddit-Device-Id").unwrap().clone(), - client.headers_map.get("User-Agent").unwrap().clone(), + client.headers_map.get("Client-Vendor-Id").cloned().unwrap_or_default(), + client.headers_map.get("X-Reddit-Device-Id").cloned().unwrap_or_default(), + client.headers_map.get("User-Agent").cloned().unwrap_or_default(), client.headers_map.get("x-reddit-loid").cloned().unwrap_or_default(), ) }; diff --git a/src/main.rs b/src/main.rs index 3a0802e..c58d49e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,10 +22,13 @@ use hyper::{header::HeaderValue, Body, Request, Response}; mod client; use client::{canonical_path, proxy}; +use log::info; use once_cell::sync::Lazy; use server::RequestExt; use utils::{error, redirect, ThemeAssets}; +use crate::client::OAUTH_CLIENT; + mod server; // Create Services @@ -169,13 +172,16 @@ async fn main() { // Force evaluation of statics. In instance_info case, we need to evaluate // the timestamp so deploy date is accurate - in config case, we need to - // evaluate the configuration to avoid paying penalty at first request. + // evaluate the configuration to avoid paying penalty at first request - + // in OAUTH case, we need to retrieve the token to avoid paying penalty + // at first request + info!("Evaluating config."); Lazy::force(&config::CONFIG); + info!("Evaluating instance info."); Lazy::force(&instance_info::INSTANCE_INFO); - - // Initialize OAuth client spoofing - oauth::initialize().await; + info!("Creating OAUTH client."); + Lazy::force(&OAUTH_CLIENT); // Define default headers (added to all responses) app.default_headers = headers! { diff --git a/src/oauth.rs b/src/oauth.rs index 7851e07..9885a83 100644 --- a/src/oauth.rs +++ b/src/oauth.rs @@ -4,6 +4,7 @@ use crate::client::{CLIENT, OAUTH_CLIENT}; use base64::{engine::general_purpose, Engine as _}; use hyper::{client, Body, Method, Request}; use log::info; + use serde_json::json; static REDDIT_ANDROID_OAUTH_CLIENT_ID: &str = "ohXpoqrZYub1kg"; @@ -25,9 +26,10 @@ pub(crate) static IOS_USER_AGENT: [&str; 3] = [ "Reddit/Version 2023.22.0/Build 613580/iOS Version 16.5", ]; // Various iOS device codes. iPhone 11 displays as `iPhone12,1` -// This is a bit of a hack, but I just changed the number a few times +// I just changed the number a few times for some plausible values pub(crate) static IOS_DEVICES: [&str; 5] = ["iPhone8,1", "iPhone11,1", "iPhone12,1", "iPhone13,1", "iPhone14,1"]; +#[derive(Debug, Clone, Default)] pub(crate) struct Oauth { // Currently unused, may be necessary if we decide to support GQL in the future pub(crate) headers_map: HashMap, @@ -37,7 +39,12 @@ pub(crate) struct Oauth { } impl Oauth { - pub(crate) fn new() -> Self { + pub(crate) async fn new() -> Self { + let mut oauth = Oauth::default(); + oauth.login().await; + oauth + } + pub(crate) fn default() -> Self { // Generate a random device to spoof let device = Device::random(); let headers = device.headers.clone(); @@ -81,7 +88,6 @@ impl Oauth { // Build request let request = builder.body(body).unwrap(); - info!("Request: {request:?}"); // Send request let client: client::Client<_, hyper::Body> = CLIENT.clone(); @@ -94,7 +100,6 @@ impl Oauth { self.headers_map.insert("x-reddit-loid".to_owned(), header.to_str().ok()?.to_string()); } - info!("OAuth response: {resp:?}"); // Serialize response let body_bytes = hyper::body::to_bytes(resp.into_body()).await.ok()?; let json: serde_json::Value = serde_json::from_slice(&body_bytes).ok()?; @@ -104,7 +109,7 @@ impl Oauth { self.expires_in = json.get("expires_in")?.as_u64()?; self.headers_map.insert("Authorization".to_owned(), format!("Bearer {}", self.token)); - info!("Retrieved token {}, expires in {}", self.token, self.expires_in); + info!("✅ Success - Retrieved token \"{}...\", expires in {}", &self.token[..32], self.expires_in); Some(()) } @@ -117,23 +122,18 @@ impl Oauth { refresh } } -// Initialize the OAuth client and launch a thread to monitor subsequent token refreshes. -pub(crate) async fn initialize() { - // Initial login - OAUTH_CLIENT.write().await.login().await.unwrap(); - // Spawn token daemon in background - we want the initial login to be blocked upon, but the - // daemon to be launched in the background. - // Initial login blocks libreddit start-up because we _need_ the oauth token. - tokio::spawn(token_daemon()); -} -async fn token_daemon() { + +pub(crate) async fn token_daemon() { // Monitor for refreshing token loop { // Get expiry time - be sure to not hold the read lock - let expires_in = OAUTH_CLIENT.read().await.expires_in; + let expires_in = { OAUTH_CLIENT.read().await.expires_in }; // sleep for the expiry time minus 2 minutes let duration = Duration::from_secs(expires_in - 120); + + info!("Waiting for {duration:?} seconds before refreshing OAuth token..."); + tokio::time::sleep(duration).await; info!("[{duration:?} ELAPSED] Refreshing OAuth token..."); @@ -145,7 +145,7 @@ async fn token_daemon() { } } } -#[derive(Debug)] +#[derive(Debug, Clone, Default)] struct Device { oauth_id: String, headers: HashMap, @@ -209,13 +209,12 @@ impl Device { } } -#[tokio::test] +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_oauth_client() { - initialize().await; + assert!(!OAUTH_CLIENT.read().await.token.is_empty()); } -#[tokio::test] +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_oauth_client_refresh() { - initialize().await; OAUTH_CLIENT.write().await.refresh().await.unwrap(); }