Merge pull request #9 from Kkevsterrr/GEN-7_tests

Added tests and code coverage for repository
This commit is contained in:
Kevin Bock 2019-12-13 10:13:32 -05:00 committed by GitHub
commit b36cb7f438
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 2494 additions and 205 deletions

16
.coveragerc Normal file
View File

@ -0,0 +1,16 @@
[run]
branch = True
source = ./
[report]
exclude_lines =
if self.debug:
pragma: no cover
raise NotImplementedError
if __name__ == .__main__.:
def get_args
def main
ignore_errors = True
omit =
tests/*
examples/*

42
.travis.yml Normal file
View File

@ -0,0 +1,42 @@
sudo: required
dist: "bionic"
language: python
python:
- "3.6"
install:
# Travis recently added systemd-resolvd to their VMs. Since full Geneva often runs its own DNS
# server to test DNS strategies, we need to disable system-resolvd.
# First disable the service
- sudo systemctl disable systemd-resolved.service
# Stop the service
- sudo systemctl stop systemd-resolved
# With systemd not running, our own hostname won't resolve - this causes issues with sudo.
# Add back our hostname to /etc/hosts/ so sudo does not complain
- echo $(hostname -I | cut -d\ -f1) $(hostname) | sudo tee -a /etc/hosts
# Replace the 127.0.0.53 nameserver with Google's
- sudo sed 's/nameserver.*/nameserver 8.8.8.8/' /etc/resolv.conf > /tmp/resolv.conf.new
- sudo mv /tmp/resolv.conf.new /etc/resolv.conf
# Now that systemd-resolv.conf is safely disabled, we can now setup for Geneva
- sudo apt-get clean # travis having mirror sync issues
# Install dependencies
- sudo apt-get update
- sudo apt-get -y install libnetfilter-queue-dev python3 python3-pip python3-setuptools graphviz
# Since sudo is required but travis does not set up the root environment, we must override the
# secure_path in sudoers in order for travis's setup to take effect for sudo commands
- printf "Defaults\tenv_reset\nDefaults\tmail_badpass\nDefaults\tsecure_path="/home/travis/virtualenv/python3.6.7/bin/:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/snap/bin"\nroot\tALL=(ALL:ALL) ALL\n#includedir /etc/sudoers.d\n" > /tmp/sudoers.tmp
# Verify the sudoers file
- sudo visudo -c -f /tmp/sudoers.tmp
# Copy in the sudoers file
- sudo cp /tmp/sudoers.tmp /etc/sudoers
# Now that sudo is good to go, finish installing dependencies
- sudo python3 -m pip install -r requirements.txt
- sudo python3 -m pip install slackclient pytest-cov
script:
- sudo python3 -m pytest --cov=./ -sv tests/ --tb=short
after_script:
- bash <(curl -s https://codecov.io/bash) -t 83a45966-78ce-44c2-80b3-964ecab4a53d || echo "Codecov did not collect coverage reports"

View File

@ -1,4 +1,4 @@
# Geneva # Geneva [![Build Status](https://travis-ci.com/Kkevsterrr/geneva.svg?branch=master)](https://travis-ci.com/Kkevsterrr/geneva) [![codecov](https://codecov.io/gh/Kkevsterrr/geneva/branch/master/graph/badge.svg)](https://codecov.io/gh/Kkevsterrr/geneva)
Geneva is an artificial intelligence tool that defeats censorship by exploiting bugs in censors, such as those in China, India, and Kazakhstan. Unlike many other anti-censorship solutions which require assistance from outside the censoring regime (Tor, VPNs, etc.), Geneva runs strictly on the client. Geneva is an artificial intelligence tool that defeats censorship by exploiting bugs in censors, such as those in China, India, and Kazakhstan. Unlike many other anti-censorship solutions which require assistance from outside the censoring regime (Tor, VPNs, etc.), Geneva runs strictly on the client.

View File

@ -13,9 +13,3 @@ class DuplicateAction(Action):
""" """
logger.debug(" - Duplicating given packet %s" % str(packet)) logger.debug(" - Duplicating given packet %s" % str(packet))
return packet, packet.copy() return packet, packet.copy()
def mutate(self, environment_id=None):
"""
Swaps its left and right child
"""
self.left, self.right = self.right, self.left

View File

@ -196,22 +196,3 @@ class FragmentAction(Action):
self.correct_order = False self.correct_order = False
return True return True
def mutate(self, environment_id=None):
"""
Mutates the fragment action - it either chooses a new segment offset,
switches the packet order, and/or changes whether it segments or fragments.
"""
self.correct_order = self.get_rand_order()
self.segment = random.choice([True, True, True, False])
if self.segment:
if random.random() < 0.5:
self.fragsize = int(random.uniform(1, 60))
else:
self.fragsize = -1
else:
if random.random() < 0.2:
self.fragsize = int(random.uniform(1, 50))
else:
self.fragsize = -1
return self

View File

@ -4,7 +4,8 @@ import string
import os import os
import urllib.parse import urllib.parse
from scapy.all import IP, RandIP, UDP, Raw, TCP, fuzz from scapy.all import IP, RandIP, UDP, DNS, DNSQR, Raw, TCP, fuzz
class Layer(): class Layer():
""" """
@ -179,6 +180,12 @@ class Layer():
value = urllib.parse.unquote(value) value = urllib.parse.unquote(value)
value = value.encode('utf-8') value = value.encode('utf-8')
# Add support for injecting arbitrary protocol payloads if requested
dns_payload = b"\x009ib\x81\x80\x00\x01\x00\x01\x00\x00\x00\x01\x08examples\x03com\x00\x00\x01\x00\x01\xc0\x0c\x00\x01\x00\x01\x00\x00\x01+\x00\x04\xc7\xbf2I\x00\x00)\x02\x00\x00\x00\x00\x00\x00\x00"
http_payload = b"GET / HTTP/1.1\r\nHost: www.example.com\r\n\r\n"
value = value.replace(b"__DNS_REQUEST__", dns_payload)
value = value.replace(b"__HTTP_REQUEST__", http_payload)
self.layer.payload = Raw(value) self.layer.payload = Raw(value)
@ -592,3 +599,237 @@ class UDPLayer(Layer):
self.generators = { self.generators = {
'load' : self.gen_load, 'load' : self.gen_load,
} }
class DNSLayer(Layer):
"""
Defines an interface to access DNS header fields.
"""
name = "DNS"
protocol = DNS
_fields = [
"id",
"qr",
"opcode",
"aa",
"tc",
"rd",
"ra",
"z",
"ad",
"cd",
"qd",
"rcode",
"qdcount",
"ancount",
"nscount",
"arcount"
]
fields = _fields
def __init__(self, layer):
"""
Initializes the DNS layer.
"""
Layer.__init__(self, layer)
self.getters = {
"qr" : self.get_bitfield,
"aa" : self.get_bitfield,
"tc" : self.get_bitfield,
"rd" : self.get_bitfield,
"ra" : self.get_bitfield,
"z" : self.get_bitfield,
"ad" : self.get_bitfield,
"cd" : self.get_bitfield
}
self.setters = {
"qr" : self.set_bitfield,
"aa" : self.set_bitfield,
"tc" : self.set_bitfield,
"rd" : self.set_bitfield,
"ra" : self.set_bitfield,
"z" : self.set_bitfield,
"ad" : self.set_bitfield,
"cd" : self.set_bitfield
}
self.generators = {
"id" : self.gen_id,
"qr" : self.gen_bitfield,
"opcode" : self.gen_opcode,
"aa" : self.gen_bitfield,
"tc" : self.gen_bitfield,
"rd" : self.gen_bitfield,
"ra" : self.gen_bitfield,
"z" : self.gen_bitfield,
"ad" : self.gen_bitfield,
"cd" : self.gen_bitfield,
"rcode" : self.gen_rcode,
"qdcount" : self.gen_count,
"ancount" : self.gen_count,
"nscount" : self.gen_count,
"arcount" : self.gen_count
}
def get_bitfield(self, field):
""""""
return int(getattr(self.layer, field))
def set_bitfield(self, packet, field, value):
""""""
return setattr(self.layer, field, int(value))
def gen_bitfield(self, field):
""""""
return random.choice([0,1])
def gen_id(self, field):
return random.randint(0, 65535)
def gen_opcode(self, field):
return random.randint(0, 15)
def gen_rcode(self, field):
return random.randint(0, 15)
def gen_count(self, field):
return random.randint(0, 65535)
@staticmethod
def dns_decompress(packet, logger):
"""
Performs DNS decompression on the given scapy packet, if applicable.
Note that DNS compression/decompression must be done on the boundaries
of a label, so DNS compression does not support arbitrary offsets.
"""
# If this is a TCP packet
if packet.haslayer("TCP"):
raise NotImplementedError
# Perform no action if this is not a DNS or DNSRQ packet
if not packet.haslayer("DNS") or not packet.haslayer("DNSQR"):
return packet
# Extract the query from the DNSQR layer
query = packet["DNSQR"].qname.decode()
if query[len(query) - 1] != '.':
query += '.'
# Split the query by label
labels = query.split(".")
# Collect the first and second half of the query
fhalf = labels[0]
shalf = ".".join(labels[1:])
# Build the first DNS query directly. The format of this a byte string like this:
# b'\x07minghui\xc0\x1a\x00\x01\x00\x01'
# \x07 = the length of the label in this DNSQR
# minghui = the portion of the domain we will request in the first DNSQR
# \xc0\x1a = offset into the DNS packet where the rest of the query will be. The actual offset
# here is the \x1a - DNS mandates that if compression is used, the first two bits be 11
# to differentiate them from the rest. \x1A = 26, which is the length of the DNS header
# plus the length of this DNSQR.
# \x00\x01 = type A record
# \x00\x01 = IN
length = bytes([len(fhalf)])
label = fhalf.encode()
# Since the domain will include an extra ".", add 1
# 2 * 6 is the DNS header
# 1 is the byte that determines the length of the label
# len(label) is the length of the label
# 2 is the offset pointer
# 4 - other record information (class, IN)
packet_offset = 2 * 6 + 1 + len(label) + 2 + 2 + 2
# The word must start with binary 11, so OR the offset with 0xC000.
offset = (0xc000 | packet_offset).to_bytes(2, byteorder='big')
request = b'\x00\x01\x00\x01'
dns_qr1 = length + label + offset + request
# Build the second DNS query directly. The format of the byte string is the same as above
# b'\x02ca\x00\x00\x01\x00\x01'
# \x02 = length of the remaining domain
# ca = portion of the domain in this DNSQR
# \x00 = null byte to signify the end of the query
# \x00\x01 = type A record
# \x00\x01 = IN
# Since the second half could potentially contain many labels, this is done in a list comprehension
dns_qr2 = b"".join([bytes([len(tld)]) + tld.encode() for tld in shalf.split(".")]) + b"\x00\x01\x00\x01"
# Next, we must rebuild the DNS packet itself. If we try to have scapy parse either dns_qr1 or dns_qr2, they
# will look malformed, since neither contains a complete request. Therefore, we must build the entire
# DNS packet at once. First, we must remove the original DNSQR, since this contains the original request
del packet["DNS"].qd
# Once the DNSQR is removed, scapy automatically sets the qdcount to 0. Adjust it to 2
packet["DNS"].qdcount = 2
# Extract the DNS header standalone now for building
dns_header = bytes(packet["DNS"])
dns_packet = DNS(dns_header + dns_qr1 + dns_qr2)
del packet["DNS"]
packet = packet / dns_packet
# Since the size and data of the packet have changed, force scapy to recalculate the important fields
# in below layers, if applicable
if packet.haslayer("IP"):
del packet["IP"].chksum
del packet["IP"].len
if packet.haslayer("UDP"):
del packet["UDP"].chksum
del packet["UDP"].len
return packet
class DNSQRLayer(Layer):
"""
Defines an interface to access DNSQR header fields.
"""
name = "DNSQR"
protocol = DNSQR
_fields = [
"qname",
"qtype",
"qclass"
]
fields = _fields
def __init__(self, layer):
"""
Initializes the DNS layer.
"""
Layer.__init__(self, layer)
self.getters = {
"qname" : self.get_qname
}
self.generators = {
"qname" : self.gen_qname
}
def get_qname(self, field):
"""
Returns decoded qname from packet.
"""
return self.layer.qname.decode('utf-8')
def gen_qname(self, field):
"""
Generates domain name.
"""
return "example.com."
@classmethod
def name_matches(cls, name):
"""
Scapy returns the name of DNSQR as _both_ DNSQR and "DNS Question Record",
which breaks parsing. Override the name_matches method to handle that case
here.
"""
return name.upper() in ["DNSQR", "DNS QUESTION RECORD"]

View File

@ -7,7 +7,9 @@ import actions.layer
_SUPPORTED_LAYERS = [ _SUPPORTED_LAYERS = [
actions.layer.IPLayer, actions.layer.IPLayer,
actions.layer.TCPLayer, actions.layer.TCPLayer,
actions.layer.UDPLayer actions.layer.UDPLayer,
actions.layer.DNSLayer,
actions.layer.DNSQRLayer
] ]
SUPPORTED_LAYERS = _SUPPORTED_LAYERS SUPPORTED_LAYERS = _SUPPORTED_LAYERS
@ -64,9 +66,25 @@ class Packet():
@staticmethod @staticmethod
def _str_load(packet, protocol): def _str_load(packet, protocol):
""" """
Prints packet payload Prints DNS header for now
""" """
return str(packet[protocol].payload) if packet.haslayer("DNS") and packet.haslayer("DNSQR"):
res = "%s:%s:%s " % (
packet["DNSQR"].qname.decode('utf8'),
str(packet["DNSQR"].qtype),
str(packet["DNSQR"].qclass))
DNS_res = ""
for i in range(packet["DNS"].ancount):
dnsrr = packet["DNS"].an[i]
DNS_res += " " + ':'.join([str(dnsrr.rrname.decode('utf8')),
str(dnsrr.type),
str(dnsrr.rclass),
str(dnsrr.ttl),
str(dnsrr.rdlen),
str(dnsrr.rdata)])
return "%s %s" % (res, DNS_res)
else:
return str(packet[protocol].payload)
def __bytes__(self): def __bytes__(self):
""" """
@ -238,3 +256,77 @@ class Packet():
return layer return layer
return None return None
@staticmethod
def reset_restrictions():
"""
Removes layer and field restrictions.
"""
global SUPPORTED_LAYERS, _SUPPORTED_LAYERS
SUPPORTED_LAYERS = _SUPPORTED_LAYERS
for layer in SUPPORTED_LAYERS:
layer.reset_restrictions()
@staticmethod
def restrict_fields(logger, filter_protocols, filter_fields, disable_fields):
"""
Validates input arguments. Used by evolve.py to restrict the scope
of this evolution.
"""
global SUPPORTED_LAYERS
if not disable_fields:
disable_fields = []
# First, apply a field whitelist if it was requested
valid = []
if filter_fields:
for layer in SUPPORTED_LAYERS:
new_fields = []
for field in filter_fields:
if field in layer.fields:
new_fields.append(field)
valid.append(field)
layer.fields = new_fields
if valid and logger:
logger.info("Strategies will only be allowed to use fields: %s" % ", ".join(list(set(valid))))
elif logger:
logger.error("None of the given fields exist in the packet headers of given protocols.")
# Apply a field blacklist if it was requested
for field in disable_fields:
for layer in SUPPORTED_LAYERS:
layer.fields = [f for f in layer.fields if f not in disable_fields]
if disable_fields and logger:
logger.info("Strategies will not be allowed to use fields %s" % ", ".join(disable_fields))
allowed_layers = []
# Finally, filter protocols
for protocol in filter_protocols:
allowed_layer = Packet.get_supported_protocol(protocol)
if not allowed_layer:
if logger:
logger.error("%s not a supported protocol." % protocol)
continue
# Only keep the layer allowed if it contains allowed fields
if allowed_layer.fields:
allowed_layers.append(allowed_layer)
assert allowed_layers, "Cannot evolve with no available packet layers!"
SUPPORTED_LAYERS = allowed_layers
if logger and allowed_layers:
logger.info("Strategies will only be allowed to use protocols: %s" % ", ".join([l.name for l in allowed_layers]))
def dns_decompress(self, logger):
"""
Performs DNS decompression, if applicable. Returns a new packet.
"""
self.packet = actions.layer.DNSLayer.dns_decompress(self.packet, logger)
self.layers = self.setup_layers()
return self

View File

@ -2,7 +2,7 @@ from actions.action import Action
class SleepAction(Action): class SleepAction(Action):
def __init__(self, time=1, environment_id=None): def __init__(self, time=1, environment_id=None):
Action.__init__(self, "sleep", "out") Action.__init__(self, "sleep", "both")
self.terminal = False self.terminal = False
self.branching = False self.branching = False
self.time = time self.time = time

View File

@ -1,85 +0,0 @@
import threading
import os
import actions.packet
from scapy.all import sniff
from scapy.utils import PcapWriter
class Sniffer():
"""
The sniffer class lets the user begin and end sniffing whenever in a given location with a port to filter on.
Call start_sniffing to begin sniffing and stop_sniffing to stop sniffing.
"""
def __init__(self, location, port, logger):
"""
Intializes a sniffer object.
Needs a location and a port to filter on.
"""
self.stop_sniffing_flag = False
self.location = location
self.port = port
self.pcap_thread = None
self.packet_dumper = None
self.logger = logger
full_path = os.path.dirname(location)
assert port, "Need to specify a port in order to launch a sniffer"
if not os.path.exists(full_path):
os.makedirs(full_path)
def __packet_callback(self, scapy_packet):
"""
This callback is called whenever a packet is applied.
Returns true if it should finish, otherwise, returns false.
"""
packet = actions.packet.Packet(scapy_packet)
for proto in ["TCP", "UDP"]:
if(packet.haslayer(proto) and ((packet[proto].sport == self.port) or (packet[proto].dport == self.port))):
break
else:
return self.stop_sniffing_flag
self.logger.debug(str(packet))
self.packet_dumper.write(scapy_packet)
return self.stop_sniffing_flag
def __spawn_sniffer(self):
"""
Saves pcaps to a file. Should be run as a thread.
Ends when the stop_sniffing_flag is set. Should not be called by user
"""
self.packet_dumper = PcapWriter(self.location, append=True, sync=True)
while(self.stop_sniffing_flag == False):
sniff(stop_filter=self.__packet_callback, timeout=1)
def start_sniffing(self):
"""
Starts sniffing. Should be called by user.
"""
self.stop_sniffing_flag = False
self.pcap_thread = threading.Thread(target=self.__spawn_sniffer)
self.pcap_thread.start()
self.logger.debug("Sniffer starting to port %d" % self.port)
def __enter__(self):
"""
Defines a context manager for this sniffer; simply starts sniffing.
"""
self.start_sniffing()
return self
def __exit__(self, exc_type, exc_value, tb):
"""
Defines exit context manager behavior for this sniffer; simply stops sniffing.
"""
self.stop_sniffing()
def stop_sniffing(self):
"""
Stops the sniffer by setting the flag and calling join
"""
if(self.pcap_thread):
self.stop_sniffing_flag = True
self.pcap_thread.join()
self.logger.debug("Sniffer stopping")

View File

@ -2,17 +2,25 @@
TamperAction TamperAction
One of the four packet-level primitives supported by Geneva. Responsible for any packet-level One of the four packet-level primitives supported by Geneva. Responsible for any packet-level
modifications (particularly header modifications). It supports replace and corrupt mode - modifications (particularly header modifications). It supports the following primitives:
in replace mode, it changes a packet field to a fixed value; in corrupt mode, it changes a packet - no operation: it returns the packet given
field to a randomly generated value each time it is run. - replace: it changes a packet field to a fixed value
- corrupt: it changes a packet field to a randomly generated value each time it is run
- add: adds a given value to the value in a field
- compress: performs DNS decompression on the packet (if applicable)
""" """
from actions.action import Action from actions.action import Action
import actions.utils import actions.utils
from actions.layer import DNSLayer
import random import random
# All supported tamper primitives
SUPPORTED_PRIMITIVES = ["corrupt", "replace", "add", "compress"]
class TamperAction(Action): class TamperAction(Action):
""" """
Defines the TamperAction for Geneva. Defines the TamperAction for Geneva.
@ -23,10 +31,7 @@ class TamperAction(Action):
self.tamper_value = tamper_value self.tamper_value = tamper_value
self.tamper_proto = actions.utils.string_to_protocol(tamper_proto) self.tamper_proto = actions.utils.string_to_protocol(tamper_proto)
self.tamper_proto_str = tamper_proto self.tamper_proto_str = tamper_proto
self.tamper_type = tamper_type self.tamper_type = tamper_type
if not self.tamper_type:
self.tamper_type = random.choice(["corrupt", "replace"])
def tamper(self, packet, logger): def tamper(self, packet, logger):
""" """
@ -41,8 +46,19 @@ class TamperAction(Action):
new_value = self.tamper_value new_value = self.tamper_value
# If corrupting the packet field, generate a value for it # If corrupting the packet field, generate a value for it
if self.tamper_type == "corrupt": try:
new_value = packet.gen(self.tamper_proto_str, self.field) if self.tamper_type == "corrupt":
new_value = packet.gen(self.tamper_proto_str, self.field)
elif self.tamper_type == "add":
new_value = int(self.tamper_value) + int(old_value)
elif self.tamper_type == "compress":
return packet.dns_decompress(logger)
except NotImplementedError:
# If a primitive does not support the type of packet given
return packet
except Exception:
# If an unexpected error has occurred
return packet
logger.debug(" - Tampering %s field `%s` (%s) by %s (to %s)" % logger.debug(" - Tampering %s field `%s` (%s) by %s (to %s)" %
(self.tamper_proto_str, self.field, str(old_value), self.tamper_type, str(new_value))) (self.tamper_proto_str, self.field, str(old_value), self.tamper_type, str(new_value)))
@ -67,8 +83,10 @@ class TamperAction(Action):
s = Action.__str__(self) s = Action.__str__(self)
if self.tamper_type == "corrupt": if self.tamper_type == "corrupt":
s += "{%s:%s:%s}" % (self.tamper_proto_str, self.field, self.tamper_type) s += "{%s:%s:%s}" % (self.tamper_proto_str, self.field, self.tamper_type)
elif self.tamper_type in ["replace"]: elif self.tamper_type in ["replace", "add"]:
s += "{%s:%s:%s:%s}" % (self.tamper_proto_str, self.field, self.tamper_type, self.tamper_value) s += "{%s:%s:%s:%s}" % (self.tamper_proto_str, self.field, self.tamper_type, self.tamper_value)
elif self.tamper_type == "compress":
s += "{%s:%s:compress}" % ("DNS", "qd", )
return s return s

View File

@ -65,14 +65,15 @@ class TraceAction(Action):
""" """
Parses a string representation for this object. Parses a string representation for this object.
""" """
if not string:
return False
try: try:
if string: self.start_ttl, self.end_ttl = string.split(":")
self.start_ttl, self.end_ttl = string.split(":") self.start_ttl = int(self.start_ttl)
self.start_ttl = int(self.start_ttl) self.end_ttl = int(self.end_ttl)
self.end_ttl = int(self.end_ttl) if self.start_ttl > self.end_ttl:
if self.start_ttl > self.end_ttl: logger.error("Cannot use a trace with a start ttl greater than end_ttl (%d > %d)" % (self.start_ttl, self.end_ttl))
logger.error("Cannot use a trace with a start ttl greater than end_ttl (%d > %d)" % (self.start_ttl, self.end_ttl)) return False
return False
except ValueError: except ValueError:
logger.exception("Cannot parse ttls from given data %s" % string) logger.exception("Cannot parse ttls from given data %s" % string)
return False return False

View File

@ -30,22 +30,6 @@ class ActionTree():
self.environment_id = None self.environment_id = None
self.ran = False self.ran = False
def initialize(self, num_actions, environment_id, allow_terminal=True, disabled=None):
"""
Sets up this action tree with a given number of random actions.
Note that the returned action trees may have less actions than num_actions
if terminal actions are used.
"""
self.environment_id = environment_id
self.trigger = actions.trigger.Trigger(None, None, None, environment_id=environment_id)
if not allow_terminal or random.random() > 0.1:
allow_terminal = False
for _ in range(num_actions):
new_action = self.get_rand_action(self.direction, disabled=disabled)
self.add_action(new_action)
return self
def __iter__(self): def __iter__(self):
""" """
Sets up a preoder iterator for the tree. Sets up a preoder iterator for the tree.

View File

@ -27,20 +27,6 @@ class Trigger(object):
self.bomb_trigger = bool(gas and gas < 0) self.bomb_trigger = bool(gas and gas < 0)
self.ran = False self.ran = False
@staticmethod
def get_gas():
"""
Returns a random value for gas for this trigger.
"""
if GAS_ENABLED and random.random() < 0.2:
# Use gas in 20% of scenarios
# Pick a number for gas between 0 - 5
gas_remaining = int(random.random() * 5)
else:
# Do not use gas
gas_remaining = None
return gas_remaining
def is_applicable(self, packet, logger): def is_applicable(self, packet, logger):
""" """
Checks if this trigger is applicable to a given packet. Checks if this trigger is applicable to a given packet.

View File

@ -119,7 +119,7 @@ def get_logger(basepath, log_dir, logger_name, log_name, environment_id, log_lev
ch = logging.StreamHandler() ch = logging.StreamHandler()
ch.setFormatter(formatter) ch.setFormatter(formatter)
ch.setLevel(log_level) ch.setLevel(log_level)
CONSOLE_LOG_LEVEL = log_level CONSOLE_LOG_LEVEL = ch.level
logger.addHandler(ch) logger.addHandler(ch)
return logger return logger
@ -135,34 +135,6 @@ def close_logger(logger):
handler.close() handler.close()
class Logger():
"""
Logging class context manager, as a thin wrapper around the logging class to help
handle closing open file descriptors.
"""
def __init__(self, log_dir, logger_name, log_name, environment_id, log_level=logging.DEBUG):
self.log_dir = log_dir
self.logger_name = logger_name
self.log_name = log_name
self.environment_id = environment_id
self.log_level = log_level
self.logger = None
def __enter__(self):
"""
Sets up a logger.
"""
self.logger = get_logger(PROJECT_ROOT, self.log_dir, self.logger_name, self.log_name, self.environment_id, log_level=self.log_level)
return self.logger
def __exit__(self, exc_type, exc_value, tb):
"""
Closes file handles.
"""
close_logger(self.logger)
def get_console_log_level(): def get_console_log_level():
""" """
returns log level of console handler returns log level of console handler
@ -205,18 +177,6 @@ def setup_dirs(output_dir):
return ga_log_dir return ga_log_dir
def get_from_fuzzed_or_real_packet(environment_id, real_packet_probability, enable_options=True, enable_load=True):
"""
Retrieves a protocol, field, and value from a fuzzed or real packet, depending on
the given probability and if given packets is not None.
"""
packets = actions.utils.read_packets(environment_id)
if packets and random.random() < real_packet_probability:
packet = random.choice(packets)
return packet.get_random()
return actions.packet.Packet().gen_random()
def get_interface(): def get_interface():
""" """
Chooses an interface on the machine to use for socket testing. Chooses an interface on the machine to use for socket testing.

View File

@ -34,8 +34,6 @@ class Engine():
self.server_port = server_port self.server_port = server_port
self.seen_packets = [] self.seen_packets = []
# Set up the directory and ID for logging # Set up the directory and ID for logging
if not output_directory:
output_directory = "trials"
actions.utils.setup_dirs(output_directory) actions.utils.setup_dirs(output_directory)
if not environment_id: if not environment_id:
environment_id = actions.utils.get_id() environment_id = actions.utils.get_id()

72
tests/test_engine.py Normal file
View File

@ -0,0 +1,72 @@
import os
import sys
# Add the path to the engine so we can import it
BASEPATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASEPATH)
import engine
def test_engine():
"""
Basic engine test
"""
# Port to run the engine on
port = 80
# Strategy to use
strategy = "[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:corrupt},),)-| \/"
# Create the engine in debug mode
with engine.Engine(port, strategy, log_level="debug") as eng:
os.system("curl http://example.com?q=ultrasurf")
def test_engine_sleep():
"""
Basic engine test with sleep action
"""
# Port to run the engine on
port = 80
# Strategy to use
strategy = "[TCP:flags:S]-sleep{1}-|"
# Create the engine in debug mode
with engine.Engine(port, strategy, log_level="info") as eng:
os.system("curl http://example.com?q=ultrasurf")
# Strategy to use in opposite direction
strategy = "\/ [TCP:flags:SA]-sleep{1}-|"
# Create the engine in debug mode
with engine.Engine(port, strategy, log_level="debug") as eng:
os.system("curl http://example.com?q=ultrasurf")
def test_engine_trace():
"""
Basic engine test with trace
"""
# Port to run the engine on
port = 80
# Strategy to use
strategy = "[TCP:flags:PA]-trace{2:10}-|"
# Create the engine in debug mode
with engine.Engine(port, strategy, log_level="debug") as eng:
os.system("curl -m 5 http://example.com?q=ultrasurf")
def test_engine_drop():
"""
Basic engine test with drop
"""
# Port to run the engine on
port = 80
# Strategy to use
strategy = "\/ [TCP:flags:SA]-drop-|"
# Create the engine in debug mode
with engine.Engine(port, strategy, log_level="debug") as eng:
os.system("curl -m 3 http://example.com?q=ultrasurf")

220
tests/test_fragment.py Normal file
View File

@ -0,0 +1,220 @@
import logging
import pytest
import sys
# Include the root of the project
sys.path.append("..")
import actions.fragment
import actions.packet
import actions.strategy
import actions.utils
from scapy.all import IP, TCP, UDP
logger = logging.getLogger("test")
def test_segment():
"""
Tests the duplicate action primitive.
"""
fragment = actions.fragment.FragmentAction(correct_order=True)
assert str(fragment) == "fragment{tcp:-1:True}", "Fragment returned incorrect string representation: %s" % str(fragment)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP()/("data"))
packet1, packet2 = fragment.run(packet, logger)
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
assert packet1["Raw"].load != packet2["Raw"].load, "Packets were not different"
assert packet1["Raw"].load == b'da', "Left packet incorrectly fragmented"
assert packet2["Raw"].load == b"ta", "Right packet incorrectly fragmented"
def test_segment_reverse():
"""
Tests the duplicate action primitive in reverse!
"""
fragment = actions.fragment.FragmentAction(correct_order=False)
assert str(fragment) == "fragment{tcp:-1:False}", "Fragment returned incorrect string representation: %s" % str(fragment)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP()/("data"))
packet1, packet2 = fragment.run(packet, logger)
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
assert packet1["Raw"].load != packet2["Raw"].load, "Packets were not different"
assert packet1["Raw"].load == b'ta', "Left packet incorrectly fragmented"
assert packet2["Raw"].load == b"da", "Right packet incorrectly fragmented"
def test_odd_fragment():
"""
Tests long IP fragmentation
"""
fragment = actions.fragment.FragmentAction(correct_order=True, segment=False)
assert str(fragment) == "fragment{ip:-1:True}", "Fragment returned incorrect string representation: %s" % str(fragment)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")/("dataisodd"))
packet1, packet2 = fragment.run(packet, logger)
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
assert str(packet1["Raw"].load) != str(packet2["Raw"].load), "Packets were not different"
assert packet1["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d', "Left packet incorrectly fragmented"
assert packet2["Raw"].load == b'\x00\x00\x00dP\x02 \x00e\xc1\x00\x00dataisodd', "Right packet incorrectly fragmented"
assert packet1["Raw"].load + packet2["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d\x00\x00\x00dP\x02 \x00e\xc1\x00\x00dataisodd', "Packets fragmentation was incorrect"
def test_custom_fragment():
"""
Tests IP fragments with custom sized lengths
"""
fragment = actions.fragment.FragmentAction(correct_order=True, fragsize=3, segment=False)
assert str(fragment) == "fragment{ip:3:True}", "Fragment returned incorrect string representation: %s" % str(fragment)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")/("thisissomedata"))
packet1, packet2 = fragment.run(packet, logger)
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
assert str(packet1["Raw"].load) != str(packet2["Raw"].load), "Packets were not different"
assert packet1["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d\x00\x00\x00dP\x02 \x00zp\x00\x00this', "Left packet incorrectly fragmented"
assert packet2["Raw"].load == b'issomedata', "Right packet incorrectly fragmented"
assert packet1["Raw"].load + packet2["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d\x00\x00\x00dP\x02 \x00zp\x00\x00thisissomedata', "Packets fragmentation was incorrect"
def test_reverse_fragment():
"""
Tests fragmentation with reversed packets
"""
fragment = actions.fragment.FragmentAction(correct_order=False, fragsize=3, segment=False)
assert str(fragment) == "fragment{ip:3:False}", "Fragment returned incorrect string representation: %s" % str(fragment)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")/("thisissomedata"))
packet1, packet2 = fragment.run(packet, logger)
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
assert str(packet1["Raw"].load) != str(packet2["Raw"].load), "Packets were not different"
assert packet2["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d\x00\x00\x00dP\x02 \x00zp\x00\x00this', "Left packet incorrectly fragmented"
assert packet1["Raw"].load == b'issomedata', "Right packet incorrectly fragmented"
assert packet2["Raw"].load + packet1["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d\x00\x00\x00dP\x02 \x00zp\x00\x00thisissomedata', "Packets fragmentation was incorrect"
def test_udp_fragment():
"""
Tests fragmentation with reversed packets
"""
fragment = actions.fragment.FragmentAction(correct_order=False, fragsize=2, segment=False)
assert str(fragment) == "fragment{ip:2:False}", "Fragment returned incorrect string representation: %s" % str(fragment)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/UDP(sport=2222, dport=3333, chksum=0x4444)/("thisissomedata"))
packet1, packet2 = fragment.run(packet, logger)
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
assert str(packet1["Raw"].load) != str(packet2["Raw"].load), "Packets were not different"
def test_parse():
"""
Tests parsing.
"""
fragment = actions.fragment.FragmentAction(correct_order=False, fragsize=2, segment=False)
assert str(fragment) == "fragment{ip:2:False}", "Fragment returned incorrect string representation: %s" % str(fragment)
fragment.parse("fragment{tcp:5:False}", logger)
assert fragment.correct_order == False
assert fragment.fragsize == 5
assert fragment.segment == True
with pytest.raises(Exception):
fragment.parse("fragment{tcp:5}", logger)
with pytest.raises(Exception):
fragment.parse("fragment{tcp:a:True}", logger)
assert fragment.correct_order == False
assert fragment.fragsize == 5
assert fragment.segment == True
fragment = actions.fragment.FragmentAction()
assert fragment.correct_order in [True, False]
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
strat = actions.utils.parse("[IP:proto:6:0]-tamper{IP:proto:replace:6}(fragment{ip:-1:True}(tamper{TCP:dataofs:replace:8}(duplicate,),tamper{IP:frag:replace:0}),)-| [IP:tos:0:0]-duplicate-| \/", logger)
strat.act_on_packet(packet, logger)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/UDP(sport=2222, dport=3333, chksum=0x4444))
strat = actions.utils.parse("[IP:proto:6:0]-tamper{IP:proto:replace:6}(fragment{ip:-1:True}(tamper{TCP:dataofs:replace:8}(duplicate,),tamper{IP:frag:replace:0}),)-| [IP:tos:0:0]-duplicate-| \/", logger)
strat.act_on_packet(packet, logger)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, chksum=0x4444))
strat = actions.utils.parse("[TCP:urgptr:0]-tamper{TCP:options-altchksumopt:corrupt}(fragment{tcp:-1:True}(tamper{IP:proto:corrupt},tamper{TCP:seq:replace:654077552}),)-| \/", logger)
strat.act_on_packet(packet, logger)
strat = actions.utils.parse("[TCP:options-mss:]-tamper{TCP:load:replace:}(fragment{tcp:-1:True},)-| \/", logger)
strat.act_on_packet(packet, logger)
strat = actions.utils.parse("[TCP:options-mss:]-tamper{IP:frag:replace:1353}(tamper{TCP:load:replace:}(fragment{tcp:-1:True},),)-| \/", logger)
strat.act_on_packet(packet, logger)
strat = actions.utils.parse("[IP:ihl:5]-duplicate-| [TCP:options-mss:]-tamper{IP:frag:replace:1353}(fragment{tcp:-1:True}(tamper{TCP:load:replace:}(fragment{tcp:-1:False},),tamper{DNSQR:qtype:replace:45416}),)-| \/", logger)
strat.act_on_packet(packet, logger)
strat = actions.utils.parse("[DNSQR:qclass:25989]-duplicate(duplicate(tamper{DNSQR:qtype:replace:30882},),tamper{UDP:sport:replace:42042})-| [TCP:options-nop:]-tamper{TCP:options-nop:corrupt}(tamper{TCP:load:replace:mjkuskjzgy}(tamper{IP:frag:replace:410}(fragment{tcp:-1:True},),),)-| \/", logger)
strat.act_on_packet(packet, logger)
def test_fallback():
"""
Tests fallback behavior.
"""
fragment = actions.fragment.FragmentAction(correct_order=False, fragsize=2, segment=False)
assert str(fragment) == "fragment{ip:2:False}", "Fragment returned incorrect string representation: %s" % str(fragment)
fragment.parse("fragment{ip:0:False}", logger)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/UDP(sport=2222, dport=3333, chksum=0x4444)/("thisissomedata"))
packet1, packet2 = fragment.run(packet, logger)
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
assert str(packet1) == str(packet2)
fragment.parse("fragment{tcp:-1:False}", logger)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/UDP(sport=2222, dport=3333, chksum=0x4444)/("thisissomedata"))
packet1, packet2 = fragment.run(packet, logger)
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
assert str(packet1) == str(packet2)
fragment.parse("fragment{tcp:-1:False}", logger)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/TCP(sport=2222, dport=3333, chksum=0x4444))
packet1, packet2 = fragment.run(packet, logger)
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
assert str(packet1) == str(packet2)
fragment.parse("fragment{ip:-1:False}", logger)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06))
packet1, packet2 = fragment.run(packet, logger)
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
assert str(packet1) == str(packet2)
def test_ip_only_fragment():
"""
Tests fragmentation without higher protocols.
"""
fragment = actions.fragment.FragmentAction(correct_order=True)
fragment.parse("fragment{ip:-1:True}", logger)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/("datadata11datadata"))
packet1, packet2 = fragment.run(packet, logger)
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
assert packet1["Raw"].load != packet2["Raw"].load, "Packets were not different"
assert packet1["Raw"].load == b'datadata', "Left packet incorrectly fragmented"
assert packet2["Raw"].load == b"11datadata", "Right packet incorrectly fragmented"

