From 528048bcfd2ca4175269bd14c05c31029910ec5f Mon Sep 17 00:00:00 2001 From: Denis Drakhnia Date: Tue, 17 Oct 2023 10:21:52 +0300 Subject: [PATCH] Add protocol crate --- Cargo.lock | 40 ++- Cargo.toml | 34 +- master/Cargo.toml | 29 ++ {config => master/config}/main.toml | 0 {src => master/src}/cli.rs | 0 {src => master/src}/config.rs | 3 +- {src => master/src}/logger.rs | 0 {src => master/src}/main.rs | 5 - master/src/master_server.rs | 287 ++++++++++++++++ {src => master/src}/parser.rs | 0 protocol/Cargo.toml | 12 + protocol/src/admin.rs | 82 +++++ protocol/src/cursor.rs | 516 ++++++++++++++++++++++++++++ {src => protocol/src}/filter.rs | 166 +++++---- protocol/src/game.rs | 128 +++++++ protocol/src/lib.rs | 28 ++ protocol/src/master.rs | 170 +++++++++ protocol/src/server.rs | 506 +++++++++++++++++++++++++++ protocol/src/server_info.rs | 27 ++ protocol/src/types.rs | 51 +++ src/client.rs | 91 ----- src/master_server.rs | 342 ------------------ src/server.rs | 26 -- src/server_info.rs | 329 ------------------ 24 files changed, 1968 insertions(+), 904 deletions(-) create mode 100644 master/Cargo.toml rename {config => master/config}/main.toml (100%) rename {src => master/src}/cli.rs (100%) rename {src => master/src}/config.rs (99%) rename {src => master/src}/logger.rs (100%) rename {src => master/src}/main.rs (93%) create mode 100644 master/src/master_server.rs rename {src => master/src}/parser.rs (100%) create mode 100644 protocol/Cargo.toml create mode 100644 protocol/src/admin.rs create mode 100644 protocol/src/cursor.rs rename {src => protocol/src}/filter.rs (75%) create mode 100644 protocol/src/game.rs create mode 100644 protocol/src/lib.rs create mode 100644 protocol/src/master.rs create mode 100644 protocol/src/server.rs create mode 100644 protocol/src/server_info.rs create mode 100644 protocol/src/types.rs delete mode 100644 src/client.rs delete mode 100644 src/master_server.rs delete mode 100644 src/server.rs delete mode 100644 src/server_info.rs diff --git a/Cargo.lock b/Cargo.lock index 08f9b0c..0487d6a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -83,21 +83,6 @@ dependencies = [ "unicode-width", ] -[[package]] -name = "hlmaster" -version = "0.1.0" -dependencies = [ - "bitflags", - "chrono", - "fastrand", - "getopts", - "log", - "once_cell", - "serde", - "thiserror", - "toml", -] - [[package]] name = "iana-time-zone" version = "0.1.57" @@ -388,3 +373,28 @@ name = "windows_x86_64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "xash3d-master" +version = "0.1.0" +dependencies = [ + "bitflags", + "chrono", + "fastrand", + "getopts", + "log", + "once_cell", + "serde", + "thiserror", + "toml", + "xash3d-protocol", +] + +[[package]] +name = "xash3d-protocol" +version = "0.1.0" +dependencies = [ + "bitflags", + "log", + "thiserror", +] diff --git a/Cargo.toml b/Cargo.toml index c8613c9..67b457f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,28 +1,6 @@ -[package] -name = "hlmaster" -version = "0.1.0" -license = "GPL-3.0-only" -authors = ["Denis Drakhnia "] -edition = "2021" -rust-version = "1.56" - -[features] -default = ["logtime"] -logtime = ["chrono"] - -[dependencies] -thiserror = "1.0.49" -getopts = "0.2.21" -log = "<0.4.19" -bitflags = "2.4" -fastrand = "2.0.1" -serde = { version = "1.0.188", features = ["derive"] } -toml = "0.5.11" - -[dependencies.chrono] -version = "<0.4.27" -optional = true -default-features = false -features = ["clock"] -[target.wasm32-unknown-emscripten.dependencies] -once_cell = { version = "<1.18", optional = true } +[workspace] +resolver = "2" +members = [ + "protocol", + "master", +] diff --git a/master/Cargo.toml b/master/Cargo.toml new file mode 100644 index 0000000..48706b7 --- /dev/null +++ b/master/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "xash3d-master" +version = "0.1.0" +license = "GPL-3.0-only" +authors = ["Denis Drakhnia "] +edition = "2021" +rust-version = "1.56" + +[features] +default = ["logtime"] +logtime = ["chrono"] + +[dependencies] +thiserror = "1.0.49" +getopts = "0.2.21" +log = "<0.4.19" +bitflags = "2.4" +fastrand = "2.0.1" +serde = { version = "1.0.188", features = ["derive"] } +toml = "0.5.11" +xash3d-protocol = { path = "../protocol", version = "0.1.0" } + +[dependencies.chrono] +version = "<0.4.27" +optional = true +default-features = false +features = ["clock"] +[target.wasm32-unknown-emscripten.dependencies] +once_cell = { version = "<1.18", optional = true } diff --git a/config/main.toml b/master/config/main.toml similarity index 100% rename from config/main.toml rename to master/config/main.toml diff --git a/src/cli.rs b/master/src/cli.rs similarity index 100% rename from src/cli.rs rename to master/src/cli.rs diff --git a/src/config.rs b/master/src/config.rs similarity index 99% rename from src/config.rs rename to master/src/config.rs index 681cb22..8dd8c65 100644 --- a/src/config.rs +++ b/master/src/config.rs @@ -9,8 +9,7 @@ use std::path::Path; use log::LevelFilter; use serde::{de::Error as _, Deserialize, Deserializer}; use thiserror::Error; - -use crate::filter::Version; +use xash3d_protocol::filter::Version; pub const DEFAULT_CONFIG_PATH: &str = "config/main.toml"; diff --git a/src/logger.rs b/master/src/logger.rs similarity index 100% rename from src/logger.rs rename to master/src/logger.rs diff --git a/src/main.rs b/master/src/main.rs similarity index 93% rename from src/main.rs rename to master/src/main.rs index 13c0752..6640638 100644 --- a/src/main.rs +++ b/master/src/main.rs @@ -2,14 +2,9 @@ // SPDX-FileCopyrightText: 2023 Denis Drakhnia mod cli; -mod client; mod config; -mod filter; mod logger; mod master_server; -mod parser; -mod server; -mod server_info; use log::error; diff --git a/master/src/master_server.rs b/master/src/master_server.rs new file mode 100644 index 0000000..739eb61 --- /dev/null +++ b/master/src/master_server.rs @@ -0,0 +1,287 @@ +// SPDX-License-Identifier: GPL-3.0-only +// SPDX-FileCopyrightText: 2023 Denis Drakhnia + +use std::collections::HashMap; +use std::io; +use std::net::{SocketAddr, SocketAddrV4, ToSocketAddrs, UdpSocket}; +use std::ops::Deref; +use std::time::Instant; + +use fastrand::Rng; +use log::{error, info, trace, warn}; +use thiserror::Error; +use xash3d_protocol::filter::{Filter, Version}; +use xash3d_protocol::server::Region; +use xash3d_protocol::ServerInfo; +use xash3d_protocol::{game, master, server, Error as ProtocolError}; + +use crate::config::{self, Config}; + +/// The maximum size of UDP packets. +const MAX_PACKET_SIZE: usize = 512; + +/// How many cleanup calls should be skipped before removing outdated servers. +const SERVER_CLEANUP_MAX: usize = 100; + +/// How many cleanup calls should be skipped before removing outdated challenges. +const CHALLENGE_CLEANUP_MAX: usize = 100; + +#[derive(Error, Debug)] +pub enum Error { + #[error("Failed to bind server socket: {0}")] + BindSocket(io::Error), + #[error(transparent)] + Protocol(#[from] ProtocolError), + #[error(transparent)] + Io(#[from] io::Error), +} + +/// HashMap entry to keep tracking creation time. +#[derive(Clone, Debug)] +struct Entry { + time: u32, + value: T, +} + +impl Entry { + fn new(time: u32, value: T) -> Self { + Self { time, value } + } + + fn is_valid(&self, now: u32, duration: u32) -> bool { + (now - self.time) < duration + } +} + +impl Entry { + fn matches(&self, addr: SocketAddrV4, region: Region, filter: &Filter) -> bool { + self.region == region && filter.matches(addr, &self.value) + } +} + +impl Deref for Entry { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.value + } +} + +struct MasterServer { + sock: UdpSocket, + challenges: HashMap>, + servers: HashMap>, + rng: Rng, + + start_time: Instant, + cleanup_challenges: usize, + cleanup_servers: usize, + timeout: config::TimeoutConfig, + + clver: Version, + update_title: Box, + update_map: Box, + update_addr: SocketAddrV4, +} + +impl MasterServer { + fn new(cfg: Config) -> Result { + let addr = SocketAddr::new(cfg.server.ip, cfg.server.port); + info!("Listen address: {}", addr); + let sock = UdpSocket::bind(addr).map_err(Error::BindSocket)?; + let update_addr = + cfg.client + .update_addr + .unwrap_or_else(|| match sock.local_addr().unwrap() { + SocketAddr::V4(addr) => addr, + _ => todo!(), + }); + + Ok(Self { + sock, + start_time: Instant::now(), + challenges: Default::default(), + servers: Default::default(), + rng: Rng::new(), + cleanup_challenges: 0, + cleanup_servers: 0, + timeout: cfg.server.timeout, + clver: cfg.client.version, + update_title: cfg.client.update_title, + update_map: cfg.client.update_map, + update_addr, + }) + } + + fn run(&mut self) -> Result<(), Error> { + let mut buf = [0; MAX_PACKET_SIZE]; + loop { + let (n, from) = self.sock.recv_from(&mut buf)?; + let from = match from { + SocketAddr::V4(a) => a, + _ => { + warn!("{}: Received message from IPv6, unimplemented", from); + continue; + } + }; + + if let Err(e) = self.handle_packet(from, &buf[..n]) { + error!("{}: {}", from, e); + } + } + } + + fn handle_packet(&mut self, from: SocketAddrV4, src: &[u8]) -> Result<(), Error> { + if let Ok(p) = server::Packet::decode(src) { + match p { + server::Packet::Challenge(p) => { + trace!("{}: recv {:?}", from, p); + let master_challenge = self.add_challenge(from); + let mut buf = [0; MAX_PACKET_SIZE]; + let p = master::ChallengeResponse::new(master_challenge, p.server_challenge); + trace!("{}: send {:?}", from, p); + let n = p.encode(&mut buf)?; + self.sock.send_to(&buf[..n], from)?; + self.remove_outdated_challenges(); + } + server::Packet::ServerAdd(p) => { + trace!("{}: recv {:?}", from, p); + let entry = match self.challenges.get(&from) { + Some(e) => e, + None => { + trace!("{}: Challenge does not exists", from); + return Ok(()); + } + }; + if !entry.is_valid(self.now(), self.timeout.challenge) { + return Ok(()); + } + if p.challenge != entry.value { + warn!( + "{}: Expected challenge {} but received {}", + from, entry.value, p.challenge + ); + return Ok(()); + } + if self.challenges.remove(&from).is_some() { + self.add_server(from, ServerInfo::new(&p)); + } + self.remove_outdated_servers(); + } + _ => { + trace!("{}: recv {:?}", from, p); + } + } + } + + if let Ok(p) = game::Packet::decode(src) { + match p { + game::Packet::QueryServers(p) => { + trace!("{}: recv {:?}", from, p); + if p.filter.clver < self.clver { + let iter = std::iter::once(self.update_addr); + self.send_server_list(from, iter)?; + } else { + let now = self.now(); + let iter = self + .servers + .iter() + .filter(|i| i.1.is_valid(now, self.timeout.server)) + .filter(|i| i.1.matches(*i.0, p.region, &p.filter)) + .map(|i| *i.0); + self.send_server_list(from, iter)?; + } + } + game::Packet::GetServerInfo(p) => { + trace!("{}: recv {:?}", from, p); + let p = server::GetServerInfoResponse { + map: self.update_map.as_ref(), + host: self.update_title.as_ref(), + protocol: 49, + dm: true, + maxcl: 32, + gamedir: "valve", + ..Default::default() + }; + trace!("{}: send {:?}", from, p); + let mut buf = [0; MAX_PACKET_SIZE]; + let n = p.encode(&mut buf)?; + self.sock.send_to(&buf[..n], from)?; + } + } + } + + Ok(()) + } + + fn now(&self) -> u32 { + self.start_time.elapsed().as_secs() as u32 + } + + fn add_challenge(&mut self, addr: SocketAddrV4) -> u32 { + let x = self.rng.u32(..); + let entry = Entry::new(self.now(), x); + self.challenges.insert(addr, entry); + x + } + + fn remove_outdated_challenges(&mut self) { + if self.cleanup_challenges < CHALLENGE_CLEANUP_MAX { + self.cleanup_challenges += 1; + return; + } + let now = self.now(); + let old = self.challenges.len(); + self.challenges + .retain(|_, v| v.is_valid(now, self.timeout.challenge)); + let new = self.challenges.len(); + if old != new { + trace!("Removed {} outdated challenges", old - new); + } + self.cleanup_challenges = 0; + } + + fn add_server(&mut self, addr: SocketAddrV4, server: ServerInfo) { + match self.servers.insert(addr, Entry::new(self.now(), server)) { + Some(_) => trace!("{}: Updated GameServer", addr), + None => trace!("{}: New GameServer", addr), + } + } + + fn remove_outdated_servers(&mut self) { + if self.cleanup_servers < SERVER_CLEANUP_MAX { + self.cleanup_servers += 1; + return; + } + let now = self.now(); + let old = self.servers.len(); + self.servers + .retain(|_, v| v.is_valid(now, self.timeout.server)); + let new = self.servers.len(); + if old != new { + trace!("Removed {} outdated servers", old - new); + } + self.cleanup_servers = 0; + } + + fn send_server_list(&self, to: A, iter: I) -> Result<(), Error> + where + A: ToSocketAddrs, + I: Iterator, + { + let mut list = master::QueryServersResponse::new(iter); + loop { + let mut buf = [0; MAX_PACKET_SIZE]; + let (n, is_end) = list.encode(&mut buf)?; + self.sock.send_to(&buf[..n], &to)?; + if is_end { + break; + } + } + Ok(()) + } +} + +pub fn run(cfg: Config) -> Result<(), Error> { + MasterServer::new(cfg)?.run() +} diff --git a/src/parser.rs b/master/src/parser.rs similarity index 100% rename from src/parser.rs rename to master/src/parser.rs diff --git a/protocol/Cargo.toml b/protocol/Cargo.toml new file mode 100644 index 0000000..d922d74 --- /dev/null +++ b/protocol/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "xash3d-protocol" +version = "0.1.0" +license = "GPL-3.0-only" +authors = ["Denis Drakhnia "] +edition = "2021" +rust-version = "1.56" + +[dependencies] +thiserror = "1.0.49" +log = "<0.4.19" +bitflags = "2.4" diff --git a/protocol/src/admin.rs b/protocol/src/admin.rs new file mode 100644 index 0000000..f7738a3 --- /dev/null +++ b/protocol/src/admin.rs @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: GPL-3.0-only +// SPDX-FileCopyrightText: 2023 Denis Drakhnia + +use crate::cursor::{Cursor, CursorMut}; +use crate::types::Str; +use crate::Error; + +pub const HASH_LEN: usize = 64; + +#[derive(Clone, Debug, PartialEq)] +pub struct AdminChallenge; + +impl AdminChallenge { + pub const HEADER: &'static [u8] = b"adminchallenge"; + + pub fn decode(src: &[u8]) -> Result { + if src == Self::HEADER { + Ok(Self) + } else { + Err(Error::InvalidPacket) + } + } + + pub fn encode(&self, buf: &mut [u8]) -> Result { + Ok(CursorMut::new(buf).put_bytes(Self::HEADER)?.pos()) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct AdminCommand<'a> { + pub hash: &'a [u8], + pub command: Str<&'a [u8]>, +} + +impl<'a> AdminCommand<'a> { + pub const HEADER: &'static [u8] = b"admin"; + + pub fn new(hash: &'a [u8], command: &'a str) -> Self { + Self { + hash, + command: Str(command.as_bytes()), + } + } + + pub fn decode(src: &'a [u8]) -> Result { + let mut cur = Cursor::new(src); + cur.expect(Self::HEADER)?; + let hash = cur.get_bytes(HASH_LEN)?; + let command = Str(cur.get_bytes(cur.remaining())?); + cur.expect_empty()?; + Ok(Self { hash, command }) + } + + pub fn encode(&self, buf: &mut [u8]) -> Result { + Ok(CursorMut::new(buf) + .put_bytes(Self::HEADER)? + .put_bytes(self.hash)? + .put_bytes(&self.command)? + .pos()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn admin_challenge() { + let p = AdminChallenge; + let mut buf = [0; 512]; + let n = p.encode(&mut buf).unwrap(); + assert_eq!(AdminChallenge::decode(&buf[..n]), Ok(p)); + } + + #[test] + fn admin_command() { + let p = AdminCommand::new(&[1; HASH_LEN], "foo bar baz"); + let mut buf = [0; 512]; + let n = p.encode(&mut buf).unwrap(); + assert_eq!(AdminCommand::decode(&buf[..n]), Ok(p)); + } +} diff --git a/protocol/src/cursor.rs b/protocol/src/cursor.rs new file mode 100644 index 0000000..a5d1f0d --- /dev/null +++ b/protocol/src/cursor.rs @@ -0,0 +1,516 @@ +// SPDX-License-Identifier: GPL-3.0-only +// SPDX-FileCopyrightText: 2023 Denis Drakhnia + +use std::fmt; +use std::io::{self, Write as _}; +use std::mem; +use std::slice; +use std::str; + +use super::types::Str; +use super::Error; + +pub trait GetKeyValue<'a>: Sized { + fn get_key_value(cur: &mut Cursor<'a>) -> Result; +} + +impl<'a> GetKeyValue<'a> for &'a [u8] { + fn get_key_value(cur: &mut Cursor<'a>) -> Result { + cur.get_key_value_raw() + } +} + +impl<'a> GetKeyValue<'a> for Str<&'a [u8]> { + fn get_key_value(cur: &mut Cursor<'a>) -> Result { + cur.get_key_value_raw().map(Str) + } +} + +impl<'a> GetKeyValue<'a> for &'a str { + fn get_key_value(cur: &mut Cursor<'a>) -> Result { + let raw = cur.get_key_value_raw()?; + str::from_utf8(raw).map_err(|_| Error::InvalidString) + } +} + +impl<'a> GetKeyValue<'a> for bool { + fn get_key_value(cur: &mut Cursor<'a>) -> Result { + match cur.get_key_value_raw()? { + b"0" => Ok(false), + b"1" => Ok(true), + _ => Err(Error::InvalidPacket), + } + } +} + +macro_rules! impl_get_value { + ($($t:ty),+ $(,)?) => { + $(impl<'a> GetKeyValue<'a> for $t { + fn get_key_value(cur: &mut Cursor<'a>) -> Result { + cur.get_key_value::<&str>()?.parse().map_err(|_| Error::InvalidPacket) + } + })+ + }; +} + +impl_get_value! { + u8, + u16, + u32, + u64, + + i8, + i16, + i32, + i64, +} + +// TODO: impl GetKeyValue for f32 and f64 + +#[derive(Copy, Clone)] +pub struct Cursor<'a> { + buffer: &'a [u8], +} + +macro_rules! impl_get { + ($($n:ident: $t:ty = $f:ident),+ $(,)?) => ( + $(#[inline] + pub fn $n(&mut self) -> Result<$t, Error> { + const N: usize = mem::size_of::<$t>(); + self.get_array::().map(<$t>::$f) + })+ + ); +} + +impl<'a> Cursor<'a> { + pub fn new(buffer: &'a [u8]) -> Self { + Self { buffer } + } + + pub fn end(self) -> &'a [u8] { + self.buffer + } + + #[inline(always)] + pub fn remaining(&self) -> usize { + self.buffer.len() + } + + #[inline(always)] + pub fn has_remaining(&self) -> bool { + self.remaining() != 0 + } + + pub fn get_bytes(&mut self, count: usize) -> Result<&'a [u8], Error> { + if count <= self.remaining() { + let (head, tail) = self.buffer.split_at(count); + self.buffer = tail; + Ok(head) + } else { + Err(Error::UnexpectedEnd) + } + } + + pub fn advance(&mut self, count: usize) -> Result<(), Error> { + self.get_bytes(count).map(|_| ()) + } + + pub fn get_array(&mut self) -> Result<[u8; N], Error> { + self.get_bytes(N).map(|s| { + let mut array = [0; N]; + array.copy_from_slice(s); + array + }) + } + + pub fn get_str(&mut self, n: usize) -> Result<&'a str, Error> { + let mut cur = *self; + let s = cur + .get_bytes(n) + .and_then(|s| str::from_utf8(s).map_err(|_| Error::InvalidString))?; + *self = cur; + Ok(s) + } + + pub fn get_cstr(&mut self) -> Result, Error> { + let pos = self + .buffer + .iter() + .position(|&c| c == b'\0') + .ok_or(Error::UnexpectedEnd)?; + let (head, tail) = self.buffer.split_at(pos); + self.buffer = &tail[1..]; + Ok(Str(&head[..pos])) + } + + pub fn get_cstr_as_str(&mut self) -> Result<&'a str, Error> { + str::from_utf8(&self.get_cstr()?).map_err(|_| Error::InvalidString) + } + + #[inline(always)] + pub fn get_u8(&mut self) -> Result { + self.get_array::<1>().map(|s| s[0]) + } + + #[inline(always)] + pub fn get_i8(&mut self) -> Result { + self.get_array::<1>().map(|s| s[0] as i8) + } + + impl_get! { + get_u16_le: u16 = from_le_bytes, + get_u32_le: u32 = from_le_bytes, + get_u64_le: u64 = from_le_bytes, + get_i16_le: i16 = from_le_bytes, + get_i32_le: i32 = from_le_bytes, + get_i64_le: i64 = from_le_bytes, + get_f32_le: f32 = from_le_bytes, + get_f64_le: f64 = from_le_bytes, + + get_u16_be: u16 = from_be_bytes, + get_u32_be: u32 = from_be_bytes, + get_u64_be: u64 = from_be_bytes, + get_i16_be: i16 = from_be_bytes, + get_i32_be: i32 = from_be_bytes, + get_i64_be: i64 = from_be_bytes, + get_f32_be: f32 = from_be_bytes, + get_f64_be: f64 = from_be_bytes, + + get_u16_ne: u16 = from_ne_bytes, + get_u32_ne: u32 = from_ne_bytes, + get_u64_ne: u64 = from_ne_bytes, + get_i16_ne: i16 = from_ne_bytes, + get_i32_ne: i32 = from_ne_bytes, + get_i64_ne: i64 = from_ne_bytes, + get_f32_ne: f32 = from_ne_bytes, + get_f64_ne: f64 = from_ne_bytes, + } + + pub fn expect(&mut self, s: &[u8]) -> Result<(), Error> { + if self.buffer.starts_with(s) { + self.advance(s.len())?; + Ok(()) + } else { + Err(Error::InvalidPacket) + } + } + + pub fn expect_empty(&self) -> Result<(), Error> { + if self.has_remaining() { + Err(Error::InvalidPacket) + } else { + Ok(()) + } + } + + pub fn take_while(&mut self, mut cond: F) -> Result<&'a [u8], Error> + where + F: FnMut(u8) -> bool, + { + self.buffer + .iter() + .position(|&i| !cond(i)) + .ok_or(Error::UnexpectedEnd) + .and_then(|n| self.get_bytes(n)) + } + + pub fn take_while_or_all(&mut self, cond: F) -> &'a [u8] + where + F: FnMut(u8) -> bool, + { + self.take_while(cond).unwrap_or_else(|_| { + let (head, tail) = self.buffer.split_at(self.buffer.len()); + self.buffer = tail; + head + }) + } + + pub fn get_key_value_raw(&mut self) -> Result<&'a [u8], Error> { + let mut cur = *self; + if cur.get_u8()? == b'\\' { + let value = cur.take_while_or_all(|c| c != b'\\' && c != b'\n'); + *self = cur; + Ok(value) + } else { + Err(Error::InvalidPacket) + } + } + + pub fn get_key_value>(&mut self) -> Result { + T::get_key_value(self) + } + + pub fn get_key_raw(&mut self) -> Result<&'a [u8], Error> { + let mut cur = *self; + if cur.get_u8()? == b'\\' { + let value = cur.take_while(|c| c != b'\\' && c != b'\n')?; + *self = cur; + Ok(value) + } else { + Err(Error::InvalidPacket) + } + } + + pub fn get_key>(&mut self) -> Result<(&'a [u8], T), Error> { + Ok((self.get_key_raw()?, self.get_key_value()?)) + } +} + +pub trait PutKeyValue { + fn put_key_value<'a, 'b>( + &self, + cur: &'b mut CursorMut<'a>, + ) -> Result<&'b mut CursorMut<'a>, Error>; +} + +impl PutKeyValue for &str { + fn put_key_value<'a, 'b>( + &self, + cur: &'b mut CursorMut<'a>, + ) -> Result<&'b mut CursorMut<'a>, Error> { + cur.put_str(self) + } +} + +impl PutKeyValue for bool { + fn put_key_value<'a, 'b>( + &self, + cur: &'b mut CursorMut<'a>, + ) -> Result<&'b mut CursorMut<'a>, Error> { + cur.put_u8(if *self { b'1' } else { b'0' }) + } +} + +macro_rules! impl_put_key_value { + ($($t:ty),+ $(,)?) => { + $(impl PutKeyValue for $t { + fn put_key_value<'a, 'b>(&self, cur: &'b mut CursorMut<'a>) -> Result<&'b mut CursorMut<'a>, Error> { + cur.put_as_str(self) + } + })+ + }; +} + +impl_put_key_value! { + u8, + u16, + u32, + u64, + + i8, + i16, + i32, + i64, + + f32, + f64, +} + +pub struct CursorMut<'a> { + buffer: &'a [u8], + buffer_mut: &'a mut [u8], +} + +macro_rules! impl_put { + ($($n:ident: $t:ty = $f:ident),+ $(,)?) => ( + $(#[inline] + pub fn $n(&mut self, n: $t) -> Result<&mut Self, Error> { + self.put_array(&n.$f()) + })+ + ); +} + +impl<'a> CursorMut<'a> { + pub fn new(buffer: &'a mut [u8]) -> Self { + Self { + buffer: unsafe { slice::from_raw_parts(buffer.as_ptr(), 0) }, + buffer_mut: buffer, + } + } + + pub fn buffer(&self) -> &'a [u8] { + self.buffer + } + + pub fn buffer_mut<'b: 'a>(&'b mut self) -> &'a mut [u8] { + self.buffer_mut + } + + pub fn end(self) -> (&'a [u8], &'a mut [u8]) { + (self.buffer, self.buffer_mut) + } + + pub fn pos(&mut self) -> usize { + self.buffer.len() + } + + #[inline(always)] + pub fn remaining(&self) -> usize { + self.buffer_mut.len() + } + + pub fn advance(&mut self, count: usize, mut f: F) -> Result<&mut Self, Error> + where + F: FnMut(&'a mut [u8]), + { + if count <= self.remaining() { + let buffer_mut = mem::take(&mut self.buffer_mut); + let (head, tail) = buffer_mut.split_at_mut(count); + f(head); + self.buffer = + unsafe { slice::from_raw_parts(self.buffer.as_ptr(), self.buffer.len() + count) }; + self.buffer_mut = tail; + Ok(self) + } else { + Err(Error::UnexpectedEnd) + } + } + + pub fn put_bytes(&mut self, s: &[u8]) -> Result<&mut Self, Error> { + self.advance(s.len(), |i| { + i.copy_from_slice(s); + }) + } + + pub fn put_array(&mut self, s: &[u8; N]) -> Result<&mut Self, Error> { + self.advance(N, |i| { + i.copy_from_slice(s); + }) + } + + pub fn put_str(&mut self, s: &str) -> Result<&mut Self, Error> { + self.put_bytes(s.as_bytes()) + } + + pub fn put_cstr(&mut self, s: &str) -> Result<&mut Self, Error> { + self.put_str(s)?.put_u8(0) + } + + #[inline(always)] + pub fn put_u8(&mut self, n: u8) -> Result<&mut Self, Error> { + self.put_array(&[n]) + } + + #[inline(always)] + pub fn put_i8(&mut self, n: i8) -> Result<&mut Self, Error> { + self.put_u8(n as u8) + } + + impl_put! { + put_u16_le: u16 = to_le_bytes, + put_u32_le: u32 = to_le_bytes, + put_u64_le: u64 = to_le_bytes, + put_i16_le: i16 = to_le_bytes, + put_i32_le: i32 = to_le_bytes, + put_i64_le: i64 = to_le_bytes, + put_f32_le: f32 = to_le_bytes, + put_f64_le: f64 = to_le_bytes, + + put_u16_be: u16 = to_be_bytes, + put_u32_be: u32 = to_be_bytes, + put_u64_be: u64 = to_be_bytes, + put_i16_be: i16 = to_be_bytes, + put_i32_be: i32 = to_be_bytes, + put_i64_be: i64 = to_be_bytes, + put_f32_be: f32 = to_be_bytes, + put_f64_be: f64 = to_be_bytes, + + put_u16_ne: u16 = to_ne_bytes, + put_u32_ne: u32 = to_ne_bytes, + put_u64_ne: u64 = to_ne_bytes, + put_i16_ne: i16 = to_ne_bytes, + put_i32_ne: i32 = to_ne_bytes, + put_i64_ne: i64 = to_ne_bytes, + put_f32_ne: f32 = to_ne_bytes, + put_f64_ne: f64 = to_ne_bytes, + } + + pub fn put_as_str(&mut self, value: T) -> Result<&mut Self, Error> { + let mut cur = io::Cursor::new(mem::take(&mut self.buffer_mut)); + write!(&mut cur, "{}", value).map_err(|_| Error::UnexpectedEnd)?; + let n = cur.position() as usize; + self.buffer_mut = cur.into_inner(); + self.advance(n, |_| {}) + } + + pub fn put_key_value(&mut self, value: T) -> Result<&mut Self, Error> { + value.put_key_value(self) + } + + pub fn put_key_raw(&mut self, key: &str, value: &[u8]) -> Result<&mut Self, Error> { + self.put_u8(b'\\')? + .put_str(key)? + .put_u8(b'\\')? + .put_bytes(value) + } + + pub fn put_key(&mut self, key: &str, value: T) -> Result<&mut Self, Error> { + self.put_u8(b'\\')? + .put_str(key)? + .put_u8(b'\\')? + .put_key_value(value) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cursor() -> Result<(), Error> { + let mut buf = [0; 64]; + let s = CursorMut::new(&mut buf) + .put_bytes(b"12345678")? + .put_array(b"4321")? + .put_str("abc")? + .put_cstr("def")? + .put_u8(0x7f)? + .put_i8(-128)? + .put_u32_le(0x44332211)? + .buffer(); + + let mut cur = Cursor::new(s); + assert_eq!(cur.get_bytes(8), Ok(&b"12345678"[..])); + assert_eq!(cur.get_array::<4>(), Ok(*b"4321")); + assert_eq!(cur.get_str(3), Ok("abc")); + assert_eq!(cur.get_cstr(), Ok(Str(&b"def"[..]))); + assert_eq!(cur.get_u8(), Ok(0x7f)); + assert_eq!(cur.get_i8(), Ok(-128)); + assert_eq!(cur.get_u32_le(), Ok(0x44332211)); + assert_eq!(cur.get_u8(), Err(Error::UnexpectedEnd)); + + Ok(()) + } + + #[test] + fn key() -> Result<(), Error> { + let mut buf = [0; 512]; + let s = CursorMut::new(&mut buf) + .put_key("p", 49)? + .put_key("map", "crossfire")? + .put_key("dm", true)? + .put_key("team", false)? + .put_key("coop", false)? + .put_key("numcl", 4)? + .put_key("maxcl", 32)? + .put_key("gamedir", "valve")? + .put_key("password", false)? + .put_key("host", "test")? + .buffer(); + + let mut cur = Cursor::new(s); + assert_eq!(cur.get_key(), Ok((&b"p"[..], 49_u8))); + assert_eq!(cur.get_key(), Ok((&b"map"[..], "crossfire"))); + assert_eq!(cur.get_key(), Ok((&b"dm"[..], true))); + assert_eq!(cur.get_key(), Ok((&b"team"[..], false))); + assert_eq!(cur.get_key(), Ok((&b"coop"[..], false))); + assert_eq!(cur.get_key(), Ok((&b"numcl"[..], 4_u8))); + assert_eq!(cur.get_key(), Ok((&b"maxcl"[..], 32_u8))); + assert_eq!(cur.get_key(), Ok((&b"gamedir"[..], "valve"))); + assert_eq!(cur.get_key(), Ok((&b"password"[..], false))); + assert_eq!(cur.get_key(), Ok((&b"host"[..], "test"))); + assert_eq!(cur.get_key::<&[u8]>(), Err(Error::UnexpectedEnd)); + + Ok(()) + } +} diff --git a/src/filter.rs b/protocol/src/filter.rs similarity index 75% rename from src/filter.rs rename to protocol/src/filter.rs index f95d8d1..0aa4572 100644 --- a/src/filter.rs +++ b/protocol/src/filter.rs @@ -34,11 +34,12 @@ use std::num::ParseIntError; use std::str::FromStr; use bitflags::bitflags; -use log::{debug, log_enabled, Level}; +use log::debug; -use crate::parser::{Error as ParserError, ParseValue, Parser}; -use crate::server::Server; -use crate::server_info::{ServerFlags, ServerInfo, ServerType}; +use crate::cursor::{Cursor, GetKeyValue, PutKeyValue}; +use crate::server::{ServerAdd, ServerFlags, ServerType}; +use crate::types::Str; +use crate::{Error, ServerInfo}; bitflags! { #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] @@ -64,8 +65,8 @@ bitflags! { } } -impl From<&ServerInfo> for FilterFlags { - fn from(info: &ServerInfo) -> Self { +impl From<&ServerAdd> for FilterFlags { + fn from(info: &ServerAdd) -> Self { let mut flags = Self::empty(); flags.set(Self::DEDICATED, info.server_type == ServerType::Dedicated); @@ -115,24 +116,31 @@ impl FromStr for Version { } } -impl ParseValue<'_> for Version { - type Err = ParserError; +impl GetKeyValue<'_> for Version { + fn get_key_value(cur: &mut Cursor) -> Result { + Self::from_str(cur.get_key_value()?).map_err(|_| Error::InvalidPacket) + } +} - fn parse(p: &mut Parser<'_>) -> Result { - let s = p.parse::<&str>()?; - let v = s.parse()?; - Ok(v) +impl PutKeyValue for Version { + fn put_key_value<'a, 'b>( + &self, + cur: &'b mut crate::cursor::CursorMut<'a>, + ) -> Result<&'b mut crate::cursor::CursorMut<'a>, Error> { + cur.put_key_value(self.major)? + .put_u8(b'.')? + .put_key_value(self.minor) } } #[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct Filter<'a> { /// Servers running the specified modification (ex. cstrike) - pub gamedir: Option<&'a str>, + pub gamedir: &'a [u8], /// Servers running the specified map (ex. cs_italy) - pub map: Option<&'a str>, + pub map: &'a [u8], /// Client version. - pub clver: Option, + pub clver: Version, pub flags: FilterFlags, pub flags_mask: FilterFlags, @@ -144,62 +152,47 @@ impl Filter<'_> { self.flags_mask.insert(flag); } - pub fn matches(&self, _addr: SocketAddrV4, server: &Server) -> bool { - if (server.flags & self.flags_mask) != self.flags { - return false; - } - if self.gamedir.map_or(false, |i| &*server.gamedir != i) { - return false; - } - if self.map.map_or(false, |i| &*server.map != i) { - return false; - } - true + pub fn matches(&self, _addr: SocketAddrV4, info: &ServerInfo) -> bool { + !((info.flags & self.flags_mask) != self.flags + || (!self.gamedir.is_empty() && self.gamedir != &*info.gamedir) + || (!self.map.is_empty() && self.map != &*info.map)) } } impl<'a> Filter<'a> { - pub fn from_bytes(src: &'a [u8]) -> Result { - let mut parser = Parser::new(src); - let filter = parser.parse()?; - Ok(filter) - } -} - -impl<'a> ParseValue<'a> for Filter<'a> { - type Err = ParserError; - - fn parse(p: &mut Parser<'a>) -> Result { + pub fn from_bytes(src: &'a [u8]) -> Result { + let mut cur = Cursor::new(src); let mut filter = Self::default(); loop { - let name = match p.parse_bytes() { + let key = match cur.get_key_raw().map(Str) { Ok(s) => s, - Err(ParserError::End) => break, + Err(Error::UnexpectedEnd) => break, Err(e) => return Err(e), }; - match name { - b"dedicated" => filter.insert_flag(FilterFlags::DEDICATED, p.parse()?), - b"secure" => filter.insert_flag(FilterFlags::SECURE, p.parse()?), - b"gamedir" => filter.gamedir = Some(p.parse()?), - b"map" => filter.map = Some(p.parse()?), - b"empty" => filter.insert_flag(FilterFlags::NOT_EMPTY, p.parse()?), - b"full" => filter.insert_flag(FilterFlags::FULL, p.parse()?), - b"password" => filter.insert_flag(FilterFlags::PASSWORD, p.parse()?), - b"noplayers" => filter.insert_flag(FilterFlags::NOPLAYERS, p.parse()?), - b"clver" => filter.clver = Some(p.parse()?), - b"nat" => filter.insert_flag(FilterFlags::NAT, p.parse()?), - b"lan" => filter.insert_flag(FilterFlags::LAN, p.parse()?), - b"bots" => filter.insert_flag(FilterFlags::BOTS, p.parse()?), + match *key { + b"dedicated" => filter.insert_flag(FilterFlags::DEDICATED, cur.get_key_value()?), + b"secure" => filter.insert_flag(FilterFlags::SECURE, cur.get_key_value()?), + b"gamedir" => filter.gamedir = cur.get_key_value()?, + b"map" => filter.map = cur.get_key_value()?, + b"empty" => filter.insert_flag(FilterFlags::NOT_EMPTY, cur.get_key_value()?), + b"full" => filter.insert_flag(FilterFlags::FULL, cur.get_key_value()?), + b"password" => filter.insert_flag(FilterFlags::PASSWORD, cur.get_key_value()?), + b"noplayers" => filter.insert_flag(FilterFlags::NOPLAYERS, cur.get_key_value()?), + b"clver" => { + filter.clver = cur + .get_key_value::<&str>()? + .parse() + .map_err(|_| Error::InvalidPacket)? + } + b"nat" => filter.insert_flag(FilterFlags::NAT, cur.get_key_value()?), + b"lan" => filter.insert_flag(FilterFlags::LAN, cur.get_key_value()?), + b"bots" => filter.insert_flag(FilterFlags::BOTS, cur.get_key_value()?), _ => { // skip unknown fields - let value = p.parse_bytes()?; - if log_enabled!(Level::Debug) { - let name = String::from_utf8_lossy(name); - let value = String::from_utf8_lossy(value); - debug!("Invalid Filter field \"{}\" = \"{}\"", name, value); - } + let value = Str(cur.get_key_value_raw()?); + debug!("Invalid Filter field \"{}\" = \"{}\"", key, value); } } } @@ -208,8 +201,43 @@ impl<'a> ParseValue<'a> for Filter<'a> { } } +impl fmt::Display for &Filter<'_> { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + macro_rules! display_flag { + ($n:expr, $f:expr) => { + if self.flags_mask.contains($f) { + let flag = if self.flags.contains($f) { '1' } else { '0' }; + write!(fmt, "\\{}\\{}", $n, flag)?; + } + }; + } + + display_flag!("dedicated", FilterFlags::DEDICATED); + display_flag!("secure", FilterFlags::SECURE); + if !self.gamedir.is_empty() { + write!(fmt, "\\gamedir\\{}", Str(self.gamedir))?; + } + display_flag!("secure", FilterFlags::SECURE); + if !self.map.is_empty() { + write!(fmt, "\\map\\{}", Str(self.map))?; + } + display_flag!("empty", FilterFlags::NOT_EMPTY); + display_flag!("full", FilterFlags::FULL); + display_flag!("password", FilterFlags::PASSWORD); + display_flag!("noplayers", FilterFlags::NOPLAYERS); + write!(fmt, "\\clver\\{}", self.clver)?; + display_flag!("nat", FilterFlags::NAT); + display_flag!("lan", FilterFlags::LAN); + display_flag!("bots", FilterFlags::BOTS); + + Ok(()) + } +} + #[cfg(test)] mod tests { + use super::super::cursor::CursorMut; + use super::super::types::Str; use super::*; macro_rules! tests { @@ -238,17 +266,17 @@ mod tests { tests! { parse_gamedir { b"\\gamedir\\valve" => { - gamedir: Some("valve"), + gamedir: &b"valve"[..], } } parse_map { b"\\map\\crossfire" => { - map: Some("crossfire"), + map: &b"crossfire"[..], } } parse_clver { b"\\clver\\0.20" => { - clver: Some(Version::new(0, 20)), + clver: Version::new(0, 20), } } parse_dedicated(flags_mask: FilterFlags::DEDICATED) { @@ -321,9 +349,9 @@ mod tests { \\password\\1\ \\secure\\1\ " => { - gamedir: Some("valve"), - map: Some("crossfire"), - clver: Some(Version::new(0, 20)), + gamedir: &b"valve"[..], + map: &b"crossfire"[..], + clver: Version::new(0, 20), flags: FilterFlags::all(), flags_mask: FilterFlags::all(), } @@ -334,8 +362,14 @@ mod tests { ($($addr:expr => $info:expr $(=> $func:expr)?)+) => ( [$({ let addr = $addr.parse::().unwrap(); - let (_, info, _) = ServerInfo::<&str>::from_bytes($info).unwrap(); - let server = Server::new(&info); + let mut buf = [0; 512]; + let n = CursorMut::new(&mut buf) + .put_bytes(ServerAdd::HEADER).unwrap() + .put_key("challenge", 0).unwrap() + .put_bytes($info).unwrap() + .pos(); + let p = ServerAdd::>::decode(&buf[..n]).unwrap(); + let server = ServerInfo::new(&p); $( let mut server = server; let func: fn(&mut Server) = $func; diff --git a/protocol/src/game.rs b/protocol/src/game.rs new file mode 100644 index 0000000..c6df3ef --- /dev/null +++ b/protocol/src/game.rs @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: GPL-3.0-only +// SPDX-FileCopyrightText: 2023 Denis Drakhnia + +use std::net::SocketAddrV4; + +use crate::cursor::{Cursor, CursorMut}; +use crate::filter::Filter; +use crate::server::Region; +use crate::Error; + +#[derive(Clone, Debug, PartialEq)] +pub struct QueryServers<'a> { + pub region: Region, + pub last: SocketAddrV4, + pub filter: Filter<'a>, +} + +impl<'a> QueryServers<'a> { + pub const HEADER: &'static [u8] = b"1"; + + pub fn decode(src: &'a [u8]) -> Result { + let mut cur = Cursor::new(src); + cur.expect(Self::HEADER)?; + let region = cur.get_u8()?.try_into().map_err(|_| Error::InvalidPacket)?; + let last = cur.get_cstr_as_str()?; + let filter = cur.get_cstr()?; + cur.expect_empty()?; + Ok(Self { + region, + last: last.parse().map_err(|_| Error::InvalidPacket)?, + filter: Filter::from_bytes(&filter)?, + }) + } + + pub fn encode(&self, buf: &mut [u8]) -> Result { + Ok(CursorMut::new(buf) + .put_bytes(Self::HEADER)? + .put_u8(self.region as u8)? + .put_as_str(self.last)? + .put_u8(0)? + .put_as_str(&self.filter)? + .put_u8(0)? + .pos()) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct GetServerInfo { + pub protocol: u8, +} + +impl GetServerInfo { + pub const HEADER: &'static [u8] = b"\xff\xff\xff\xffinfo "; + + pub fn new(protocol: u8) -> Self { + Self { protocol } + } + + pub fn decode(src: &[u8]) -> Result { + let mut cur = Cursor::new(src); + cur.expect(Self::HEADER)?; + let protocol = cur + .get_str(cur.remaining())? + .parse() + .map_err(|_| Error::InvalidPacket)?; + Ok(Self { protocol }) + } + + pub fn encode(&self, buf: &mut [u8]) -> Result { + Ok(CursorMut::new(buf) + .put_bytes(Self::HEADER)? + .put_as_str(self.protocol)? + .pos()) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Packet<'a> { + QueryServers(QueryServers<'a>), + GetServerInfo(GetServerInfo), +} + +impl<'a> Packet<'a> { + pub fn decode(src: &'a [u8]) -> Result { + if let Ok(p) = QueryServers::decode(src) { + return Ok(Self::QueryServers(p)); + } + + if let Ok(p) = GetServerInfo::decode(src) { + return Ok(Self::GetServerInfo(p)); + } + + Err(Error::InvalidPacket) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::filter::{FilterFlags, Version}; + use std::net::Ipv4Addr; + + #[test] + fn query_servers() { + let p = QueryServers { + region: Region::RestOfTheWorld, + last: SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0), + filter: Filter { + gamedir: &b"valve"[..], + map: &b"crossfire"[..], + clver: Version::new(0, 20), + flags: FilterFlags::all(), + flags_mask: FilterFlags::all(), + }, + }; + let mut buf = [0; 512]; + let n = p.encode(&mut buf).unwrap(); + assert_eq!(QueryServers::decode(&buf[..n]), Ok(p)); + } + + #[test] + fn get_server_info() { + let p = GetServerInfo::new(49); + let mut buf = [0; 512]; + let n = p.encode(&mut buf).unwrap(); + assert_eq!(GetServerInfo::decode(&buf[..n]), Ok(p)); + } +} diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs new file mode 100644 index 0000000..e4fde48 --- /dev/null +++ b/protocol/src/lib.rs @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: GPL-3.0-only +// SPDX-FileCopyrightText: 2023 Denis Drakhnia + +mod cursor; +mod server_info; + +pub mod admin; +pub mod filter; +pub mod game; +pub mod master; +pub mod server; +pub mod types; + +pub use server_info::ServerInfo; + +use thiserror::Error; + +pub const VERSION: u32 = 49; + +#[derive(Error, Debug, PartialEq, Eq)] +pub enum Error { + #[error("Invalid packet")] + InvalidPacket, + #[error("Invalid UTF-8 string")] + InvalidString, + #[error("Unexpected end of buffer")] + UnexpectedEnd, +} diff --git a/protocol/src/master.rs b/protocol/src/master.rs new file mode 100644 index 0000000..02276bb --- /dev/null +++ b/protocol/src/master.rs @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: GPL-3.0-only +// SPDX-FileCopyrightText: 2023 Denis Drakhnia + +use std::net::{Ipv4Addr, SocketAddrV4}; + +use super::cursor::{Cursor, CursorMut}; +use super::Error; + +#[derive(Clone, Debug, PartialEq)] +pub struct ChallengeResponse { + pub master_challenge: u32, + pub server_challenge: u32, +} + +impl ChallengeResponse { + pub const HEADER: &'static [u8] = b"\xff\xff\xff\xffs\n"; + + pub fn new(master_challenge: u32, server_challenge: u32) -> Self { + Self { + master_challenge, + server_challenge, + } + } + + pub fn decode(src: &[u8]) -> Result { + let mut cur = Cursor::new(src); + cur.expect(Self::HEADER)?; + let master_challenge = cur.get_u32_le()?; + let server_challenge = cur.get_u32_le()?; + cur.expect_empty()?; + Ok(Self { + master_challenge, + server_challenge, + }) + } + + pub fn encode(&self, buf: &mut [u8; N]) -> Result { + Ok(CursorMut::new(buf) + .put_bytes(Self::HEADER)? + .put_u32_le(self.master_challenge)? + .put_u32_le(self.server_challenge)? + .pos()) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct QueryServersResponse { + inner: I, +} + +impl QueryServersResponse<()> { + pub const HEADER: &'static [u8] = b"\xff\xff\xff\xfff\n"; +} + +impl<'a> QueryServersResponse<&'a [u8]> { + pub fn decode(src: &'a [u8]) -> Result { + let mut cur = Cursor::new(src); + cur.expect(QueryServersResponse::HEADER)?; + if cur.remaining() % 6 != 0 { + return Err(Error::InvalidPacket); + } + let s = cur.get_bytes(cur.remaining())?; + let inner = if s.ends_with(&[0; 6]) { + &s[..s.len() - 6] + } else { + s + }; + Ok(Self { inner }) + } + + pub fn iter(&self) -> impl 'a + Iterator { + let mut cur = Cursor::new(self.inner); + (0..self.inner.len() / 6).map(move |_| { + let ip = Ipv4Addr::from(cur.get_array().unwrap()); + let port = cur.get_u16_be().unwrap(); + SocketAddrV4::new(ip, port) + }) + } +} + +impl QueryServersResponse +where + I: Iterator, +{ + pub fn new(iter: I) -> Self { + Self { inner: iter } + } + + pub fn encode(&mut self, buf: &mut [u8]) -> Result<(usize, bool), Error> { + let mut cur = CursorMut::new(buf); + cur.put_bytes(QueryServersResponse::HEADER)?; + let mut is_end = false; + while cur.remaining() >= 12 { + match self.inner.next() { + Some(i) => { + cur.put_array(&i.ip().octets())?.put_u16_be(i.port())?; + } + None => { + is_end = true; + break; + } + } + } + Ok((cur.put_array(&[0; 6])?.pos(), is_end)) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct AdminChallengeResponse { + pub challenge: u32, +} + +impl AdminChallengeResponse { + pub const HEADER: &'static [u8] = b"\xff\xff\xff\xffadminchallenge"; + + pub fn new(challenge: u32) -> Self { + Self { challenge } + } + + pub fn decode(src: &[u8]) -> Result { + let mut cur = Cursor::new(src); + cur.expect(Self::HEADER)?; + let challenge = cur.get_u32_le()?; + cur.expect_empty()?; + Ok(Self { challenge }) + } + + pub fn encode(&self, buf: &mut [u8]) -> Result { + Ok(CursorMut::new(buf) + .put_bytes(Self::HEADER)? + .put_u32_le(self.challenge)? + .pos()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn challenge_response() { + let p = ChallengeResponse::new(0x12345678, 0x87654321); + let mut buf = [0; 512]; + let n = p.encode(&mut buf).unwrap(); + assert_eq!(ChallengeResponse::decode(&buf[..n]), Ok(p)); + } + + #[test] + fn query_servers_response() { + let servers: &[SocketAddrV4] = &[ + "1.2.3.4:27001".parse().unwrap(), + "1.2.3.4:27002".parse().unwrap(), + "1.2.3.4:27003".parse().unwrap(), + "1.2.3.4:27004".parse().unwrap(), + ]; + let mut p = QueryServersResponse::new(servers.iter().cloned()); + let mut buf = [0; 512]; + let (n, _) = p.encode(&mut buf).unwrap(); + let e = QueryServersResponse::decode(&buf[..n]).unwrap(); + assert_eq!(e.iter().collect::>(), servers); + } + + #[test] + fn admin_challenge_response() { + let p = AdminChallengeResponse::new(0x12345678); + let mut buf = [0; 64]; + let n = p.encode(&mut buf).unwrap(); + assert_eq!(AdminChallengeResponse::decode(&buf[..n]), Ok(p)); + } +} diff --git a/protocol/src/server.rs b/protocol/src/server.rs new file mode 100644 index 0000000..ed64158 --- /dev/null +++ b/protocol/src/server.rs @@ -0,0 +1,506 @@ +// SPDX-License-Identifier: GPL-3.0-only +// SPDX-FileCopyrightText: 2023 Denis Drakhnia + +use std::fmt; + +use bitflags::bitflags; +use log::debug; + +use super::cursor::{Cursor, CursorMut, GetKeyValue, PutKeyValue}; +use super::filter::Version; +use super::types::Str; +use super::Error; + +#[derive(Clone, Debug, PartialEq)] +pub struct Challenge { + pub server_challenge: u32, +} + +impl Challenge { + pub const HEADER: &'static [u8] = b"q\xff"; + + pub fn new(server_challenge: u32) -> Self { + Self { server_challenge } + } + + pub fn decode(src: &[u8]) -> Result { + let mut cur = Cursor::new(src); + cur.expect(Self::HEADER)?; + let server_challenge = cur.get_u32_le()?; + cur.expect_empty()?; + Ok(Self { server_challenge }) + } + + pub fn encode(&self, buf: &mut [u8; N]) -> Result { + Ok(CursorMut::new(buf) + .put_bytes(Self::HEADER)? + .put_u32_le(self.server_challenge)? + .pos()) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[repr(u8)] +pub enum Os { + Linux, + Windows, + Mac, + Unknown, +} + +impl Default for Os { + fn default() -> Os { + Os::Unknown + } +} + +impl TryFrom<&[u8]> for Os { + type Error = Error; + + fn try_from(value: &[u8]) -> Result { + match value { + b"l" => Ok(Os::Linux), + b"w" => Ok(Os::Windows), + b"m" => Ok(Os::Mac), + _ => Ok(Os::Unknown), + } + } +} + +impl GetKeyValue<'_> for Os { + fn get_key_value(cur: &mut Cursor) -> Result { + cur.get_key_value_raw()?.try_into() + } +} + +impl PutKeyValue for Os { + fn put_key_value<'a, 'b>( + &self, + cur: &'b mut CursorMut<'a>, + ) -> Result<&'b mut CursorMut<'a>, Error> { + match self { + Self::Linux => cur.put_str("l"), + Self::Windows => cur.put_str("w"), + Self::Mac => cur.put_str("m"), + Self::Unknown => cur.put_str("?"), + } + } +} + +impl fmt::Display for Os { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + let s = match self { + Os::Linux => "Linux", + Os::Windows => "Windows", + Os::Mac => "Mac", + Os::Unknown => "Unknown", + }; + write!(fmt, "{}", s) + } +} + +#[derive(Copy, Clone, Debug, PartialEq)] +#[repr(u8)] +pub enum ServerType { + Dedicated, + Local, + Proxy, + Unknown, +} + +impl Default for ServerType { + fn default() -> Self { + Self::Unknown + } +} + +impl TryFrom<&[u8]> for ServerType { + type Error = Error; + + fn try_from(value: &[u8]) -> Result { + match value { + b"d" => Ok(Self::Dedicated), + b"l" => Ok(Self::Local), + b"p" => Ok(Self::Proxy), + _ => Ok(Self::Unknown), + } + } +} + +impl GetKeyValue<'_> for ServerType { + fn get_key_value(cur: &mut Cursor) -> Result { + cur.get_key_value_raw()?.try_into() + } +} + +impl PutKeyValue for ServerType { + fn put_key_value<'a, 'b>( + &self, + cur: &'b mut CursorMut<'a>, + ) -> Result<&'b mut CursorMut<'a>, Error> { + match self { + Self::Dedicated => cur.put_str("d"), + Self::Local => cur.put_str("l"), + Self::Proxy => cur.put_str("p"), + Self::Unknown => cur.put_str("?"), + } + } +} + +impl fmt::Display for ServerType { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + use ServerType as E; + + let s = match self { + E::Dedicated => "dedicated", + E::Local => "local", + E::Proxy => "proxy", + E::Unknown => "unknown", + }; + + write!(fmt, "{}", s) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[repr(u8)] +pub enum Region { + USEastCoast = 0x00, + USWestCoast = 0x01, + SouthAmerica = 0x02, + Europe = 0x03, + Asia = 0x04, + Australia = 0x05, + MiddleEast = 0x06, + Africa = 0x07, + RestOfTheWorld = 0xff, +} + +impl Default for Region { + fn default() -> Self { + Self::RestOfTheWorld + } +} + +impl TryFrom for Region { + type Error = Error; + + fn try_from(value: u8) -> Result { + match value { + 0x00 => Ok(Region::USEastCoast), + 0x01 => Ok(Region::USWestCoast), + 0x02 => Ok(Region::SouthAmerica), + 0x03 => Ok(Region::Europe), + 0x04 => Ok(Region::Asia), + 0x05 => Ok(Region::Australia), + 0x06 => Ok(Region::MiddleEast), + 0x07 => Ok(Region::Africa), + 0xff => Ok(Region::RestOfTheWorld), + _ => Err(Error::InvalidPacket), + } + } +} + +impl GetKeyValue<'_> for Region { + fn get_key_value(cur: &mut Cursor) -> Result { + cur.get_key_value::()?.try_into() + } +} + +bitflags! { + #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] + pub struct ServerFlags: u8 { + const BOTS = 1 << 0; + const PASSWORD = 1 << 1; + const SECURE = 1 << 2; + const LAN = 1 << 3; + const NAT = 1 << 4; + } +} + +#[derive(Clone, Debug, PartialEq, Default)] +pub struct ServerAdd { + pub gamedir: T, + pub map: T, + pub version: Version, + pub product: T, + pub challenge: u32, + pub server_type: ServerType, + pub os: Os, + pub region: Region, + pub protocol: u8, + pub players: u8, + pub max: u8, + pub flags: ServerFlags, +} + +impl ServerAdd<()> { + pub const HEADER: &'static [u8] = b"0\n"; +} + +impl<'a, T> ServerAdd +where + T: 'a + Default + GetKeyValue<'a>, +{ + pub fn decode(src: &'a [u8]) -> Result { + let mut cur = Cursor::new(src); + cur.expect(ServerAdd::HEADER)?; + + let mut ret = Self::default(); + let mut challenge = None; + loop { + let key = match cur.get_key_raw() { + Ok(s) => s, + Err(Error::UnexpectedEnd) => break, + Err(e) => return Err(e), + }; + + match key { + b"protocol" => ret.protocol = cur.get_key_value()?, + b"challenge" => challenge = Some(cur.get_key_value()?), + b"players" => ret.players = cur.get_key_value()?, + b"max" => ret.max = cur.get_key_value()?, + b"gamedir" => ret.gamedir = cur.get_key_value()?, + b"map" => ret.map = cur.get_key_value()?, + b"type" => ret.server_type = cur.get_key_value()?, + b"os" => ret.os = cur.get_key_value()?, + b"version" => ret.version = cur.get_key_value()?, + b"region" => ret.region = cur.get_key_value()?, + b"product" => ret.product = cur.get_key_value()?, + b"bots" => ret.flags.set(ServerFlags::BOTS, cur.get_key_value()?), + b"password" => ret.flags.set(ServerFlags::PASSWORD, cur.get_key_value()?), + b"secure" => ret.flags.set(ServerFlags::SECURE, cur.get_key_value()?), + b"lan" => ret.flags.set(ServerFlags::LAN, cur.get_key_value()?), + b"nat" => ret.flags.set(ServerFlags::NAT, cur.get_key_value()?), + _ => { + // skip unknown fields + let value = cur.get_key_value::>()?; + debug!("Invalid ServerInfo field \"{}\" = \"{}\"", Str(key), value); + } + } + } + + match challenge { + Some(c) => { + ret.challenge = c; + Ok(ret) + } + None => Err(Error::InvalidPacket), + } + } +} + +impl ServerAdd +where + T: PutKeyValue + Clone, +{ + pub fn encode(&self, buf: &mut [u8]) -> Result { + Ok(CursorMut::new(buf) + .put_bytes(ServerAdd::HEADER)? + .put_key("protocol", self.protocol)? + .put_key("challenge", self.challenge)? + .put_key("players", self.players)? + .put_key("max", self.max)? + .put_key("gamedir", self.gamedir.clone())? + .put_key("map", self.map.clone())? + .put_key("type", self.server_type)? + .put_key("os", self.os)? + .put_key("version", self.version)? + .put_key("region", self.region as u8)? + .put_key("product", self.product.clone())? + .put_key("bots", self.flags.contains(ServerFlags::BOTS))? + .put_key("password", self.flags.contains(ServerFlags::PASSWORD))? + .put_key("secure", self.flags.contains(ServerFlags::SECURE))? + .put_key("lan", self.flags.contains(ServerFlags::LAN))? + .put_key("nat", self.flags.contains(ServerFlags::NAT))? + .pos()) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct ServerRemove; + +impl ServerRemove { + pub const HEADER: &'static [u8] = b"b\n"; + + pub fn decode(src: &[u8]) -> Result { + let mut cur = Cursor::new(src); + cur.expect(Self::HEADER)?; + cur.expect_empty()?; + Ok(Self) + } + + pub fn encode(&self, buf: &mut [u8; N]) -> Result { + Ok(CursorMut::new(buf).put_bytes(Self::HEADER)?.pos()) + } +} + +#[derive(Clone, Debug, PartialEq, Default)] +pub struct GetServerInfoResponse { + pub gamedir: T, + pub map: T, + pub host: T, + pub protocol: u8, + pub numcl: u8, + pub maxcl: u8, + pub dm: bool, + pub team: bool, + pub coop: bool, + pub password: bool, +} + +impl GetServerInfoResponse<()> { + pub const HEADER: &'static [u8] = b"\xff\xff\xff\xffinfo\n"; +} + +impl<'a, T> GetServerInfoResponse +where + T: 'a + Default + GetKeyValue<'a>, +{ + pub fn decode(src: &'a [u8]) -> Result { + let mut cur = Cursor::new(src); + cur.expect(GetServerInfoResponse::HEADER)?; + + let mut ret = Self::default(); + loop { + let key = match cur.get_key_raw() { + Ok(s) => s, + Err(Error::UnexpectedEnd) => break, + Err(e) => return Err(e), + }; + + match key { + b"p" => ret.protocol = cur.get_key_value()?, + b"map" => ret.map = cur.get_key_value()?, + b"dm" => ret.dm = cur.get_key_value()?, + b"team" => ret.team = cur.get_key_value()?, + b"coop" => ret.coop = cur.get_key_value()?, + b"numcl" => ret.numcl = cur.get_key_value()?, + b"maxcl" => ret.maxcl = cur.get_key_value()?, + b"gamedir" => ret.gamedir = cur.get_key_value()?, + b"password" => ret.password = cur.get_key_value()?, + b"host" => ret.host = cur.get_key_value()?, + _ => { + // skip unknown fields + let value = cur.get_key_value::>()?; + debug!( + "Invalid GetServerInfo field \"{}\" = \"{}\"", + Str(key), + value + ); + } + } + } + + Ok(ret) + } +} + +impl<'a> GetServerInfoResponse<&'a str> { + pub fn encode(&self, buf: &mut [u8]) -> Result { + Ok(CursorMut::new(buf) + .put_bytes(GetServerInfoResponse::HEADER)? + .put_key("p", self.protocol)? + .put_key("map", self.map)? + .put_key("dm", self.dm)? + .put_key("team", self.team)? + .put_key("coop", self.coop)? + .put_key("numcl", self.numcl)? + .put_key("maxcl", self.maxcl)? + .put_key("gamedir", self.gamedir)? + .put_key("password", self.password)? + .put_key("host", self.host)? + .pos()) + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Packet<'a> { + Challenge(Challenge), + ServerAdd(ServerAdd>), + ServerRemove, + GetServerInfoResponse(GetServerInfoResponse>), +} + +impl<'a> Packet<'a> { + pub fn decode(src: &'a [u8]) -> Result { + if let Ok(p) = Challenge::decode(src) { + return Ok(Self::Challenge(p)); + } + + if let Ok(p) = ServerAdd::decode(src) { + return Ok(Self::ServerAdd(p)); + } + + if ServerRemove::decode(src).is_ok() { + return Ok(Self::ServerRemove); + } + + if let Ok(p) = GetServerInfoResponse::decode(src) { + return Ok(Self::GetServerInfoResponse(p)); + } + + Err(Error::InvalidPacket) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn challenge() { + let p = Challenge::new(0x12345678); + let mut buf = [0; 128]; + let n = p.encode(&mut buf).unwrap(); + assert_eq!(Challenge::decode(&buf[..n]), Ok(p)); + } + + #[test] + fn server_add() { + let p = ServerAdd { + gamedir: "valve", + map: "crossfire", + version: Version::new(0, 20), + product: "foobar", + challenge: 0x12345678, + server_type: ServerType::Dedicated, + os: Os::Linux, + region: Region::RestOfTheWorld, + protocol: 49, + players: 4, + max: 32, + flags: ServerFlags::all(), + }; + let mut buf = [0; 512]; + let n = p.encode(&mut buf).unwrap(); + assert_eq!(ServerAdd::decode(&buf[..n]), Ok(p)); + } + + #[test] + fn server_remove() { + let p = ServerRemove; + let mut buf = [0; 64]; + let n = p.encode(&mut buf).unwrap(); + assert_eq!(ServerRemove::decode(&buf[..n]), Ok(p)); + } + + #[test] + fn get_server_info_response() { + let p = GetServerInfoResponse { + protocol: 49, + map: "crossfire", + dm: true, + team: true, + coop: true, + numcl: 4, + maxcl: 32, + gamedir: "valve", + password: true, + host: "Test", + }; + let mut buf = [0; 512]; + let n = p.encode(&mut buf).unwrap(); + assert_eq!(GetServerInfoResponse::decode(&buf[..n]), Ok(p)); + } +} diff --git a/protocol/src/server_info.rs b/protocol/src/server_info.rs new file mode 100644 index 0000000..6a40f63 --- /dev/null +++ b/protocol/src/server_info.rs @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: GPL-3.0-only +// SPDX-FileCopyrightText: 2023 Denis Drakhnia + +use super::filter::{FilterFlags, Version}; +use super::server::{Region, ServerAdd}; +use super::types::Str; + +#[derive(Clone, Debug)] +pub struct ServerInfo { + pub version: Version, + pub gamedir: Box<[u8]>, + pub map: Box<[u8]>, + pub flags: FilterFlags, + pub region: Region, +} + +impl ServerInfo { + pub fn new(info: &ServerAdd>) -> Self { + Self { + version: info.version, + gamedir: info.gamedir.to_vec().into_boxed_slice(), + map: info.map.to_vec().into_boxed_slice(), + flags: FilterFlags::from(info), + region: info.region, + } + } +} diff --git a/protocol/src/types.rs b/protocol/src/types.rs new file mode 100644 index 0000000..98aa90e --- /dev/null +++ b/protocol/src/types.rs @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: GPL-3.0-only +// SPDX-FileCopyrightText: 2023 Denis Drakhnia + +use std::fmt; +use std::ops::Deref; + +/// Wrapper for slice of bytes with printing the bytes as a string +#[derive(Copy, Clone, PartialEq, Eq, Default)] +pub struct Str(pub T); + +impl From for Str { + fn from(value: T) -> Self { + Self(value) + } +} + +impl fmt::Debug for Str +where + T: AsRef<[u8]>, +{ + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + for &c in self.0.as_ref() { + match c { + b'\n' => write!(fmt, "\\n")?, + b'\t' => write!(fmt, "\\t")?, + _ if c.is_ascii_graphic() || c == b' ' => { + write!(fmt, "{}", c as char)?; + } + _ => write!(fmt, "\\x{:02x}", c)?, + } + } + Ok(()) + } +} + +impl fmt::Display for Str +where + T: AsRef<[u8]>, +{ + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + ::fmt(self, fmt) + } +} + +impl Deref for Str { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/src/client.rs b/src/client.rs deleted file mode 100644 index 04e619c..0000000 --- a/src/client.rs +++ /dev/null @@ -1,91 +0,0 @@ -// SPDX-License-Identifier: GPL-3.0-only -// SPDX-FileCopyrightText: 2023 Denis Drakhnia - -use std::fmt; -use std::io; -use std::ops::Deref; -use std::str; - -use log::debug; -use thiserror::Error; - -use crate::server_info::{Region, ServerInfo}; - -#[derive(Error, Debug)] -pub enum Error { - #[error("Invalid packet")] - InvalidPacket, - #[error(transparent)] - IoError(#[from] io::Error), -} - -pub struct Filter<'a>(&'a [u8]); - -impl fmt::Debug for Filter<'_> { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - String::from_utf8_lossy(self.0).fmt(fmt) - } -} - -impl<'a> Deref for Filter<'a> { - type Target = [u8]; - - fn deref(&self) -> &Self::Target { - self.0 - } -} - -#[derive(Debug)] -pub enum Packet<'a> { - Challenge(Option), - ServerAdd(Option, ServerInfo<&'a str>), - ServerRemove, - QueryServers(Region, Filter<'a>), - ServerInfo, -} - -impl<'a> Packet<'a> { - pub fn decode(s: &'a [u8]) -> Result { - match s { - [b'1', region, tail @ ..] => { - let region = Region::try_from(*region).map_err(|_| Error::InvalidPacket)?; - let (tail, _last_ip) = decode_cstr(tail)?; - let (tail, filter) = decode_cstr(tail)?; - if !tail.is_empty() { - return Err(Error::InvalidPacket); - } - - Ok(Self::QueryServers(region, Filter(filter))) - } - [b'q', 0xff, b0, b1, b2, b3] => { - let challenge = u32::from_le_bytes([*b0, *b1, *b2, *b3]); - Ok(Self::Challenge(Some(challenge))) - } - [b'0', b'\n', tail @ ..] => { - let (challenge, info, tail) = - ServerInfo::from_bytes(tail).map_err(|_| Error::InvalidPacket)?; - if tail != b"" && tail != b"\n" { - debug!("unexpected end {:?}", tail); - } - Ok(Self::ServerAdd(challenge, info)) - } - [b'b', b'\n'] => Ok(Self::ServerRemove), - [b'q'] => Ok(Self::Challenge(None)), - [0xff, 0xff, 0xff, 0xff, b'i', b'n', b'f', b'o', b' ', _, _] => Ok(Self::ServerInfo), - _ => Err(Error::InvalidPacket), - } - } -} - -fn decode_cstr(data: &[u8]) -> Result<(&[u8], &[u8]), Error> { - data.iter() - .position(|&c| c == 0) - .ok_or(Error::InvalidPacket) - .map(|offset| (&data[offset + 1..], &data[..offset])) -} - -// fn decode_str(data: &[u8]) -> Result<(&[u8], &str), Error> { -// let (tail, s) = decode_cstr(data)?; -// let s = str::from_utf8(s).map_err(|_| Error::InvalidPacket)?; -// Ok((tail, s)) -// } diff --git a/src/master_server.rs b/src/master_server.rs deleted file mode 100644 index 3f3c0da..0000000 --- a/src/master_server.rs +++ /dev/null @@ -1,342 +0,0 @@ -// SPDX-License-Identifier: GPL-3.0-only -// SPDX-FileCopyrightText: 2023 Denis Drakhnia - -use std::collections::HashMap; -use std::io::prelude::*; -use std::io::{self, Cursor}; -use std::net::{SocketAddr, SocketAddrV4, ToSocketAddrs, UdpSocket}; -use std::ops::Deref; -use std::time::Instant; - -use fastrand::Rng; -use log::{error, info, trace, warn}; -use thiserror::Error; - -use crate::client::Packet; -use crate::config::{self, Config}; -use crate::filter::{Filter, Version}; -use crate::server::Server; -use crate::server_info::Region; - -/// The maximum size of UDP packets. -const MAX_PACKET_SIZE: usize = 512; - -const CHALLENGE_RESPONSE_HEADER: &[u8] = b"\xff\xff\xff\xffs\n"; -const SERVER_LIST_HEADER: &[u8] = b"\xff\xff\xff\xfff\n"; - -/// How many cleanup calls should be skipped before removing outdated servers. -const SERVER_CLEANUP_MAX: usize = 100; - -/// How many cleanup calls should be skipped before removing outdated challenges. -const CHALLENGE_CLEANUP_MAX: usize = 100; - -#[derive(Error, Debug)] -pub enum Error { - #[error("Failed to bind server socket: {0}")] - BindSocket(io::Error), - #[error("Failed to decode packet: {0}")] - ClientPacket(#[from] crate::client::Error), - #[error("Missing challenge in ServerInfo")] - MissingChallenge, - #[error(transparent)] - Io(#[from] io::Error), -} - -/// HashMap entry to keep tracking creation time. -struct Entry { - time: u32, - value: T, -} - -impl Entry { - fn new(time: u32, value: T) -> Self { - Self { time, value } - } - - fn is_valid(&self, now: u32, duration: u32) -> bool { - (now - self.time) < duration - } -} - -impl Entry { - fn matches(&self, addr: SocketAddrV4, region: Region, filter: &Filter) -> bool { - self.region == region && filter.matches(addr, self) - } -} - -impl Deref for Entry { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.value - } -} - -struct MasterServer { - sock: UdpSocket, - challenges: HashMap>, - servers: HashMap>, - rng: Rng, - - start_time: Instant, - cleanup_challenges: usize, - cleanup_servers: usize, - timeout: config::TimeoutConfig, - - clver: Version, - update_title: Box, - update_map: Box, - update_addr: SocketAddrV4, -} - -impl MasterServer { - fn new(cfg: Config) -> Result { - let addr = SocketAddr::new(cfg.server.ip, cfg.server.port); - info!("Listen address: {}", addr); - let sock = UdpSocket::bind(addr).map_err(Error::BindSocket)?; - let update_addr = - cfg.client - .update_addr - .unwrap_or_else(|| match sock.local_addr().unwrap() { - SocketAddr::V4(addr) => addr, - _ => todo!(), - }); - - Ok(Self { - sock, - start_time: Instant::now(), - challenges: Default::default(), - servers: Default::default(), - rng: Rng::new(), - cleanup_challenges: 0, - cleanup_servers: 0, - timeout: cfg.server.timeout, - clver: cfg.client.version, - update_title: cfg.client.update_title, - update_map: cfg.client.update_map, - update_addr, - }) - } - - fn run(&mut self) -> Result<(), Error> { - let mut buf = [0; MAX_PACKET_SIZE]; - loop { - let (n, from) = self.sock.recv_from(&mut buf)?; - let from = match from { - SocketAddr::V4(a) => a, - _ => { - warn!("{}: Received message from IPv6, unimplemented", from); - continue; - } - }; - - if let Err(e) = self.handle_packet(from, &buf[..n]) { - error!("{}: {}", from, e); - } - } - } - - fn handle_packet(&mut self, from: SocketAddrV4, s: &[u8]) -> Result<(), Error> { - let packet = match Packet::decode(s) { - Ok(p) => p, - Err(_) => { - trace!("{}: Failed to decode {:?}", from, s); - return Ok(()); - } - }; - - trace!("{}: recv {:?}", from, packet); - - match packet { - Packet::Challenge(server_challenge) => { - let challenge = self.add_challenge(from); - trace!("{}: New challenge {}", from, challenge); - self.send_challenge_response(from, challenge, server_challenge)?; - self.remove_outdated_challenges(); - } - Packet::ServerAdd(challenge, info) => { - let challenge = match challenge { - Some(c) => c, - None => return Err(Error::MissingChallenge), - }; - let entry = match self.challenges.get(&from) { - Some(e) => e, - None => { - trace!("{}: Challenge does not exists", from); - return Ok(()); - } - }; - if !entry.is_valid(self.now(), self.timeout.challenge) { - return Ok(()); - } - if challenge != entry.value { - warn!( - "{}: Expected challenge {} but received {}", - from, entry.value, challenge - ); - return Ok(()); - } - if self.challenges.remove(&from).is_some() { - self.add_server(from, Server::new(&info)); - } - self.remove_outdated_servers(); - } - Packet::ServerRemove => { /* ignore */ } - Packet::QueryServers(region, filter) => { - let filter = match Filter::from_bytes(&filter) { - Ok(f) => f, - _ => { - warn!("{}: Invalid filter: {:?}", from, filter); - return Ok(()); - } - }; - - if filter.clver.map_or(true, |v| v < self.clver) { - let iter = std::iter::once(&self.update_addr); - self.send_server_list(from, iter)?; - } else { - let now = self.now(); - let iter = self - .servers - .iter() - .filter(|i| i.1.is_valid(now, self.timeout.server)) - .filter(|i| i.1.matches(*i.0, region, &filter)) - .map(|i| i.0); - self.send_server_list(from, iter)?; - } - } - Packet::ServerInfo => { - let mut buf = [0; MAX_PACKET_SIZE]; - let mut cur = Cursor::new(&mut buf[..]); - cur.write_all(b"\xff\xff\xff\xffinfo\n")?; - cur.write_all(b"\\p\\49")?; - cur.write_all(b"\\map\\")?; - cur.write_all(self.update_map.as_bytes())?; - cur.write_all(b"\\dm\\1")?; - cur.write_all(b"\\team\\0")?; - cur.write_all(b"\\coop\\0")?; - cur.write_all(b"\\numcl\\0")?; - cur.write_all(b"\\maxcl\\0")?; - cur.write_all(b"\\gamedir\\valve")?; - cur.write_all(b"\\password\\0")?; - cur.write_all(b"\\host\\")?; - cur.write_all(self.update_title.as_bytes())?; - let n = cur.position() as usize; - self.sock.send_to(&buf[..n], from)?; - } - } - - Ok(()) - } - - fn now(&self) -> u32 { - self.start_time.elapsed().as_secs() as u32 - } - - fn add_challenge(&mut self, addr: SocketAddrV4) -> u32 { - let x = self.rng.u32(..); - let entry = Entry::new(self.now(), x); - self.challenges.insert(addr, entry); - x - } - - fn remove_outdated_challenges(&mut self) { - if self.cleanup_challenges < CHALLENGE_CLEANUP_MAX { - self.cleanup_challenges += 1; - return; - } - let now = self.now(); - let old = self.challenges.len(); - self.challenges - .retain(|_, v| v.is_valid(now, self.timeout.challenge)); - let new = self.challenges.len(); - if old != new { - trace!("Removed {} outdated challenges", old - new); - } - self.cleanup_challenges = 0; - } - - fn add_server(&mut self, addr: SocketAddrV4, server: Server) { - match self.servers.insert(addr, Entry::new(self.now(), server)) { - Some(_) => trace!("{}: Updated GameServer", addr), - None => trace!("{}: New GameServer", addr), - } - } - - fn remove_outdated_servers(&mut self) { - if self.cleanup_servers < SERVER_CLEANUP_MAX { - self.cleanup_servers += 1; - return; - } - let now = self.now(); - let old = self.servers.len(); - self.servers - .retain(|_, v| v.is_valid(now, self.timeout.server)); - let new = self.servers.len(); - if old != new { - trace!("Removed {} outdated servers", old - new); - } - self.cleanup_servers = 0; - } - - fn send_challenge_response( - &self, - to: A, - challenge: u32, - server_challenge: Option, - ) -> Result<(), io::Error> { - let mut buf = [0; MAX_PACKET_SIZE]; - let mut cur = Cursor::new(&mut buf[..]); - - cur.write_all(CHALLENGE_RESPONSE_HEADER)?; - cur.write_all(&challenge.to_le_bytes())?; - if let Some(x) = server_challenge { - cur.write_all(&x.to_le_bytes())?; - } - - let n = cur.position() as usize; - self.sock.send_to(&buf[..n], to)?; - Ok(()) - } - - fn send_server_list<'a, A, I>(&self, to: A, mut iter: I) -> Result<(), io::Error> - where - A: ToSocketAddrs, - I: Iterator, - { - let mut buf = [0; MAX_PACKET_SIZE]; - let mut done = false; - while !done { - let mut cur = Cursor::new(&mut buf[..]); - cur.write_all(SERVER_LIST_HEADER)?; - - loop { - match iter.next() { - Some(i) => { - cur.write_all(&i.ip().octets()[..])?; - cur.write_all(&i.port().to_be_bytes())?; - } - None => { - done = true; - break; - } - } - - if (cur.position() as usize) > (MAX_PACKET_SIZE - 12) { - break; - } - } - - // terminate list - cur.write_all(&[0; 6][..])?; - - let n = cur.position() as usize; - self.sock.send_to(&buf[..n], &to)?; - } - Ok(()) - } -} - -pub fn run(cfg: Config) -> Result<(), Error> { - MasterServer::new(cfg)?.run() -} diff --git a/src/server.rs b/src/server.rs deleted file mode 100644 index 2e94a9e..0000000 --- a/src/server.rs +++ /dev/null @@ -1,26 +0,0 @@ -// SPDX-License-Identifier: GPL-3.0-only -// SPDX-FileCopyrightText: 2023 Denis Drakhnia - -use crate::filter::FilterFlags; -use crate::server_info::{Region, ServerInfo}; - -#[derive(Clone, Debug)] -pub struct Server { - pub version: Box, - pub gamedir: Box, - pub map: Box, - pub flags: FilterFlags, - pub region: Region, -} - -impl Server { - pub fn new(info: &ServerInfo<&str>) -> Self { - Self { - version: info.version.to_string().into_boxed_str(), - gamedir: info.gamedir.to_string().into_boxed_str(), - map: info.map.to_string().into_boxed_str(), - flags: FilterFlags::from(info), - region: info.region, - } - } -} diff --git a/src/server_info.rs b/src/server_info.rs deleted file mode 100644 index b8b14d2..0000000 --- a/src/server_info.rs +++ /dev/null @@ -1,329 +0,0 @@ -// SPDX-License-Identifier: GPL-3.0-only -// SPDX-FileCopyrightText: 2023 Denis Drakhnia - -use std::fmt; - -use bitflags::bitflags; -use log::{debug, log_enabled, Level}; -use thiserror::Error; - -use crate::parser::{Error as ParserError, ParseValue, Parser}; - -#[derive(Copy, Clone, Error, Debug, PartialEq, Eq)] -pub enum Error { - #[error("Invalid region")] - InvalidRegion, - #[error(transparent)] - Parser(#[from] ParserError), -} - -pub type Result = std::result::Result; - -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -#[repr(u8)] -pub enum Os { - Linux, - Windows, - Mac, - Unknown, -} - -impl Default for Os { - fn default() -> Os { - Os::Unknown - } -} - -impl ParseValue<'_> for Os { - type Err = Error; - - fn parse(p: &mut Parser) -> Result { - match p.parse_bytes()? { - b"l" => Ok(Os::Linux), - b"w" => Ok(Os::Windows), - b"m" => Ok(Os::Mac), - _ => Ok(Os::Unknown), - } - } -} - -impl fmt::Display for Os { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - let s = match self { - Os::Linux => "Linux", - Os::Windows => "Windows", - Os::Mac => "Mac", - Os::Unknown => "Unknown", - }; - write!(fmt, "{}", s) - } -} - -#[derive(Copy, Clone, Debug, PartialEq)] -#[repr(u8)] -pub enum ServerType { - Dedicated, - Local, - Proxy, - Unknown, -} - -impl Default for ServerType { - fn default() -> Self { - Self::Unknown - } -} - -impl ParseValue<'_> for ServerType { - type Err = Error; - - fn parse(p: &mut Parser) -> Result { - match p.parse_bytes()? { - b"d" => Ok(Self::Dedicated), - b"l" => Ok(Self::Local), - b"p" => Ok(Self::Proxy), - _ => Ok(Self::Unknown), - } - } -} - -impl fmt::Display for ServerType { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - use ServerType as E; - - let s = match self { - E::Dedicated => "dedicated", - E::Local => "local", - E::Proxy => "proxy", - E::Unknown => "unknown", - }; - - write!(fmt, "{}", s) - } -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -#[repr(u8)] -pub enum Region { - USEastCoast = 0x00, - USWestCoast = 0x01, - SouthAmerica = 0x02, - Europe = 0x03, - Asia = 0x04, - Australia = 0x05, - MiddleEast = 0x06, - Africa = 0x07, - RestOfTheWorld = 0xff, -} - -impl Default for Region { - fn default() -> Self { - Self::RestOfTheWorld - } -} - -impl TryFrom for Region { - type Error = (); - - fn try_from(value: u8) -> Result { - match value { - 0x00 => Ok(Region::USEastCoast), - 0x01 => Ok(Region::USWestCoast), - 0x02 => Ok(Region::SouthAmerica), - 0x03 => Ok(Region::Europe), - 0x04 => Ok(Region::Asia), - 0x05 => Ok(Region::Australia), - 0x06 => Ok(Region::MiddleEast), - 0x07 => Ok(Region::Africa), - 0xff => Ok(Region::RestOfTheWorld), - _ => Err(()), - } - } -} - -impl ParseValue<'_> for Region { - type Err = Error; - - fn parse(p: &mut Parser<'_>) -> Result { - let value = p.parse::()?; - Self::try_from(value).map_err(|_| Error::InvalidRegion) - } -} - -bitflags! { - #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] - pub struct ServerFlags: u8 { - const BOTS = 1 << 0; - const PASSWORD = 1 << 1; - const SECURE = 1 << 2; - const LAN = 1 << 3; - const NAT = 1 << 4; - } -} - -#[derive(Clone, Debug, Default, PartialEq)] -pub struct ServerInfo> { - pub gamedir: T, - pub map: T, - pub version: T, - pub product: T, - pub server_type: ServerType, - pub os: Os, - pub region: Region, - pub protocol: u8, - pub players: u8, - pub max: u8, - pub flags: ServerFlags, -} - -impl<'a, T> ServerInfo -where - T: 'a + Default + ParseValue<'a, Err = ParserError>, -{ - pub fn from_bytes(src: &'a [u8]) -> Result<(Option, Self, &'a [u8]), Error> { - let mut parser = Parser::new(src); - let (challenge, info) = parser.parse()?; - let tail = match parser.end() { - [b'\n', tail @ ..] => tail, - tail => tail, - }; - Ok((challenge, info, tail)) - } -} - -impl<'a, T> ParseValue<'a> for (Option, ServerInfo) -where - T: 'a + Default + ParseValue<'a, Err = ParserError>, -{ - type Err = Error; - - fn parse(p: &mut Parser<'a>) -> Result { - let mut info = ServerInfo::default(); - let mut challenge = None; - - loop { - let name = match p.parse_bytes() { - Ok(s) => s, - Err(ParserError::End) => break, - Err(e) => return Err(e.into()), - }; - - match name { - b"protocol" => info.protocol = p.parse()?, - b"challenge" => challenge = Some(p.parse()?), - b"players" => info.players = p.parse()?, - b"max" => info.max = p.parse()?, - b"gamedir" => info.gamedir = p.parse()?, - b"map" => info.map = p.parse()?, - b"type" => info.server_type = p.parse()?, - b"os" => info.os = p.parse()?, - b"version" => info.version = p.parse()?, - b"region" => info.region = p.parse()?, - b"product" => info.product = p.parse()?, - b"bots" => info.flags.set(ServerFlags::BOTS, p.parse()?), - b"password" => info.flags.set(ServerFlags::PASSWORD, p.parse()?), - b"secure" => info.flags.set(ServerFlags::SECURE, p.parse()?), - b"lan" => info.flags.set(ServerFlags::LAN, p.parse()?), - b"nat" => info.flags.set(ServerFlags::NAT, p.parse()?), - _ => { - // skip unknown fields - let value = p.parse_bytes()?; - if log_enabled!(Level::Debug) { - let name = String::from_utf8_lossy(name); - let value = String::from_utf8_lossy(value); - debug!("Invalid ServerInfo field \"{}\" = \"{}\"", name, value); - } - } - } - } - - Ok((challenge, info)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::parser::parse; - - #[test] - fn parse_os() { - assert_eq!(parse(b"\\l\\"), Ok(Os::Linux)); - assert_eq!(parse(b"\\w\\"), Ok(Os::Windows)); - assert_eq!(parse(b"\\m\\"), Ok(Os::Mac)); - assert_eq!(parse::(b"\\u\\"), Ok(Os::Unknown)); - } - - #[test] - fn parse_server_type() { - use ServerType as E; - - assert_eq!(parse(b"\\d\\"), Ok(E::Dedicated)); - assert_eq!(parse(b"\\l\\"), Ok(E::Local)); - assert_eq!(parse(b"\\p\\"), Ok(E::Proxy)); - assert_eq!(parse::(b"\\u\\"), Ok(E::Unknown)); - } - - #[test] - fn parse_region() { - assert_eq!(parse(b"\\0\\"), Ok(Region::USEastCoast)); - assert_eq!(parse(b"\\1\\"), Ok(Region::USWestCoast)); - assert_eq!(parse(b"\\2\\"), Ok(Region::SouthAmerica)); - assert_eq!(parse(b"\\3\\"), Ok(Region::Europe)); - assert_eq!(parse(b"\\4\\"), Ok(Region::Asia)); - assert_eq!(parse(b"\\5\\"), Ok(Region::Australia)); - assert_eq!(parse(b"\\6\\"), Ok(Region::MiddleEast)); - assert_eq!(parse(b"\\7\\"), Ok(Region::Africa)); - assert_eq!(parse(b"\\-1\\"), Ok(Region::RestOfTheWorld)); - assert_eq!(parse::(b"\\-2\\"), Err(Error::InvalidRegion)); - assert_eq!( - parse::(b"\\u\\"), - Err(Error::Parser(ParserError::InvalidInteger)) - ); - } - - #[test] - fn parse_server_info() { - let buf = b"\ - \\protocol\\47\ - \\challenge\\12345678\ - \\players\\16\ - \\max\\32\ - \\bots\\1\ - \\invalid_field\\field_value\ - \\gamedir\\cstrike\ - \\map\\de_dust\ - \\type\\d\ - \\password\\1\ - \\os\\l\ - \\secure\\1\ - \\lan\\1\ - \\version\\1.1.2.5\ - \\region\\-1\ - \\product\\cstrike\ - \\nat\\1\ - \ntail\ - "; - - assert_eq!( - ServerInfo::from_bytes(&buf[..]), - Ok(( - Some(12345678), - ServerInfo::<&str> { - protocol: 47, - players: 16, - max: 32, - gamedir: "cstrike", - map: "de_dust", - server_type: ServerType::Dedicated, - os: Os::Linux, - version: "1.1.2.5", - region: Region::RestOfTheWorld, - product: "cstrike", - flags: ServerFlags::all(), - }, - &b"tail"[..] - )) - ); - } -}