protocol: report more informative errors

This commit is contained in:
Denis Drakhnia 2023-12-27 10:22:37 +02:00
parent 2d180481e3
commit 142b28ad64
10 changed files with 477 additions and 330 deletions

View File

@ -31,7 +31,7 @@ fn send_command(cli: &cli::Cli) -> Result<(), Error> {
let n = sock.recv(&mut buf)?;
let (master_challenge, hash_challenge) = match master::Packet::decode(&buf[..n])? {
master::Packet::AdminChallengeResponse(p) => (p.master_challenge, p.hash_challenge),
Some(master::Packet::AdminChallengeResponse(p)) => (p.master_challenge, p.hash_challenge),
_ => return Err(Error::UnexpectedPacket),
};

View File

@ -44,6 +44,10 @@ pub enum Error {
Io(#[from] io::Error),
#[error("Admin challenge do not exist")]
AdminChallengeNotFound,
#[error("Undefined packet")]
UndefinedPacket,
#[error("Unexpected packet")]
UnexpectedPacket,
}
/// HashMap entry to keep tracking creation time.
@ -235,8 +239,9 @@ impl MasterServer {
}
};
if let Err(e) = self.handle_packet(from, &buf[..n]) {
error!("{}: {}", from, e);
let src = &buf[..n];
if let Err(e) = self.handle_packet(from, src) {
debug!("{}: {}: \"{}\"", from, e, Str(src));
}
}
Ok(())
@ -249,165 +254,194 @@ impl MasterServer {
self.admin_challenges.clear();
}
fn handle_server_packet(&mut self, from: SocketAddrV4, p: server::Packet) -> Result<(), Error> {
trace!("{}: recv {:?}", from, p);
match p {
server::Packet::Challenge(p) => {
let master_challenge = self.add_challenge(from);
let mut buf = [0; MAX_PACKET_SIZE];
let p = master::ChallengeResponse::new(master_challenge, p.server_challenge);
trace!("{}: send {:?}", from, p);
let n = p.encode(&mut buf)?;
self.sock.send_to(&buf[..n], from)?;
self.remove_outdated_challenges();
}
server::Packet::ServerAdd(p) => {
let entry = match self.challenges.get(&from) {
Some(e) => e,
None => {
trace!("{}: Challenge does not exists", from);
return Ok(());
}
};
if !entry.is_valid(self.now(), self.timeout.challenge) {
return Ok(());
}
if p.challenge != entry.value {
warn!(
"{}: Expected challenge {} but received {}",
from, entry.value, p.challenge
);
return Ok(());
}
if self.challenges.remove(&from).is_some() {
self.add_server(from, ServerInfo::new(&p));
}
self.remove_outdated_servers();
}
server::Packet::ServerRemove => {
// ignore
}
_ => {
return Err(Error::UnexpectedPacket);
}
}
Ok(())
}
fn handle_game_packet(&mut self, from: SocketAddrV4, p: game::Packet) -> Result<(), Error> {
trace!("{}: recv {:?}", from, p);
match p {
game::Packet::QueryServers(p) => {
if p.filter.clver.map_or(false, |v| v < self.clver) {
let iter = std::iter::once(self.update_addr);
self.send_server_list(from, p.filter.key, iter)?;
} else {
let now = self.now();
let iter = self
.servers
.iter()
.filter(|i| i.1.is_valid(now, self.timeout.server))
.filter(|i| i.1.matches(*i.0, p.region, &p.filter))
.map(|i| *i.0);
self.send_server_list(from, p.filter.key, iter.clone())?;
if p.filter.flags.contains(FilterFlags::NAT) {
self.send_client_to_nat_servers(from, iter)?;
}
}
}
game::Packet::GetServerInfo(_) => {
let p = server::GetServerInfoResponse {
map: self.update_map.as_ref(),
host: self.update_title.as_ref(),
protocol: 48, // XXX: how to detect what version client will accept?
dm: true,
maxcl: 32,
gamedir: "valve", // XXX: probably must be specific for client...
..Default::default()
};
trace!("{}: send {:?}", from, p);
let mut buf = [0; MAX_PACKET_SIZE];
let n = p.encode(&mut buf)?;
self.sock.send_to(&buf[..n], from)?;
}
}
Ok(())
}
fn handle_admin_packet(&mut self, from: SocketAddrV4, p: admin::Packet) -> Result<(), Error> {
trace!("{}: recv {:?}", from, p);
let now = self.now();
if let Some(e) = self.admin_limit.get(from.ip()) {
if e.is_valid(now, self.timeout.admin) {
trace!("{}: rate limit", from);
return Ok(());
}
}
match p {
admin::Packet::AdminChallenge => {
let (master_challenge, hash_challenge) = self.admin_challenge_add(from);
let p = master::AdminChallengeResponse::new(master_challenge, hash_challenge);
trace!("{}: send {:?}", from, p);
let mut buf = [0; 64];
let n = p.encode(&mut buf)?;
self.sock.send_to(&buf[..n], from)?;
self.admin_challenges_cleanup();
}
admin::Packet::AdminCommand(p) => {
let entry = *self
.admin_challenges
.get(from.ip())
.ok_or(Error::AdminChallengeNotFound)?;
if entry.0 != p.master_challenge {
trace!("{}: master challenge is not valid", from);
return Ok(());
}
if !entry.is_valid(now, self.timeout.challenge) {
trace!("{}: challenge is outdated", from);
return Ok(());
}
let state = Params::new()
.hash_length(self.hash.len)
.key(self.hash.key.as_bytes())
.personal(self.hash.personal.as_bytes())
.to_state();
let admin = self.admin_list.iter().find(|i| {
let hash = state
.clone()
.update(i.password.as_bytes())
.update(&entry.1.to_le_bytes())
.finalize();
*p.hash == hash.as_bytes()
});
match admin {
Some(admin) => {
info!("{}: admin({}), command: {:?}", from, &admin.name, p.command);
self.admin_command(p.command);
self.admin_challenge_remove(from);
}
None => {
warn!("{}: invalid admin hash, command: {:?}", from, p.command);
self.admin_limit.insert(*from.ip(), Entry::new(now, ()));
self.admin_limit_cleanup();
}
}
}
}
Ok(())
}
fn handle_packet(&mut self, from: SocketAddrV4, src: &[u8]) -> Result<(), Error> {
if self.is_blocked(from.ip()) {
return Ok(());
}
if let Ok(p) = server::Packet::decode(src) {
match p {
server::Packet::Challenge(p) => {
trace!("{}: recv {:?}", from, p);
let master_challenge = self.add_challenge(from);
let mut buf = [0; MAX_PACKET_SIZE];
let p = master::ChallengeResponse::new(master_challenge, p.server_challenge);
trace!("{}: send {:?}", from, p);
let n = p.encode(&mut buf)?;
self.sock.send_to(&buf[..n], from)?;
self.remove_outdated_challenges();
}
server::Packet::ServerAdd(p) => {
trace!("{}: recv {:?}", from, p);
let entry = match self.challenges.get(&from) {
Some(e) => e,
None => {
trace!("{}: Challenge does not exists", from);
return Ok(());
}
};
if !entry.is_valid(self.now(), self.timeout.challenge) {
return Ok(());
}
if p.challenge != entry.value {
warn!(
"{}: Expected challenge {} but received {}",
from, entry.value, p.challenge
);
return Ok(());
}
if self.challenges.remove(&from).is_some() {
self.add_server(from, ServerInfo::new(&p));
}
self.remove_outdated_servers();
}
_ => {
trace!("{}: recv {:?}", from, p);
}
}
} else if let Ok(p) = game::Packet::decode(src) {
match p {
game::Packet::QueryServers(p) => {
trace!("{}: recv {:?}", from, p);
if p.filter.clver.map_or(false, |v| v < self.clver) {
let iter = std::iter::once(self.update_addr);
self.send_server_list(from, p.filter.key, iter)?;
} else {
let now = self.now();
let iter = self
.servers
.iter()
.filter(|i| i.1.is_valid(now, self.timeout.server))
.filter(|i| i.1.matches(*i.0, p.region, &p.filter))
.map(|i| *i.0);
self.send_server_list(from, p.filter.key, iter.clone())?;
if p.filter.flags.contains(FilterFlags::NAT) {
self.send_client_to_nat_servers(from, iter)?;
}
}
}
game::Packet::GetServerInfo(p) => {
trace!("{}: recv {:?}", from, p);
let p = server::GetServerInfoResponse {
map: self.update_map.as_ref(),
host: self.update_title.as_ref(),
protocol: 48, // XXX: how to detect what version client will accept?
dm: true,
maxcl: 32,
gamedir: "valve", // XXX: probably must be specific for client...
..Default::default()
};
trace!("{}: send {:?}", from, p);
let mut buf = [0; MAX_PACKET_SIZE];
let n = p.encode(&mut buf)?;
self.sock.send_to(&buf[..n], from)?;
}
}
} else if let Ok(p) = admin::Packet::decode(self.hash.len, src) {
let now = self.now();
if let Some(e) = self.admin_limit.get(from.ip()) {
if e.is_valid(now, self.timeout.admin) {
trace!("{}: rate limit", from);
return Ok(());
}
}
match p {
admin::Packet::AdminChallenge(p) => {
trace!("{}: recv {:?}", from, p);
let (master_challenge, hash_challenge) = self.admin_challenge_add(from);
let p = master::AdminChallengeResponse::new(master_challenge, hash_challenge);
trace!("{}: send {:?}", from, p);
let mut buf = [0; 64];
let n = p.encode(&mut buf)?;
self.sock.send_to(&buf[..n], from)?;
self.admin_challenges_cleanup();
}
admin::Packet::AdminCommand(p) => {
trace!("{}: recv {:?}", from, p);
let entry = *self
.admin_challenges
.get(from.ip())
.ok_or(Error::AdminChallengeNotFound)?;
if entry.0 != p.master_challenge {
trace!("{}: master challenge is not valid", from);
return Ok(());
}
if !entry.is_valid(now, self.timeout.challenge) {
trace!("{}: challenge is outdated", from);
return Ok(());
}
let state = Params::new()
.hash_length(self.hash.len)
.key(self.hash.key.as_bytes())
.personal(self.hash.personal.as_bytes())
.to_state();
let admin = self.admin_list.iter().find(|i| {
let hash = state
.clone()
.update(i.password.as_bytes())
.update(&entry.1.to_le_bytes())
.finalize();
*p.hash == hash.as_bytes()
});
match admin {
Some(admin) => {
info!("{}: admin({}), command: {:?}", from, &admin.name, p.command);
self.admin_command(p.command);
self.admin_challenge_remove(from);
}
None => {
warn!("{}: invalid admin hash, command: {:?}", from, p.command);
self.admin_limit.insert(*from.ip(), Entry::new(now, ()));
self.admin_limit_cleanup();
}
}
}
}
} else {
debug!("{}: invalid packet: \"{}\"", from, Str(src));
match server::Packet::decode(src) {
Ok(Some(p)) => return self.handle_server_packet(from, p),
Ok(None) => {}
Err(e) => Err(e)?,
}
Ok(())
match game::Packet::decode(src) {
Ok(Some(p)) => return self.handle_game_packet(from, p),
Ok(None) => {}
Err(e) => Err(e)?,
}
match admin::Packet::decode(self.hash.len, src) {
Ok(Some(p)) => return self.handle_admin_packet(from, p),
Ok(None) => {}
Err(e) => Err(e)?,
}
Err(Error::UndefinedPacket)
}
fn now(&self) -> u32 {

View File

@ -5,7 +5,7 @@
use crate::cursor::{Cursor, CursorMut};
use crate::types::Hide;
use crate::Error;
use crate::{CursorError, Error};
/// Default hash length.
pub const HASH_LEN: usize = 64;
@ -27,7 +27,7 @@ impl AdminChallenge {
if src == Self::HEADER {
Ok(Self)
} else {
Err(Error::InvalidPacket)
Err(CursorError::Expect)?
}
}
@ -97,23 +97,22 @@ impl<'a> AdminCommand<'a> {
#[derive(Clone, Debug, PartialEq)]
pub enum Packet<'a> {
/// Admin challenge request.
AdminChallenge(AdminChallenge),
AdminChallenge,
/// Admin command.
AdminCommand(AdminCommand<'a>),
}
impl<'a> Packet<'a> {
/// Decode packet from `src` with specified hash length.
pub fn decode(hash_len: usize, src: &'a [u8]) -> Result<Self, Error> {
if let Ok(p) = AdminChallenge::decode(src) {
return Ok(Self::AdminChallenge(p));
pub fn decode(hash_len: usize, src: &'a [u8]) -> Result<Option<Self>, Error> {
if src.starts_with(AdminChallenge::HEADER) {
AdminChallenge::decode(src).map(|_| Self::AdminChallenge)
} else if src.starts_with(AdminCommand::HEADER) {
AdminCommand::decode_with_hash_len(hash_len, src).map(Self::AdminCommand)
} else {
return Ok(None);
}
if let Ok(p) = AdminCommand::decode_with_hash_len(hash_len, src) {
return Ok(Self::AdminCommand(p));
}
Err(Error::InvalidPacket)
.map(Some)
}
}
@ -126,7 +125,10 @@ mod tests {
let p = AdminChallenge;
let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap();
assert_eq!(AdminChallenge::decode(&buf[..n]), Ok(p));
assert_eq!(
Packet::decode(HASH_LEN, &buf[..n]),
Ok(Some(Packet::AdminChallenge))
);
}
#[test]
@ -134,6 +136,9 @@ mod tests {
let p = AdminCommand::new(0x12345678, &[1; HASH_LEN], "foo bar baz");
let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap();
assert_eq!(AdminCommand::decode(&buf[..n]), Ok(p));
assert_eq!(
Packet::decode(HASH_LEN, &buf[..n]),
Ok(Some(Packet::AdminCommand(p)))
);
}
}

View File

@ -7,8 +7,42 @@ use std::mem;
use std::slice;
use std::str;
use thiserror::Error;
use super::color;
use super::types::Str;
use super::{color, Error};
/// The error type for `Cursor` and `CursorMut`.
#[derive(Error, Debug, PartialEq, Eq)]
pub enum Error {
/// Invalid number.
#[error("Invalid number")]
InvalidNumber,
/// Invalid string.
#[error("Invalid string")]
InvalidString,
/// Invalid boolean.
#[error("Invalid boolean")]
InvalidBool,
/// Invalid table entry.
#[error("Invalid table key")]
InvalidTableKey,
/// Invalid table entry.
#[error("Invalid table entry")]
InvalidTableValue,
/// Table end found.
#[error("Table end")]
TableEnd,
/// Expected data not found.
#[error("Expected data not found")]
Expect,
/// An unexpected data found.
#[error("Unexpected data")]
ExpectEmpty,
/// Buffer size is no enougth to decode or encode a packet.
#[error("Unexpected end of buffer")]
UnexpectedEnd,
}
pub trait GetKeyValue<'a>: Sized {
fn get_key_value(cur: &mut Cursor<'a>) -> Result<Self, Error>;
@ -56,7 +90,7 @@ impl<'a> GetKeyValue<'a> for bool {
match cur.get_key_value_raw()? {
b"0" => Ok(false),
b"1" => Ok(true),
_ => Err(Error::InvalidPacket),
_ => Err(Error::InvalidBool),
}
}
}
@ -68,7 +102,7 @@ macro_rules! impl_get_value {
let s = cur.get_key_value::<&str>()?;
// HACK: special case for one asshole
let (_, s) = color::trim_start_color(s);
s.parse().map_err(|_| Error::InvalidPacket)
s.parse().map_err(|_| Error::InvalidNumber)
}
})+
};
@ -216,13 +250,13 @@ impl<'a> Cursor<'a> {
self.advance(s.len())?;
Ok(())
} else {
Err(Error::InvalidPacket)
Err(Error::Expect)
}
}
pub fn expect_empty(&self) -> Result<(), Error> {
if self.has_remaining() {
Err(Error::InvalidPacket)
Err(Error::ExpectEmpty)
} else {
Ok(())
}
@ -252,12 +286,13 @@ impl<'a> Cursor<'a> {
pub fn get_key_value_raw(&mut self) -> Result<&'a [u8], Error> {
let mut cur = *self;
if cur.get_u8()? == b'\\' {
let value = cur.take_while_or_all(|c| c != b'\\' && c != b'\n');
*self = cur;
Ok(value)
} else {
Err(Error::InvalidPacket)
match cur.get_u8()? {
b'\\' => {
let value = cur.take_while_or_all(|c| c != b'\\' && c != b'\n');
*self = cur;
Ok(value)
}
_ => Err(Error::InvalidTableValue),
}
}
@ -265,14 +300,20 @@ impl<'a> Cursor<'a> {
T::get_key_value(self)
}
pub fn skip_key_value<T: GetKeyValue<'a>>(&mut self) -> Result<(), Error> {
T::get_key_value(self).map(|_| ())
}
pub fn get_key_raw(&mut self) -> Result<&'a [u8], Error> {
let mut cur = *self;
if cur.get_u8()? == b'\\' {
let value = cur.take_while(|c| c != b'\\' && c != b'\n')?;
*self = cur;
Ok(value)
} else {
Err(Error::InvalidPacket)
match cur.get_u8() {
Ok(b'\\') => {
let value = cur.take_while(|c| c != b'\\' && c != b'\n')?;
*self = cur;
Ok(value)
}
Ok(b'\n') | Err(Error::UnexpectedEnd) => Err(Error::TableEnd),
_ => Err(Error::InvalidTableKey),
}
}
@ -288,6 +329,18 @@ pub trait PutKeyValue {
) -> Result<&'b mut CursorMut<'a>, Error>;
}
impl<T> PutKeyValue for &T
where
T: PutKeyValue,
{
fn put_key_value<'a, 'b>(
&self,
cur: &'b mut CursorMut<'a>,
) -> Result<&'b mut CursorMut<'a>, Error> {
(*self).put_key_value(cur)
}
}
impl PutKeyValue for &str {
fn put_key_value<'a, 'b>(
&self,
@ -532,7 +585,7 @@ mod tests {
assert_eq!(cur.get_key(), Ok((&b"gamedir"[..], "valve")));
assert_eq!(cur.get_key(), Ok((&b"password"[..], false)));
assert_eq!(cur.get_key(), Ok((&b"host"[..], "test")));
assert_eq!(cur.get_key::<&[u8]>(), Err(Error::UnexpectedEnd));
assert_eq!(cur.get_key::<&[u8]>(), Err(Error::TableEnd));
Ok(())
}

View File

@ -31,7 +31,6 @@
use std::fmt;
use std::net::SocketAddrV4;
use std::num::ParseIntError;
use std::str::FromStr;
use bitflags::bitflags;
@ -40,7 +39,7 @@ use log::debug;
use crate::cursor::{Cursor, GetKeyValue, PutKeyValue};
use crate::server::{ServerAdd, ServerFlags, ServerType};
use crate::types::Str;
use crate::{Error, ServerInfo};
use crate::{CursorError, Error, ServerInfo};
bitflags! {
/// Additional filter flags.
@ -129,21 +128,21 @@ impl fmt::Display for Version {
}
impl FromStr for Version {
type Err = ParseIntError;
type Err = CursorError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let (major, tail) = s.split_once('.').unwrap_or((s, "0"));
let (minor, patch) = tail.split_once('.').unwrap_or((tail, "0"));
let major = major.parse()?;
let minor = minor.parse()?;
let patch = patch.parse()?;
let major = major.parse().map_err(|_| CursorError::InvalidNumber)?;
let minor = minor.parse().map_err(|_| CursorError::InvalidNumber)?;
let patch = patch.parse().map_err(|_| CursorError::InvalidNumber)?;
Ok(Self::with_patch(major, minor, patch))
}
}
impl GetKeyValue<'_> for Version {
fn get_key_value(cur: &mut Cursor) -> Result<Self, Error> {
Self::from_str(cur.get_key_value()?).map_err(|_| Error::InvalidPacket)
fn get_key_value(cur: &mut Cursor) -> Result<Self, CursorError> {
cur.get_key_value().and_then(Self::from_str)
}
}
@ -151,7 +150,7 @@ impl PutKeyValue for Version {
fn put_key_value<'a, 'b>(
&self,
cur: &'b mut crate::cursor::CursorMut<'a>,
) -> Result<&'b mut crate::cursor::CursorMut<'a>, Error> {
) -> Result<&'b mut crate::cursor::CursorMut<'a>, CursorError> {
cur.put_key_value(self.major)?
.put_u8(b'.')?
.put_key_value(self.minor)?;
@ -201,42 +200,48 @@ impl<'a> TryFrom<&'a [u8]> for Filter<'a> {
type Error = Error;
fn try_from(src: &'a [u8]) -> Result<Self, Self::Error> {
trait Helper<'a> {
fn get<T: GetKeyValue<'a>>(&mut self, key: &'static str) -> Result<T, Error>;
}
impl<'a> Helper<'a> for Cursor<'a> {
fn get<T: GetKeyValue<'a>>(&mut self, key: &'static str) -> Result<T, Error> {
T::get_key_value(self).map_err(|e| Error::InvalidFilterValue(key, e))
}
}
let mut cur = Cursor::new(src);
let mut filter = Self::default();
loop {
let key = match cur.get_key_raw().map(Str) {
Ok(s) => s,
Err(Error::UnexpectedEnd) => break,
Err(e) => return Err(e),
Err(CursorError::TableEnd) => break,
Err(e) => Err(e)?,
};
match *key {
b"dedicated" => filter.insert_flag(FilterFlags::DEDICATED, cur.get_key_value()?),
b"secure" => filter.insert_flag(FilterFlags::SECURE, cur.get_key_value()?),
b"gamedir" => filter.gamedir = Some(cur.get_key_value()?),
b"map" => filter.map = Some(cur.get_key_value()?),
b"protocol" => filter.protocol = Some(cur.get_key_value()?),
b"empty" => filter.insert_flag(FilterFlags::EMPTY, cur.get_key_value()?),
b"full" => filter.insert_flag(FilterFlags::FULL, cur.get_key_value()?),
b"password" => filter.insert_flag(FilterFlags::PASSWORD, cur.get_key_value()?),
b"noplayers" => filter.insert_flag(FilterFlags::NOPLAYERS, cur.get_key_value()?),
b"clver" => {
filter.clver = Some(
cur.get_key_value::<&str>()?
.parse()
.map_err(|_| Error::InvalidPacket)?,
);
}
b"nat" => filter.insert_flag(FilterFlags::NAT, cur.get_key_value()?),
b"lan" => filter.insert_flag(FilterFlags::LAN, cur.get_key_value()?),
b"bots" => filter.insert_flag(FilterFlags::BOTS, cur.get_key_value()?),
b"dedicated" => filter.insert_flag(FilterFlags::DEDICATED, cur.get("dedicated")?),
b"secure" => filter.insert_flag(FilterFlags::SECURE, cur.get("secure")?),
b"gamedir" => filter.gamedir = Some(cur.get("gamedir")?),
b"map" => filter.map = Some(cur.get("map")?),
b"protocol" => filter.protocol = Some(cur.get("protocol")?),
b"empty" => filter.insert_flag(FilterFlags::EMPTY, cur.get("empty")?),
b"full" => filter.insert_flag(FilterFlags::FULL, cur.get("full")?),
b"password" => filter.insert_flag(FilterFlags::PASSWORD, cur.get("password")?),
b"noplayers" => filter.insert_flag(FilterFlags::NOPLAYERS, cur.get("noplayers")?),
b"clver" => filter.clver = Some(cur.get("clver")?),
b"nat" => filter.insert_flag(FilterFlags::NAT, cur.get("nat")?),
b"lan" => filter.insert_flag(FilterFlags::LAN, cur.get("lan")?),
b"bots" => filter.insert_flag(FilterFlags::BOTS, cur.get("bots")?),
b"key" => {
filter.key = {
let s = cur.get_key_value::<&str>()?;
let x = u32::from_str_radix(s, 16).map_err(|_| Error::InvalidPacket)?;
Some(x)
}
filter.key = Some(
cur.get_key_value::<&str>()
.and_then(|s| {
u32::from_str_radix(s, 16).map_err(|_| CursorError::InvalidNumber)
})
.map_err(|e| Error::InvalidFilterValue("key", e))?,
)
}
_ => {
// skip unknown fields

View File

@ -35,7 +35,7 @@ where
pub fn decode(src: &'a [u8]) -> Result<Self, Error> {
let mut cur = Cursor::new(src);
cur.expect(QueryServers::HEADER)?;
let region = cur.get_u8()?.try_into().map_err(|_| Error::InvalidPacket)?;
let region = cur.get_u8()?.try_into().map_err(|_| Error::InvalidRegion)?;
let last = cur.get_cstr_as_str()?;
let filter = match cur.get_bytes(cur.remaining())? {
// some clients may have bug and filter will be with zero at the end
@ -44,7 +44,7 @@ where
};
Ok(Self {
region,
last: last.parse().map_err(|_| Error::InvalidPacket)?,
last: last.parse().map_err(|_| Error::InvalidQueryServersLast)?,
filter: T::try_from(filter)?,
})
}
@ -114,16 +114,15 @@ pub enum Packet<'a> {
impl<'a> Packet<'a> {
/// Decode packet from `src`.
pub fn decode(src: &'a [u8]) -> Result<Self, Error> {
if let Ok(p) = QueryServers::decode(src) {
return Ok(Self::QueryServers(p));
pub fn decode(src: &'a [u8]) -> Result<Option<Self>, Error> {
if src.starts_with(QueryServers::HEADER) {
QueryServers::decode(src).map(Self::QueryServers)
} else if src.starts_with(GetServerInfo::HEADER) {
GetServerInfo::decode(src).map(Self::GetServerInfo)
} else {
return Ok(None);
}
if let Ok(p) = GetServerInfo::decode(src) {
return Ok(Self::GetServerInfo(p));
}
Err(Error::InvalidPacket)
.map(Some)
}
}
@ -151,7 +150,7 @@ mod tests {
};
let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap();
assert_eq!(QueryServers::decode(&buf[..n]), Ok(p));
assert_eq!(Packet::decode(&buf[..n]), Ok(Some(Packet::QueryServers(p))));
}
#[test]
@ -171,10 +170,10 @@ mod tests {
};
let s = b"1\xff0.0.0.0:0\x00\\protocol\\48\\clver\\0.20\\nat\\0\0";
assert_eq!(QueryServers::decode(s), Ok(p.clone()));
assert_eq!(Packet::decode(s), Ok(Some(Packet::QueryServers(p.clone()))));
let s = b"1\xff0.0.0.0:0\x00\\protocol\\48\\clver\\0.20\\nat\\0";
assert_eq!(QueryServers::decode(s), Ok(p));
assert_eq!(Packet::decode(s), Ok(Some(Packet::QueryServers(p))));
}
#[test]
@ -182,6 +181,9 @@ mod tests {
let p = GetServerInfo::new(49);
let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap();
assert_eq!(GetServerInfo::decode(&buf[..n]), Ok(p));
assert_eq!(
Packet::decode(&buf[..n]),
Ok(Some(Packet::GetServerInfo(p)))
);
}
}

View File

@ -16,6 +16,7 @@ pub mod master;
pub mod server;
pub mod types;
pub use cursor::Error as CursorError;
pub use server_info::ServerInfo;
use thiserror::Error;
@ -33,13 +34,25 @@ pub enum Error {
/// Failed to decode a packet.
#[error("Invalid packet")]
InvalidPacket,
/// Invalid string in a packet.
#[error("Invalid UTF-8 string")]
InvalidString,
/// Buffer size is no enougth to decode or encode a packet.
#[error("Unexpected end of buffer")]
UnexpectedEnd,
/// Invalid region.
#[error("Invalid region")]
InvalidRegion,
/// Invalid client announce IP.
#[error("Invalid client announce IP")]
InvalidClientAnnounceIp,
/// Invalid last IP.
#[error("Invalid last server IP")]
InvalidQueryServersLast,
/// Server protocol version is not supported.
#[error("Invalid protocol version")]
InvalidProtocolVersion,
/// Cursor error.
#[error("{0}")]
CursorError(#[from] CursorError),
/// Invalid value for server add packet.
#[error("Invalid value for server add key `{0}`: {1}")]
InvalidServerValue(&'static str, #[source] CursorError),
/// Invalid value for query servers packet.
#[error("Invalid value for filter key `{0}`: {1}")]
InvalidFilterValue(&'static str, #[source] CursorError),
}

View File

@ -174,7 +174,7 @@ impl ClientAnnounce {
let addr = cur
.get_str(cur.remaining())?
.parse()
.map_err(|_| Error::InvalidPacket)?;
.map_err(|_| Error::InvalidClientAnnounceIp)?;
cur.expect_empty()?;
Ok(Self { addr })
}
@ -247,24 +247,19 @@ pub enum Packet<'a> {
impl<'a> Packet<'a> {
/// Decode packet from `src`.
pub fn decode(src: &'a [u8]) -> Result<Self, Error> {
if let Ok(p) = ChallengeResponse::decode(src) {
return Ok(Self::ChallengeResponse(p));
pub fn decode(src: &'a [u8]) -> Result<Option<Self>, Error> {
if src.starts_with(ChallengeResponse::HEADER) {
ChallengeResponse::decode(src).map(Self::ChallengeResponse)
} else if src.starts_with(QueryServersResponse::HEADER) {
QueryServersResponse::decode(src).map(Self::QueryServersResponse)
} else if src.starts_with(ClientAnnounce::HEADER) {
ClientAnnounce::decode(src).map(Self::ClientAnnounce)
} else if src.starts_with(AdminChallengeResponse::HEADER) {
AdminChallengeResponse::decode(src).map(Self::AdminChallengeResponse)
} else {
return Ok(None);
}
if let Ok(p) = QueryServersResponse::decode(src) {
return Ok(Self::QueryServersResponse(p));
}
if let Ok(p) = ClientAnnounce::decode(src) {
return Ok(Self::ClientAnnounce(p));
}
if let Ok(p) = AdminChallengeResponse::decode(src) {
return Ok(Self::AdminChallengeResponse(p));
}
Err(Error::InvalidPacket)
.map(Some)
}
}
@ -277,7 +272,10 @@ mod tests {
let p = ChallengeResponse::new(0x12345678, Some(0x87654321));
let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap();
assert_eq!(ChallengeResponse::decode(&buf[..n]), Ok(p));
assert_eq!(
Packet::decode(&buf[..n]),
Ok(Some(Packet::ChallengeResponse(p)))
);
}
#[test]
@ -291,7 +289,10 @@ mod tests {
let p = ChallengeResponse::new(0x12345678, None);
let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap();
assert_eq!(ChallengeResponse::decode(&buf[..n]), Ok(p));
assert_eq!(
Packet::decode(&buf[..n]),
Ok(Some(Packet::ChallengeResponse(p)))
);
}
#[test]
@ -314,7 +315,10 @@ mod tests {
let p = ClientAnnounce::new("1.2.3.4:12345".parse().unwrap());
let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap();
assert_eq!(ClientAnnounce::decode(&buf[..n]), Ok(p));
assert_eq!(
Packet::decode(&buf[..n]),
Ok(Some(Packet::ClientAnnounce(p)))
);
}
#[test]
@ -322,6 +326,9 @@ mod tests {
let p = AdminChallengeResponse::new(0x12345678, 0x87654321);
let mut buf = [0; 64];
let n = p.encode(&mut buf).unwrap();
assert_eq!(AdminChallengeResponse::decode(&buf[..n]), Ok(p));
assert_eq!(
Packet::decode(&buf[..n]),
Ok(Some(Packet::AdminChallengeResponse(p)))
);
}
}

View File

@ -11,7 +11,7 @@ use log::debug;
use super::cursor::{Cursor, CursorMut, GetKeyValue, PutKeyValue};
use super::filter::Version;
use super::types::Str;
use super::Error;
use super::{CursorError, Error};
/// Sended to a master server before `ServerAdd` packet.
#[derive(Clone, Debug, PartialEq)]
@ -74,7 +74,7 @@ impl Default for Os {
}
impl TryFrom<&[u8]> for Os {
type Error = Error;
type Error = CursorError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
match value {
@ -87,7 +87,7 @@ impl TryFrom<&[u8]> for Os {
}
impl GetKeyValue<'_> for Os {
fn get_key_value(cur: &mut Cursor) -> Result<Self, Error> {
fn get_key_value(cur: &mut Cursor) -> Result<Self, CursorError> {
cur.get_key_value_raw()?.try_into()
}
}
@ -96,7 +96,7 @@ impl PutKeyValue for Os {
fn put_key_value<'a, 'b>(
&self,
cur: &'b mut CursorMut<'a>,
) -> Result<&'b mut CursorMut<'a>, Error> {
) -> Result<&'b mut CursorMut<'a>, CursorError> {
match self {
Self::Linux => cur.put_str("l"),
Self::Windows => cur.put_str("w"),
@ -139,7 +139,7 @@ impl Default for ServerType {
}
impl TryFrom<&[u8]> for ServerType {
type Error = Error;
type Error = CursorError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
match value {
@ -152,7 +152,7 @@ impl TryFrom<&[u8]> for ServerType {
}
impl GetKeyValue<'_> for ServerType {
fn get_key_value(cur: &mut Cursor) -> Result<Self, Error> {
fn get_key_value(cur: &mut Cursor) -> Result<Self, CursorError> {
cur.get_key_value_raw()?.try_into()
}
}
@ -161,7 +161,7 @@ impl PutKeyValue for ServerType {
fn put_key_value<'a, 'b>(
&self,
cur: &'b mut CursorMut<'a>,
) -> Result<&'b mut CursorMut<'a>, Error> {
) -> Result<&'b mut CursorMut<'a>, CursorError> {
match self {
Self::Dedicated => cur.put_str("d"),
Self::Local => cur.put_str("l"),
@ -217,7 +217,7 @@ impl Default for Region {
}
impl TryFrom<u8> for Region {
type Error = Error;
type Error = CursorError;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
@ -230,13 +230,13 @@ impl TryFrom<u8> for Region {
0x06 => Ok(Region::MiddleEast),
0x07 => Ok(Region::Africa),
0xff => Ok(Region::RestOfTheWorld),
_ => Err(Error::InvalidPacket),
_ => Err(CursorError::InvalidNumber),
}
}
}
impl GetKeyValue<'_> for Region {
fn get_key_value(cur: &mut Cursor) -> Result<Self, Error> {
fn get_key_value(cur: &mut Cursor) -> Result<Self, CursorError> {
cur.get_key_value::<u8>()?.try_into()
}
}
@ -304,28 +304,38 @@ where
{
/// Decode packet from `src`.
pub fn decode(src: &'a [u8]) -> Result<Self, Error> {
trait Helper<'a> {
fn get<T: GetKeyValue<'a>>(&mut self, key: &'static str) -> Result<T, Error>;
}
impl<'a> Helper<'a> for Cursor<'a> {
fn get<T: GetKeyValue<'a>>(&mut self, key: &'static str) -> Result<T, Error> {
T::get_key_value(self).map_err(|e| Error::InvalidServerValue(key, e))
}
}
let mut cur = Cursor::new(src);
cur.expect(ServerAdd::HEADER)?;
let mut ret = Self::default();
let mut challenge = None;
while cur.as_slice().starts_with(&[b'\\']) {
loop {
let key = match cur.get_key_raw() {
Ok(s) => s,
Err(Error::UnexpectedEnd) => break,
Err(e) => return Err(e),
Err(CursorError::TableEnd) => break,
Err(e) => Err(e)?,
};
match key {
b"protocol" => ret.protocol = cur.get_key_value()?,
b"challenge" => challenge = Some(cur.get_key_value()?),
b"players" => ret.players = cur.get_key_value()?,
b"max" => ret.max = cur.get_key_value()?,
b"gamedir" => ret.gamedir = cur.get_key_value()?,
b"product" => { let _ = cur.get_key_value::<Str<&[u8]>>()?; }, // legacy key, ignore
b"map" => ret.map = cur.get_key_value()?,
b"type" => ret.server_type = cur.get_key_value()?,
b"os" => ret.os = cur.get_key_value()?,
b"protocol" => ret.protocol = cur.get("protocol")?,
b"challenge" => challenge = Some(cur.get("challenge")?),
b"players" => ret.players = cur.get("players")?,
b"max" => ret.max = cur.get("max")?,
b"gamedir" => ret.gamedir = cur.get("gamedir")?,
b"product" => cur.skip_key_value::<&[u8]>()?, // legacy key, ignore
b"map" => ret.map = cur.get("map")?,
b"type" => ret.server_type = cur.get("type")?,
b"os" => ret.os = cur.get("os")?,
b"version" => {
ret.version = cur
.get_key_value()
@ -335,12 +345,14 @@ where
})
.unwrap_or_default()
}
b"region" => ret.region = cur.get_key_value()?,
b"bots" => ret.flags.set(ServerFlags::BOTS, cur.get_key_value::<u8>()? != 0),
b"password" => ret.flags.set(ServerFlags::PASSWORD, cur.get_key_value()?),
b"secure" => ret.flags.set(ServerFlags::SECURE, cur.get_key_value()?),
b"lan" => ret.flags.set(ServerFlags::LAN, cur.get_key_value()?),
b"nat" => ret.flags.set(ServerFlags::NAT, cur.get_key_value()?),
b"region" => ret.region = cur.get("region")?,
b"bots" => ret
.flags
.set(ServerFlags::BOTS, cur.get::<u8>("bots")? != 0),
b"password" => ret.flags.set(ServerFlags::PASSWORD, cur.get("password")?),
b"secure" => ret.flags.set(ServerFlags::SECURE, cur.get("secure")?),
b"lan" => ret.flags.set(ServerFlags::LAN, cur.get("lan")?),
b"nat" => ret.flags.set(ServerFlags::NAT, cur.get("nat")?),
_ => {
// skip unknown fields
let value = cur.get_key_value::<Str<&[u8]>>()?;
@ -354,14 +366,14 @@ where
ret.challenge = c;
Ok(ret)
}
None => Err(Error::InvalidPacket),
None => Err(Error::InvalidServerValue("challenge", CursorError::Expect)),
}
}
}
impl<T> ServerAdd<T>
where
T: PutKeyValue + Clone,
T: PutKeyValue,
{
/// Encode packet to `buf`.
pub fn encode(&self, buf: &mut [u8]) -> Result<usize, Error> {
@ -371,8 +383,8 @@ where
.put_key("challenge", self.challenge)?
.put_key("players", self.players)?
.put_key("max", self.max)?
.put_key("gamedir", self.gamedir.clone())?
.put_key("map", self.map.clone())?
.put_key("gamedir", &self.gamedir)?
.put_key("map", &self.map)?
.put_key("type", self.server_type)?
.put_key("os", self.os)?
.put_key("version", self.version)?
@ -469,8 +481,8 @@ where
loop {
let key = match cur.get_key_raw() {
Ok(s) => s,
Err(Error::UnexpectedEnd) => break,
Err(e) => return Err(e),
Err(CursorError::TableEnd) => break,
Err(e) => Err(e)?,
};
match key {
@ -500,21 +512,24 @@ where
}
}
impl<'a> GetServerInfoResponse<&'a str> {
impl<T> GetServerInfoResponse<T>
where
T: PutKeyValue,
{
/// Encode packet to `buf`.
pub fn encode(&self, buf: &mut [u8]) -> Result<usize, Error> {
Ok(CursorMut::new(buf)
.put_bytes(GetServerInfoResponse::HEADER)?
.put_key("p", self.protocol)?
.put_key("map", self.map)?
.put_key("map", &self.map)?
.put_key("dm", self.dm)?
.put_key("team", self.team)?
.put_key("coop", self.coop)?
.put_key("numcl", self.numcl)?
.put_key("maxcl", self.maxcl)?
.put_key("gamedir", self.gamedir)?
.put_key("gamedir", &self.gamedir)?
.put_key("password", self.password)?
.put_key("host", self.host)?
.put_key("host", &self.host)?
.pos())
}
}
@ -534,24 +549,19 @@ pub enum Packet<'a> {
impl<'a> Packet<'a> {
/// Decode packet from `src`.
pub fn decode(src: &'a [u8]) -> Result<Self, Error> {
if let Ok(p) = Challenge::decode(src) {
return Ok(Self::Challenge(p));
pub fn decode(src: &'a [u8]) -> Result<Option<Self>, Error> {
if src.starts_with(Challenge::HEADER) {
Challenge::decode(src).map(Self::Challenge)
} else if src.starts_with(ServerAdd::HEADER) {
ServerAdd::decode(src).map(Self::ServerAdd)
} else if src.starts_with(ServerRemove::HEADER) {
ServerRemove::decode(src).map(|_| Self::ServerRemove)
} else if src.starts_with(GetServerInfoResponse::HEADER) {
GetServerInfoResponse::decode(src).map(Self::GetServerInfoResponse)
} else {
return Ok(None);
}
if let Ok(p) = ServerAdd::decode(src) {
return Ok(Self::ServerAdd(p));
}
if ServerRemove::decode(src).is_ok() {
return Ok(Self::ServerRemove);
}
if let Ok(p) = GetServerInfoResponse::decode(src) {
return Ok(Self::GetServerInfoResponse(p));
}
Err(Error::InvalidPacket)
.map(Some)
}
}
@ -564,13 +574,16 @@ mod tests {
let p = Challenge::new(Some(0x12345678));
let mut buf = [0; 128];
let n = p.encode(&mut buf).unwrap();
assert_eq!(Challenge::decode(&buf[..n]), Ok(p));
assert_eq!(Packet::decode(&buf[..n]), Ok(Some(Packet::Challenge(p))));
}
#[test]
fn challenge_old() {
let s = b"q\xff";
assert_eq!(Challenge::decode(s), Ok(Challenge::new(None)));
assert_eq!(
Packet::decode(s),
Ok(Some(Packet::Challenge(Challenge::new(None))))
);
let p = Challenge::new(None);
let mut buf = [0; 128];
@ -581,8 +594,8 @@ mod tests {
#[test]
fn server_add() {
let p = ServerAdd {
gamedir: "valve",
map: "crossfire",
gamedir: Str(&b"valve"[..]),
map: Str(&b"crossfire"[..]),
version: Version::new(0, 20),
challenge: 0x12345678,
server_type: ServerType::Dedicated,
@ -595,7 +608,7 @@ mod tests {
};
let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap();
assert_eq!(ServerAdd::decode(&buf[..n]), Ok(p));
assert_eq!(Packet::decode(&buf[..n]), Ok(Some(Packet::ServerAdd(p))));
}
#[test]
@ -603,26 +616,29 @@ mod tests {
let p = ServerRemove;
let mut buf = [0; 64];
let n = p.encode(&mut buf).unwrap();
assert_eq!(ServerRemove::decode(&buf[..n]), Ok(p));
assert_eq!(Packet::decode(&buf[..n]), Ok(Some(Packet::ServerRemove)));
}
#[test]
fn get_server_info_response() {
let p = GetServerInfoResponse {
protocol: 49,
map: "crossfire",
map: Str("crossfire".as_bytes()),
dm: true,
team: true,
coop: true,
numcl: 4,
maxcl: 32,
gamedir: "valve",
gamedir: Str("valve".as_bytes()),
password: true,
host: "Test",
host: Str("Test".as_bytes()),
};
let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap();
assert_eq!(GetServerInfoResponse::decode(&buf[..n]), Ok(p));
assert_eq!(
Packet::decode(&buf[..n]),
Ok(Some(Packet::GetServerInfoResponse(p)))
);
}
#[test]

View File

@ -6,6 +6,9 @@
use std::fmt;
use std::ops::Deref;
use crate::cursor::{CursorMut, PutKeyValue};
use crate::CursorError;
/// Wrapper for slice of bytes with printing the bytes as a string.
///
/// # Examples
@ -24,6 +27,15 @@ impl<T> From<T> for Str<T> {
}
}
impl PutKeyValue for Str<&[u8]> {
fn put_key_value<'a, 'b>(
&self,
cur: &'b mut CursorMut<'a>,
) -> Result<&'b mut CursorMut<'a>, CursorError> {
cur.put_bytes(self.0)
}
}
impl<T> fmt::Debug for Str<T>
where
T: AsRef<[u8]>,