544
tests/test_packet.py Normal file
View File

@ -0,0 +1,544 @@
import logging
import pytest
import actions.packet
import actions.layer
from scapy.all import IP, TCP, UDP, DNS, DNSQR, Raw, DNSRR
logger = logging.getLogger("test")
def test_parse_layers():
"""
Tests layer parsing.
"""
pkt = IP()/TCP()/Raw("")
packet = actions.packet.Packet(pkt)
layers = list(packet.read_layers())
assert layers[0].name == "IP"
assert layers[1].name == "TCP"
layers_dict = packet.setup_layers()
assert layers_dict["IP"]
assert layers_dict["TCP"]
def test_get_random():
"""
Tests get random
"""
tcplayer = actions.layer.TCPLayer(TCP())
field, value = tcplayer.get_random()
assert field in actions.layer.TCPLayer.fields
def test_gen_random():
"""
Tests gen random
"""
for i in range(0, 2000):
layer, field, value = actions.packet.Packet().gen_random()
assert layer in [DNS, TCP, UDP, IP, DNSQR]
def test_dnsqr():
"""
Tests DNSQR.
"""
pkt = UDP()/DNS(ancount=1)/DNSQR()
pkt.show()
packet = actions.packet.Packet(pkt)
packet.show()
assert len(packet.layers) == 3
assert "UDP" in packet.layers
assert "DNS" in packet.layers
assert "DNSQR" in packet.layers
pkt = IP()/UDP()/DNS()/DNSQR()
packet = actions.packet.Packet(pkt)
assert str(packet)
def test_load():
"""
Tests loads.
"""
tcp = actions.layer.TCPLayer(TCP())
assert tcp.gen("load")
pkt = IP()/"datadata"
p = actions.packet.Packet(pkt)
assert p.get("IP", "load") == "datadata"
p2 = actions.packet.Packet(IP(bytes(p)))
assert p2.get("IP", "load") == "datadata"
p2.set("IP", "load", "data2")
# Check p is unchanged
assert p.get("IP", "load") == "datadata"
assert p2.get("IP", "load") == "data2"
p2.show2()
# Check that we can dump
assert p2.show2(dump=True)
# Check that we can dump
assert p2.show(dump=True)
assert p2.get("IP", "chksum") == None
pkt = IP()/TCP()/"datadata"
p = actions.packet.Packet(pkt)
assert p.get("TCP", "load") == "datadata"
p2 = actions.packet.Packet(IP(bytes(p)))
assert p2.get("TCP", "load") == "datadata"
p2.set("TCP", "load", "data2")
# Check p is unchanged
assert p.get("TCP", "load") == "datadata"
assert p2.get("TCP", "load") == "data2"
p2.show2()
assert p2.get("IP", "chksum") == None
def test_parse_load():
"""
Tests load parsing.
"""
pkt = actions.packet.Packet(IP()/TCP()/"TYPE A\r\n")
print("Parsed: %s" % pkt.get("TCP", "load"))
strat = actions.utils.parse("[TCP:load:TYPE%20A%0D%0A]-drop-| \/", logger)
results = strat.act_on_packet(pkt, logger)
assert not results
value = pkt.gen("TCP", "load") + " " + pkt.gen("TCP", "load")
pkt.set("TCP", "load", value)
assert " " not in pkt.get("TCP", "load"), "%s contained a space!" % pkt.get("TCP", "load")
def test_dns():
"""
Tests DNS layer.
"""
dns = actions.layer.DNSLayer(DNS())
print(dns.gen("id"))
assert dns.gen("id")
p = actions.packet.Packet(DNS(id=0xabcd))
p2 = actions.packet.Packet(DNS(bytes(p)))
assert p.get("DNS", "id") == 0xabcd
assert p2.get("DNS", "id") == 0xabcd
p2.set("DNS", "id", 0x4321)
assert p.get("DNS", "id") == 0xabcd # Check p is unchanged
assert p2.get("DNS", "id") == 0x4321
dns = actions.packet.Packet(DNS(aa=1))
assert dns.get("DNS", "aa") == 1
aa = dns.gen("DNS", "aa")
assert aa == 0 or aa == 1
assert dns.get("DNS", "aa") == 1 # Original value unchanged
dns = actions.packet.Packet(DNS(opcode=15))
assert dns.get("DNS", "opcode") == 15
opcode = dns.gen("DNS", "opcode")
assert opcode >= 0 and opcode <= 15
assert dns.get("DNS", "opcode") == 15 # Original value unchanged
dns.set("DNS", "opcode", 3)
assert dns.get("DNS", "opcode") == 3
dns = actions.packet.Packet(DNS(qr=0))
assert dns.get("DNS", "qr") == 0
qr = dns.gen("DNS", "qr")
assert qr == 0 or qr == 1
assert dns.get("DNS", "qr") == 0 # Original value unchanged
dns.set("DNS", "qr", 1)
assert dns.get("DNS", "qr") == 1
dns = actions.packet.Packet(DNS(arcount=0xAABB))
assert dns.get("DNS", "arcount") == 0xAABB
arcount = dns.gen("DNS", "arcount")
assert arcount >= 0 and arcount <= 0xffff
assert dns.get("DNS", "arcount") == 0xAABB # Original value unchanged
dns.set("DNS", "arcount", 65432)
assert dns.get("DNS", "arcount") == 65432
dns = actions.layer.DNSLayer(DNS()/DNSQR(qname="example.com"))
assert isinstance(dns.get_next_layer(), DNSQR)
print(dns.gen("id"))
assert dns.gen("id")
p = actions.packet.Packet(DNS(id=0xabcd))
p2 = actions.packet.Packet(DNS(bytes(p)))
assert p.get("DNS", "id") == 0xabcd
assert p2.get("DNS", "id") == 0xabcd
def test_read_layers():
"""
Tests the ability to read each layer
"""
packet = IP() / UDP() / TCP() / DNS() / DNSQR(qname="example.com") / DNSQR(qname="example2.com") / DNSQR(qname="example3.com")
packet_geneva = actions.packet.Packet(packet)
packet_geneva.setup_layers()
i = 0
for layer in packet_geneva.read_layers():
if i == 0:
assert isinstance(layer, actions.layer.IPLayer)
elif i == 1:
assert isinstance(layer, actions.layer.UDPLayer)
elif i == 2:
assert isinstance(layer, actions.layer.TCPLayer)
elif i == 3:
assert isinstance(layer, actions.layer.DNSLayer)
elif i == 4:
assert isinstance(layer, actions.layer.DNSQRLayer)
assert layer.layer.qname == b"example.com"
elif i == 5:
assert isinstance(layer, actions.layer.DNSQRLayer)
assert layer.layer.qname == b"example2.com"
elif i == 6:
assert isinstance(layer, actions.layer.DNSQRLayer)
assert layer.layer.qname == b"example3.com"
i += 1
def test_multi_opts():
"""
Tests various option getting/setting.
"""
pkt = IP()/TCP(options=[('MSS', 1460), ('SAckOK', b''), ('Timestamp', (4154603075, 0)), ('NOP', None), ('WScale', 7)])
packet = actions.packet.Packet(pkt)
assert packet.get("TCP", "options-sackok") == ''
assert packet.get("TCP", "options-mss") == 1460
assert packet.get("TCP", "options-timestamp") == 4154603075
assert packet.get("TCP", "options-wscale") == 7
packet.set("TCP", "options-timestamp", 400000000)
assert packet.get("TCP", "options-sackok") == ''
assert packet.get("TCP", "options-mss") == 1460
assert packet.get("TCP", "options-timestamp") == 400000000
assert packet.get("TCP", "options-wscale") == 7
pkt = IP()/TCP(options=[('SAckOK', b''), ('Timestamp', (4154603075, 0)), ('NOP', None), ('WScale', 7)])
packet = actions.packet.Packet(pkt)
# If the option isn't present, it will be returned as an empty string
assert packet.get("TCP", "options-mss") == ''
packet.set("TCP", "options-mss", "")
assert packet.get("TCP", "options-mss") == 0
def test_options_eol():
"""
Tests options-eol.
"""
pkt = TCP(options=[("EOL", None)])
p = actions.packet.Packet(pkt)
assert p.get("TCP", "options-eol") == ""
p2 = actions.packet.Packet(TCP(bytes(p)))
assert p2.get("TCP", "options-eol") == ""
p = actions.packet.Packet(IP()/TCP(options=[]))
assert p.get("TCP", "options-eol") == ""
p.set("TCP", "options-eol", "")
p.show()
assert len(p["TCP"].options) == 1
assert any(k == "EOL" for k, v in p["TCP"].options)
value = p.gen("TCP", "options-eol")
assert value == "", "eol cannot store data"
p.set("TCP", "options-eol", value)
p2 = TCP(bytes(p))
assert any(k == "EOL" for k, v in p2["TCP"].options)
def test_options_mss():
"""
Tests options-eol.
"""
pkt = TCP(options=[("MSS", 1440)])
p = actions.packet.Packet(pkt)
assert p.get("TCP", "options-mss") == 1440
p2 = actions.packet.Packet(TCP(bytes(p)))
assert p2.get("TCP", "options-mss") == 1440
p = actions.packet.Packet(TCP(options=[]))
assert p.get("TCP", "options-mss") == ""
p.set("TCP", "options-mss", 2880)
p.show()
assert len(p["TCP"].options) == 1
assert any(k == "MSS" for k, v in p["TCP"].options)
value = p.gen("TCP", "options-mss")
p.set("TCP", "options-mss", value)
p2 = TCP(bytes(p))
assert any(k == "MSS" for k, v in p2["TCP"].options)
def check_get(protocol, field, value):
"""
Checks if the get method worked for this protocol, field, and value.
"""
pkt = protocol()
setattr(pkt, field, value)
packet = actions.packet.Packet(pkt)
assert packet.get(protocol.__name__, field) == value
def get_test_configs():
"""
Generates test configurations for the getters.
"""
return [
(IP, 'version', 4),
(IP, 'version', 6),
(IP, 'version', 0),
(IP, 'ihl', 0),
(IP, 'tos', 0),
(IP, 'len', 50),
(IP, 'len', 6),
(IP, 'flags', 'MF'),
(IP, 'flags', 'DF'),
(IP, 'flags', 'MF+DF'),
(IP, 'ttl', 25),
(IP, 'proto', 4),
(IP, 'chksum', 0x4444),
(IP, 'src', '127.0.0.1'),
(IP, 'dst', '127.0.0.1'),
(TCP, 'sport', 12345),
(TCP, 'dport', 55555),
(TCP, 'seq', 123123123),
(TCP, 'ack', 181818181),
(TCP, 'dataofs', 5),
(TCP, 'dataofs', 0),
(TCP, 'dataofs', 15),
(TCP, 'reserved', 0),
(TCP, 'window', 100),
(TCP, 'chksum', 0x4444),
(TCP, 'urgptr', 1),
(DNS, 'id', 0xabcd),
(DNS, 'qr', 1),
(DNS, 'opcode', 9),
(DNS, 'aa', 0),
(DNS, 'tc', 1),
(DNS, 'rd', 0),
(DNS, 'ra', 1),
(DNS, 'z', 0),
(DNS, 'ad', 1),
(DNS, 'cd', 0),
(DNS, 'qdcount', 0x1234),
(DNS, 'ancount', 12345),
(DNS, 'nscount', 49870),
(DNS, 'arcount', 0xABCD),
(DNSQR, 'qname', 'example.com.'),
(DNSQR, 'qtype', 1),
(DNSQR, 'qclass', 0),
]
def get_custom_configs():
"""
Generates test configurations that can use the custom getters.
"""
return [
(IP, 'flags', ''),
(TCP, 'options-eol', ''),
(TCP, 'options-nop', ''),
(TCP, 'options-mss', 0),
(TCP, 'options-mss', 1440),
(TCP, 'options-mss', 5000),
(TCP, 'options-wscale', 20),
(TCP, 'options-sackok', ''),
(TCP, 'options-sack', ''),
(TCP, 'options-timestamp', 12345678),
(TCP, 'options-altchksum', 0x44),
(TCP, 'options-altchksumopt', ''),
(TCP, 'options-uto', 1),
#(TCP, 'options-md5header', 'deadc0ffee')
]
@pytest.mark.parametrize("config", get_test_configs(),
ids=['%s-%s-%s' % (proto.__name__, field, str(val)) for proto, field, val in get_test_configs()])
def test_get(config):
"""
Tests value retrieval.
"""
proto, field, val = config
check_get(proto, field, val)
def check_set_get(protocol, field, value):
"""
Checks if the get method worked for this protocol, field, and value.
"""
pkt = actions.packet.Packet(protocol())
pkt.set(protocol.__name__, field, value)
assert pkt.get(protocol.__name__, field) == value
# Rebuild the packet to confirm the type survived
pkt2 = actions.packet.Packet(protocol(bytes(pkt)))
assert pkt2.get(protocol.__name__, field) == value, "Value %s for header %s didn't survive packet parsing." % (value, field)
@pytest.mark.parametrize("config", get_test_configs() + get_custom_configs(),
ids=['%s-%s-%s' % (proto.__name__, field, str(val)) for proto, field, val in get_test_configs() + get_custom_configs()])
def test_set_get(config):
"""
Tests value retrieval.
"""
proto, field, value = config
check_set_get(proto, field, value)
def check_gen_set_get(protocol, field):
"""
Checks if the get method worked for this protocol, field, and value.
"""
pkt = actions.packet.Packet(protocol())
new_value = pkt.gen(protocol.__name__, field)
pkt.set(protocol.__name__, field, new_value)
assert pkt.get(protocol.__name__, field) == new_value
# Rebuild the packet to confirm the type survived
pkt2 = actions.packet.Packet(protocol(bytes(pkt)))
assert pkt2.get(protocol.__name__, field) == new_value
@pytest.mark.parametrize("config", get_test_configs() + get_custom_configs(),
ids=['%s-%s' % (proto.__name__, field) for proto, field, _ in get_test_configs() + get_custom_configs()])
def test_gen_set_get(config):
"""
Tests value retrieval.
"""
# Test each generator 50 times to hit a range of values
for i in range(0, 50):
proto, field, _ = config
check_gen_set_get(proto, field)
def test_custom_get():
"""
Tests value retrieval for custom getters.
"""
pkt = IP()/TCP()/Raw(load="AAAA")
tcp = actions.packet.Packet(pkt)
assert tcp.get("TCP", "load") == "AAAA"
def test_restrict_fields():
"""
Tests packet field restriction.
"""
actions.packet.SUPPORTED_LAYERS = [
actions.layer.IPLayer,
actions.layer.TCPLayer,
actions.layer.UDPLayer
]
tcpfields = actions.layer.TCPLayer.fields
udpfields = actions.layer.UDPLayer.fields
ipfields = actions.layer.IPLayer.fields
actions.packet.Packet.restrict_fields(logger, ["TCP", "UDP"], [], [])
assert len(actions.packet.SUPPORTED_LAYERS) == 2
assert actions.layer.TCPLayer in actions.packet.SUPPORTED_LAYERS
assert actions.layer.UDPLayer in actions.packet.SUPPORTED_LAYERS
assert not actions.layer.IPLayer in actions.packet.SUPPORTED_LAYERS
pkt = IP()/TCP()
packet = actions.packet.Packet(pkt)
assert "TCP" in packet.layers
assert not "IP" in packet.layers
assert len(packet.layers) == 1
for i in range(0, 2000):
layer, proto, field = actions.packet.Packet().gen_random()
assert layer in [TCP, UDP]
# Check we can't retrieve any IP fields
for field in actions.layer.IPLayer.fields:
with pytest.raises(AssertionError):
packet.get("IP", field)
# Check we can get all the TCP fields
for field in actions.layer.TCPLayer.fields:
packet.get("TCP", field)
actions.packet.Packet.restrict_fields(logger, ["TCP", "UDP"], ["flags"], [])
packet = actions.packet.Packet(pkt)
assert len(actions.packet.SUPPORTED_LAYERS) == 1
assert actions.layer.TCPLayer in actions.packet.SUPPORTED_LAYERS
assert not actions.layer.UDPLayer in actions.packet.SUPPORTED_LAYERS
assert not actions.layer.IPLayer in actions.packet.SUPPORTED_LAYERS
assert actions.layer.TCPLayer.fields == ["flags"]
assert not actions.layer.UDPLayer.fields
# Check we can't retrieve any IP fields
for field in actions.layer.IPLayer.fields:
with pytest.raises(AssertionError):
packet.get("IP", field)
# Check we can get all the TCP fields
for field in tcpfields:
if field == "flags":
packet.get("TCP", field)
else:
with pytest.raises(AssertionError):
packet.get("TCP", field)
for i in range(0, 2000):
layer, field, value = actions.packet.Packet().gen_random()
assert layer == TCP
assert field == "flags"
actions.packet.Packet.reset_restrictions()
actions.packet.SUPPORTED_LAYERS = [
actions.layer.IPLayer,
actions.layer.TCPLayer,
actions.layer.UDPLayer
]
actions.packet.Packet.restrict_fields(logger, ["TCP", "IP"], [], ["sport", "dport", "seq", "src"])
packet = actions.packet.Packet(pkt)
packet = packet.copy()
assert packet.has_supported_layers()
assert len(actions.packet.SUPPORTED_LAYERS) == 2
assert actions.layer.TCPLayer in actions.packet.SUPPORTED_LAYERS
assert not actions.layer.UDPLayer in actions.packet.SUPPORTED_LAYERS
assert actions.layer.IPLayer in actions.packet.SUPPORTED_LAYERS
assert set(actions.layer.TCPLayer.fields) == set([f for f in tcpfields if f not in ["sport", "dport", "seq"]])
assert set(actions.layer.IPLayer.fields) == set([f for f in ipfields if f not in ["src"]])
# Check we can't retrieve any IP fields
for field in actions.layer.IPLayer.fields:
if field == "src":
with pytest.raises(AssertionError):
packet.get("IP", field)
else:
packet.get("IP", field)
# Check we can get all the TCP fields
for field in tcpfields:
if field in ["sport", "dport", "seq"]:
with pytest.raises(AssertionError):
packet.get("TCP", field)
else:
packet.get("TCP", field)
for i in range(0, 2000):
layer, field, value = actions.packet.Packet().gen_random()
assert layer in [TCP, IP]
assert field not in ["sport", "dport", "seq", "src"]
actions.packet.Packet.reset_restrictions()
actions.packet.SUPPORTED_LAYERS = [
actions.layer.IPLayer,
actions.layer.TCPLayer,
actions.layer.UDPLayer
]
actions.packet.Packet.restrict_fields(logger, ["IP", "UDP", "DNS"], [], ["version"])
packet = actions.packet.Packet(pkt)
proto, field, value = packet.get_random()
assert proto.__name__ in ["IP", "UDP"]
assert len(actions.packet.SUPPORTED_LAYERS) == 2
assert not actions.layer.TCPLayer in actions.packet.SUPPORTED_LAYERS
assert actions.layer.UDPLayer in actions.packet.SUPPORTED_LAYERS
assert actions.layer.IPLayer in actions.packet.SUPPORTED_LAYERS
assert set(actions.layer.IPLayer.fields) == set([f for f in ipfields if f not in ["version"]])
assert set(actions.layer.UDPLayer.fields) == set(udpfields)
actions.packet.Packet.reset_restrictions()
for layer in actions.packet.SUPPORTED_LAYERS:
assert layer.fields, '%s has no fields - reset failed!' % str(layer)

