mirror of https://github.com/Kkevsterrr/geneva
Merge pull request #9 from Kkevsterrr/GEN-7_tests
Added tests and code coverage for repository
This commit is contained in:
commit
b36cb7f438
|
@ -0,0 +1,16 @@
|
|||
[run]
|
||||
branch = True
|
||||
source = ./
|
||||
|
||||
[report]
|
||||
exclude_lines =
|
||||
if self.debug:
|
||||
pragma: no cover
|
||||
raise NotImplementedError
|
||||
if __name__ == .__main__.:
|
||||
def get_args
|
||||
def main
|
||||
ignore_errors = True
|
||||
omit =
|
||||
tests/*
|
||||
examples/*
|
|
@ -0,0 +1,42 @@
|
|||
sudo: required
|
||||
dist: "bionic"
|
||||
|
||||
language: python
|
||||
|
||||
python:
|
||||
- "3.6"
|
||||
|
||||
install:
|
||||
# Travis recently added systemd-resolvd to their VMs. Since full Geneva often runs its own DNS
|
||||
# server to test DNS strategies, we need to disable system-resolvd.
|
||||
# First disable the service
|
||||
- sudo systemctl disable systemd-resolved.service
|
||||
# Stop the service
|
||||
- sudo systemctl stop systemd-resolved
|
||||
# With systemd not running, our own hostname won't resolve - this causes issues with sudo.
|
||||
# Add back our hostname to /etc/hosts/ so sudo does not complain
|
||||
- echo $(hostname -I | cut -d\ -f1) $(hostname) | sudo tee -a /etc/hosts
|
||||
# Replace the 127.0.0.53 nameserver with Google's
|
||||
- sudo sed 's/nameserver.*/nameserver 8.8.8.8/' /etc/resolv.conf > /tmp/resolv.conf.new
|
||||
- sudo mv /tmp/resolv.conf.new /etc/resolv.conf
|
||||
# Now that systemd-resolv.conf is safely disabled, we can now setup for Geneva
|
||||
- sudo apt-get clean # travis having mirror sync issues
|
||||
# Install dependencies
|
||||
- sudo apt-get update
|
||||
- sudo apt-get -y install libnetfilter-queue-dev python3 python3-pip python3-setuptools graphviz
|
||||
# Since sudo is required but travis does not set up the root environment, we must override the
|
||||
# secure_path in sudoers in order for travis's setup to take effect for sudo commands
|
||||
- printf "Defaults\tenv_reset\nDefaults\tmail_badpass\nDefaults\tsecure_path="/home/travis/virtualenv/python3.6.7/bin/:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/snap/bin"\nroot\tALL=(ALL:ALL) ALL\n#includedir /etc/sudoers.d\n" > /tmp/sudoers.tmp
|
||||
# Verify the sudoers file
|
||||
- sudo visudo -c -f /tmp/sudoers.tmp
|
||||
# Copy in the sudoers file
|
||||
- sudo cp /tmp/sudoers.tmp /etc/sudoers
|
||||
# Now that sudo is good to go, finish installing dependencies
|
||||
- sudo python3 -m pip install -r requirements.txt
|
||||
- sudo python3 -m pip install slackclient pytest-cov
|
||||
|
||||
script:
|
||||
- sudo python3 -m pytest --cov=./ -sv tests/ --tb=short
|
||||
|
||||
after_script:
|
||||
- bash <(curl -s https://codecov.io/bash) -t 83a45966-78ce-44c2-80b3-964ecab4a53d || echo "Codecov did not collect coverage reports"
|
|
@ -1,4 +1,4 @@
|
|||
# Geneva
|
||||
# Geneva [![Build Status](https://travis-ci.com/Kkevsterrr/geneva.svg?branch=master)](https://travis-ci.com/Kkevsterrr/geneva) [![codecov](https://codecov.io/gh/Kkevsterrr/geneva/branch/master/graph/badge.svg)](https://codecov.io/gh/Kkevsterrr/geneva)
|
||||
|
||||
Geneva is an artificial intelligence tool that defeats censorship by exploiting bugs in censors, such as those in China, India, and Kazakhstan. Unlike many other anti-censorship solutions which require assistance from outside the censoring regime (Tor, VPNs, etc.), Geneva runs strictly on the client.
|
||||
|
||||
|
|
|
@ -13,9 +13,3 @@ class DuplicateAction(Action):
|
|||
"""
|
||||
logger.debug(" - Duplicating given packet %s" % str(packet))
|
||||
return packet, packet.copy()
|
||||
|
||||
def mutate(self, environment_id=None):
|
||||
"""
|
||||
Swaps its left and right child
|
||||
"""
|
||||
self.left, self.right = self.right, self.left
|
||||
|
|
|
@ -196,22 +196,3 @@ class FragmentAction(Action):
|
|||
self.correct_order = False
|
||||
|
||||
return True
|
||||
|
||||
def mutate(self, environment_id=None):
|
||||
"""
|
||||
Mutates the fragment action - it either chooses a new segment offset,
|
||||
switches the packet order, and/or changes whether it segments or fragments.
|
||||
"""
|
||||
self.correct_order = self.get_rand_order()
|
||||
self.segment = random.choice([True, True, True, False])
|
||||
if self.segment:
|
||||
if random.random() < 0.5:
|
||||
self.fragsize = int(random.uniform(1, 60))
|
||||
else:
|
||||
self.fragsize = -1
|
||||
else:
|
||||
if random.random() < 0.2:
|
||||
self.fragsize = int(random.uniform(1, 50))
|
||||
else:
|
||||
self.fragsize = -1
|
||||
return self
|
||||
|
|
243
actions/layer.py
243
actions/layer.py
|
@ -4,7 +4,8 @@ import string
|
|||
import os
|
||||
import urllib.parse
|
||||
|
||||
from scapy.all import IP, RandIP, UDP, Raw, TCP, fuzz
|
||||
from scapy.all import IP, RandIP, UDP, DNS, DNSQR, Raw, TCP, fuzz
|
||||
|
||||
|
||||
class Layer():
|
||||
"""
|
||||
|
@ -179,6 +180,12 @@ class Layer():
|
|||
value = urllib.parse.unquote(value)
|
||||
|
||||
value = value.encode('utf-8')
|
||||
# Add support for injecting arbitrary protocol payloads if requested
|
||||
dns_payload = b"\x009ib\x81\x80\x00\x01\x00\x01\x00\x00\x00\x01\x08examples\x03com\x00\x00\x01\x00\x01\xc0\x0c\x00\x01\x00\x01\x00\x00\x01+\x00\x04\xc7\xbf2I\x00\x00)\x02\x00\x00\x00\x00\x00\x00\x00"
|
||||
http_payload = b"GET / HTTP/1.1\r\nHost: www.example.com\r\n\r\n"
|
||||
|
||||
value = value.replace(b"__DNS_REQUEST__", dns_payload)
|
||||
value = value.replace(b"__HTTP_REQUEST__", http_payload)
|
||||
|
||||
self.layer.payload = Raw(value)
|
||||
|
||||
|
@ -592,3 +599,237 @@ class UDPLayer(Layer):
|
|||
self.generators = {
|
||||
'load' : self.gen_load,
|
||||
}
|
||||
|
||||
|
||||
class DNSLayer(Layer):
|
||||
"""
|
||||
Defines an interface to access DNS header fields.
|
||||
"""
|
||||
name = "DNS"
|
||||
protocol = DNS
|
||||
_fields = [
|
||||
"id",
|
||||
"qr",
|
||||
"opcode",
|
||||
"aa",
|
||||
"tc",
|
||||
"rd",
|
||||
"ra",
|
||||
"z",
|
||||
"ad",
|
||||
"cd",
|
||||
"qd",
|
||||
"rcode",
|
||||
"qdcount",
|
||||
"ancount",
|
||||
"nscount",
|
||||
"arcount"
|
||||
]
|
||||
fields = _fields
|
||||
def __init__(self, layer):
|
||||
"""
|
||||
Initializes the DNS layer.
|
||||
"""
|
||||
Layer.__init__(self, layer)
|
||||
|
||||
self.getters = {
|
||||
"qr" : self.get_bitfield,
|
||||
"aa" : self.get_bitfield,
|
||||
"tc" : self.get_bitfield,
|
||||
"rd" : self.get_bitfield,
|
||||
"ra" : self.get_bitfield,
|
||||
"z" : self.get_bitfield,
|
||||
"ad" : self.get_bitfield,
|
||||
"cd" : self.get_bitfield
|
||||
}
|
||||
|
||||
self.setters = {
|
||||
"qr" : self.set_bitfield,
|
||||
"aa" : self.set_bitfield,
|
||||
"tc" : self.set_bitfield,
|
||||
"rd" : self.set_bitfield,
|
||||
"ra" : self.set_bitfield,
|
||||
"z" : self.set_bitfield,
|
||||
"ad" : self.set_bitfield,
|
||||
"cd" : self.set_bitfield
|
||||
}
|
||||
|
||||
self.generators = {
|
||||
"id" : self.gen_id,
|
||||
"qr" : self.gen_bitfield,
|
||||
"opcode" : self.gen_opcode,
|
||||
"aa" : self.gen_bitfield,
|
||||
"tc" : self.gen_bitfield,
|
||||
"rd" : self.gen_bitfield,
|
||||
"ra" : self.gen_bitfield,
|
||||
"z" : self.gen_bitfield,
|
||||
"ad" : self.gen_bitfield,
|
||||
"cd" : self.gen_bitfield,
|
||||
"rcode" : self.gen_rcode,
|
||||
"qdcount" : self.gen_count,
|
||||
"ancount" : self.gen_count,
|
||||
"nscount" : self.gen_count,
|
||||
"arcount" : self.gen_count
|
||||
}
|
||||
|
||||
def get_bitfield(self, field):
|
||||
""""""
|
||||
return int(getattr(self.layer, field))
|
||||
|
||||
def set_bitfield(self, packet, field, value):
|
||||
""""""
|
||||
return setattr(self.layer, field, int(value))
|
||||
|
||||
def gen_bitfield(self, field):
|
||||
""""""
|
||||
return random.choice([0,1])
|
||||
|
||||
def gen_id(self, field):
|
||||
return random.randint(0, 65535)
|
||||
|
||||
def gen_opcode(self, field):
|
||||
return random.randint(0, 15)
|
||||
|
||||
def gen_rcode(self, field):
|
||||
return random.randint(0, 15)
|
||||
|
||||
def gen_count(self, field):
|
||||
return random.randint(0, 65535)
|
||||
|
||||
@staticmethod
|
||||
def dns_decompress(packet, logger):
|
||||
"""
|
||||
Performs DNS decompression on the given scapy packet, if applicable.
|
||||
Note that DNS compression/decompression must be done on the boundaries
|
||||
of a label, so DNS compression does not support arbitrary offsets.
|
||||
"""
|
||||
# If this is a TCP packet
|
||||
if packet.haslayer("TCP"):
|
||||
raise NotImplementedError
|
||||
|
||||
# Perform no action if this is not a DNS or DNSRQ packet
|
||||
if not packet.haslayer("DNS") or not packet.haslayer("DNSQR"):
|
||||
return packet
|
||||
|
||||
# Extract the query from the DNSQR layer
|
||||
query = packet["DNSQR"].qname.decode()
|
||||
if query[len(query) - 1] != '.':
|
||||
query += '.'
|
||||
|
||||
# Split the query by label
|
||||
labels = query.split(".")
|
||||
|
||||
# Collect the first and second half of the query
|
||||
fhalf = labels[0]
|
||||
shalf = ".".join(labels[1:])
|
||||
|
||||
# Build the first DNS query directly. The format of this a byte string like this:
|
||||
# b'\x07minghui\xc0\x1a\x00\x01\x00\x01'
|
||||
# \x07 = the length of the label in this DNSQR
|
||||
# minghui = the portion of the domain we will request in the first DNSQR
|
||||
# \xc0\x1a = offset into the DNS packet where the rest of the query will be. The actual offset
|
||||
# here is the \x1a - DNS mandates that if compression is used, the first two bits be 11
|
||||
# to differentiate them from the rest. \x1A = 26, which is the length of the DNS header
|
||||
# plus the length of this DNSQR.
|
||||
# \x00\x01 = type A record
|
||||
# \x00\x01 = IN
|
||||
length = bytes([len(fhalf)])
|
||||
label = fhalf.encode()
|
||||
|
||||
# Since the domain will include an extra ".", add 1
|
||||
# 2 * 6 is the DNS header
|
||||
# 1 is the byte that determines the length of the label
|
||||
# len(label) is the length of the label
|
||||
# 2 is the offset pointer
|
||||
# 4 - other record information (class, IN)
|
||||
packet_offset = 2 * 6 + 1 + len(label) + 2 + 2 + 2
|
||||
|
||||
# The word must start with binary 11, so OR the offset with 0xC000.
|
||||
offset = (0xc000 | packet_offset).to_bytes(2, byteorder='big')
|
||||
request = b'\x00\x01\x00\x01'
|
||||
|
||||
dns_qr1 = length + label + offset + request
|
||||
|
||||
# Build the second DNS query directly. The format of the byte string is the same as above
|
||||
# b'\x02ca\x00\x00\x01\x00\x01'
|
||||
# \x02 = length of the remaining domain
|
||||
# ca = portion of the domain in this DNSQR
|
||||
# \x00 = null byte to signify the end of the query
|
||||
# \x00\x01 = type A record
|
||||
# \x00\x01 = IN
|
||||
# Since the second half could potentially contain many labels, this is done in a list comprehension
|
||||
dns_qr2 = b"".join([bytes([len(tld)]) + tld.encode() for tld in shalf.split(".")]) + b"\x00\x01\x00\x01"
|
||||
|
||||
# Next, we must rebuild the DNS packet itself. If we try to have scapy parse either dns_qr1 or dns_qr2, they
|
||||
# will look malformed, since neither contains a complete request. Therefore, we must build the entire
|
||||
# DNS packet at once. First, we must remove the original DNSQR, since this contains the original request
|
||||
del packet["DNS"].qd
|
||||
|
||||
# Once the DNSQR is removed, scapy automatically sets the qdcount to 0. Adjust it to 2
|
||||
packet["DNS"].qdcount = 2
|
||||
|
||||
# Extract the DNS header standalone now for building
|
||||
dns_header = bytes(packet["DNS"])
|
||||
|
||||
dns_packet = DNS(dns_header + dns_qr1 + dns_qr2)
|
||||
|
||||
del packet["DNS"]
|
||||
packet = packet / dns_packet
|
||||
|
||||
# Since the size and data of the packet have changed, force scapy to recalculate the important fields
|
||||
# in below layers, if applicable
|
||||
if packet.haslayer("IP"):
|
||||
del packet["IP"].chksum
|
||||
del packet["IP"].len
|
||||
if packet.haslayer("UDP"):
|
||||
del packet["UDP"].chksum
|
||||
del packet["UDP"].len
|
||||
|
||||
return packet
|
||||
|
||||
|
||||
class DNSQRLayer(Layer):
|
||||
"""
|
||||
Defines an interface to access DNSQR header fields.
|
||||
"""
|
||||
name = "DNSQR"
|
||||
protocol = DNSQR
|
||||
_fields = [
|
||||
"qname",
|
||||
"qtype",
|
||||
"qclass"
|
||||
]
|
||||
fields = _fields
|
||||
|
||||
def __init__(self, layer):
|
||||
"""
|
||||
Initializes the DNS layer.
|
||||
"""
|
||||
Layer.__init__(self, layer)
|
||||
self.getters = {
|
||||
"qname" : self.get_qname
|
||||
}
|
||||
self.generators = {
|
||||
"qname" : self.gen_qname
|
||||
}
|
||||
|
||||
def get_qname(self, field):
|
||||
"""
|
||||
Returns decoded qname from packet.
|
||||
"""
|
||||
return self.layer.qname.decode('utf-8')
|
||||
|
||||
def gen_qname(self, field):
|
||||
"""
|
||||
Generates domain name.
|
||||
"""
|
||||
return "example.com."
|
||||
|
||||
@classmethod
|
||||
def name_matches(cls, name):
|
||||
"""
|
||||
Scapy returns the name of DNSQR as _both_ DNSQR and "DNS Question Record",
|
||||
which breaks parsing. Override the name_matches method to handle that case
|
||||
here.
|
||||
"""
|
||||
return name.upper() in ["DNSQR", "DNS QUESTION RECORD"]
|
||||
|
|
|
@ -7,7 +7,9 @@ import actions.layer
|
|||
_SUPPORTED_LAYERS = [
|
||||
actions.layer.IPLayer,
|
||||
actions.layer.TCPLayer,
|
||||
actions.layer.UDPLayer
|
||||
actions.layer.UDPLayer,
|
||||
actions.layer.DNSLayer,
|
||||
actions.layer.DNSQRLayer
|
||||
]
|
||||
SUPPORTED_LAYERS = _SUPPORTED_LAYERS
|
||||
|
||||
|
@ -64,9 +66,25 @@ class Packet():
|
|||
@staticmethod
|
||||
def _str_load(packet, protocol):
|
||||
"""
|
||||
Prints packet payload
|
||||
Prints DNS header for now
|
||||
"""
|
||||
return str(packet[protocol].payload)
|
||||
if packet.haslayer("DNS") and packet.haslayer("DNSQR"):
|
||||
res = "%s:%s:%s " % (
|
||||
packet["DNSQR"].qname.decode('utf8'),
|
||||
str(packet["DNSQR"].qtype),
|
||||
str(packet["DNSQR"].qclass))
|
||||
DNS_res = ""
|
||||
for i in range(packet["DNS"].ancount):
|
||||
dnsrr = packet["DNS"].an[i]
|
||||
DNS_res += " " + ':'.join([str(dnsrr.rrname.decode('utf8')),
|
||||
str(dnsrr.type),
|
||||
str(dnsrr.rclass),
|
||||
str(dnsrr.ttl),
|
||||
str(dnsrr.rdlen),
|
||||
str(dnsrr.rdata)])
|
||||
return "%s %s" % (res, DNS_res)
|
||||
else:
|
||||
return str(packet[protocol].payload)
|
||||
|
||||
def __bytes__(self):
|
||||
"""
|
||||
|
@ -238,3 +256,77 @@ class Packet():
|
|||
return layer
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def reset_restrictions():
|
||||
"""
|
||||
Removes layer and field restrictions.
|
||||
"""
|
||||
global SUPPORTED_LAYERS, _SUPPORTED_LAYERS
|
||||
|
||||
SUPPORTED_LAYERS = _SUPPORTED_LAYERS
|
||||
for layer in SUPPORTED_LAYERS:
|
||||
layer.reset_restrictions()
|
||||
|
||||
@staticmethod
|
||||
def restrict_fields(logger, filter_protocols, filter_fields, disable_fields):
|
||||
"""
|
||||
Validates input arguments. Used by evolve.py to restrict the scope
|
||||
of this evolution.
|
||||
"""
|
||||
global SUPPORTED_LAYERS
|
||||
|
||||
if not disable_fields:
|
||||
disable_fields = []
|
||||
|
||||
# First, apply a field whitelist if it was requested
|
||||
valid = []
|
||||
if filter_fields:
|
||||
for layer in SUPPORTED_LAYERS:
|
||||
new_fields = []
|
||||
for field in filter_fields:
|
||||
if field in layer.fields:
|
||||
new_fields.append(field)
|
||||
valid.append(field)
|
||||
layer.fields = new_fields
|
||||
|
||||
if valid and logger:
|
||||
logger.info("Strategies will only be allowed to use fields: %s" % ", ".join(list(set(valid))))
|
||||
elif logger:
|
||||
logger.error("None of the given fields exist in the packet headers of given protocols.")
|
||||
|
||||
# Apply a field blacklist if it was requested
|
||||
for field in disable_fields:
|
||||
for layer in SUPPORTED_LAYERS:
|
||||
layer.fields = [f for f in layer.fields if f not in disable_fields]
|
||||
|
||||
if disable_fields and logger:
|
||||
logger.info("Strategies will not be allowed to use fields %s" % ", ".join(disable_fields))
|
||||
|
||||
allowed_layers = []
|
||||
# Finally, filter protocols
|
||||
for protocol in filter_protocols:
|
||||
allowed_layer = Packet.get_supported_protocol(protocol)
|
||||
if not allowed_layer:
|
||||
if logger:
|
||||
logger.error("%s not a supported protocol." % protocol)
|
||||
continue
|
||||
|
||||
# Only keep the layer allowed if it contains allowed fields
|
||||
if allowed_layer.fields:
|
||||
allowed_layers.append(allowed_layer)
|
||||
|
||||
assert allowed_layers, "Cannot evolve with no available packet layers!"
|
||||
|
||||
SUPPORTED_LAYERS = allowed_layers
|
||||
|
||||
if logger and allowed_layers:
|
||||
logger.info("Strategies will only be allowed to use protocols: %s" % ", ".join([l.name for l in allowed_layers]))
|
||||
|
||||
def dns_decompress(self, logger):
|
||||
"""
|
||||
Performs DNS decompression, if applicable. Returns a new packet.
|
||||
"""
|
||||
self.packet = actions.layer.DNSLayer.dns_decompress(self.packet, logger)
|
||||
self.layers = self.setup_layers()
|
||||
return self
|
||||
|
|
|
@ -2,7 +2,7 @@ from actions.action import Action
|
|||
|
||||
class SleepAction(Action):
|
||||
def __init__(self, time=1, environment_id=None):
|
||||
Action.__init__(self, "sleep", "out")
|
||||
Action.__init__(self, "sleep", "both")
|
||||
self.terminal = False
|
||||
self.branching = False
|
||||
self.time = time
|
||||
|
|
|
@ -1,85 +0,0 @@
|
|||
import threading
|
||||
import os
|
||||
|
||||
import actions.packet
|
||||
from scapy.all import sniff
|
||||
from scapy.utils import PcapWriter
|
||||
|
||||
|
||||
class Sniffer():
|
||||
"""
|
||||
The sniffer class lets the user begin and end sniffing whenever in a given location with a port to filter on.
|
||||
Call start_sniffing to begin sniffing and stop_sniffing to stop sniffing.
|
||||
"""
|
||||
|
||||
def __init__(self, location, port, logger):
|
||||
"""
|
||||
Intializes a sniffer object.
|
||||
Needs a location and a port to filter on.
|
||||
"""
|
||||
self.stop_sniffing_flag = False
|
||||
self.location = location
|
||||
self.port = port
|
||||
self.pcap_thread = None
|
||||
self.packet_dumper = None
|
||||
self.logger = logger
|
||||
full_path = os.path.dirname(location)
|
||||
assert port, "Need to specify a port in order to launch a sniffer"
|
||||
if not os.path.exists(full_path):
|
||||
os.makedirs(full_path)
|
||||
|
||||
def __packet_callback(self, scapy_packet):
|
||||
"""
|
||||
This callback is called whenever a packet is applied.
|
||||
Returns true if it should finish, otherwise, returns false.
|
||||
"""
|
||||
packet = actions.packet.Packet(scapy_packet)
|
||||
for proto in ["TCP", "UDP"]:
|
||||
if(packet.haslayer(proto) and ((packet[proto].sport == self.port) or (packet[proto].dport == self.port))):
|
||||
break
|
||||
else:
|
||||
return self.stop_sniffing_flag
|
||||
|
||||
self.logger.debug(str(packet))
|
||||
self.packet_dumper.write(scapy_packet)
|
||||
return self.stop_sniffing_flag
|
||||
|
||||
def __spawn_sniffer(self):
|
||||
"""
|
||||
Saves pcaps to a file. Should be run as a thread.
|
||||
Ends when the stop_sniffing_flag is set. Should not be called by user
|
||||
"""
|
||||
self.packet_dumper = PcapWriter(self.location, append=True, sync=True)
|
||||
while(self.stop_sniffing_flag == False):
|
||||
sniff(stop_filter=self.__packet_callback, timeout=1)
|
||||
|
||||
def start_sniffing(self):
|
||||
"""
|
||||
Starts sniffing. Should be called by user.
|
||||
"""
|
||||
self.stop_sniffing_flag = False
|
||||
self.pcap_thread = threading.Thread(target=self.__spawn_sniffer)
|
||||
self.pcap_thread.start()
|
||||
self.logger.debug("Sniffer starting to port %d" % self.port)
|
||||
|
||||
def __enter__(self):
|
||||
"""
|
||||
Defines a context manager for this sniffer; simply starts sniffing.
|
||||
"""
|
||||
self.start_sniffing()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, tb):
|
||||
"""
|
||||
Defines exit context manager behavior for this sniffer; simply stops sniffing.
|
||||
"""
|
||||
self.stop_sniffing()
|
||||
|
||||
def stop_sniffing(self):
|
||||
"""
|
||||
Stops the sniffer by setting the flag and calling join
|
||||
"""
|
||||
if(self.pcap_thread):
|
||||
self.stop_sniffing_flag = True
|
||||
self.pcap_thread.join()
|
||||
self.logger.debug("Sniffer stopping")
|
|
@ -2,17 +2,25 @@
|
|||
TamperAction
|
||||
|
||||
One of the four packet-level primitives supported by Geneva. Responsible for any packet-level
|
||||
modifications (particularly header modifications). It supports replace and corrupt mode -
|
||||
in replace mode, it changes a packet field to a fixed value; in corrupt mode, it changes a packet
|
||||
field to a randomly generated value each time it is run.
|
||||
modifications (particularly header modifications). It supports the following primitives:
|
||||
- no operation: it returns the packet given
|
||||
- replace: it changes a packet field to a fixed value
|
||||
- corrupt: it changes a packet field to a randomly generated value each time it is run
|
||||
- add: adds a given value to the value in a field
|
||||
- compress: performs DNS decompression on the packet (if applicable)
|
||||
"""
|
||||
|
||||
from actions.action import Action
|
||||
import actions.utils
|
||||
from actions.layer import DNSLayer
|
||||
|
||||
import random
|
||||
|
||||
|
||||
# All supported tamper primitives
|
||||
SUPPORTED_PRIMITIVES = ["corrupt", "replace", "add", "compress"]
|
||||
|
||||
|
||||
class TamperAction(Action):
|
||||
"""
|
||||
Defines the TamperAction for Geneva.
|
||||
|
@ -23,10 +31,7 @@ class TamperAction(Action):
|
|||
self.tamper_value = tamper_value
|
||||
self.tamper_proto = actions.utils.string_to_protocol(tamper_proto)
|
||||
self.tamper_proto_str = tamper_proto
|
||||
|
||||
self.tamper_type = tamper_type
|
||||
if not self.tamper_type:
|
||||
self.tamper_type = random.choice(["corrupt", "replace"])
|
||||
|
||||
def tamper(self, packet, logger):
|
||||
"""
|
||||
|
@ -41,8 +46,19 @@ class TamperAction(Action):
|
|||
|
||||
new_value = self.tamper_value
|
||||
# If corrupting the packet field, generate a value for it
|
||||
if self.tamper_type == "corrupt":
|
||||
new_value = packet.gen(self.tamper_proto_str, self.field)
|
||||
try:
|
||||
if self.tamper_type == "corrupt":
|
||||
new_value = packet.gen(self.tamper_proto_str, self.field)
|
||||
elif self.tamper_type == "add":
|
||||
new_value = int(self.tamper_value) + int(old_value)
|
||||
elif self.tamper_type == "compress":
|
||||
return packet.dns_decompress(logger)
|
||||
except NotImplementedError:
|
||||
# If a primitive does not support the type of packet given
|
||||
return packet
|
||||
except Exception:
|
||||
# If an unexpected error has occurred
|
||||
return packet
|
||||
|
||||
logger.debug(" - Tampering %s field `%s` (%s) by %s (to %s)" %
|
||||
(self.tamper_proto_str, self.field, str(old_value), self.tamper_type, str(new_value)))
|
||||
|
@ -67,8 +83,10 @@ class TamperAction(Action):
|
|||
s = Action.__str__(self)
|
||||
if self.tamper_type == "corrupt":
|
||||
s += "{%s:%s:%s}" % (self.tamper_proto_str, self.field, self.tamper_type)
|
||||
elif self.tamper_type in ["replace"]:
|
||||
elif self.tamper_type in ["replace", "add"]:
|
||||
s += "{%s:%s:%s:%s}" % (self.tamper_proto_str, self.field, self.tamper_type, self.tamper_value)
|
||||
elif self.tamper_type == "compress":
|
||||
s += "{%s:%s:compress}" % ("DNS", "qd", )
|
||||
|
||||
return s
|
||||
|
||||
|
|
|
@ -65,14 +65,15 @@ class TraceAction(Action):
|
|||
"""
|
||||
Parses a string representation for this object.
|
||||
"""
|
||||
if not string:
|
||||
return False
|
||||
try:
|
||||
if string:
|
||||
self.start_ttl, self.end_ttl = string.split(":")
|
||||
self.start_ttl = int(self.start_ttl)
|
||||
self.end_ttl = int(self.end_ttl)
|
||||
if self.start_ttl > self.end_ttl:
|
||||
logger.error("Cannot use a trace with a start ttl greater than end_ttl (%d > %d)" % (self.start_ttl, self.end_ttl))
|
||||
return False
|
||||
self.start_ttl, self.end_ttl = string.split(":")
|
||||
self.start_ttl = int(self.start_ttl)
|
||||
self.end_ttl = int(self.end_ttl)
|
||||
if self.start_ttl > self.end_ttl:
|
||||
logger.error("Cannot use a trace with a start ttl greater than end_ttl (%d > %d)" % (self.start_ttl, self.end_ttl))
|
||||
return False
|
||||
except ValueError:
|
||||
logger.exception("Cannot parse ttls from given data %s" % string)
|
||||
return False
|
||||
|
|
|
@ -30,22 +30,6 @@ class ActionTree():
|
|||
self.environment_id = None
|
||||
self.ran = False
|
||||
|
||||
def initialize(self, num_actions, environment_id, allow_terminal=True, disabled=None):
|
||||
"""
|
||||
Sets up this action tree with a given number of random actions.
|
||||
Note that the returned action trees may have less actions than num_actions
|
||||
if terminal actions are used.
|
||||
"""
|
||||
self.environment_id = environment_id
|
||||
self.trigger = actions.trigger.Trigger(None, None, None, environment_id=environment_id)
|
||||
if not allow_terminal or random.random() > 0.1:
|
||||
allow_terminal = False
|
||||
|
||||
for _ in range(num_actions):
|
||||
new_action = self.get_rand_action(self.direction, disabled=disabled)
|
||||
self.add_action(new_action)
|
||||
return self
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
Sets up a preoder iterator for the tree.
|
||||
|
|
|
@ -27,20 +27,6 @@ class Trigger(object):
|
|||
self.bomb_trigger = bool(gas and gas < 0)
|
||||
self.ran = False
|
||||
|
||||
@staticmethod
|
||||
def get_gas():
|
||||
"""
|
||||
Returns a random value for gas for this trigger.
|
||||
"""
|
||||
if GAS_ENABLED and random.random() < 0.2:
|
||||
# Use gas in 20% of scenarios
|
||||
# Pick a number for gas between 0 - 5
|
||||
gas_remaining = int(random.random() * 5)
|
||||
else:
|
||||
# Do not use gas
|
||||
gas_remaining = None
|
||||
return gas_remaining
|
||||
|
||||
def is_applicable(self, packet, logger):
|
||||
"""
|
||||
Checks if this trigger is applicable to a given packet.
|
||||
|
|
|
@ -119,7 +119,7 @@ def get_logger(basepath, log_dir, logger_name, log_name, environment_id, log_lev
|
|||
ch = logging.StreamHandler()
|
||||
ch.setFormatter(formatter)
|
||||
ch.setLevel(log_level)
|
||||
CONSOLE_LOG_LEVEL = log_level
|
||||
CONSOLE_LOG_LEVEL = ch.level
|
||||
logger.addHandler(ch)
|
||||
return logger
|
||||
|
||||
|
@ -135,34 +135,6 @@ def close_logger(logger):
|
|||
handler.close()
|
||||
|
||||
|
||||
class Logger():
|
||||
"""
|
||||
Logging class context manager, as a thin wrapper around the logging class to help
|
||||
handle closing open file descriptors.
|
||||
"""
|
||||
def __init__(self, log_dir, logger_name, log_name, environment_id, log_level=logging.DEBUG):
|
||||
self.log_dir = log_dir
|
||||
self.logger_name = logger_name
|
||||
self.log_name = log_name
|
||||
self.environment_id = environment_id
|
||||
self.log_level = log_level
|
||||
self.logger = None
|
||||
|
||||
def __enter__(self):
|
||||
"""
|
||||
Sets up a logger.
|
||||
"""
|
||||
self.logger = get_logger(PROJECT_ROOT, self.log_dir, self.logger_name, self.log_name, self.environment_id, log_level=self.log_level)
|
||||
return self.logger
|
||||
|
||||
def __exit__(self, exc_type, exc_value, tb):
|
||||
"""
|
||||
Closes file handles.
|
||||
"""
|
||||
close_logger(self.logger)
|
||||
|
||||
|
||||
|
||||
def get_console_log_level():
|
||||
"""
|
||||
returns log level of console handler
|
||||
|
@ -205,18 +177,6 @@ def setup_dirs(output_dir):
|
|||
return ga_log_dir
|
||||
|
||||
|
||||
def get_from_fuzzed_or_real_packet(environment_id, real_packet_probability, enable_options=True, enable_load=True):
|
||||
"""
|
||||
Retrieves a protocol, field, and value from a fuzzed or real packet, depending on
|
||||
the given probability and if given packets is not None.
|
||||
"""
|
||||
packets = actions.utils.read_packets(environment_id)
|
||||
if packets and random.random() < real_packet_probability:
|
||||
packet = random.choice(packets)
|
||||
return packet.get_random()
|
||||
return actions.packet.Packet().gen_random()
|
||||
|
||||
|
||||
def get_interface():
|
||||
"""
|
||||
Chooses an interface on the machine to use for socket testing.
|
||||
|
|
|
@ -34,8 +34,6 @@ class Engine():
|
|||
self.server_port = server_port
|
||||
self.seen_packets = []
|
||||
# Set up the directory and ID for logging
|
||||
if not output_directory:
|
||||
output_directory = "trials"
|
||||
actions.utils.setup_dirs(output_directory)
|
||||
if not environment_id:
|
||||
environment_id = actions.utils.get_id()
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
# Add the path to the engine so we can import it
|
||||
BASEPATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(BASEPATH)
|
||||
|
||||
import engine
|
||||
|
||||
def test_engine():
|
||||
"""
|
||||
Basic engine test
|
||||
"""
|
||||
# Port to run the engine on
|
||||
port = 80
|
||||
# Strategy to use
|
||||
strategy = "[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:corrupt},),)-| \/"
|
||||
|
||||
# Create the engine in debug mode
|
||||
with engine.Engine(port, strategy, log_level="debug") as eng:
|
||||
os.system("curl http://example.com?q=ultrasurf")
|
||||
|
||||
|
||||
def test_engine_sleep():
|
||||
"""
|
||||
Basic engine test with sleep action
|
||||
"""
|
||||
# Port to run the engine on
|
||||
port = 80
|
||||
# Strategy to use
|
||||
strategy = "[TCP:flags:S]-sleep{1}-|"
|
||||
|
||||
# Create the engine in debug mode
|
||||
with engine.Engine(port, strategy, log_level="info") as eng:
|
||||
os.system("curl http://example.com?q=ultrasurf")
|
||||
|
||||
# Strategy to use in opposite direction
|
||||
strategy = "\/ [TCP:flags:SA]-sleep{1}-|"
|
||||
|
||||
# Create the engine in debug mode
|
||||
with engine.Engine(port, strategy, log_level="debug") as eng:
|
||||
os.system("curl http://example.com?q=ultrasurf")
|
||||
|
||||
|
||||
|
||||
def test_engine_trace():
|
||||
"""
|
||||
Basic engine test with trace
|
||||
"""
|
||||
# Port to run the engine on
|
||||
port = 80
|
||||
# Strategy to use
|
||||
strategy = "[TCP:flags:PA]-trace{2:10}-|"
|
||||
|
||||
# Create the engine in debug mode
|
||||
with engine.Engine(port, strategy, log_level="debug") as eng:
|
||||
os.system("curl -m 5 http://example.com?q=ultrasurf")
|
||||
|
||||
|
||||
def test_engine_drop():
|
||||
"""
|
||||
Basic engine test with drop
|
||||
"""
|
||||
# Port to run the engine on
|
||||
port = 80
|
||||
# Strategy to use
|
||||
strategy = "\/ [TCP:flags:SA]-drop-|"
|
||||
|
||||
# Create the engine in debug mode
|
||||
with engine.Engine(port, strategy, log_level="debug") as eng:
|
||||
os.system("curl -m 3 http://example.com?q=ultrasurf")
|
||||
|
|
@ -0,0 +1,220 @@
|
|||
import logging
|
||||
import pytest
|
||||
import sys
|
||||
# Include the root of the project
|
||||
sys.path.append("..")
|
||||
|
||||
import actions.fragment
|
||||
import actions.packet
|
||||
import actions.strategy
|
||||
import actions.utils
|
||||
|
||||
from scapy.all import IP, TCP, UDP
|
||||
|
||||
logger = logging.getLogger("test")
|
||||
|
||||
|
||||
def test_segment():
|
||||
"""
|
||||
Tests the duplicate action primitive.
|
||||
"""
|
||||
fragment = actions.fragment.FragmentAction(correct_order=True)
|
||||
assert str(fragment) == "fragment{tcp:-1:True}", "Fragment returned incorrect string representation: %s" % str(fragment)
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP()/("data"))
|
||||
packet1, packet2 = fragment.run(packet, logger)
|
||||
|
||||
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
|
||||
|
||||
assert packet1["Raw"].load != packet2["Raw"].load, "Packets were not different"
|
||||
assert packet1["Raw"].load == b'da', "Left packet incorrectly fragmented"
|
||||
assert packet2["Raw"].load == b"ta", "Right packet incorrectly fragmented"
|
||||
|
||||
|
||||
def test_segment_reverse():
|
||||
"""
|
||||
Tests the duplicate action primitive in reverse!
|
||||
"""
|
||||
fragment = actions.fragment.FragmentAction(correct_order=False)
|
||||
assert str(fragment) == "fragment{tcp:-1:False}", "Fragment returned incorrect string representation: %s" % str(fragment)
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP()/("data"))
|
||||
packet1, packet2 = fragment.run(packet, logger)
|
||||
|
||||
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
|
||||
|
||||
assert packet1["Raw"].load != packet2["Raw"].load, "Packets were not different"
|
||||
assert packet1["Raw"].load == b'ta', "Left packet incorrectly fragmented"
|
||||
assert packet2["Raw"].load == b"da", "Right packet incorrectly fragmented"
|
||||
|
||||
|
||||
def test_odd_fragment():
|
||||
"""
|
||||
Tests long IP fragmentation
|
||||
"""
|
||||
|
||||
fragment = actions.fragment.FragmentAction(correct_order=True, segment=False)
|
||||
assert str(fragment) == "fragment{ip:-1:True}", "Fragment returned incorrect string representation: %s" % str(fragment)
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")/("dataisodd"))
|
||||
packet1, packet2 = fragment.run(packet, logger)
|
||||
|
||||
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
|
||||
|
||||
assert str(packet1["Raw"].load) != str(packet2["Raw"].load), "Packets were not different"
|
||||
assert packet1["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d', "Left packet incorrectly fragmented"
|
||||
assert packet2["Raw"].load == b'\x00\x00\x00dP\x02 \x00e\xc1\x00\x00dataisodd', "Right packet incorrectly fragmented"
|
||||
assert packet1["Raw"].load + packet2["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d\x00\x00\x00dP\x02 \x00e\xc1\x00\x00dataisodd', "Packets fragmentation was incorrect"
|
||||
|
||||
|
||||
def test_custom_fragment():
|
||||
"""
|
||||
Tests IP fragments with custom sized lengths
|
||||
"""
|
||||
|
||||
fragment = actions.fragment.FragmentAction(correct_order=True, fragsize=3, segment=False)
|
||||
assert str(fragment) == "fragment{ip:3:True}", "Fragment returned incorrect string representation: %s" % str(fragment)
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")/("thisissomedata"))
|
||||
packet1, packet2 = fragment.run(packet, logger)
|
||||
|
||||
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
|
||||
assert str(packet1["Raw"].load) != str(packet2["Raw"].load), "Packets were not different"
|
||||
assert packet1["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d\x00\x00\x00dP\x02 \x00zp\x00\x00this', "Left packet incorrectly fragmented"
|
||||
assert packet2["Raw"].load == b'issomedata', "Right packet incorrectly fragmented"
|
||||
assert packet1["Raw"].load + packet2["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d\x00\x00\x00dP\x02 \x00zp\x00\x00thisissomedata', "Packets fragmentation was incorrect"
|
||||
|
||||
|
||||
def test_reverse_fragment():
|
||||
"""
|
||||
Tests fragmentation with reversed packets
|
||||
"""
|
||||
|
||||
fragment = actions.fragment.FragmentAction(correct_order=False, fragsize=3, segment=False)
|
||||
assert str(fragment) == "fragment{ip:3:False}", "Fragment returned incorrect string representation: %s" % str(fragment)
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S")/("thisissomedata"))
|
||||
packet1, packet2 = fragment.run(packet, logger)
|
||||
|
||||
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
|
||||
assert str(packet1["Raw"].load) != str(packet2["Raw"].load), "Packets were not different"
|
||||
assert packet2["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d\x00\x00\x00dP\x02 \x00zp\x00\x00this', "Left packet incorrectly fragmented"
|
||||
assert packet1["Raw"].load == b'issomedata', "Right packet incorrectly fragmented"
|
||||
assert packet2["Raw"].load + packet1["Raw"].load == b'\x08\xae\r\x05\x00\x00\x00d\x00\x00\x00dP\x02 \x00zp\x00\x00thisissomedata', "Packets fragmentation was incorrect"
|
||||
|
||||
|
||||
def test_udp_fragment():
|
||||
"""
|
||||
Tests fragmentation with reversed packets
|
||||
"""
|
||||
|
||||
fragment = actions.fragment.FragmentAction(correct_order=False, fragsize=2, segment=False)
|
||||
assert str(fragment) == "fragment{ip:2:False}", "Fragment returned incorrect string representation: %s" % str(fragment)
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/UDP(sport=2222, dport=3333, chksum=0x4444)/("thisissomedata"))
|
||||
packet1, packet2 = fragment.run(packet, logger)
|
||||
|
||||
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
|
||||
assert str(packet1["Raw"].load) != str(packet2["Raw"].load), "Packets were not different"
|
||||
|
||||
|
||||
def test_parse():
|
||||
"""
|
||||
Tests parsing.
|
||||
"""
|
||||
fragment = actions.fragment.FragmentAction(correct_order=False, fragsize=2, segment=False)
|
||||
assert str(fragment) == "fragment{ip:2:False}", "Fragment returned incorrect string representation: %s" % str(fragment)
|
||||
|
||||
fragment.parse("fragment{tcp:5:False}", logger)
|
||||
assert fragment.correct_order == False
|
||||
assert fragment.fragsize == 5
|
||||
assert fragment.segment == True
|
||||
|
||||
with pytest.raises(Exception):
|
||||
fragment.parse("fragment{tcp:5}", logger)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
fragment.parse("fragment{tcp:a:True}", logger)
|
||||
|
||||
assert fragment.correct_order == False
|
||||
assert fragment.fragsize == 5
|
||||
assert fragment.segment == True
|
||||
|
||||
fragment = actions.fragment.FragmentAction()
|
||||
assert fragment.correct_order in [True, False]
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
|
||||
|
||||
strat = actions.utils.parse("[IP:proto:6:0]-tamper{IP:proto:replace:6}(fragment{ip:-1:True}(tamper{TCP:dataofs:replace:8}(duplicate,),tamper{IP:frag:replace:0}),)-| [IP:tos:0:0]-duplicate-| \/", logger)
|
||||
strat.act_on_packet(packet, logger)
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/UDP(sport=2222, dport=3333, chksum=0x4444))
|
||||
strat = actions.utils.parse("[IP:proto:6:0]-tamper{IP:proto:replace:6}(fragment{ip:-1:True}(tamper{TCP:dataofs:replace:8}(duplicate,),tamper{IP:frag:replace:0}),)-| [IP:tos:0:0]-duplicate-| \/", logger)
|
||||
strat.act_on_packet(packet, logger)
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, chksum=0x4444))
|
||||
strat = actions.utils.parse("[TCP:urgptr:0]-tamper{TCP:options-altchksumopt:corrupt}(fragment{tcp:-1:True}(tamper{IP:proto:corrupt},tamper{TCP:seq:replace:654077552}),)-| \/", logger)
|
||||
strat.act_on_packet(packet, logger)
|
||||
|
||||
strat = actions.utils.parse("[TCP:options-mss:]-tamper{TCP:load:replace:}(fragment{tcp:-1:True},)-| \/", logger)
|
||||
strat.act_on_packet(packet, logger)
|
||||
|
||||
strat = actions.utils.parse("[TCP:options-mss:]-tamper{IP:frag:replace:1353}(tamper{TCP:load:replace:}(fragment{tcp:-1:True},),)-| \/", logger)
|
||||
strat.act_on_packet(packet, logger)
|
||||
|
||||
strat = actions.utils.parse("[IP:ihl:5]-duplicate-| [TCP:options-mss:]-tamper{IP:frag:replace:1353}(fragment{tcp:-1:True}(tamper{TCP:load:replace:}(fragment{tcp:-1:False},),tamper{DNSQR:qtype:replace:45416}),)-| \/", logger)
|
||||
strat.act_on_packet(packet, logger)
|
||||
|
||||
strat = actions.utils.parse("[DNSQR:qclass:25989]-duplicate(duplicate(tamper{DNSQR:qtype:replace:30882},),tamper{UDP:sport:replace:42042})-| [TCP:options-nop:]-tamper{TCP:options-nop:corrupt}(tamper{TCP:load:replace:mjkuskjzgy}(tamper{IP:frag:replace:410}(fragment{tcp:-1:True},),),)-| \/", logger)
|
||||
strat.act_on_packet(packet, logger)
|
||||
|
||||
|
||||
def test_fallback():
|
||||
"""
|
||||
Tests fallback behavior.
|
||||
"""
|
||||
fragment = actions.fragment.FragmentAction(correct_order=False, fragsize=2, segment=False)
|
||||
assert str(fragment) == "fragment{ip:2:False}", "Fragment returned incorrect string representation: %s" % str(fragment)
|
||||
|
||||
fragment.parse("fragment{ip:0:False}", logger)
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/UDP(sport=2222, dport=3333, chksum=0x4444)/("thisissomedata"))
|
||||
packet1, packet2 = fragment.run(packet, logger)
|
||||
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
|
||||
assert str(packet1) == str(packet2)
|
||||
|
||||
fragment.parse("fragment{tcp:-1:False}", logger)
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/UDP(sport=2222, dport=3333, chksum=0x4444)/("thisissomedata"))
|
||||
packet1, packet2 = fragment.run(packet, logger)
|
||||
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
|
||||
assert str(packet1) == str(packet2)
|
||||
|
||||
fragment.parse("fragment{tcp:-1:False}", logger)
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06)/TCP(sport=2222, dport=3333, chksum=0x4444))
|
||||
packet1, packet2 = fragment.run(packet, logger)
|
||||
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
|
||||
assert str(packet1) == str(packet2)
|
||||
|
||||
fragment.parse("fragment{ip:-1:False}", logger)
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1", proto=0x06))
|
||||
packet1, packet2 = fragment.run(packet, logger)
|
||||
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
|
||||
assert str(packet1) == str(packet2)
|
||||
|
||||
|
||||
def test_ip_only_fragment():
|
||||
"""
|
||||
Tests fragmentation without higher protocols.
|
||||
"""
|
||||
|
||||
fragment = actions.fragment.FragmentAction(correct_order=True)
|
||||
fragment.parse("fragment{ip:-1:True}", logger)
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/("datadata11datadata"))
|
||||
packet1, packet2 = fragment.run(packet, logger)
|
||||
|
||||
assert id(packet1) != id(packet2), "Duplicate aliased packet objects"
|
||||
|
||||
assert packet1["Raw"].load != packet2["Raw"].load, "Packets were not different"
|
||||
assert packet1["Raw"].load == b'datadata', "Left packet incorrectly fragmented"
|
||||
assert packet2["Raw"].load == b"11datadata", "Right packet incorrectly fragmented"
|
||||
|
||||
|
|
@ -0,0 +1,544 @@
|
|||
import logging
|
||||
import pytest
|
||||
|
||||
import actions.packet
|
||||
import actions.layer
|
||||
|
||||
from scapy.all import IP, TCP, UDP, DNS, DNSQR, Raw, DNSRR
|
||||
|
||||
logger = logging.getLogger("test")
|
||||
|
||||
|
||||
def test_parse_layers():
|
||||
"""
|
||||
Tests layer parsing.
|
||||
"""
|
||||
pkt = IP()/TCP()/Raw("")
|
||||
packet = actions.packet.Packet(pkt)
|
||||
layers = list(packet.read_layers())
|
||||
assert layers[0].name == "IP"
|
||||
assert layers[1].name == "TCP"
|
||||
|
||||
layers_dict = packet.setup_layers()
|
||||
assert layers_dict["IP"]
|
||||
assert layers_dict["TCP"]
|
||||
|
||||
|
||||
def test_get_random():
|
||||
"""
|
||||
Tests get random
|
||||
"""
|
||||
|
||||
tcplayer = actions.layer.TCPLayer(TCP())
|
||||
field, value = tcplayer.get_random()
|
||||
assert field in actions.layer.TCPLayer.fields
|
||||
|
||||
|
||||
def test_gen_random():
|
||||
"""
|
||||
Tests gen random
|
||||
"""
|
||||
for i in range(0, 2000):
|
||||
layer, field, value = actions.packet.Packet().gen_random()
|
||||
assert layer in [DNS, TCP, UDP, IP, DNSQR]
|
||||
|
||||
|
||||
def test_dnsqr():
|
||||
"""
|
||||
Tests DNSQR.
|
||||
"""
|
||||
pkt = UDP()/DNS(ancount=1)/DNSQR()
|
||||
pkt.show()
|
||||
packet = actions.packet.Packet(pkt)
|
||||
packet.show()
|
||||
assert len(packet.layers) == 3
|
||||
assert "UDP" in packet.layers
|
||||
assert "DNS" in packet.layers
|
||||
assert "DNSQR" in packet.layers
|
||||
pkt = IP()/UDP()/DNS()/DNSQR()
|
||||
packet = actions.packet.Packet(pkt)
|
||||
assert str(packet)
|
||||
|
||||
|
||||
def test_load():
|
||||
"""
|
||||
Tests loads.
|
||||
"""
|
||||
tcp = actions.layer.TCPLayer(TCP())
|
||||
assert tcp.gen("load")
|
||||
pkt = IP()/"datadata"
|
||||
p = actions.packet.Packet(pkt)
|
||||
assert p.get("IP", "load") == "datadata"
|
||||
p2 = actions.packet.Packet(IP(bytes(p)))
|
||||
assert p2.get("IP", "load") == "datadata"
|
||||
p2.set("IP", "load", "data2")
|
||||
# Check p is unchanged
|
||||
assert p.get("IP", "load") == "datadata"
|
||||
assert p2.get("IP", "load") == "data2"
|
||||
p2.show2()
|
||||
# Check that we can dump
|
||||
assert p2.show2(dump=True)
|
||||
# Check that we can dump
|
||||
assert p2.show(dump=True)
|
||||
assert p2.get("IP", "chksum") == None
|
||||
|
||||
pkt = IP()/TCP()/"datadata"
|
||||
p = actions.packet.Packet(pkt)
|
||||
assert p.get("TCP", "load") == "datadata"
|
||||
p2 = actions.packet.Packet(IP(bytes(p)))
|
||||
assert p2.get("TCP", "load") == "datadata"
|
||||
p2.set("TCP", "load", "data2")
|
||||
# Check p is unchanged
|
||||
assert p.get("TCP", "load") == "datadata"
|
||||
assert p2.get("TCP", "load") == "data2"
|
||||
p2.show2()
|
||||
assert p2.get("IP", "chksum") == None
|
||||
|
||||
|
||||
def test_parse_load():
|
||||
"""
|
||||
Tests load parsing.
|
||||
"""
|
||||
pkt = actions.packet.Packet(IP()/TCP()/"TYPE A\r\n")
|
||||
print("Parsed: %s" % pkt.get("TCP", "load"))
|
||||
|
||||
strat = actions.utils.parse("[TCP:load:TYPE%20A%0D%0A]-drop-| \/", logger)
|
||||
results = strat.act_on_packet(pkt, logger)
|
||||
assert not results
|
||||
|
||||
value = pkt.gen("TCP", "load") + " " + pkt.gen("TCP", "load")
|
||||
pkt.set("TCP", "load", value)
|
||||
assert " " not in pkt.get("TCP", "load"), "%s contained a space!" % pkt.get("TCP", "load")
|
||||
|
||||
|
||||
def test_dns():
|
||||
"""
|
||||
Tests DNS layer.
|
||||
"""
|
||||
dns = actions.layer.DNSLayer(DNS())
|
||||
print(dns.gen("id"))
|
||||
assert dns.gen("id")
|
||||
|
||||
p = actions.packet.Packet(DNS(id=0xabcd))
|
||||
p2 = actions.packet.Packet(DNS(bytes(p)))
|
||||
assert p.get("DNS", "id") == 0xabcd
|
||||
assert p2.get("DNS", "id") == 0xabcd
|
||||
|
||||
p2.set("DNS", "id", 0x4321)
|
||||
assert p.get("DNS", "id") == 0xabcd # Check p is unchanged
|
||||
assert p2.get("DNS", "id") == 0x4321
|
||||
|
||||
dns = actions.packet.Packet(DNS(aa=1))
|
||||
assert dns.get("DNS", "aa") == 1
|
||||
aa = dns.gen("DNS", "aa")
|
||||
assert aa == 0 or aa == 1
|
||||
assert dns.get("DNS", "aa") == 1 # Original value unchanged
|
||||
|
||||
dns = actions.packet.Packet(DNS(opcode=15))
|
||||
assert dns.get("DNS", "opcode") == 15
|
||||
opcode = dns.gen("DNS", "opcode")
|
||||
assert opcode >= 0 and opcode <= 15
|
||||
assert dns.get("DNS", "opcode") == 15 # Original value unchanged
|
||||
|
||||
dns.set("DNS", "opcode", 3)
|
||||
assert dns.get("DNS", "opcode") == 3
|
||||
|
||||
dns = actions.packet.Packet(DNS(qr=0))
|
||||
assert dns.get("DNS", "qr") == 0
|
||||
qr = dns.gen("DNS", "qr")
|
||||
assert qr == 0 or qr == 1
|
||||
assert dns.get("DNS", "qr") == 0 # Original value unchanged
|
||||
|
||||
dns.set("DNS", "qr", 1)
|
||||
assert dns.get("DNS", "qr") == 1
|
||||
|
||||
dns = actions.packet.Packet(DNS(arcount=0xAABB))
|
||||
assert dns.get("DNS", "arcount") == 0xAABB
|
||||
arcount = dns.gen("DNS", "arcount")
|
||||
assert arcount >= 0 and arcount <= 0xffff
|
||||
assert dns.get("DNS", "arcount") == 0xAABB # Original value unchanged
|
||||
|
||||
dns.set("DNS", "arcount", 65432)
|
||||
assert dns.get("DNS", "arcount") == 65432
|
||||
|
||||
dns = actions.layer.DNSLayer(DNS()/DNSQR(qname="example.com"))
|
||||
assert isinstance(dns.get_next_layer(), DNSQR)
|
||||
print(dns.gen("id"))
|
||||
assert dns.gen("id")
|
||||
|
||||
p = actions.packet.Packet(DNS(id=0xabcd))
|
||||
p2 = actions.packet.Packet(DNS(bytes(p)))
|
||||
assert p.get("DNS", "id") == 0xabcd
|
||||
assert p2.get("DNS", "id") == 0xabcd
|
||||
|
||||
|
||||
def test_read_layers():
|
||||
"""
|
||||
Tests the ability to read each layer
|
||||
"""
|
||||
packet = IP() / UDP() / TCP() / DNS() / DNSQR(qname="example.com") / DNSQR(qname="example2.com") / DNSQR(qname="example3.com")
|
||||
packet_geneva = actions.packet.Packet(packet)
|
||||
packet_geneva.setup_layers()
|
||||
|
||||
i = 0
|
||||
for layer in packet_geneva.read_layers():
|
||||
if i == 0:
|
||||
assert isinstance(layer, actions.layer.IPLayer)
|
||||
elif i == 1:
|
||||
assert isinstance(layer, actions.layer.UDPLayer)
|
||||
elif i == 2:
|
||||
assert isinstance(layer, actions.layer.TCPLayer)
|
||||
elif i == 3:
|
||||
assert isinstance(layer, actions.layer.DNSLayer)
|
||||
elif i == 4:
|
||||
assert isinstance(layer, actions.layer.DNSQRLayer)
|
||||
assert layer.layer.qname == b"example.com"
|
||||
elif i == 5:
|
||||
assert isinstance(layer, actions.layer.DNSQRLayer)
|
||||
assert layer.layer.qname == b"example2.com"
|
||||
elif i == 6:
|
||||
assert isinstance(layer, actions.layer.DNSQRLayer)
|
||||
assert layer.layer.qname == b"example3.com"
|
||||
i += 1
|
||||
|
||||
def test_multi_opts():
|
||||
"""
|
||||
Tests various option getting/setting.
|
||||
"""
|
||||
pkt = IP()/TCP(options=[('MSS', 1460), ('SAckOK', b''), ('Timestamp', (4154603075, 0)), ('NOP', None), ('WScale', 7)])
|
||||
packet = actions.packet.Packet(pkt)
|
||||
assert packet.get("TCP", "options-sackok") == ''
|
||||
assert packet.get("TCP", "options-mss") == 1460
|
||||
assert packet.get("TCP", "options-timestamp") == 4154603075
|
||||
assert packet.get("TCP", "options-wscale") == 7
|
||||
packet.set("TCP", "options-timestamp", 400000000)
|
||||
assert packet.get("TCP", "options-sackok") == ''
|
||||
assert packet.get("TCP", "options-mss") == 1460
|
||||
assert packet.get("TCP", "options-timestamp") == 400000000
|
||||
assert packet.get("TCP", "options-wscale") == 7
|
||||
pkt = IP()/TCP(options=[('SAckOK', b''), ('Timestamp', (4154603075, 0)), ('NOP', None), ('WScale', 7)])
|
||||
packet = actions.packet.Packet(pkt)
|
||||
# If the option isn't present, it will be returned as an empty string
|
||||
assert packet.get("TCP", "options-mss") == ''
|
||||
packet.set("TCP", "options-mss", "")
|
||||
assert packet.get("TCP", "options-mss") == 0
|
||||
|
||||
|
||||
def test_options_eol():
|
||||
"""
|
||||
Tests options-eol.
|
||||
"""
|
||||
pkt = TCP(options=[("EOL", None)])
|
||||
p = actions.packet.Packet(pkt)
|
||||
assert p.get("TCP", "options-eol") == ""
|
||||
p2 = actions.packet.Packet(TCP(bytes(p)))
|
||||
assert p2.get("TCP", "options-eol") == ""
|
||||
p = actions.packet.Packet(IP()/TCP(options=[]))
|
||||
assert p.get("TCP", "options-eol") == ""
|
||||
p.set("TCP", "options-eol", "")
|
||||
p.show()
|
||||
assert len(p["TCP"].options) == 1
|
||||
assert any(k == "EOL" for k, v in p["TCP"].options)
|
||||
value = p.gen("TCP", "options-eol")
|
||||
assert value == "", "eol cannot store data"
|
||||
p.set("TCP", "options-eol", value)
|
||||
p2 = TCP(bytes(p))
|
||||
assert any(k == "EOL" for k, v in p2["TCP"].options)
|
||||
|
||||
|
||||
def test_options_mss():
|
||||
"""
|
||||
Tests options-eol.
|
||||
"""
|
||||
pkt = TCP(options=[("MSS", 1440)])
|
||||
p = actions.packet.Packet(pkt)
|
||||
assert p.get("TCP", "options-mss") == 1440
|
||||
p2 = actions.packet.Packet(TCP(bytes(p)))
|
||||
assert p2.get("TCP", "options-mss") == 1440
|
||||
p = actions.packet.Packet(TCP(options=[]))
|
||||
assert p.get("TCP", "options-mss") == ""
|
||||
p.set("TCP", "options-mss", 2880)
|
||||
p.show()
|
||||
assert len(p["TCP"].options) == 1
|
||||
assert any(k == "MSS" for k, v in p["TCP"].options)
|
||||
value = p.gen("TCP", "options-mss")
|
||||
p.set("TCP", "options-mss", value)
|
||||
p2 = TCP(bytes(p))
|
||||
assert any(k == "MSS" for k, v in p2["TCP"].options)
|
||||
|
||||
|
||||
def check_get(protocol, field, value):
|
||||
"""
|
||||
Checks if the get method worked for this protocol, field, and value.
|
||||
"""
|
||||
pkt = protocol()
|
||||
setattr(pkt, field, value)
|
||||
packet = actions.packet.Packet(pkt)
|
||||
assert packet.get(protocol.__name__, field) == value
|
||||
|
||||
|
||||
def get_test_configs():
|
||||
"""
|
||||
Generates test configurations for the getters.
|
||||
"""
|
||||
return [
|
||||
(IP, 'version', 4),
|
||||
(IP, 'version', 6),
|
||||
(IP, 'version', 0),
|
||||
(IP, 'ihl', 0),
|
||||
(IP, 'tos', 0),
|
||||
(IP, 'len', 50),
|
||||
(IP, 'len', 6),
|
||||
(IP, 'flags', 'MF'),
|
||||
(IP, 'flags', 'DF'),
|
||||
(IP, 'flags', 'MF+DF'),
|
||||
(IP, 'ttl', 25),
|
||||
(IP, 'proto', 4),
|
||||
(IP, 'chksum', 0x4444),
|
||||
(IP, 'src', '127.0.0.1'),
|
||||
(IP, 'dst', '127.0.0.1'),
|
||||
(TCP, 'sport', 12345),
|
||||
(TCP, 'dport', 55555),
|
||||
(TCP, 'seq', 123123123),
|
||||
(TCP, 'ack', 181818181),
|
||||
(TCP, 'dataofs', 5),
|
||||
(TCP, 'dataofs', 0),
|
||||
(TCP, 'dataofs', 15),
|
||||
(TCP, 'reserved', 0),
|
||||
(TCP, 'window', 100),
|
||||
(TCP, 'chksum', 0x4444),
|
||||
(TCP, 'urgptr', 1),
|
||||
|
||||
(DNS, 'id', 0xabcd),
|
||||
(DNS, 'qr', 1),
|
||||
(DNS, 'opcode', 9),
|
||||
(DNS, 'aa', 0),
|
||||
(DNS, 'tc', 1),
|
||||
(DNS, 'rd', 0),
|
||||
(DNS, 'ra', 1),
|
||||
(DNS, 'z', 0),
|
||||
(DNS, 'ad', 1),
|
||||
(DNS, 'cd', 0),
|
||||
(DNS, 'qdcount', 0x1234),
|
||||
(DNS, 'ancount', 12345),
|
||||
(DNS, 'nscount', 49870),
|
||||
(DNS, 'arcount', 0xABCD),
|
||||
|
||||
(DNSQR, 'qname', 'example.com.'),
|
||||
(DNSQR, 'qtype', 1),
|
||||
(DNSQR, 'qclass', 0),
|
||||
]
|
||||
|
||||
|
||||
def get_custom_configs():
|
||||
"""
|
||||
Generates test configurations that can use the custom getters.
|
||||
"""
|
||||
return [
|
||||
(IP, 'flags', ''),
|
||||
(TCP, 'options-eol', ''),
|
||||
(TCP, 'options-nop', ''),
|
||||
(TCP, 'options-mss', 0),
|
||||
(TCP, 'options-mss', 1440),
|
||||
(TCP, 'options-mss', 5000),
|
||||
(TCP, 'options-wscale', 20),
|
||||
(TCP, 'options-sackok', ''),
|
||||
(TCP, 'options-sack', ''),
|
||||
(TCP, 'options-timestamp', 12345678),
|
||||
(TCP, 'options-altchksum', 0x44),
|
||||
(TCP, 'options-altchksumopt', ''),
|
||||
(TCP, 'options-uto', 1),
|
||||
#(TCP, 'options-md5header', 'deadc0ffee')
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config", get_test_configs(),
|
||||
ids=['%s-%s-%s' % (proto.__name__, field, str(val)) for proto, field, val in get_test_configs()])
|
||||
def test_get(config):
|
||||
"""
|
||||
Tests value retrieval.
|
||||
"""
|
||||
proto, field, val = config
|
||||
check_get(proto, field, val)
|
||||
|
||||
|
||||
def check_set_get(protocol, field, value):
|
||||
"""
|
||||
Checks if the get method worked for this protocol, field, and value.
|
||||
"""
|
||||
pkt = actions.packet.Packet(protocol())
|
||||
pkt.set(protocol.__name__, field, value)
|
||||
assert pkt.get(protocol.__name__, field) == value
|
||||
# Rebuild the packet to confirm the type survived
|
||||
pkt2 = actions.packet.Packet(protocol(bytes(pkt)))
|
||||
assert pkt2.get(protocol.__name__, field) == value, "Value %s for header %s didn't survive packet parsing." % (value, field)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config", get_test_configs() + get_custom_configs(),
|
||||
ids=['%s-%s-%s' % (proto.__name__, field, str(val)) for proto, field, val in get_test_configs() + get_custom_configs()])
|
||||
def test_set_get(config):
|
||||
"""
|
||||
Tests value retrieval.
|
||||
"""
|
||||
proto, field, value = config
|
||||
check_set_get(proto, field, value)
|
||||
|
||||
|
||||
def check_gen_set_get(protocol, field):
|
||||
"""
|
||||
Checks if the get method worked for this protocol, field, and value.
|
||||
"""
|
||||
pkt = actions.packet.Packet(protocol())
|
||||
new_value = pkt.gen(protocol.__name__, field)
|
||||
pkt.set(protocol.__name__, field, new_value)
|
||||
assert pkt.get(protocol.__name__, field) == new_value
|
||||
# Rebuild the packet to confirm the type survived
|
||||
pkt2 = actions.packet.Packet(protocol(bytes(pkt)))
|
||||
assert pkt2.get(protocol.__name__, field) == new_value
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config", get_test_configs() + get_custom_configs(),
|
||||
ids=['%s-%s' % (proto.__name__, field) for proto, field, _ in get_test_configs() + get_custom_configs()])
|
||||
def test_gen_set_get(config):
|
||||
"""
|
||||
Tests value retrieval.
|
||||
"""
|
||||
# Test each generator 50 times to hit a range of values
|
||||
for i in range(0, 50):
|
||||
proto, field, _ = config
|
||||
check_gen_set_get(proto, field)
|
||||
|
||||
|
||||
def test_custom_get():
|
||||
"""
|
||||
Tests value retrieval for custom getters.
|
||||
"""
|
||||
pkt = IP()/TCP()/Raw(load="AAAA")
|
||||
tcp = actions.packet.Packet(pkt)
|
||||
assert tcp.get("TCP", "load") == "AAAA"
|
||||
|
||||
|
||||
def test_restrict_fields():
|
||||
"""
|
||||
Tests packet field restriction.
|
||||
"""
|
||||
actions.packet.SUPPORTED_LAYERS = [
|
||||
actions.layer.IPLayer,
|
||||
actions.layer.TCPLayer,
|
||||
actions.layer.UDPLayer
|
||||
]
|
||||
tcpfields = actions.layer.TCPLayer.fields
|
||||
udpfields = actions.layer.UDPLayer.fields
|
||||
ipfields = actions.layer.IPLayer.fields
|
||||
|
||||
actions.packet.Packet.restrict_fields(logger, ["TCP", "UDP"], [], [])
|
||||
assert len(actions.packet.SUPPORTED_LAYERS) == 2
|
||||
assert actions.layer.TCPLayer in actions.packet.SUPPORTED_LAYERS
|
||||
assert actions.layer.UDPLayer in actions.packet.SUPPORTED_LAYERS
|
||||
assert not actions.layer.IPLayer in actions.packet.SUPPORTED_LAYERS
|
||||
|
||||
pkt = IP()/TCP()
|
||||
packet = actions.packet.Packet(pkt)
|
||||
assert "TCP" in packet.layers
|
||||
assert not "IP" in packet.layers
|
||||
assert len(packet.layers) == 1
|
||||
|
||||
for i in range(0, 2000):
|
||||
layer, proto, field = actions.packet.Packet().gen_random()
|
||||
assert layer in [TCP, UDP]
|
||||
|
||||
# Check we can't retrieve any IP fields
|
||||
for field in actions.layer.IPLayer.fields:
|
||||
with pytest.raises(AssertionError):
|
||||
packet.get("IP", field)
|
||||
|
||||
# Check we can get all the TCP fields
|
||||
for field in actions.layer.TCPLayer.fields:
|
||||
packet.get("TCP", field)
|
||||
|
||||
actions.packet.Packet.restrict_fields(logger, ["TCP", "UDP"], ["flags"], [])
|
||||
packet = actions.packet.Packet(pkt)
|
||||
assert len(actions.packet.SUPPORTED_LAYERS) == 1
|
||||
assert actions.layer.TCPLayer in actions.packet.SUPPORTED_LAYERS
|
||||
assert not actions.layer.UDPLayer in actions.packet.SUPPORTED_LAYERS
|
||||
assert not actions.layer.IPLayer in actions.packet.SUPPORTED_LAYERS
|
||||
assert actions.layer.TCPLayer.fields == ["flags"]
|
||||
assert not actions.layer.UDPLayer.fields
|
||||
|
||||
# Check we can't retrieve any IP fields
|
||||
for field in actions.layer.IPLayer.fields:
|
||||
with pytest.raises(AssertionError):
|
||||
packet.get("IP", field)
|
||||
|
||||
# Check we can get all the TCP fields
|
||||
for field in tcpfields:
|
||||
if field == "flags":
|
||||
packet.get("TCP", field)
|
||||
else:
|
||||
with pytest.raises(AssertionError):
|
||||
packet.get("TCP", field)
|
||||
|
||||
for i in range(0, 2000):
|
||||
layer, field, value = actions.packet.Packet().gen_random()
|
||||
assert layer == TCP
|
||||
assert field == "flags"
|
||||
|
||||
actions.packet.Packet.reset_restrictions()
|
||||
actions.packet.SUPPORTED_LAYERS = [
|
||||
actions.layer.IPLayer,
|
||||
actions.layer.TCPLayer,
|
||||
actions.layer.UDPLayer
|
||||
]
|
||||
actions.packet.Packet.restrict_fields(logger, ["TCP", "IP"], [], ["sport", "dport", "seq", "src"])
|
||||
packet = actions.packet.Packet(pkt)
|
||||
packet = packet.copy()
|
||||
assert packet.has_supported_layers()
|
||||
assert len(actions.packet.SUPPORTED_LAYERS) == 2
|
||||
assert actions.layer.TCPLayer in actions.packet.SUPPORTED_LAYERS
|
||||
assert not actions.layer.UDPLayer in actions.packet.SUPPORTED_LAYERS
|
||||
assert actions.layer.IPLayer in actions.packet.SUPPORTED_LAYERS
|
||||
assert set(actions.layer.TCPLayer.fields) == set([f for f in tcpfields if f not in ["sport", "dport", "seq"]])
|
||||
assert set(actions.layer.IPLayer.fields) == set([f for f in ipfields if f not in ["src"]])
|
||||
|
||||
# Check we can't retrieve any IP fields
|
||||
for field in actions.layer.IPLayer.fields:
|
||||
if field == "src":
|
||||
with pytest.raises(AssertionError):
|
||||
packet.get("IP", field)
|
||||
else:
|
||||
packet.get("IP", field)
|
||||
|
||||
# Check we can get all the TCP fields
|
||||
for field in tcpfields:
|
||||
if field in ["sport", "dport", "seq"]:
|
||||
with pytest.raises(AssertionError):
|
||||
packet.get("TCP", field)
|
||||
else:
|
||||
packet.get("TCP", field)
|
||||
|
||||
for i in range(0, 2000):
|
||||
layer, field, value = actions.packet.Packet().gen_random()
|
||||
assert layer in [TCP, IP]
|
||||
assert field not in ["sport", "dport", "seq", "src"]
|
||||
|
||||
actions.packet.Packet.reset_restrictions()
|
||||
actions.packet.SUPPORTED_LAYERS = [
|
||||
actions.layer.IPLayer,
|
||||
actions.layer.TCPLayer,
|
||||
actions.layer.UDPLayer
|
||||
]
|
||||
|
||||
actions.packet.Packet.restrict_fields(logger, ["IP", "UDP", "DNS"], [], ["version"])
|
||||
packet = actions.packet.Packet(pkt)
|
||||
proto, field, value = packet.get_random()
|
||||
assert proto.__name__ in ["IP", "UDP"]
|
||||
assert len(actions.packet.SUPPORTED_LAYERS) == 2
|
||||
assert not actions.layer.TCPLayer in actions.packet.SUPPORTED_LAYERS
|
||||
assert actions.layer.UDPLayer in actions.packet.SUPPORTED_LAYERS
|
||||
assert actions.layer.IPLayer in actions.packet.SUPPORTED_LAYERS
|
||||
assert set(actions.layer.IPLayer.fields) == set([f for f in ipfields if f not in ["version"]])
|
||||
assert set(actions.layer.UDPLayer.fields) == set(udpfields)
|
||||
|
||||
actions.packet.Packet.reset_restrictions()
|
||||
for layer in actions.packet.SUPPORTED_LAYERS:
|
||||
assert layer.fields, '%s has no fields - reset failed!' % str(layer)
|
|
@ -0,0 +1,94 @@
|
|||
import logging
|
||||
import pytest
|
||||
|
||||
import actions.tree
|
||||
import actions.drop
|
||||
import actions.tamper
|
||||
import actions.trace
|
||||
import actions.duplicate
|
||||
import actions.sleep
|
||||
import actions.utils
|
||||
import actions.strategy
|
||||
|
||||
from scapy.all import IP, TCP
|
||||
|
||||
logger = logging.getLogger("test")
|
||||
|
||||
|
||||
def test_run():
|
||||
"""
|
||||
Tests strategy execution.
|
||||
"""
|
||||
strat1 = actions.utils.parse("[TCP:flags:R]-duplicate-| \/", logger)
|
||||
strat2 = actions.utils.parse("[TCP:flags:S]-drop-| \/", logger)
|
||||
strat3 = actions.utils.parse("[TCP:flags:A]-duplicate(tamper{TCP:dataofs:replace:0},)-| \/", logger)
|
||||
strat4 = actions.utils.parse("[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:replace:15239},),duplicate(tamper{TCP:flags:replace:S}(tamper{TCP:chksum:replace:14539}(tamper{TCP:seq:corrupt},),),))-| \/", logger)
|
||||
|
||||
p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
|
||||
packets = strat1.act_on_packet(p1, logger, direction="out")
|
||||
assert packets, "Strategy dropped SYN packets"
|
||||
assert len(packets) == 1
|
||||
assert packets[0]["TCP"].flags == "S"
|
||||
|
||||
p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
|
||||
packets = strat2.act_on_packet(p1, logger, direction="out")
|
||||
assert not packets, "Strategy failed to drop SYN packets"
|
||||
|
||||
p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="A", dataofs=5))
|
||||
packets = strat3.act_on_packet(p1, logger, direction="out")
|
||||
assert packets, "Strategy dropped packets"
|
||||
assert len(packets) == 2, "Incorrect number of packets emerged from forest"
|
||||
assert packets[0]["TCP"].dataofs == 0, "Packet tamper failed"
|
||||
assert packets[1]["TCP"].dataofs == 5, "Duplicate packet was tampered"
|
||||
|
||||
p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="A", dataofs=5, chksum=100))
|
||||
packets = strat4.act_on_packet(p1, logger, direction="out")
|
||||
assert packets, "Strategy dropped packets"
|
||||
assert len(packets) == 3, "Incorrect number of packets emerged from forest"
|
||||
assert packets[0]["TCP"].flags == "R", "Packet tamper failed"
|
||||
assert packets[0]["TCP"].chksum != p1["TCP"].chksum, "Packet tamper failed"
|
||||
assert packets[1]["TCP"].flags == "S", "Packet tamper failed"
|
||||
assert packets[1]["TCP"].chksum != p1["TCP"].chksum, "Packet tamper failed"
|
||||
assert packets[1]["TCP"].seq != p1["TCP"].seq, "Packet tamper failed"
|
||||
assert packets[2]["TCP"].flags == "A", "Duplicate failed"
|
||||
|
||||
strat4 = actions.utils.parse("[TCP:load:]-tamper{TCP:load:replace:mhe76jm0bd}(fragment{ip:-1:True}(tamper{IP:load:corrupt},drop),)-| \/ ", logger)
|
||||
p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
|
||||
packets = strat4.act_on_packet(p1, logger)
|
||||
|
||||
# Will fail with scapy 2.4.2 if packet is reparsed
|
||||
strat5 = actions.utils.parse("\"[TCP:options-eol:]-tamper{TCP:load:replace:o}(tamper{TCP:dataofs:replace:11},)-| \/\"", logger)
|
||||
p1 = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
|
||||
packets = strat5.act_on_packet(p1, logger)
|
||||
|
||||
|
||||
def test_pretty_print():
|
||||
"""
|
||||
Tests if the string representation of this strategy is correct
|
||||
"""
|
||||
logger = logging.getLogger("test")
|
||||
strat = actions.utils.parse("[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:corrupt},),)-| \/ ", logger)
|
||||
correct = "TCP:flags:A\nduplicate\n├── tamper{TCP:flags:replace:R}\n│ └── tamper{TCP:chksum:corrupt}\n│ └── ===> \n└── ===> \n \n \/ \n "
|
||||
assert strat.pretty_print() == correct
|
||||
|
||||
|
||||
def test_sleep_parse_handling():
|
||||
"""
|
||||
Tests that the sleep action handles bad parsing.
|
||||
"""
|
||||
|
||||
print("Testing incorrect parsing:")
|
||||
assert not actions.sleep.SleepAction().parse("THISHSOULDFAIL", logger)
|
||||
|
||||
assert actions.sleep.SleepAction().parse("10.5", logger)
|
||||
|
||||
|
||||
def test_trace_parse_handling():
|
||||
"""
|
||||
Tests that the sleep action handles bad parsing.
|
||||
"""
|
||||
|
||||
print("Testing incorrect parsing:")
|
||||
assert not actions.trace.TraceAction().parse("5:4", logger)
|
||||
assert not actions.trace.TraceAction().parse("THISHOULDFAIL", logger)
|
||||
assert not actions.trace.TraceAction().parse("", logger)
|
|
@ -0,0 +1,389 @@
|
|||
import copy
|
||||
import logging
|
||||
import sys
|
||||
import pytest
|
||||
import random
|
||||
# Include the root of the project
|
||||
sys.path.append("..")
|
||||
|
||||
import actions.strategy
|
||||
import actions.packet
|
||||
import actions.utils
|
||||
import actions.tamper
|
||||
import actions.layer
|
||||
|
||||
from scapy.all import IP, TCP, UDP, DNS, DNSQR, sr1
|
||||
|
||||
|
||||
logger = logging.getLogger("test")
|
||||
|
||||
|
||||
def test_tamper():
|
||||
"""
|
||||
Tests tampering with replace
|
||||
"""
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
|
||||
original = copy.deepcopy(packet)
|
||||
tamper = actions.tamper.TamperAction(None, field="flags", tamper_type="replace", tamper_value="R")
|
||||
lpacket, rpacket = tamper.run(packet, logger)
|
||||
assert not rpacket, "Tamper must not return right child"
|
||||
assert lpacket, "Tamper must give a left child"
|
||||
assert id(lpacket) == id(packet), "Tamper must edit in place"
|
||||
|
||||
# Confirm tamper replaced the field it was supposed to
|
||||
assert packet[TCP].flags == "R", "Tamper did not replace flags."
|
||||
new_value = packet[TCP].flags
|
||||
|
||||
# Must run this check repeatedly - if a scapy fuzz-ed value is not properly
|
||||
# ._fix()-ed, it will return different values each time it's requested
|
||||
for _ in range(0, 5):
|
||||
assert packet[TCP].flags == new_value, "Replaced value is not stable"
|
||||
|
||||
# Confirm tamper didn't corrupt anything else in the TCP header
|
||||
assert confirm_unchanged(packet, original, TCP, ["flags"])
|
||||
|
||||
# Confirm tamper didn't corrupt anything in the IP header
|
||||
assert confirm_unchanged(packet, original, IP, [])
|
||||
|
||||
|
||||
def test_tamper_ip():
|
||||
"""
|
||||
Tests tampering with IP
|
||||
"""
|
||||
packet = actions.packet.Packet(IP(src='127.0.0.1', dst='127.0.0.1')/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
|
||||
original = copy.deepcopy(packet)
|
||||
tamper = actions.tamper.TamperAction(None, field="src", tamper_type="replace", tamper_value="192.168.1.1", tamper_proto="IP")
|
||||
lpacket, rpacket = tamper.run(packet, logger)
|
||||
assert not rpacket, "Tamper must not return right child"
|
||||
assert lpacket, "Tamper must give a left child"
|
||||
assert id(lpacket) == id(packet), "Tamper must edit in place"
|
||||
|
||||
# Confirm tamper replaced the field it was supposed to
|
||||
assert packet[IP].src == "192.168.1.1", "Tamper did not replace flags."
|
||||
|
||||
# Confirm tamper didn't corrupt anything in the TCP header
|
||||
assert confirm_unchanged(packet, original, TCP, [])
|
||||
|
||||
# Confirm tamper didn't corrupt anything else in the IP header
|
||||
assert confirm_unchanged(packet, original, IP, ["src"])
|
||||
|
||||
|
||||
def test_tamper_udp():
|
||||
"""
|
||||
Tests tampering with UDP
|
||||
"""
|
||||
packet = actions.packet.Packet(IP(src='127.0.0.1', dst='127.0.0.1')/UDP(sport=2222, dport=53))
|
||||
original = copy.deepcopy(packet)
|
||||
tamper = actions.tamper.TamperAction(None, field="chksum", tamper_type="replace", tamper_value=4444, tamper_proto="UDP")
|
||||
lpacket, rpacket = tamper.run(packet, logger)
|
||||
assert not rpacket, "Tamper must not return right child"
|
||||
assert lpacket, "Tamper must give a left child"
|
||||
assert id(lpacket) == id(packet), "Tamper must edit in place"
|
||||
|
||||
# Confirm tamper replaced the field it was supposed to
|
||||
assert packet[UDP].chksum == 4444, "Tamper did not replace flags."
|
||||
|
||||
# Confirm tamper didn't corrupt anything in the TCP header
|
||||
assert confirm_unchanged(packet, original, UDP, ["chksum"])
|
||||
|
||||
# Confirm tamper didn't corrupt anything else in the IP header
|
||||
assert confirm_unchanged(packet, original, IP, [])
|
||||
|
||||
|
||||
def test_tamper_ip_ident():
|
||||
"""
|
||||
Tests tampering with IP and that the checksum is correctly changed
|
||||
"""
|
||||
|
||||
packet = actions.packet.Packet(IP(src='127.0.0.1', dst='127.0.0.1')/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
|
||||
original = copy.deepcopy(packet)
|
||||
tamper = actions.tamper.TamperAction(None, field='id', tamper_type='replace', tamper_value=3333, tamper_proto="IP")
|
||||
lpacket, rpacket = tamper.run(packet, logger)
|
||||
assert not rpacket, "Tamper must not return right child"
|
||||
assert lpacket, "Tamper must give a left child"
|
||||
assert id(lpacket) == id(packet), "Tamper must edit in place"
|
||||
|
||||
# Confirm tamper replaced the field it was supposed to
|
||||
assert packet[IP].id == 3333, "Tamper did not replace flags."
|
||||
|
||||
# Confirm tamper didn't corrupt anything in the TCP header
|
||||
assert confirm_unchanged(packet, original, TCP, [])
|
||||
|
||||
# Confirm tamper didn't corrupt anything else in the IP header
|
||||
assert confirm_unchanged(packet, original, IP, ["id"])
|
||||
|
||||
|
||||
def confirm_unchanged(packet, original, protocol, changed):
|
||||
"""
|
||||
Checks that no other field besides the given array of changed fields
|
||||
are different between these two packets.
|
||||
"""
|
||||
for header in packet.layers:
|
||||
if packet.layers[header].protocol != protocol:
|
||||
continue
|
||||
for field in packet.layers[header].fields:
|
||||
# Skip checking the field we just changed
|
||||
if field in changed or field == "load":
|
||||
continue
|
||||
assert packet.get(protocol.__name__, field) == original.get(protocol.__name__, field), "Tamper changed %s field %s." % (str(protocol), field)
|
||||
return True
|
||||
|
||||
|
||||
def test_parse_parameters():
|
||||
"""
|
||||
Tests that tamper properly rejects malformed tamper actions
|
||||
"""
|
||||
with pytest.raises(Exception):
|
||||
actions.tamper.TamperAction().parse("this:has:too:many:parameters", logger)
|
||||
with pytest.raises(Exception):
|
||||
actions.tamper.TamperAction().parse("not:enough", logger)
|
||||
|
||||
|
||||
def test_corrupt():
|
||||
"""
|
||||
Tests the tamper 'corrupt' primitive.
|
||||
"""
|
||||
tamper = actions.tamper.TamperAction(None, field="flags", tamper_type="corrupt", tamper_value="R")
|
||||
assert tamper.field == "flags", "Tamper action changed fields."
|
||||
assert tamper.tamper_type == "corrupt", "Tamper action changed types."
|
||||
assert str(tamper) == "tamper{TCP:flags:corrupt}", "Tamper returned incorrect string representation: %s" % str(tamper)
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
|
||||
original = copy.deepcopy(packet)
|
||||
tamper.tamper(packet, logger)
|
||||
|
||||
new_value = packet[TCP].flags
|
||||
|
||||
# Must run this check repeatedly - if a scapy fuzz-ed value is not properly
|
||||
# ._fix()-ed, it will return different values each time it's requested
|
||||
for _ in range(0, 5):
|
||||
assert packet[TCP].flags == new_value, "Corrupted value is not stable"
|
||||
|
||||
# Confirm tamper didn't corrupt anything else in the TCP header
|
||||
assert confirm_unchanged(packet, original, TCP, ["flags"])
|
||||
|
||||
# Confirm tamper didn't corrupt anything else in the IP header
|
||||
assert confirm_unchanged(packet, original, IP, [])
|
||||
|
||||
|
||||
def test_add():
|
||||
"""
|
||||
Tests the tamper 'add' primitive.
|
||||
"""
|
||||
tamper = actions.tamper.TamperAction(None, field="seq", tamper_type="add", tamper_value=10)
|
||||
assert tamper.field == "seq", "Tamper action changed fields."
|
||||
assert tamper.tamper_type == "add", "Tamper action changed types."
|
||||
assert str(tamper) == "tamper{TCP:seq:add:10}", "Tamper returned incorrect string representation: %s" % str(tamper)
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
|
||||
original = copy.deepcopy(packet)
|
||||
tamper.tamper(packet, logger)
|
||||
|
||||
new_value = packet[TCP].seq
|
||||
assert new_value == 110, "Tamper did not add"
|
||||
|
||||
# Must run this check repeatedly - if a scapy fuzz-ed value is not properly
|
||||
# ._fix()-ed, it will return different values each time it's requested
|
||||
for _ in range(0, 5):
|
||||
assert packet[TCP].seq == new_value, "Corrupted value is not stable"
|
||||
|
||||
# Confirm tamper didn't corrupt anything else in the TCP header
|
||||
assert confirm_unchanged(packet, original, TCP, ["seq"])
|
||||
|
||||
# Confirm tamper didn't corrupt anything else in the IP header
|
||||
assert confirm_unchanged(packet, original, IP, [])
|
||||
|
||||
|
||||
def test_decompress():
|
||||
"""
|
||||
Tests the tamper 'decompress' primitive.
|
||||
"""
|
||||
tamper = actions.tamper.TamperAction(None, field="qd", tamper_type="compress", tamper_value=10, tamper_proto="DNS")
|
||||
assert tamper.field == "qd", "Tamper action changed fields."
|
||||
assert tamper.tamper_type == "compress", "Tamper action changed types."
|
||||
assert str(tamper) == "tamper{DNS:qd:compress}", "Tamper returned incorrect string representation: %s" % str(tamper)
|
||||
|
||||
packet = actions.packet.Packet(IP(dst="8.8.8.8")/UDP(dport=53)/DNS(qd=DNSQR(qname="minghui.ca.")))
|
||||
original = packet.copy()
|
||||
tamper.tamper(packet, logger)
|
||||
assert bytes(packet["DNS"]) == b'\x00\x00\x01\x00\x00\x02\x00\x00\x00\x00\x00\x00\x07minghui\xc0\x1a\x00\x01\x00\x01\x02ca\x00\x00\x01\x00\x01'
|
||||
resp = sr1(packet.packet)
|
||||
assert resp["DNS"]
|
||||
assert resp["DNS"].rcode != 1
|
||||
assert resp["DNSQR"]
|
||||
assert resp["DNSRR"].rdata
|
||||
assert confirm_unchanged(packet, original, IP, ["len"])
|
||||
print(resp.summary())
|
||||
|
||||
packet = actions.packet.Packet(IP(dst="8.8.8.8")/UDP(dport=53)/DNS(qd=DNSQR(qname="maps.google.com")))
|
||||
original = packet.copy()
|
||||
tamper.tamper(packet, logger)
|
||||
assert bytes(packet["DNS"]) == b'\x00\x00\x01\x00\x00\x02\x00\x00\x00\x00\x00\x00\x04maps\xc0\x17\x00\x01\x00\x01\x06google\x03com\x00\x00\x01\x00\x01'
|
||||
resp = sr1(packet.packet)
|
||||
assert resp["DNS"]
|
||||
assert resp["DNS"].rcode != 1
|
||||
assert resp["DNSQR"]
|
||||
assert resp["DNSRR"].rdata
|
||||
assert confirm_unchanged(packet, original, IP, ["len"])
|
||||
print(resp.summary())
|
||||
|
||||
# Confirm this is a NOP on normal packets
|
||||
packet = actions.packet.Packet(IP()/UDP())
|
||||
original = packet.copy()
|
||||
tamper.tamper(packet, logger)
|
||||
assert packet.packet.summary() == original.packet.summary()
|
||||
|
||||
# Confirm tamper didn't corrupt anything else in the TCP header
|
||||
assert confirm_unchanged(packet, original, UDP, [])
|
||||
|
||||
# Confirm tamper didn't corrupt anything else in the IP header
|
||||
assert confirm_unchanged(packet, original, IP, [])
|
||||
|
||||
packet = actions.packet.Packet(IP(dst="8.8.8.8")/TCP(dport=53)/DNS(qd=DNSQR(qname="maps.google.com")))
|
||||
original = packet.copy()
|
||||
tamper.tamper(packet, logger)
|
||||
assert bytes(packet) == bytes(original)
|
||||
|
||||
|
||||
|
||||
def test_corrupt_chksum():
|
||||
"""
|
||||
Tests the tamper 'replace' primitive.
|
||||
"""
|
||||
tamper = actions.tamper.TamperAction(None, field="chksum", tamper_type="corrupt", tamper_value="R")
|
||||
assert tamper.field == "chksum", "Tamper action changed checksum."
|
||||
assert tamper.tamper_type == "corrupt", "Tamper action changed types."
|
||||
assert str(tamper) == "tamper{TCP:chksum:corrupt}", "Tamper returned incorrect string representation: %s" % str(tamper)
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
|
||||
original = copy.deepcopy(packet)
|
||||
tamper.tamper(packet, logger)
|
||||
|
||||
# Confirm tamper actually corrupted the checksum
|
||||
assert packet[TCP].chksum != 0
|
||||
new_value = packet[TCP].chksum
|
||||
|
||||
# Must run this check repeatedly - if a scapy fuzz-ed value is not properly
|
||||
# ._fix()-ed, it will return different values each time it's requested
|
||||
for _ in range(0, 5):
|
||||
assert packet[TCP].chksum == new_value, "Corrupted value is not stable"
|
||||
|
||||
# Confirm tamper didn't corrupt anything else in the TCP header
|
||||
assert confirm_unchanged(packet, original, TCP, ["chksum"])
|
||||
|
||||
# Confirm tamper didn't corrupt anything else in the IP header
|
||||
assert confirm_unchanged(packet, original, IP, [])
|
||||
|
||||
|
||||
def test_corrupt_dataofs():
|
||||
"""
|
||||
Tests the tamper 'replace' primitive.
|
||||
"""
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S", dataofs="6L"))
|
||||
original = copy.deepcopy(packet)
|
||||
tamper = actions.tamper.TamperAction(None, field="dataofs", tamper_type="corrupt")
|
||||
|
||||
tamper.tamper(packet, logger)
|
||||
|
||||
# Confirm tamper actually corrupted the checksum
|
||||
assert packet[TCP].dataofs != "0"
|
||||
new_value = packet[TCP].dataofs
|
||||
|
||||
# Must run this check repeatedly - if a scapy fuzz-ed value is not properly
|
||||
# ._fix()-ed, it will return different values each time it's requested
|
||||
for _ in range(0, 5):
|
||||
assert packet[TCP].dataofs == new_value, "Corrupted value is not stable"
|
||||
|
||||
# Confirm tamper didn't corrupt anything else in the TCP header
|
||||
assert confirm_unchanged(packet, original, TCP, ["dataofs"])
|
||||
|
||||
# Confirm tamper didn't corrupt anything in the IP header
|
||||
assert confirm_unchanged(packet, original, IP, [])
|
||||
|
||||
|
||||
def test_replace():
|
||||
"""
|
||||
Tests the tamper 'replace' primitive.
|
||||
"""
|
||||
tamper = actions.tamper.TamperAction(None, field="flags", tamper_type="replace", tamper_value="R")
|
||||
|
||||
assert tamper.field == "flags", "Tamper action changed fields."
|
||||
assert tamper.tamper_type == "replace", "Tamper action changed types."
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
|
||||
original = copy.deepcopy(packet)
|
||||
tamper.tamper(packet, logger)
|
||||
|
||||
# Confirm tamper replaced the field it was supposed to
|
||||
assert packet[TCP].flags == "R", "Tamper did not replace flags."
|
||||
# Confirm tamper didn't replace anything else in the TCP header
|
||||
assert confirm_unchanged(packet, original, TCP, ["flags"])
|
||||
|
||||
# Confirm tamper didn't replace anything else in the IP header
|
||||
assert confirm_unchanged(packet, original, IP, [])
|
||||
|
||||
# chksums must be handled specially by tamper, so run a second check on this value
|
||||
tamper.field = "chksum"
|
||||
tamper.tamper_value = 0x4444
|
||||
original = copy.deepcopy(packet)
|
||||
tamper.tamper(packet, logger)
|
||||
assert packet[TCP].chksum == 0x4444, "Tamper failed to change chksum."
|
||||
# Confirm tamper didn't replace anything else in the TCP header
|
||||
assert confirm_unchanged(packet, original, TCP, ["chksum"])
|
||||
# Confirm tamper didn't replace anything else in the IP header
|
||||
assert confirm_unchanged(packet, original, IP, [])
|
||||
|
||||
|
||||
def test_parse_flags():
|
||||
"""
|
||||
Tests the tamper 'replace' primitive.
|
||||
"""
|
||||
tamper = actions.tamper.TamperAction(None, field="flags", tamper_type="replace", tamper_value="FRAPUN")
|
||||
assert tamper.field == "flags", "Tamper action changed checksum."
|
||||
assert tamper.tamper_type == "replace", "Tamper action changed types."
|
||||
assert str(tamper) == "tamper{TCP:flags:replace:FRAPUN}", "Tamper returned incorrect string representation: %s" % str(tamper)
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
|
||||
tamper.tamper(packet, logger)
|
||||
assert packet[TCP].flags == "FRAPUN", "Tamper failed to change flags."
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_type", ["parsed", "direct"])
|
||||
@pytest.mark.parametrize("value", ["EOL", "NOP", "Timestamp", "MSS", "WScale", "SAckOK", "SAck", "Timestamp", "AltChkSum", "AltChkSumOpt", "UTO"])
|
||||
def test_options(value, test_type):
|
||||
"""
|
||||
Tests tampering options
|
||||
"""
|
||||
if test_type == "direct":
|
||||
tamper = actions.tamper.TamperAction(None, field="options-%s" % value.lower(), tamper_type="corrupt", tamper_value=bytes([12]))
|
||||
else:
|
||||
tamper = actions.tamper.TamperAction(None)
|
||||
assert tamper.parse("TCP:options-%s:replace:" % value.lower(), logger)
|
||||
assert tamper.parse("TCP:options-%s:corrupt" % value.lower(), logger)
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="S"))
|
||||
tamper.run(packet, logger)
|
||||
opts_dict_lookup = value.lower().replace(" ", "_")
|
||||
|
||||
for optname, optval in packet["TCP"].options:
|
||||
if optname == value:
|
||||
break
|
||||
elif optname == actions.layer.TCPLayer.options_names[opts_dict_lookup]:
|
||||
break
|
||||
else:
|
||||
pytest.fail("Failed to find %s in options" % value)
|
||||
assert len(packet["TCP"].options) == 1
|
||||
raw_p = bytes(packet)
|
||||
assert raw_p, "options broke scapy bytes"
|
||||
p2 = actions.packet.Packet(IP(bytes(raw_p)))
|
||||
assert p2.haslayer("IP")
|
||||
assert p2.haslayer("TCP")
|
||||
# EOLs might be added for padding, so just check >= 1
|
||||
assert len(p2["TCP"].options) >= 1
|
||||
for optname, optval in p2["TCP"].options:
|
||||
if optname == value:
|
||||
break
|
||||
elif optname == actions.layer.TCPLayer.options_names[opts_dict_lookup]:
|
||||
break
|
||||
else:
|
||||
pytest.fail("Failed to find %s in options" % value)
|
|
@ -0,0 +1,549 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
from scapy.all import IP, TCP
|
||||
import actions.tree
|
||||
import actions.drop
|
||||
import actions.tamper
|
||||
import actions.duplicate
|
||||
import actions.utils
|
||||
|
||||
|
||||
def test_init():
|
||||
"""
|
||||
Tests initialization
|
||||
"""
|
||||
print(actions.action.Action.get_actions("out"))
|
||||
|
||||
|
||||
def test_count_leaves():
|
||||
"""
|
||||
Tests leaf count is correct.
|
||||
"""
|
||||
a = actions.tree.ActionTree("out")
|
||||
logger = logging.getLogger("test")
|
||||
|
||||
assert not a.parse("TCP:reserved:0tamper{TCP:flags:replace:S}-|", logger), "Tree parsed malformed DNA"
|
||||
a.parse("[TCP:reserved:0]-tamper{TCP:flags:replace:S}-|", logger)
|
||||
duplicate = actions.duplicate.DuplicateAction()
|
||||
duplicate2 = actions.duplicate.DuplicateAction()
|
||||
drop = actions.drop.DropAction()
|
||||
|
||||
assert a.count_leaves() == 1
|
||||
assert a.remove_one()
|
||||
a.add_action(duplicate)
|
||||
assert a.count_leaves() == 1
|
||||
duplicate.left = duplicate2
|
||||
assert a.count_leaves() == 1
|
||||
duplicate.right = drop
|
||||
assert a.count_leaves() == 2
|
||||
|
||||
|
||||
def test_check():
|
||||
"""
|
||||
Tests action tree check function.
|
||||
"""
|
||||
a = actions.tree.ActionTree("out")
|
||||
logger = logging.getLogger("test")
|
||||
a.parse("[TCP:flags:RA]-tamper{TCP:flags:replace:S}-|", logger)
|
||||
p = actions.packet.Packet(IP()/TCP(flags="A"))
|
||||
assert not a.check(p, logger)
|
||||
p = actions.packet.Packet(IP(ttl=64)/TCP(flags="RA"))
|
||||
assert a.check(p, logger)
|
||||
assert a.remove_one()
|
||||
assert a.check(p, logger)
|
||||
a.parse("[TCP:reserved:0]-tamper{TCP:flags:replace:S}-|", logger)
|
||||
assert a.check(p, logger)
|
||||
a.parse("[IP:ttl:64]-tamper{TCP:flags:replace:S}-|", logger)
|
||||
assert a.check(p, logger)
|
||||
p = actions.packet.Packet(IP(ttl=15)/TCP(flags="RA"))
|
||||
assert not a.check(p, logger)
|
||||
|
||||
|
||||
def test_scapy():
|
||||
"""
|
||||
Tests misc. scapy aspects relevant to strategies.
|
||||
"""
|
||||
a = actions.tree.ActionTree("out")
|
||||
logger = logging.getLogger("test")
|
||||
a.parse("[TCP:reserved:0]-tamper{TCP:flags:replace:S}-|", logger)
|
||||
p = actions.packet.Packet(IP()/TCP(flags="A"))
|
||||
assert a.check(p, logger)
|
||||
packets = a.run(p, logger)
|
||||
assert packets[0][TCP].flags == "S"
|
||||
p = actions.packet.Packet(IP()/TCP(flags="A"))
|
||||
assert a.check(p, logger)
|
||||
a.parse("[TCP:reserved:0]-tamper{TCP:chksum:corrupt}-|", logger)
|
||||
packets = a.run(p, logger)
|
||||
assert packets[0][TCP].chksum
|
||||
assert a.check(p, logger)
|
||||
|
||||
|
||||
def test_str():
|
||||
"""
|
||||
Tests string representation.
|
||||
"""
|
||||
logger = logging.getLogger("test")
|
||||
|
||||
t = actions.trigger.Trigger("field", "flags", "TCP")
|
||||
a = actions.tree.ActionTree("out", trigger=t)
|
||||
assert str(a).strip() == "[%s]-|" % str(t)
|
||||
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
|
||||
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
|
||||
assert a.add_action(tamper)
|
||||
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}-|"
|
||||
# Tree will not add a duplicate action
|
||||
assert not a.add_action(tamper)
|
||||
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}-|"
|
||||
assert a.add_action(tamper2)
|
||||
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R},)-|"
|
||||
assert a.add_action(actions.duplicate.DuplicateAction())
|
||||
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate,),)-|"
|
||||
drop = actions.drop.DropAction()
|
||||
assert a.add_action(drop)
|
||||
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate(drop,),),)-|" or \
|
||||
str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate(,drop),),)-|"
|
||||
assert a.remove_action(drop)
|
||||
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate,),)-|"
|
||||
# Cannot remove action that is not present
|
||||
assert not a.remove_action(drop)
|
||||
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate,),)-|"
|
||||
|
||||
a = actions.tree.ActionTree("out", trigger=t)
|
||||
orig = "[TCP:urgptr:15963]-duplicate(,drop)-|"
|
||||
a.parse(orig, logger)
|
||||
assert a.remove_one()
|
||||
assert orig != str(a)
|
||||
assert str(a) in ["[TCP:urgptr:15963]-drop-|", "[TCP:urgptr:15963]-duplicate-|"]
|
||||
|
||||
|
||||
def test_pretty_print_send():
|
||||
t = actions.trigger.Trigger("field", "flags", "TCP")
|
||||
a = actions.tree.ActionTree("out", trigger=t)
|
||||
duplicate = actions.duplicate.DuplicateAction()
|
||||
a.add_action(duplicate)
|
||||
correct_string = "TCP:flags:0\nduplicate\n├── ===> \n└── ===> "
|
||||
assert a.pretty_print() == correct_string
|
||||
|
||||
|
||||
def test_pretty_print():
|
||||
"""
|
||||
Print complex tree, although difficult to test
|
||||
"""
|
||||
t = actions.trigger.Trigger("field", "flags", "TCP")
|
||||
a = actions.tree.ActionTree("out", trigger=t)
|
||||
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
|
||||
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
|
||||
duplicate = actions.duplicate.DuplicateAction()
|
||||
duplicate2 = actions.duplicate.DuplicateAction()
|
||||
duplicate3 = actions.duplicate.DuplicateAction()
|
||||
duplicate4 = actions.duplicate.DuplicateAction()
|
||||
duplicate5 = actions.duplicate.DuplicateAction()
|
||||
drop = actions.drop.DropAction()
|
||||
drop2 = actions.drop.DropAction()
|
||||
drop3 = actions.drop.DropAction()
|
||||
drop4 = actions.drop.DropAction()
|
||||
|
||||
duplicate.left = duplicate2
|
||||
duplicate.right = duplicate3
|
||||
duplicate2.left = tamper
|
||||
duplicate2.right = drop
|
||||
duplicate3.left = duplicate4
|
||||
duplicate3.right = drop2
|
||||
duplicate4.left = duplicate5
|
||||
duplicate4.right = drop3
|
||||
duplicate5.left = drop4
|
||||
duplicate5.right = tamper2
|
||||
|
||||
a.add_action(duplicate)
|
||||
correct_string = "TCP:flags:0\nduplicate\n├── duplicate\n│ ├── tamper{TCP:flags:replace:S}\n│ │ └── ===> \n│ └── drop\n└── duplicate\n ├── duplicate\n │ ├── duplicate\n │ │ ├── drop\n │ │ └── tamper{TCP:flags:replace:R}\n │ │ └── ===> \n │ └── drop\n └── drop"
|
||||
assert a.pretty_print() == correct_string
|
||||
assert a.pretty_print(visual=True)
|
||||
assert os.path.exists("tree.png")
|
||||
os.remove("tree.png")
|
||||
a.parse("[TCP:flags:0]-|", logging.getLogger("test"))
|
||||
a.pretty_print(visual=True) # Empty action tree
|
||||
assert not os.path.exists("tree.png")
|
||||
|
||||
def test_pretty_print_order():
|
||||
"""
|
||||
Tests the left/right ordering by reading in a new tree
|
||||
"""
|
||||
logger = logging.getLogger("test")
|
||||
a = actions.tree.ActionTree("out")
|
||||
assert a.parse("[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:replace:14239},),duplicate(tamper{TCP:flags:replace:S}(tamper{TCP:chksum:replace:14239},),))-|", logger)
|
||||
correct_pretty_print = "TCP:flags:A\nduplicate\n├── tamper{TCP:flags:replace:R}\n│ └── tamper{TCP:chksum:replace:14239}\n│ └── ===> \n└── duplicate\n ├── tamper{TCP:flags:replace:S}\n │ └── tamper{TCP:chksum:replace:14239}\n │ └── ===> \n └── ===> "
|
||||
assert a.pretty_print() == correct_pretty_print
|
||||
|
||||
def test_parse():
|
||||
"""
|
||||
Tests string parsing.
|
||||
"""
|
||||
logger = logging.getLogger("test")
|
||||
t = actions.trigger.Trigger("field", "flags", "TCP")
|
||||
a = actions.tree.ActionTree("out", trigger=t)
|
||||
|
||||
base_t = actions.trigger.Trigger("field", "flags", "TCP")
|
||||
base_a = actions.tree.ActionTree("out", trigger=base_t)
|
||||
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
|
||||
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
|
||||
tamper3 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
|
||||
tamper4 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
|
||||
a.parse("[TCP:flags:0]-|", logger)
|
||||
assert str(a) == str(base_a)
|
||||
assert len(a) == 0
|
||||
|
||||
base_a.add_action(tamper)
|
||||
|
||||
assert a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}-|", logger)
|
||||
|
||||
assert str(a) == str(base_a)
|
||||
assert len(a) == 1
|
||||
assert a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R},)-|", logging.getLogger("test"))
|
||||
base_a.add_action(tamper2)
|
||||
assert str(a) == str(base_a)
|
||||
assert len(a) == 2
|
||||
|
||||
base_a.add_action(tamper3)
|
||||
base_a.add_action(tamper4)
|
||||
assert a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R},),),)-|", logging.getLogger("test"))
|
||||
assert str(a) == str(base_a)
|
||||
assert len(a) == 4
|
||||
|
||||
base_t = actions.trigger.Trigger("field", "flags", "TCP")
|
||||
base_a = actions.tree.ActionTree("out", trigger=base_t)
|
||||
duplicate = actions.duplicate.DuplicateAction()
|
||||
assert a.parse("[TCP:flags:0]-duplicate-|", logger)
|
||||
base_a.add_action(duplicate)
|
||||
assert str(a) == str(base_a)
|
||||
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
|
||||
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
|
||||
tamper3 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="A")
|
||||
tamper4 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
|
||||
duplicate.left = tamper
|
||||
assert a.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},)-|", logger)
|
||||
assert str(a) == str(base_a)
|
||||
|
||||
duplicate.right = tamper2
|
||||
assert a.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},tamper{TCP:flags:replace:R})-|", logger)
|
||||
assert str(a) == str(base_a)
|
||||
|
||||
tamper2.left = tamper3
|
||||
assert a.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},tamper{TCP:flags:replace:R}(tamper{TCP:flags:replace:A},))-|", logger)
|
||||
assert str(a) == str(base_a)
|
||||
|
||||
strategy = actions.utils.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},tamper{TCP:flags:replace:R})-| \/", logger)
|
||||
assert strategy
|
||||
assert len(strategy.out_actions[0]) == 3
|
||||
assert len(strategy.in_actions) == 0
|
||||
|
||||
assert not a.parse("[]", logger) # No valid trigger
|
||||
assert not a.parse("[TCP:flags:0]-", logger) # No valid ending "|"
|
||||
assert not a.parse("[TCP:]-|", logger) # invalid trigger
|
||||
assert not a.parse("[TCP:flags:0]-foo-|", logger) # Non-existent action
|
||||
assert not a.parse("[TCP:flags:0]--|", logger) # Empty action
|
||||
assert not a.parse("[TCP:flags:0]-duplicate(,,,)-|", logger) # Bad tree
|
||||
assert not a.parse("[TCP:flags:0]-duplicate()))-|", logger) # Bad tree
|
||||
assert not a.parse("[TCP:flags:0]-duplicate(((()-|", logger) # Bad tree
|
||||
assert not a.parse("[TCP:flags:0]-duplicate(,))))-|", logger) # Bad tree
|
||||
assert not a.parse("[TCP:flags:0]-drop(duplicate,)-|", logger) # Terminal action with children
|
||||
assert not a.parse("[TCP:flags:0]-drop(duplicate,duplicate)-|", logger) # Terminal action with children
|
||||
assert not a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(,duplicate)-|", logger) # Non-branching action with right child
|
||||
assert not a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(drop,duplicate)-|", logger) # Non-branching action with children
|
||||
|
||||
|
||||
def test_tree():
|
||||
"""
|
||||
Tests basic tree functionality.
|
||||
"""
|
||||
t = actions.trigger.Trigger(None, None, None)
|
||||
a = actions.tree.ActionTree("out", trigger=t)
|
||||
tamper = actions.tamper.TamperAction()
|
||||
tamper2 = actions.tamper.TamperAction()
|
||||
duplicate = actions.duplicate.DuplicateAction()
|
||||
assert a.get_parent(None) == (None, None)
|
||||
|
||||
a.add_action(None)
|
||||
a.add_action(tamper)
|
||||
assert a.get_slots() == 1
|
||||
a.add_action(tamper2)
|
||||
assert a.get_parent(tamper2) == (tamper, "left")
|
||||
assert a.get_slots() == 1
|
||||
a.add_action(duplicate)
|
||||
assert a.get_slots() == 2
|
||||
|
||||
t = actions.trigger.Trigger(None, None, None)
|
||||
a = actions.tree.ActionTree("out", trigger=t)
|
||||
drop = actions.drop.DropAction()
|
||||
a.add_action(drop)
|
||||
assert a.get_parent(drop) == (None, None)
|
||||
assert a.get_slots() == 0
|
||||
add_success = a.add_action(tamper)
|
||||
assert not add_success
|
||||
assert a.get_slots() == 0
|
||||
|
||||
rep = ""
|
||||
for s in a.string_repr(a.action_root):
|
||||
rep += s
|
||||
assert rep == "drop"
|
||||
|
||||
print(str(a))
|
||||
|
||||
assert a.parse("[TCP:flags:A]-duplicate(tamper{TCP:seq:corrupt},)-|", logging.getLogger("test"))
|
||||
for act in a:
|
||||
print(str(a))
|
||||
assert len(a) == 2
|
||||
assert a.get_slots() == 2
|
||||
|
||||
|
||||
def test_remove():
|
||||
"""
|
||||
Tests remove
|
||||
"""
|
||||
t = actions.trigger.Trigger(None, None, None)
|
||||
a = actions.tree.ActionTree("out", trigger=t)
|
||||
tamper = actions.tamper.TamperAction()
|
||||
tamper2 = actions.tamper.TamperAction()
|
||||
tamper3 = actions.tamper.TamperAction()
|
||||
assert not a.remove_action(tamper)
|
||||
a.add_action(tamper)
|
||||
assert a.remove_action(tamper)
|
||||
a.add_action(tamper)
|
||||
a.add_action(tamper2)
|
||||
a.add_action(tamper3)
|
||||
assert a.remove_action(tamper2)
|
||||
assert tamper2 not in a
|
||||
assert tamper.left == tamper3
|
||||
assert not tamper.right
|
||||
assert len(a) == 2
|
||||
a = actions.tree.ActionTree("out", trigger=t)
|
||||
duplicate = actions.duplicate.DuplicateAction()
|
||||
tamper = actions.tamper.TamperAction()
|
||||
tamper2 = actions.tamper.TamperAction()
|
||||
tamper3 = actions.tamper.TamperAction()
|
||||
a.add_action(tamper)
|
||||
assert a.action_root == tamper
|
||||
duplicate.left = tamper2
|
||||
duplicate.right = tamper3
|
||||
a.add_action(duplicate)
|
||||
assert a.get_parent(tamper3) == (duplicate, "right")
|
||||
assert len(a) == 4
|
||||
assert a.remove_action(duplicate)
|
||||
assert duplicate not in a
|
||||
assert tamper.left == tamper2
|
||||
assert not tamper.right
|
||||
assert len(a) == 2
|
||||
|
||||
a.parse("[TCP:flags:A]-|", logging.getLogger("test"))
|
||||
assert not a.remove_one(), "Cannot remove one with no action root"
|
||||
|
||||
|
||||
def test_len():
|
||||
"""
|
||||
Tests length calculation.
|
||||
"""
|
||||
t = actions.trigger.Trigger(None, None, None)
|
||||
a = actions.tree.ActionTree("out", trigger=t)
|
||||
tamper = actions.tamper.TamperAction()
|
||||
tamper2 = actions.tamper.TamperAction()
|
||||
assert len(a) == 0, "__len__ returned wrong length"
|
||||
a.add_action(tamper)
|
||||
assert len(a) == 1, "__len__ returned wrong length"
|
||||
a.add_action(tamper)
|
||||
assert len(a) == 1, "__len__ returned wrong length"
|
||||
a.add_action(tamper2)
|
||||
assert len(a) == 2, "__len__ returned wrong length"
|
||||
duplicate = actions.duplicate.DuplicateAction()
|
||||
a.add_action(duplicate)
|
||||
assert len(a) == 3, "__len__ returned wrong length"
|
||||
|
||||
|
||||
def test_contains():
|
||||
"""
|
||||
Tests contains method
|
||||
"""
|
||||
t = actions.trigger.Trigger(None, None, None)
|
||||
a = actions.tree.ActionTree("out", trigger=t)
|
||||
tamper = actions.tamper.TamperAction()
|
||||
tamper2 = actions.tamper.TamperAction()
|
||||
tamper3 = actions.tamper.TamperAction()
|
||||
|
||||
assert not a.contains(tamper), "contains incorrect behavior"
|
||||
assert not a.contains(tamper2), "contains incorrect behavior"
|
||||
a.add_action(tamper)
|
||||
assert a.contains(tamper), "contains incorrect behavior"
|
||||
assert not a.contains(tamper2), "contains incorrect behavior"
|
||||
add_success = a.add_action(tamper)
|
||||
assert not add_success, "added duplicate action"
|
||||
assert a.contains(tamper), "contains incorrect behavior"
|
||||
assert not a.contains(tamper2), "contains incorrect behavior"
|
||||
a.add_action(tamper2)
|
||||
assert a.contains(tamper), "contains incorrect behavior"
|
||||
assert a.contains(tamper2), "contains incorrect behavior"
|
||||
a.remove_action(tamper2)
|
||||
assert a.contains(tamper), "contains incorrect behavior"
|
||||
assert not a.contains(tamper2), "contains incorrect behavior"
|
||||
a.add_action(tamper2)
|
||||
assert a.contains(tamper), "contains incorrect behavior"
|
||||
assert a.contains(tamper2), "contains incorrect behavior"
|
||||
remove_success = a.remove_action(tamper)
|
||||
assert remove_success
|
||||
assert not a.contains(tamper), "contains incorrect behavior"
|
||||
assert a.contains(tamper2), "contains incorrect behavior"
|
||||
a.add_action(tamper3)
|
||||
assert a.contains(tamper3), "contains incorrect behavior"
|
||||
assert len(a) == 2, "len incorrect return"
|
||||
remove_success = a.remove_action(tamper2)
|
||||
assert remove_success
|
||||
|
||||
|
||||
def test_iter():
|
||||
"""
|
||||
Tests iterator.
|
||||
"""
|
||||
t = actions.trigger.Trigger(None, None, None)
|
||||
a = actions.tree.ActionTree("out", trigger=t)
|
||||
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
|
||||
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
|
||||
|
||||
assert a.add_action(tamper)
|
||||
assert a.add_action(tamper2)
|
||||
assert not a.add_action(tamper)
|
||||
for node in a:
|
||||
print(node)
|
||||
|
||||
|
||||
def test_run():
|
||||
"""
|
||||
Tests running packets through the chain.
|
||||
"""
|
||||
logger = logging.getLogger("test")
|
||||
t = actions.trigger.Trigger(None, None, None)
|
||||
a = actions.tree.ActionTree("out", trigger=t)
|
||||
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
|
||||
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
|
||||
duplicate = actions.duplicate.DuplicateAction()
|
||||
duplicate2 = actions.duplicate.DuplicateAction()
|
||||
drop = actions.drop.DropAction()
|
||||
|
||||
packet = actions.packet.Packet(IP()/TCP())
|
||||
a.add_action(tamper)
|
||||
packets = a.run(packet, logging.getLogger("test"))
|
||||
assert len(packets) == 1
|
||||
assert None not in packets
|
||||
assert packets[0].get("TCP", "flags") == "S"
|
||||
a.add_action(tamper2)
|
||||
print(str(a))
|
||||
|
||||
packet = actions.packet.Packet(IP()/TCP())
|
||||
assert not a.add_action(tamper), "tree added duplicate action"
|
||||
packets = a.run(packet, logging.getLogger("test"))
|
||||
assert len(packets) == 1
|
||||
assert None not in packets
|
||||
assert packets[0].get("TCP", "flags") == "R"
|
||||
print(str(a))
|
||||
|
||||
a.remove_action(tamper2)
|
||||
a.remove_action(tamper)
|
||||
a.add_action(duplicate)
|
||||
packet = actions.packet.Packet(IP()/TCP(flags="RA"))
|
||||
packets = a.run(packet, logging.getLogger("test"))
|
||||
assert len(packets) == 2
|
||||
assert None not in packets
|
||||
assert packets[0][TCP].flags == "RA"
|
||||
assert packets[1][TCP].flags == "RA"
|
||||
print(str(a))
|
||||
|
||||
duplicate.left = tamper
|
||||
duplicate.right = tamper2
|
||||
packet = actions.packet.Packet(IP()/TCP(flags="RA"))
|
||||
packets = a.run(packet, logging.getLogger("test"))
|
||||
assert len(packets) == 2
|
||||
assert None not in packets
|
||||
print(str(a))
|
||||
print(str(packets[0]))
|
||||
print(str(packets[1]))
|
||||
assert packets[0][TCP].flags == "S"
|
||||
assert packets[1][TCP].flags == "R"
|
||||
print(str(a))
|
||||
|
||||
tamper.left = duplicate2
|
||||
packet = actions.packet.Packet(IP()/TCP(flags="RA"))
|
||||
packets = a.run(packet, logging.getLogger("test"))
|
||||
assert len(packets) == 3
|
||||
assert None not in packets
|
||||
assert packets[0][TCP].flags == "S"
|
||||
assert packets[1][TCP].flags == "S"
|
||||
assert packets[2][TCP].flags == "R"
|
||||
print(str(a))
|
||||
|
||||
tamper2.left = drop
|
||||
packet = actions.packet.Packet(IP()/TCP(flags="RA"))
|
||||
packets = a.run(packet, logging.getLogger("test"))
|
||||
assert len(packets) == 2
|
||||
assert None not in packets
|
||||
assert packets[0][TCP].flags == "S"
|
||||
assert packets[1][TCP].flags == "S"
|
||||
print(str(a))
|
||||
|
||||
assert a.remove_action(duplicate2)
|
||||
tamper.left = actions.drop.DropAction()
|
||||
packet = actions.packet.Packet(IP()/TCP(flags="RA"))
|
||||
packets = a.run(packet, logger )
|
||||
assert len(packets) == 0
|
||||
print(str(a))
|
||||
|
||||
a.parse("[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:replace:14239},),duplicate(tamper{TCP:flags:replace:S},))-|", logger)
|
||||
packet = actions.packet.Packet(IP()/TCP(flags="A"))
|
||||
assert a.check(packet, logger)
|
||||
packets = a.run(packet, logger)
|
||||
assert len(packets) == 3
|
||||
assert packets[0][TCP].flags == "R"
|
||||
assert packets[1][TCP].flags == "S"
|
||||
assert packets[2][TCP].flags == "A"
|
||||
|
||||
|
||||
def test_index():
|
||||
"""
|
||||
Tests index
|
||||
"""
|
||||
a = actions.tree.ActionTree("out")
|
||||
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
|
||||
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
|
||||
tamper3 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="F")
|
||||
|
||||
assert a.add_action(tamper)
|
||||
assert a[0] == tamper
|
||||
assert not a[1]
|
||||
assert a.add_action(tamper2)
|
||||
assert a[0] == tamper
|
||||
assert a[1] == tamper2
|
||||
assert a[-1] == tamper2
|
||||
assert not a[10]
|
||||
assert a.add_action(tamper3)
|
||||
assert a[-1] == tamper3
|
||||
assert not a[-11]
|
||||
|
||||
|
||||
def test_choose_one():
|
||||
"""
|
||||
Tests choose_one functionality
|
||||
"""
|
||||
a = actions.tree.ActionTree("out")
|
||||
drop = actions.drop.DropAction()
|
||||
assert not a.choose_one()
|
||||
assert a.add_action(drop)
|
||||
assert a.choose_one() == drop
|
||||
assert a.remove_action(drop)
|
||||
assert not a.choose_one()
|
||||
duplicate = actions.duplicate.DuplicateAction()
|
||||
a.add_action(duplicate)
|
||||
assert a.choose_one() == duplicate
|
||||
duplicate.left = drop
|
||||
assert a.choose_one() in [duplicate, drop]
|
||||
# Make sure that both actions get chosen
|
||||
chosen = set()
|
||||
for i in range(0, 10000):
|
||||
act = a.choose_one()
|
||||
chosen.add(act)
|
||||
assert chosen == set([duplicate, drop])
|
|
@ -0,0 +1,151 @@
|
|||
import logging
|
||||
import sys
|
||||
# Include the root of the project
|
||||
sys.path.append("..")
|
||||
|
||||
import actions.packet
|
||||
import actions.strategy
|
||||
import actions.tamper
|
||||
import actions.utils
|
||||
|
||||
from scapy.all import IP, TCP
|
||||
|
||||
logger = logging.getLogger("test")
|
||||
|
||||
|
||||
def test_trigger_gas():
|
||||
"""
|
||||
Tests triggers having gas, including changing that gas while in use
|
||||
"""
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="SA"))
|
||||
trigger = actions.trigger.Trigger("field", "flags", "TCP", trigger_value="SA", gas=1)
|
||||
print(trigger)
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
print(trigger)
|
||||
# test add gas #
|
||||
trigger.add_gas(3)
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
|
||||
# Test disable, set, and enable gas #
|
||||
trigger.disable_gas()
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
trigger.set_gas(3)
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
trigger.enable_gas()
|
||||
trigger.set_gas(2)
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
|
||||
|
||||
def test_bomb_trigger_gas():
|
||||
"""
|
||||
Tests triggers having bomb gas, including changing that gas while in use
|
||||
"""
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="SA"))
|
||||
trigger = actions.trigger.Trigger("field", "flags", "TCP", trigger_value="SA", gas=-1)
|
||||
print(trigger)
|
||||
assert not trigger.is_applicable(packet, logger), "trigger should not fire on first run"
|
||||
assert trigger.is_applicable(packet, logger), "trigger should fire on second run"
|
||||
print(trigger)
|
||||
# test add gas #
|
||||
trigger.add_gas(-3)
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
|
||||
# Test disable, set, and enable gas #
|
||||
trigger.disable_gas()
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
trigger.set_gas(-3)
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
trigger.enable_gas()
|
||||
trigger.set_gas(-2)
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
|
||||
|
||||
def test_trigger_parse_gas():
|
||||
"""
|
||||
Tests triggers having gas, including changing that gas while in use
|
||||
"""
|
||||
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="SA"))
|
||||
|
||||
|
||||
# parse a trigger with 1 gas
|
||||
trigger = actions.trigger.Trigger.parse("TCP:flags:SA:1")
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
|
||||
# parse a trigger with no gas left
|
||||
trigger = actions.trigger.Trigger.parse("TCP:flags:SA:0")
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
|
||||
# parse a trigger not using gas
|
||||
trigger = actions.trigger.Trigger.parse("TCP:flags:SA")
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
# Check that adding gas while gas is disabled does not work
|
||||
trigger.add_gas(10)
|
||||
assert trigger.gas_remaining == None
|
||||
|
||||
trigger.enable_gas()
|
||||
trigger.set_gas(2)
|
||||
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
|
||||
# Test that it can handle leading/trailing []
|
||||
trigger = actions.trigger.Trigger.parse("[TCP:flags:SA]")
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
|
||||
|
||||
def test_bomb_trigger_parse_gas():
|
||||
"""
|
||||
Tests bomb triggers having gas, including changing that gas while in use
|
||||
"""
|
||||
packet = actions.packet.Packet(IP(src="127.0.0.1", dst="127.0.0.1")/TCP(sport=2222, dport=3333, seq=100, ack=100, flags="SA"))
|
||||
|
||||
# parse a bomb trigger with 1 gas
|
||||
trigger = actions.trigger.Trigger.parse("TCP:flags:SA:-1")
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
|
||||
# parse a trigger with no gas left
|
||||
trigger = actions.trigger.Trigger.parse("TCP:flags:SA:0")
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
|
||||
trigger = actions.trigger.Trigger.parse("TCP:flags:SA:-1")
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
|
||||
# parse a trigger not using gas
|
||||
trigger = actions.trigger.Trigger.parse("TCP:flags:SA")
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
# Check that adding gas while gas is disabled does not work
|
||||
trigger.add_gas(10)
|
||||
assert trigger.gas_remaining == None
|
||||
|
||||
trigger.enable_gas()
|
||||
trigger.set_gas(2)
|
||||
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
assert trigger.is_applicable(packet, logger)
|
||||
assert not trigger.is_applicable(packet, logger)
|
||||
|
||||
# Test that it can handle leading/trailing []
|
||||
trigger = actions.trigger.Trigger.parse("[TCP:flags:SA]")
|
||||
assert trigger.is_applicable(packet, logger)
|
|
@ -0,0 +1,42 @@
|
|||
import sys
|
||||
import pytest
|
||||
# Include the root of the project
|
||||
sys.path.append("..")
|
||||
|
||||
import actions.action
|
||||
import actions.strategy
|
||||
import actions.utils
|
||||
import actions.duplicate
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("test")
|
||||
|
||||
|
||||
def get_test_configs():
|
||||
"""
|
||||
Sets up the tests
|
||||
"""
|
||||
tests = [
|
||||
("both", True, ['DuplicateAction', 'DropAction', 'SleepAction', 'TraceAction', 'TamperAction', 'FragmentAction']),
|
||||
("in", True, ['DropAction', 'TamperAction', 'SleepAction']),
|
||||
("out", True, ['DropAction', 'TamperAction', 'TraceAction', 'SleepAction', 'DuplicateAction', 'FragmentAction']),
|
||||
("both", False, ['DuplicateAction', 'SleepAction', 'TamperAction', 'FragmentAction']),
|
||||
("in", False, ['TamperAction', 'SleepAction']),
|
||||
("out", False, ['TamperAction', 'SleepAction', 'DuplicateAction', 'FragmentAction']),
|
||||
]
|
||||
# To ensure caching is not breaking anything, double the tests
|
||||
return tests + tests
|
||||
|
||||
|
||||
@pytest.mark.parametrize("direction,allow_terminal,supported_actions", get_test_configs())
|
||||
def test_get_actions(direction, allow_terminal, supported_actions):
|
||||
"""
|
||||
Tests the duplicate action primitive.
|
||||
"""
|
||||
collected_actions = actions.action.Action.get_actions(direction, allow_terminal=allow_terminal)
|
||||
names = []
|
||||
for name, action_class in collected_actions:
|
||||
names.append(name)
|
||||
assert set(names) == set(supported_actions)
|
||||
assert len(names) == len(supported_actions)
|
Loading…
Reference in New Issue