94
tests/test_strategy.py Normal file
View File

@ -0,0 +1,94 @@
import logging
import pytest
import actions.tree
import actions.drop
import actions.tamper
import actions.trace
import actions.duplicate
import actions.sleep
import actions.utils
import actions.strategy
from scapy.all import IP, TCP
logger = logging.getLogger("test")
def test_run():
"""
Tests strategy execution.
"""
strat1 = actions.utils.parse("[TCP:flags:R]-duplicate-| \/", logger)
strat2 = actions.utils.parse("[TCP:flags:S]-drop-| \/", logger)
strat3 = actions.utils.parse("[TCP:flags:A]-duplicate(tamper{TCP:dataofs:replace:0},)-| \/", logger)
strat4 = actions.utils.parse("[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:replace:15239},),duplicate(tamper{TCP:flags:replace:S}(tamper{TCP:chksum:replace:14539}(tamper{TCP:seq:corrupt},),),))-| \/", logger)
p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
packets = strat1.act_on_packet(p1, logger, direction="out")
assert packets, "Strategy dropped SYN packets"
assert len(packets) == 1
assert packets[0]["TCP"].flags == "S"
p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
packets = strat2.act_on_packet(p1, logger, direction="out")
assert not packets, "Strategy failed to drop SYN packets"
p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="A", dataofs=5))
packets = strat3.act_on_packet(p1, logger, direction="out")
assert packets, "Strategy dropped packets"
assert len(packets) == 2, "Incorrect number of packets emerged from forest"
assert packets[0]["TCP"].dataofs == 0, "Packet tamper failed"
assert packets[1]["TCP"].dataofs == 5, "Duplicate packet was tampered"
p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="A", dataofs=5, chksum=100))
packets = strat4.act_on_packet(p1, logger, direction="out")
assert packets, "Strategy dropped packets"
assert len(packets) == 3, "Incorrect number of packets emerged from forest"
assert packets[0]["TCP"].flags == "R", "Packet tamper failed"
assert packets[0]["TCP"].chksum != p1["TCP"].chksum, "Packet tamper failed"
assert packets[1]["TCP"].flags == "S", "Packet tamper failed"
assert packets[1]["TCP"].chksum != p1["TCP"].chksum, "Packet tamper failed"
assert packets[1]["TCP"].seq != p1["TCP"].seq, "Packet tamper failed"
assert packets[2]["TCP"].flags == "A", "Duplicate failed"
strat4 = actions.utils.parse("[TCP:load:]-tamper{TCP:load:replace:mhe76jm0bd}(fragment{ip:-1:True}(tamper{IP:load:corrupt},drop),)-| \/ ", logger)
p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
packets = strat4.act_on_packet(p1, logger)
# Will fail with scapy 2.4.2 if packet is reparsed
strat5 = actions.utils.parse("\"[TCP:options-eol:]-tamper{TCP:load:replace:o}(tamper{TCP:dataofs:replace:11},)-| \/\"", logger)
p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
packets = strat5.act_on_packet(p1, logger)
def test_pretty_print():
"""
Tests if the string representation of this strategy is correct
"""
logger = logging.getLogger("test")
strat = actions.utils.parse("[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:corrupt},),)-| \/ ", logger)
correct = "TCP:flags:A\nduplicate\n├── tamper{TCP:flags:replace:R}\n│ └── tamper{TCP:chksum:corrupt}\n│ └── ===> \n└── ===> \n \n \/ \n "
assert strat.pretty_print() == correct
def test_sleep_parse_handling():
"""
Tests that the sleep action handles bad parsing.
"""
print("Testing incorrect parsing:")
assert not actions.sleep.SleepAction().parse("THISHSOULDFAIL", logger)
assert actions.sleep.SleepAction().parse("10.5", logger)
def test_trace_parse_handling():
"""
Tests that the sleep action handles bad parsing.
"""
print("Testing incorrect parsing:")
assert not actions.trace.TraceAction().parse("5:4", logger)
assert not actions.trace.TraceAction().parse("THISHOULDFAIL", logger)
assert not actions.trace.TraceAction().parse("", logger)

389
tests/test_tamper.py Normal file
View File

@ -0,0 +1,389 @@
import copy
import logging
import sys
import pytest
import random
# Include the root of the project
sys.path.append("..")
import actions.strategy
import actions.packet
import actions.utils
import actions.tamper
import actions.layer
from scapy.all import IP, TCP, UDP, DNS, DNSQR, sr1
logger = logging.getLogger("test")
def test_tamper():
"""
Tests tampering with replace
"""
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
original = copy.deepcopy(packet)
tamper = actions.tamper.TamperAction(None, field="flags", tamper_type="replace", tamper_value="R")
lpacket, rpacket = tamper.run(packet, logger)
assert not rpacket, "Tamper must not return right child"
assert lpacket, "Tamper must give a left child"
assert id(lpacket) == id(packet), "Tamper must edit in place"
# Confirm tamper replaced the field it was supposed to
assert packet[TCP].flags == "R", "Tamper did not replace flags."
new_value = packet[TCP].flags
# Must run this check repeatedly - if a scapy fuzz-ed value is not properly
# ._fix()-ed, it will return different values each time it's requested
for _ in range(0, 5):
assert packet[TCP].flags == new_value, "Replaced value is not stable"
# Confirm tamper didn't corrupt anything else in the TCP header
assert confirm_unchanged(packet, original, TCP, ["flags"])
# Confirm tamper didn't corrupt anything in the IP header
assert confirm_unchanged(packet, original, IP, [])
def test_tamper_ip():
"""
Tests tampering with IP
"""
packet = actions.packet.Packet(IP(src='127.0.0.1', dst='127.0.0.1')/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
original = copy.deepcopy(packet)
tamper = actions.tamper.TamperAction(None, field="src", tamper_type="replace", tamper_value="192.168.1.1", tamper_proto="IP")
lpacket, rpacket = tamper.run(packet, logger)
assert not rpacket, "Tamper must not return right child"
assert lpacket, "Tamper must give a left child"
assert id(lpacket) == id(packet), "Tamper must edit in place"
# Confirm tamper replaced the field it was supposed to
assert packet[IP].src == "192.168.1.1", "Tamper did not replace flags."
# Confirm tamper didn't corrupt anything in the TCP header
assert confirm_unchanged(packet, original, TCP, [])
# Confirm tamper didn't corrupt anything else in the IP header
assert confirm_unchanged(packet, original, IP, ["src"])
def test_tamper_udp():
"""
Tests tampering with UDP
"""
packet = actions.packet.Packet(IP(src='127.0.0.1', dst='127.0.0.1')/UDP(sport=2222, dport=53))
original = copy.deepcopy(packet)
tamper = actions.tamper.TamperAction(None, field="chksum", tamper_type="replace", tamper_value=4444, tamper_proto="UDP")
lpacket, rpacket = tamper.run(packet, logger)
assert not rpacket, "Tamper must not return right child"
assert lpacket, "Tamper must give a left child"
assert id(lpacket) == id(packet), "Tamper must edit in place"
# Confirm tamper replaced the field it was supposed to
assert packet[UDP].chksum == 4444, "Tamper did not replace flags."
# Confirm tamper didn't corrupt anything in the TCP header
assert confirm_unchanged(packet, original, UDP, ["chksum"])
# Confirm tamper didn't corrupt anything else in the IP header
assert confirm_unchanged(packet, original, IP, [])
def test_tamper_ip_ident():
"""
Tests tampering with IP and that the checksum is correctly changed
"""
packet = actions.packet.Packet(IP(src='127.0.0.1', dst='127.0.0.1')/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
original = copy.deepcopy(packet)
tamper = actions.tamper.TamperAction(None, field='id', tamper_type='replace', tamper_value=3333, tamper_proto="IP")
lpacket, rpacket = tamper.run(packet, logger)
assert not rpacket, "Tamper must not return right child"
assert lpacket, "Tamper must give a left child"
assert id(lpacket) == id(packet), "Tamper must edit in place"
# Confirm tamper replaced the field it was supposed to
assert packet[IP].id == 3333, "Tamper did not replace flags."
# Confirm tamper didn't corrupt anything in the TCP header
assert confirm_unchanged(packet, original, TCP, [])
# Confirm tamper didn't corrupt anything else in the IP header
assert confirm_unchanged(packet, original, IP, ["id"])
def confirm_unchanged(packet, original, protocol, changed):
"""
Checks that no other field besides the given array of changed fields
are different between these two packets.
"""
for header in packet.layers:
if packet.layers[header].protocol != protocol:
continue
for field in packet.layers[header].fields:
# Skip checking the field we just changed
if field in changed or field == "load":
continue
assert packet.get(protocol.__name__, field) == original.get(protocol.__name__, field), "Tamper changed %s field %s." % (str(protocol), field)
return True
def test_parse_parameters():
"""
Tests that tamper properly rejects malformed tamper actions
"""
with pytest.raises(Exception):
actions.tamper.TamperAction().parse("this:has:too:many:parameters", logger)
with pytest.raises(Exception):
actions.tamper.TamperAction().parse("not:enough", logger)
def test_corrupt():
"""
Tests the tamper 'corrupt' primitive.
"""
tamper = actions.tamper.TamperAction(None, field="flags", tamper_type="corrupt", tamper_value="R")
assert tamper.field == "flags", "Tamper action changed fields."
assert tamper.tamper_type == "corrupt", "Tamper action changed types."
assert str(tamper) == "tamper{TCP:flags:corrupt}", "Tamper returned incorrect string representation: %s" % str(tamper)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
original = copy.deepcopy(packet)
tamper.tamper(packet, logger)
new_value = packet[TCP].flags
# Must run this check repeatedly - if a scapy fuzz-ed value is not properly
# ._fix()-ed, it will return different values each time it's requested
for _ in range(0, 5):
assert packet[TCP].flags == new_value, "Corrupted value is not stable"
# Confirm tamper didn't corrupt anything else in the TCP header
assert confirm_unchanged(packet, original, TCP, ["flags"])
# Confirm tamper didn't corrupt anything else in the IP header
assert confirm_unchanged(packet, original, IP, [])
def test_add():
"""
Tests the tamper 'add' primitive.
"""
tamper = actions.tamper.TamperAction(None, field="seq", tamper_type="add", tamper_value=10)
assert tamper.field == "seq", "Tamper action changed fields."
assert tamper.tamper_type == "add", "Tamper action changed types."
assert str(tamper) == "tamper{TCP:seq:add:10}", "Tamper returned incorrect string representation: %s" % str(tamper)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
original = copy.deepcopy(packet)
tamper.tamper(packet, logger)
new_value = packet[TCP].seq
assert new_value == 110, "Tamper did not add"
# Must run this check repeatedly - if a scapy fuzz-ed value is not properly
# ._fix()-ed, it will return different values each time it's requested
for _ in range(0, 5):
assert packet[TCP].seq == new_value, "Corrupted value is not stable"
# Confirm tamper didn't corrupt anything else in the TCP header
assert confirm_unchanged(packet, original, TCP, ["seq"])
# Confirm tamper didn't corrupt anything else in the IP header
assert confirm_unchanged(packet, original, IP, [])
def test_decompress():
"""
Tests the tamper 'decompress' primitive.
"""
tamper = actions.tamper.TamperAction(None, field="qd", tamper_type="compress", tamper_value=10, tamper_proto="DNS")
assert tamper.field == "qd", "Tamper action changed fields."
assert tamper.tamper_type == "compress", "Tamper action changed types."
assert str(tamper) == "tamper{DNS:qd:compress}", "Tamper returned incorrect string representation: %s" % str(tamper)
packet = actions.packet.Packet(IP(dst="8.8.8.8")/UDP(dport=53)/DNS(qd=DNSQR(qname="minghui.ca.")))
original = packet.copy()
tamper.tamper(packet, logger)
assert bytes(packet["DNS"]) == b'\x00\x00\x01\x00\x00\x02\x00\x00\x00\x00\x00\x00\x07minghui\xc0\x1a\x00\x01\x00\x01\x02ca\x00\x00\x01\x00\x01'
resp = sr1(packet.packet)
assert resp["DNS"]
assert resp["DNS"].rcode != 1
assert resp["DNSQR"]
assert resp["DNSRR"].rdata
assert confirm_unchanged(packet, original, IP, ["len"])
print(resp.summary())
packet = actions.packet.Packet(IP(dst="8.8.8.8")/UDP(dport=53)/DNS(qd=DNSQR(qname="maps.google.com")))
original = packet.copy()
tamper.tamper(packet, logger)
assert bytes(packet["DNS"]) == b'\x00\x00\x01\x00\x00\x02\x00\x00\x00\x00\x00\x00\x04maps\xc0\x17\x00\x01\x00\x01\x06google\x03com\x00\x00\x01\x00\x01'
resp = sr1(packet.packet)
assert resp["DNS"]
assert resp["DNS"].rcode != 1
assert resp["DNSQR"]
assert resp["DNSRR"].rdata
assert confirm_unchanged(packet, original, IP, ["len"])
print(resp.summary())
# Confirm this is a NOP on normal packets
packet = actions.packet.Packet(IP()/UDP())
original = packet.copy()
tamper.tamper(packet, logger)
assert packet.packet.summary() == original.packet.summary()
# Confirm tamper didn't corrupt anything else in the TCP header
assert confirm_unchanged(packet, original, UDP, [])
# Confirm tamper didn't corrupt anything else in the IP header
assert confirm_unchanged(packet, original, IP, [])
packet = actions.packet.Packet(IP(dst="8.8.8.8")/TCP(dport=53)/DNS(qd=DNSQR(qname="maps.google.com")))
original = packet.copy()
tamper.tamper(packet, logger)
assert bytes(packet) == bytes(original)
def test_corrupt_chksum():
"""
Tests the tamper 'replace' primitive.
"""
tamper = actions.tamper.TamperAction(None, field="chksum", tamper_type="corrupt", tamper_value="R")
assert tamper.field == "chksum", "Tamper action changed checksum."
assert tamper.tamper_type == "corrupt", "Tamper action changed types."
assert str(tamper) == "tamper{TCP:chksum:corrupt}", "Tamper returned incorrect string representation: %s" % str(tamper)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
original = copy.deepcopy(packet)
tamper.tamper(packet, logger)
# Confirm tamper actually corrupted the checksum
assert packet[TCP].chksum != 0
new_value = packet[TCP].chksum
# Must run this check repeatedly - if a scapy fuzz-ed value is not properly
# ._fix()-ed, it will return different values each time it's requested
for _ in range(0, 5):
assert packet[TCP].chksum == new_value, "Corrupted value is not stable"
# Confirm tamper didn't corrupt anything else in the TCP header
assert confirm_unchanged(packet, original, TCP, ["chksum"])
# Confirm tamper didn't corrupt anything else in the IP header
assert confirm_unchanged(packet, original, IP, [])
def test_corrupt_dataofs():
"""
Tests the tamper 'replace' primitive.
"""
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S", dataofs="6L"))
original = copy.deepcopy(packet)
tamper = actions.tamper.TamperAction(None, field="dataofs", tamper_type="corrupt")
tamper.tamper(packet, logger)
# Confirm tamper actually corrupted the checksum
assert packet[TCP].dataofs != "0"
new_value = packet[TCP].dataofs
# Must run this check repeatedly - if a scapy fuzz-ed value is not properly
# ._fix()-ed, it will return different values each time it's requested
for _ in range(0, 5):
assert packet[TCP].dataofs == new_value, "Corrupted value is not stable"
# Confirm tamper didn't corrupt anything else in the TCP header
assert confirm_unchanged(packet, original, TCP, ["dataofs"])
# Confirm tamper didn't corrupt anything in the IP header
assert confirm_unchanged(packet, original, IP, [])
def test_replace():
"""
Tests the tamper 'replace' primitive.
"""
tamper = actions.tamper.TamperAction(None, field="flags", tamper_type="replace", tamper_value="R")
assert tamper.field == "flags", "Tamper action changed fields."
assert tamper.tamper_type == "replace", "Tamper action changed types."
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
original = copy.deepcopy(packet)
tamper.tamper(packet, logger)
# Confirm tamper replaced the field it was supposed to
assert packet[TCP].flags == "R", "Tamper did not replace flags."
# Confirm tamper didn't replace anything else in the TCP header
assert confirm_unchanged(packet, original, TCP, ["flags"])
# Confirm tamper didn't replace anything else in the IP header
assert confirm_unchanged(packet, original, IP, [])
# chksums must be handled specially by tamper, so run a second check on this value
tamper.field = "chksum"
tamper.tamper_value = 0x4444
original = copy.deepcopy(packet)
tamper.tamper(packet, logger)
assert packet[TCP].chksum == 0x4444, "Tamper failed to change chksum."
# Confirm tamper didn't replace anything else in the TCP header
assert confirm_unchanged(packet, original, TCP, ["chksum"])
# Confirm tamper didn't replace anything else in the IP header
assert confirm_unchanged(packet, original, IP, [])
def test_parse_flags():
"""
Tests the tamper 'replace' primitive.
"""
tamper = actions.tamper.TamperAction(None, field="flags", tamper_type="replace", tamper_value="FRAPUN")
assert tamper.field == "flags", "Tamper action changed checksum."
assert tamper.tamper_type == "replace", "Tamper action changed types."
assert str(tamper) == "tamper{TCP:flags:replace:FRAPUN}", "Tamper returned incorrect string representation: %s" % str(tamper)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
tamper.tamper(packet, logger)
assert packet[TCP].flags == "FRAPUN", "Tamper failed to change flags."
@pytest.mark.parametrize("test_type", ["parsed", "direct"])
@pytest.mark.parametrize("value", ["EOL", "NOP", "Timestamp", "MSS", "WScale", "SAckOK", "SAck", "Timestamp", "AltChkSum", "AltChkSumOpt", "UTO"])
def test_options(value, test_type):
"""
Tests tampering options
"""
if test_type == "direct":
tamper = actions.tamper.TamperAction(None, field="options-%s" % value.lower(), tamper_type="corrupt", tamper_value=bytes([12]))
else:
tamper = actions.tamper.TamperAction(None)
assert tamper.parse("TCP:options-%s:replace:" % value.lower(), logger)
assert tamper.parse("TCP:options-%s:corrupt" % value.lower(), logger)
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
tamper.run(packet, logger)
opts_dict_lookup = value.lower().replace(" ", "_")
for optname, optval in packet["TCP"].options:
if optname == value:
break
elif optname == actions.layer.TCPLayer.options_names[opts_dict_lookup]:
break
else:
pytest.fail("Failed to find %s in options" % value)
assert len(packet["TCP"].options) == 1
raw_p = bytes(packet)
assert raw_p, "options broke scapy bytes"
p2 = actions.packet.Packet(IP(bytes(raw_p)))
assert p2.haslayer("IP")
assert p2.haslayer("TCP")
# EOLs might be added for padding, so just check >= 1
assert len(p2["TCP"].options) >= 1
for optname, optval in p2["TCP"].options:
if optname == value:
break
elif optname == actions.layer.TCPLayer.options_names[opts_dict_lookup]:
break
else:
pytest.fail("Failed to find %s in options" % value)

549
tests/test_tree.py Normal file
View File

@ -0,0 +1,549 @@
import logging
import os
from scapy.all import IP, TCP
import actions.tree
import actions.drop
import actions.tamper
import actions.duplicate
import actions.utils
def test_init():
"""
Tests initialization
"""
print(actions.action.Action.get_actions("out"))
def test_count_leaves():
"""
Tests leaf count is correct.
"""
a = actions.tree.ActionTree("out")
logger = logging.getLogger("test")
assert not a.parse("TCP:reserved:0tamper{TCP:flags:replace:S}-|", logger), "Tree parsed malformed DNA"
a.parse("[TCP:reserved:0]-tamper{TCP:flags:replace:S}-|", logger)
duplicate = actions.duplicate.DuplicateAction()
duplicate2 = actions.duplicate.DuplicateAction()
drop = actions.drop.DropAction()
assert a.count_leaves() == 1
assert a.remove_one()
a.add_action(duplicate)
assert a.count_leaves() == 1
duplicate.left = duplicate2
assert a.count_leaves() == 1
duplicate.right = drop
assert a.count_leaves() == 2
def test_check():
"""
Tests action tree check function.
"""
a = actions.tree.ActionTree("out")
logger = logging.getLogger("test")
a.parse("[TCP:flags:RA]-tamper{TCP:flags:replace:S}-|", logger)
p = actions.packet.Packet(IP()/TCP(flags="A"))
assert not a.check(p, logger)
p = actions.packet.Packet(IP(ttl=64)/TCP(flags="RA"))
assert a.check(p, logger)
assert a.remove_one()
assert a.check(p, logger)
a.parse("[TCP:reserved:0]-tamper{TCP:flags:replace:S}-|", logger)
assert a.check(p, logger)
a.parse("[IP:ttl:64]-tamper{TCP:flags:replace:S}-|", logger)
assert a.check(p, logger)
p = actions.packet.Packet(IP(ttl=15)/TCP(flags="RA"))
assert not a.check(p, logger)
def test_scapy():
"""
Tests misc. scapy aspects relevant to strategies.
"""
a = actions.tree.ActionTree("out")
logger = logging.getLogger("test")
a.parse("[TCP:reserved:0]-tamper{TCP:flags:replace:S}-|", logger)
p = actions.packet.Packet(IP()/TCP(flags="A"))
assert a.check(p, logger)
packets = a.run(p, logger)
assert packets[0][TCP].flags == "S"
p = actions.packet.Packet(IP()/TCP(flags="A"))
assert a.check(p, logger)
a.parse("[TCP:reserved:0]-tamper{TCP:chksum:corrupt}-|", logger)
packets = a.run(p, logger)
assert packets[0][TCP].chksum
assert a.check(p, logger)
def test_str():
"""
Tests string representation.
"""
logger = logging.getLogger("test")
t = actions.trigger.Trigger("field", "flags", "TCP")
a = actions.tree.ActionTree("out", trigger=t)
assert str(a).strip() == "[%s]-|" % str(t)
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
assert a.add_action(tamper)
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}-|"
# Tree will not add a duplicate action
assert not a.add_action(tamper)
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}-|"
assert a.add_action(tamper2)
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R},)-|"
assert a.add_action(actions.duplicate.DuplicateAction())
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate,),)-|"
drop = actions.drop.DropAction()
assert a.add_action(drop)
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate(drop,),),)-|" or \
str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate(,drop),),)-|"
assert a.remove_action(drop)
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate,),)-|"
# Cannot remove action that is not present
assert not a.remove_action(drop)
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate,),)-|"
a = actions.tree.ActionTree("out", trigger=t)
orig = "[TCP:urgptr:15963]-duplicate(,drop)-|"
a.parse(orig, logger)
assert a.remove_one()
assert orig != str(a)
assert str(a) in ["[TCP:urgptr:15963]-drop-|", "[TCP:urgptr:15963]-duplicate-|"]
def test_pretty_print_send():
t = actions.trigger.Trigger("field", "flags", "TCP")
a = actions.tree.ActionTree("out", trigger=t)
duplicate = actions.duplicate.DuplicateAction()
a.add_action(duplicate)
correct_string = "TCP:flags:0\nduplicate\n├── ===> \n└── ===> "
assert a.pretty_print() == correct_string
def test_pretty_print():
"""
Print complex tree, although difficult to test
"""
t = actions.trigger.Trigger("field", "flags", "TCP")
a = actions.tree.ActionTree("out", trigger=t)
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
duplicate = actions.duplicate.DuplicateAction()
duplicate2 = actions.duplicate.DuplicateAction()
duplicate3 = actions.duplicate.DuplicateAction()
duplicate4 = actions.duplicate.DuplicateAction()
duplicate5 = actions.duplicate.DuplicateAction()
drop = actions.drop.DropAction()
drop2 = actions.drop.DropAction()
drop3 = actions.drop.DropAction()
drop4 = actions.drop.DropAction()
duplicate.left = duplicate2
duplicate.right = duplicate3
duplicate2.left = tamper
duplicate2.right = drop
duplicate3.left = duplicate4
duplicate3.right = drop2
duplicate4.left = duplicate5
duplicate4.right = drop3
duplicate5.left = drop4
duplicate5.right = tamper2
a.add_action(duplicate)
correct_string = "TCP:flags:0\nduplicate\n├── duplicate\n│ ├── tamper{TCP:flags:replace:S}\n│ │ └── ===> \n│ └── drop\n└── duplicate\n ├── duplicate\n │ ├── duplicate\n │ │ ├── drop\n │ │ └── tamper{TCP:flags:replace:R}\n │ │ └── ===> \n │ └── drop\n └── drop"
assert a.pretty_print() == correct_string
assert a.pretty_print(visual=True)
assert os.path.exists("tree.png")
os.remove("tree.png")
a.parse("[TCP:flags:0]-|", logging.getLogger("test"))
a.pretty_print(visual=True) # Empty action tree
assert not os.path.exists("tree.png")
def test_pretty_print_order():
"""
Tests the left/right ordering by reading in a new tree
"""
logger = logging.getLogger("test")
a = actions.tree.ActionTree("out")
assert a.parse("[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:replace:14239},),duplicate(tamper{TCP:flags:replace:S}(tamper{TCP:chksum:replace:14239},),))-|", logger)
correct_pretty_print = "TCP:flags:A\nduplicate\n├── tamper{TCP:flags:replace:R}\n│ └── tamper{TCP:chksum:replace:14239}\n│ └── ===> \n└── duplicate\n ├── tamper{TCP:flags:replace:S}\n │ └── tamper{TCP:chksum:replace:14239}\n │ └── ===> \n └── ===> "
assert a.pretty_print() == correct_pretty_print
def test_parse():
"""
Tests string parsing.
"""
logger = logging.getLogger("test")
t = actions.trigger.Trigger("field", "flags", "TCP")
a = actions.tree.ActionTree("out", trigger=t)
base_t = actions.trigger.Trigger("field", "flags", "TCP")
base_a = actions.tree.ActionTree("out", trigger=base_t)
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
tamper3 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper4 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
a.parse("[TCP:flags:0]-|", logger)
assert str(a) == str(base_a)
assert len(a) == 0
base_a.add_action(tamper)
assert a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}-|", logger)
assert str(a) == str(base_a)
assert len(a) == 1
assert a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R},)-|", logging.getLogger("test"))
base_a.add_action(tamper2)
assert str(a) == str(base_a)
assert len(a) == 2
base_a.add_action(tamper3)
base_a.add_action(tamper4)
assert a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R},),),)-|", logging.getLogger("test"))
assert str(a) == str(base_a)
assert len(a) == 4
base_t = actions.trigger.Trigger("field", "flags", "TCP")
base_a = actions.tree.ActionTree("out", trigger=base_t)
duplicate = actions.duplicate.DuplicateAction()
assert a.parse("[TCP:flags:0]-duplicate-|", logger)
base_a.add_action(duplicate)
assert str(a) == str(base_a)
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
tamper3 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="A")
tamper4 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
duplicate.left = tamper
assert a.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},)-|", logger)
assert str(a) == str(base_a)
duplicate.right = tamper2
assert a.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},tamper{TCP:flags:replace:R})-|", logger)
assert str(a) == str(base_a)
tamper2.left = tamper3
assert a.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},tamper{TCP:flags:replace:R}(tamper{TCP:flags:replace:A},))-|", logger)
assert str(a) == str(base_a)
strategy = actions.utils.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},tamper{TCP:flags:replace:R})-| \/", logger)
assert strategy
assert len(strategy.out_actions[0]) == 3
assert len(strategy.in_actions) == 0
assert not a.parse("[]", logger) # No valid trigger
assert not a.parse("[TCP:flags:0]-", logger) # No valid ending "|"
assert not a.parse("[TCP:]-|", logger) # invalid trigger
assert not a.parse("[TCP:flags:0]-foo-|", logger) # Non-existent action
assert not a.parse("[TCP:flags:0]--|", logger) # Empty action
assert not a.parse("[TCP:flags:0]-duplicate(,,,)-|", logger) # Bad tree
assert not a.parse("[TCP:flags:0]-duplicate()))-|", logger) # Bad tree
assert not a.parse("[TCP:flags:0]-duplicate(((()-|", logger) # Bad tree
assert not a.parse("[TCP:flags:0]-duplicate(,))))-|", logger) # Bad tree
assert not a.parse("[TCP:flags:0]-drop(duplicate,)-|", logger) # Terminal action with children
assert not a.parse("[TCP:flags:0]-drop(duplicate,duplicate)-|", logger) # Terminal action with children
assert not a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(,duplicate)-|", logger) # Non-branching action with right child
assert not a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(drop,duplicate)-|", logger) # Non-branching action with children
def test_tree():
"""
Tests basic tree functionality.
"""
t = actions.trigger.Trigger(None, None, None)
a = actions.tree.ActionTree("out", trigger=t)
tamper = actions.tamper.TamperAction()
tamper2 = actions.tamper.TamperAction()
duplicate = actions.duplicate.DuplicateAction()
assert a.get_parent(None) == (None, None)
a.add_action(None)
a.add_action(tamper)
assert a.get_slots() == 1
a.add_action(tamper2)
assert a.get_parent(tamper2) == (tamper, "left")
assert a.get_slots() == 1
a.add_action(duplicate)
assert a.get_slots() == 2
t = actions.trigger.Trigger(None, None, None)
a = actions.tree.ActionTree("out", trigger=t)
drop = actions.drop.DropAction()
a.add_action(drop)
assert a.get_parent(drop) == (None, None)
assert a.get_slots() == 0
add_success = a.add_action(tamper)
assert not add_success
assert a.get_slots() == 0
rep = ""
for s in a.string_repr(a.action_root):
rep += s
assert rep == "drop"
print(str(a))
assert a.parse("[TCP:flags:A]-duplicate(tamper{TCP:seq:corrupt},)-|", logging.getLogger("test"))
for act in a:
print(str(a))
assert len(a) == 2
assert a.get_slots() == 2
def test_remove():
"""
Tests remove
"""
t = actions.trigger.Trigger(None, None, None)
a = actions.tree.ActionTree("out", trigger=t)
tamper = actions.tamper.TamperAction()
tamper2 = actions.tamper.TamperAction()
tamper3 = actions.tamper.TamperAction()
assert not a.remove_action(tamper)
a.add_action(tamper)
assert a.remove_action(tamper)
a.add_action(tamper)
a.add_action(tamper2)
a.add_action(tamper3)
assert a.remove_action(tamper2)
assert tamper2 not in a
assert tamper.left == tamper3
assert not tamper.right
assert len(a) == 2
a = actions.tree.ActionTree("out", trigger=t)
duplicate = actions.duplicate.DuplicateAction()
tamper = actions.tamper.TamperAction()
tamper2 = actions.tamper.TamperAction()
tamper3 = actions.tamper.TamperAction()
a.add_action(tamper)
assert a.action_root == tamper
duplicate.left = tamper2
duplicate.right = tamper3
a.add_action(duplicate)
assert a.get_parent(tamper3) == (duplicate, "right")
assert len(a) == 4
assert a.remove_action(duplicate)
assert duplicate not in a
assert tamper.left == tamper2
assert not tamper.right
assert len(a) == 2
a.parse("[TCP:flags:A]-|", logging.getLogger("test"))
assert not a.remove_one(), "Cannot remove one with no action root"
def test_len():
"""
Tests length calculation.
"""
t = actions.trigger.Trigger(None, None, None)
a = actions.tree.ActionTree("out", trigger=t)
tamper = actions.tamper.TamperAction()
tamper2 = actions.tamper.TamperAction()
assert len(a) == 0, "__len__ returned wrong length"
a.add_action(tamper)
assert len(a) == 1, "__len__ returned wrong length"
a.add_action(tamper)
assert len(a) == 1, "__len__ returned wrong length"
a.add_action(tamper2)
assert len(a) == 2, "__len__ returned wrong length"
duplicate = actions.duplicate.DuplicateAction()
a.add_action(duplicate)
assert len(a) == 3, "__len__ returned wrong length"
def test_contains():
"""
Tests contains method
"""
t = actions.trigger.Trigger(None, None, None)
a = actions.tree.ActionTree("out", trigger=t)
tamper = actions.tamper.TamperAction()
tamper2 = actions.tamper.TamperAction()
tamper3 = actions.tamper.TamperAction()
assert not a.contains(tamper), "contains incorrect behavior"
assert not a.contains(tamper2), "contains incorrect behavior"
a.add_action(tamper)
assert a.contains(tamper), "contains incorrect behavior"
assert not a.contains(tamper2), "contains incorrect behavior"
add_success = a.add_action(tamper)
assert not add_success, "added duplicate action"
assert a.contains(tamper), "contains incorrect behavior"
assert not a.contains(tamper2), "contains incorrect behavior"
a.add_action(tamper2)
assert a.contains(tamper), "contains incorrect behavior"
assert a.contains(tamper2), "contains incorrect behavior"
a.remove_action(tamper2)
assert a.contains(tamper), "contains incorrect behavior"
assert not a.contains(tamper2), "contains incorrect behavior"
a.add_action(tamper2)
assert a.contains(tamper), "contains incorrect behavior"
assert a.contains(tamper2), "contains incorrect behavior"
remove_success = a.remove_action(tamper)
assert remove_success
assert not a.contains(tamper), "contains incorrect behavior"
assert a.contains(tamper2), "contains incorrect behavior"
a.add_action(tamper3)
assert a.contains(tamper3), "contains incorrect behavior"
assert len(a) == 2, "len incorrect return"
remove_success = a.remove_action(tamper2)
assert remove_success
def test_iter():
"""
Tests iterator.
"""
t = actions.trigger.Trigger(None, None, None)
a = actions.tree.ActionTree("out", trigger=t)
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
assert a.add_action(tamper)
assert a.add_action(tamper2)
assert not a.add_action(tamper)
for node in a:
print(node)
def test_run():
"""
Tests running packets through the chain.
"""
logger = logging.getLogger("test")
t = actions.trigger.Trigger(None, None, None)
a = actions.tree.ActionTree("out", trigger=t)
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
duplicate = actions.duplicate.DuplicateAction()
duplicate2 = actions.duplicate.DuplicateAction()
drop = actions.drop.DropAction()
packet = actions.packet.Packet(IP()/TCP())
a.add_action(tamper)
packets = a.run(packet, logging.getLogger("test"))
assert len(packets) == 1
assert None not in packets
assert packets[0].get("TCP", "flags") == "S"
a.add_action(tamper2)
print(str(a))
packet = actions.packet.Packet(IP()/TCP())
assert not a.add_action(tamper), "tree added duplicate action"
packets = a.run(packet, logging.getLogger("test"))
assert len(packets) == 1
assert None not in packets
assert packets[0].get("TCP", "flags") == "R"
print(str(a))
a.remove_action(tamper2)
a.remove_action(tamper)
a.add_action(duplicate)
packet = actions.packet.Packet(IP()/TCP(flags="RA"))
packets = a.run(packet, logging.getLogger("test"))
assert len(packets) == 2
assert None not in packets
assert packets[0][TCP].flags == "RA"
assert packets[1][TCP].flags == "RA"
print(str(a))
duplicate.left = tamper
duplicate.right = tamper2
packet = actions.packet.Packet(IP()/TCP(flags="RA"))
packets = a.run(packet, logging.getLogger("test"))
assert len(packets) == 2
assert None not in packets
print(str(a))
print(str(packets[0]))
print(str(packets[1]))
assert packets[0][TCP].flags == "S"
assert packets[1][TCP].flags == "R"
print(str(a))
tamper.left = duplicate2
packet = actions.packet.Packet(IP()/TCP(flags="RA"))
packets = a.run(packet, logging.getLogger("test"))
assert len(packets) == 3
assert None not in packets
assert packets[0][TCP].flags == "S"
assert packets[1][TCP].flags == "S"
assert packets[2][TCP].flags == "R"
print(str(a))
tamper2.left = drop
packet = actions.packet.Packet(IP()/TCP(flags="RA"))
packets = a.run(packet, logging.getLogger("test"))
assert len(packets) == 2
assert None not in packets
assert packets[0][TCP].flags == "S"
assert packets[1][TCP].flags == "S"
print(str(a))
assert a.remove_action(duplicate2)
tamper.left = actions.drop.DropAction()
packet = actions.packet.Packet(IP()/TCP(flags="RA"))
packets = a.run(packet, logger )
assert len(packets) == 0
print(str(a))
a.parse("[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:replace:14239},),duplicate(tamper{TCP:flags:replace:S},))-|", logger)
packet = actions.packet.Packet(IP()/TCP(flags="A"))
assert a.check(packet, logger)
packets = a.run(packet, logger)
assert len(packets) == 3
assert packets[0][TCP].flags == "R"
assert packets[1][TCP].flags == "S"
assert packets[2][TCP].flags == "A"
def test_index():
"""
Tests index
"""
a = actions.tree.ActionTree("out")
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
tamper3 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="F")
assert a.add_action(tamper)
assert a[0] == tamper
assert not a[1]
assert a.add_action(tamper2)
assert a[0] == tamper
assert a[1] == tamper2
assert a[-1] == tamper2
assert not a[10]
assert a.add_action(tamper3)
assert a[-1] == tamper3
assert not a[-11]
def test_choose_one():
"""
Tests choose_one functionality
"""
a = actions.tree.ActionTree("out")
drop = actions.drop.DropAction()
assert not a.choose_one()
assert a.add_action(drop)
assert a.choose_one() == drop
assert a.remove_action(drop)
assert not a.choose_one()
duplicate = actions.duplicate.DuplicateAction()
a.add_action(duplicate)
assert a.choose_one() == duplicate
duplicate.left = drop
assert a.choose_one() in [duplicate, drop]
# Make sure that both actions get chosen
chosen = set()
for i in range(0, 10000):
act = a.choose_one()
chosen.add(act)
assert chosen == set([duplicate, drop])

151
tests/test_trigger.py Normal file
View File

@ -0,0 +1,151 @@
import logging
import sys
# Include the root of the project
sys.path.append("..")
import actions.packet
import actions.strategy
import actions.tamper
import actions.utils
from scapy.all import IP, TCP
logger = logging.getLogger("test")
def test_trigger_gas():
"""
Tests triggers having gas, including changing that gas while in use
"""
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="SA"))
trigger = actions.trigger.Trigger("field", "flags", "TCP", trigger_value="SA", gas=1)
print(trigger)
assert trigger.is_applicable(packet, logger)
assert not trigger.is_applicable(packet, logger)
print(trigger)
# test add gas #
trigger.add_gas(3)
assert trigger.is_applicable(packet, logger)
assert trigger.is_applicable(packet, logger)
assert trigger.is_applicable(packet, logger)
assert not trigger.is_applicable(packet, logger)
# Test disable, set, and enable gas #
trigger.disable_gas()
assert trigger.is_applicable(packet, logger)
trigger.set_gas(3)
assert trigger.is_applicable(packet, logger)
assert trigger.is_applicable(packet, logger)
assert trigger.is_applicable(packet, logger)
trigger.enable_gas()
trigger.set_gas(2)
assert trigger.is_applicable(packet, logger)
assert trigger.is_applicable(packet, logger)
assert not trigger.is_applicable(packet, logger)
def test_bomb_trigger_gas():
"""
Tests triggers having bomb gas, including changing that gas while in use
"""
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="SA"))
trigger = actions.trigger.Trigger("field", "flags", "TCP", trigger_value="SA", gas=-1)
print(trigger)
assert not trigger.is_applicable(packet, logger), "trigger should not fire on first run"
assert trigger.is_applicable(packet, logger), "trigger should fire on second run"
print(trigger)
# test add gas #
trigger.add_gas(-3)
assert not trigger.is_applicable(packet, logger)
assert not trigger.is_applicable(packet, logger)
assert not trigger.is_applicable(packet, logger)
assert trigger.is_applicable(packet, logger)
# Test disable, set, and enable gas #
trigger.disable_gas()
assert trigger.is_applicable(packet, logger)
trigger.set_gas(-3)
assert not trigger.is_applicable(packet, logger)
assert not trigger.is_applicable(packet, logger)
assert not trigger.is_applicable(packet, logger)
assert trigger.is_applicable(packet, logger)
trigger.enable_gas()
trigger.set_gas(-2)
assert not trigger.is_applicable(packet, logger)
assert not trigger.is_applicable(packet, logger)
assert trigger.is_applicable(packet, logger)
def test_trigger_parse_gas():
"""
Tests triggers having gas, including changing that gas while in use
"""
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="SA"))
# parse a trigger with 1 gas
trigger = actions.trigger.Trigger.parse("TCP:flags:SA:1")
assert trigger.is_applicable(packet, logger)
assert not trigger.is_applicable(packet, logger)
# parse a trigger with no gas left
trigger = actions.trigger.Trigger.parse("TCP:flags:SA:0")
assert not trigger.is_applicable(packet, logger)
# parse a trigger not using gas
trigger = actions.trigger.Trigger.parse("TCP:flags:SA")
assert trigger.is_applicable(packet, logger)
# Check that adding gas while gas is disabled does not work
trigger.add_gas(10)
assert trigger.gas_remaining == None
trigger.enable_gas()
trigger.set_gas(2)
assert trigger.is_applicable(packet, logger)
assert trigger.is_applicable(packet, logger)
assert not trigger.is_applicable(packet, logger)
# Test that it can handle leading/trailing []
trigger = actions.trigger.Trigger.parse("[TCP:flags:SA]")
assert trigger.is_applicable(packet, logger)
def test_bomb_trigger_parse_gas():
"""
Tests bomb triggers having gas, including changing that gas while in use
"""
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="SA"))
# parse a bomb trigger with 1 gas
trigger = actions.trigger.Trigger.parse("TCP:flags:SA:-1")
assert not trigger.is_applicable(packet, logger)
assert trigger.is_applicable(packet, logger)
# parse a trigger with no gas left
trigger = actions.trigger.Trigger.parse("TCP:flags:SA:0")
assert not trigger.is_applicable(packet, logger)
trigger = actions.trigger.Trigger.parse("TCP:flags:SA:-1")
assert not trigger.is_applicable(packet, logger)
# parse a trigger not using gas
trigger = actions.trigger.Trigger.parse("TCP:flags:SA")
assert trigger.is_applicable(packet, logger)
# Check that adding gas while gas is disabled does not work
trigger.add_gas(10)
assert trigger.gas_remaining == None
trigger.enable_gas()
trigger.set_gas(2)
assert trigger.is_applicable(packet, logger)
assert trigger.is_applicable(packet, logger)
assert not trigger.is_applicable(packet, logger)
# Test that it can handle leading/trailing []
trigger = actions.trigger.Trigger.parse("[TCP:flags:SA]")
assert trigger.is_applicable(packet, logger)

42
tests/test_utils.py Normal file
View File

@ -0,0 +1,42 @@
import sys
import pytest
# Include the root of the project
sys.path.append("..")
import actions.action
import actions.strategy
import actions.utils
import actions.duplicate
import logging
logger = logging.getLogger("test")
def get_test_configs():
"""
Sets up the tests
"""
tests = [
("both", True, ['DuplicateAction', 'DropAction', 'SleepAction', 'TraceAction', 'TamperAction', 'FragmentAction']),
("in", True, ['DropAction', 'TamperAction', 'SleepAction']),
("out", True, ['DropAction', 'TamperAction', 'TraceAction', 'SleepAction', 'DuplicateAction', 'FragmentAction']),
("both", False, ['DuplicateAction', 'SleepAction', 'TamperAction', 'FragmentAction']),
("in", False, ['TamperAction', 'SleepAction']),
("out", False, ['TamperAction', 'SleepAction', 'DuplicateAction', 'FragmentAction']),
]
# To ensure caching is not breaking anything, double the tests
return tests + tests
@pytest.mark.parametrize("direction,allow_terminal,supported_actions", get_test_configs())
def test_get_actions(direction, allow_terminal, supported_actions):
"""
Tests the duplicate action primitive.
"""
collected_actions = actions.action.Action.get_actions(direction, allow_terminal=allow_terminal)
names = []
for name, action_class in collected_actions:
names.append(name)
assert set(names) == set(supported_actions)
assert len(names) == len(supported_actions)