From 7495b334bbcbc44276201ba1a1d8d67b1eace1c2 Mon Sep 17 00:00:00 2001 From: Kkevsterrr Date: Thu, 12 Dec 2019 21:02:03 -0500 Subject: [PATCH] Added DNS layers and packet test --- README.md | 2 +- actions/layer.py | 246 ++++++++++++++++++- actions/packet.py | 98 +++++++- tests/test_packet.py | 544 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 884 insertions(+), 6 deletions(-) create mode 100644 tests/test_packet.py diff --git a/README.md b/README.md index e2e2804..14075b2 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Geneva +# Geneva [![Build Status](https://travis-ci.com/Kkevsterrr/geneva.svg?branch=master)](https://travis-ci.com/Kkevsterrr/geneva) [![codecov](https://codecov.io/gh/Kkevsterrr/geneva/branch/master/graph/badge.svg)](https://codecov.io/gh/Kkevsterrr/geneva) Geneva is an artificial intelligence tool that defeats censorship by exploiting bugs in censors, such as those in China, India, and Kazakhstan. Unlike many other anti-censorship solutions which require assistance from outside the censoring regime (Tor, VPNs, etc.), Geneva runs strictly on the client. diff --git a/actions/layer.py b/actions/layer.py index 8b976c5..36e8b59 100644 --- a/actions/layer.py +++ b/actions/layer.py @@ -4,7 +4,8 @@ import string import os import urllib.parse -from scapy.all import IP, RandIP, UDP, Raw, TCP, fuzz +from scapy.all import IP, RandIP, UDP, DNS, DNSQR, Raw, TCP, fuzz + class Layer(): """ @@ -179,6 +180,11 @@ class Layer(): value = urllib.parse.unquote(value) value = value.encode('utf-8') + dns_payload = b"\x009ib\x81\x80\x00\x01\x00\x01\x00\x00\x00\x01\x08examples\x03com\x00\x00\x01\x00\x01\xc0\x0c\x00\x01\x00\x01\x00\x00\x01+\x00\x04\xc7\xbf2I\x00\x00)\x02\x00\x00\x00\x00\x00\x00\x00" + http_payload = b"GET / HTTP/1.1\r\nHost: www.example.com\r\n\r\n" + + value = value.replace(b"__DNS_REQUEST__", dns_payload) + value = value.replace(b"__HTTP_REQUEST__", http_payload) self.layer.payload = Raw(value) @@ -188,7 +194,7 @@ class Layer(): as a field properly. """ load = ''.join([random.choice(string.ascii_lowercase + string.digits) for k in range(10)]) - return urllib.parse.quote(load) + return random.choice(urllib.parse.quote(load)) class IPLayer(Layer): @@ -592,3 +598,239 @@ class UDPLayer(Layer): self.generators = { 'load' : self.gen_load, } + + +class DNSLayer(Layer): + """ + Defines an interface to access DNS header fields. + """ + name = "DNS" + protocol = DNS + _fields = [ + "id", + "qr", + "opcode", + "aa", + "tc", + "rd", + "ra", + "z", + "ad", + "cd", + "qd", + "rcode", + "qdcount", + "ancount", + "nscount", + "arcount" + ] + fields = _fields + def __init__(self, layer): + """ + Initializes the DNS layer. + """ + Layer.__init__(self, layer) + + self.getters = { + "qr" : self.get_bitfield, + "aa" : self.get_bitfield, + "tc" : self.get_bitfield, + "rd" : self.get_bitfield, + "ra" : self.get_bitfield, + "z" : self.get_bitfield, + "ad" : self.get_bitfield, + "cd" : self.get_bitfield + } + + self.setters = { + "qr" : self.set_bitfield, + "aa" : self.set_bitfield, + "tc" : self.set_bitfield, + "rd" : self.set_bitfield, + "ra" : self.set_bitfield, + "z" : self.set_bitfield, + "ad" : self.set_bitfield, + "cd" : self.set_bitfield + } + + self.generators = { + "id" : self.gen_id, + "qr" : self.gen_bitfield, + "opcode" : self.gen_opcode, + "aa" : self.gen_bitfield, + "tc" : self.gen_bitfield, + "rd" : self.gen_bitfield, + "ra" : self.gen_bitfield, + "z" : self.gen_bitfield, + "ad" : self.gen_bitfield, + "cd" : self.gen_bitfield, + "rcode" : self.gen_rcode, + "qdcount" : self.gen_count, + "ancount" : self.gen_count, + "nscount" : self.gen_count, + "arcount" : self.gen_count + } + + def get_bitfield(self, field): + """""" + return int(getattr(self.layer, field)) + + def set_bitfield(self, packet, field, value): + """""" + return setattr(self.layer, field, int(value)) + + def gen_bitfield(self, field): + """""" + return random.choice([0,1]) + + def gen_id(self, field): + return random.randint(0, 65535) + + def gen_opcode(self, field): + return random.randint(0, 15) + + def gen_rcode(self, field): + return random.randint(0, 15) + + def gen_count(self, field): + return random.randint(0, 65535) + + @staticmethod + def dns_decompress(packet, logger): + """ + Performs DNS decompression on the given scapy packet, if applicable. + Note that DNS compression/decompression must be done on the boundaries + of a label, so DNS compression does not support arbitrary offsets. + """ + # If this is a TCP packet + if packet.haslayer("TCP"): + raise NotImplementedError + + # Perform no action if this is not a DNS or DNSRQ packet + if not packet.haslayer("DNS") or not packet.haslayer("DNSQR"): + return packet + + # Extract the query from the DNSQR layer + query = packet["DNSQR"].qname.decode() + if query[len(query) - 1] != '.': + query += '.' + + # Split the query by label + labels = query.split(".") + + # Collect the first and second half of the query + fhalf = labels[0] + shalf = ".".join(labels[1:]) + + # Build the first DNS query directly. The format of this a byte string like this: + # b'\x07minghui\xc0\x1a\x00\x01\x00\x01' + # \x07 = the length of the label in this DNSQR + # minghui = the portion of the domain we will request in the first DNSQR + # \xc0\x1a = offset into the DNS packet where the rest of the query will be. The actual offset + # here is the \x1a - DNS mandates that if compression is used, the first two bits be 11 + # to differentiate them from the rest. \x1A = 26, which is the length of the DNS header + # plus the length of this DNSQR. + # \x00\x01 = type A record + # \x00\x01 = IN + length = bytes([len(fhalf)]) + label = fhalf.encode() + + # Since the domain will include an extra ".", add 1 + # 2 * 6 is the DNS header + # 1 is the byte that determines the length of the label + # len(label) is the length of the label + # 2 is the offset pointer + # 4 - other record information (class, IN) + packet_offset = 2 * 6 + 1 + len(label) + 2 + 2 + 2 + + # The word must start with binary 11, so OR the offset with 0xC000. + offset = (0xc000 | packet_offset).to_bytes(2, byteorder='big') + request = b'\x00\x01\x00\x01' + + dns_qr1 = length + label + offset + request + + # Build the second DNS query directly. The format of the byte string is the same as above + # b'\x02ca\x00\x00\x01\x00\x01' + # \x02 = length of the remaining domain + # ca = portion of the domain in this DNSQR + # \x00 = null byte to signify the end of the query + # \x00\x01 = type A record + # \x00\x01 = IN + # Since the second half could potentially contain many labels, this is done in a list comprehension + dns_qr2 = b"".join([bytes([len(tld)]) + tld.encode() for tld in shalf.split(".")]) + b"\x00\x01\x00\x01" + + # Next, we must rebuild the DNS packet itself. If we try to have scapy parse either dns_qr1 or dns_qr2, they + # will look malformed, since neither contains a complete request. Therefore, we must build the entire + # DNS packet at once. First, we must remove the original DNSQR, since this contains the original request + del packet["DNS"].qd + + # Once the DNSQR is removed, scapy automatically sets the qdcount to 0. Adjust it to 2 + packet["DNS"].qdcount = 2 + + # Extract the DNS header standalone now for building + dns_header = bytes(packet["DNS"]) + + dns_packet = DNS(dns_header + dns_qr1 + dns_qr2) + + del packet["DNS"] + packet = packet / dns_packet + + # Since the size and data of the packet have changed, force scapy to recalculate the important fields + # in below layers, if applicable + if packet.haslayer("IP"): + del packet["IP"].chksum + del packet["IP"].len + if packet.haslayer("TCP"): + del packet["TCP"].chksum + if packet.haslayer("UDP"): + del packet["UDP"].chksum + del packet["UDP"].len + + return packet + + +class DNSQRLayer(Layer): + """ + Defines an interface to access DNSQR header fields. + """ + name = "DNSQR" + protocol = DNSQR + _fields = [ + "qname", + "qtype", + "qclass" + ] + fields = _fields + + def __init__(self, layer): + """ + Initializes the DNS layer. + """ + Layer.__init__(self, layer) + self.getters = { + "qname" : self.get_qname + } + self.generators = { + "qname" : self.gen_qname + } + + def get_qname(self, field): + """ + Returns decoded qname from packet. + """ + return self.layer.qname.decode('utf-8') + + def gen_qname(self, field): + """ + Generates domain name. + """ + return "example.com." + + @classmethod + def name_matches(cls, name): + """ + Scapy returns the name of DNSQR as _both_ DNSQR and "DNS Question Record", + which breaks parsing. Override the name_matches method to handle that case + here. + """ + return name.upper() in ["DNSQR", "DNS QUESTION RECORD"] diff --git a/actions/packet.py b/actions/packet.py index ed1c575..f40b6d1 100644 --- a/actions/packet.py +++ b/actions/packet.py @@ -7,7 +7,9 @@ import actions.layer _SUPPORTED_LAYERS = [ actions.layer.IPLayer, actions.layer.TCPLayer, - actions.layer.UDPLayer + actions.layer.UDPLayer, + actions.layer.DNSLayer, + actions.layer.DNSQRLayer ] SUPPORTED_LAYERS = _SUPPORTED_LAYERS @@ -64,9 +66,25 @@ class Packet(): @staticmethod def _str_load(packet, protocol): """ - Prints packet payload + Prints DNS header for now """ - return str(packet[protocol].payload) + if packet.haslayer("DNS") and packet.haslayer("DNSQR"): + res = "%s:%s:%s " % ( + packet["DNSQR"].qname.decode('utf8'), + str(packet["DNSQR"].qtype), + str(packet["DNSQR"].qclass)) + DNS_res = "" + for i in range(packet["DNS"].ancount): + dnsrr = packet["DNS"].an[i] + DNS_res += " " + ':'.join([str(dnsrr.rrname.decode('utf8')), + str(dnsrr.type), + str(dnsrr.rclass), + str(dnsrr.ttl), + str(dnsrr.rdlen), + str(dnsrr.rdata)]) + return "%s %s" % (res, DNS_res) + else: + return str(packet[protocol].payload) def __bytes__(self): """ @@ -238,3 +256,77 @@ class Packet(): return layer return None + + @staticmethod + def reset_restrictions(): + """ + Removes layer and field restrictions. + """ + global SUPPORTED_LAYERS, _SUPPORTED_LAYERS + + SUPPORTED_LAYERS = _SUPPORTED_LAYERS + for layer in SUPPORTED_LAYERS: + layer.reset_restrictions() + + @staticmethod + def restrict_fields(logger, filter_protocols, filter_fields, disable_fields): + """ + Validates input arguments. Used by evolve.py to restrict the scope + of this evolution. + """ + global SUPPORTED_LAYERS + + if not disable_fields: + disable_fields = [] + + # First, apply a field whitelist if it was requested + valid = [] + if filter_fields: + for layer in SUPPORTED_LAYERS: + new_fields = [] + for field in filter_fields: + if field in layer.fields: + new_fields.append(field) + valid.append(field) + layer.fields = new_fields + + if valid and logger: + logger.info("Strategies will only be allowed to use fields: %s" % ", ".join(list(set(valid)))) + elif logger: + logger.error("None of the given fields exist in the packet headers of given protocols.") + + # Apply a field blacklist if it was requested + for field in disable_fields: + for layer in SUPPORTED_LAYERS: + layer.fields = [f for f in layer.fields if f not in disable_fields] + + if disable_fields and logger: + logger.info("Strategies will not be allowed to use fields %s" % ", ".join(disable_fields)) + + allowed_layers = [] + # Finally, filter protocols + for protocol in filter_protocols: + allowed_layer = Packet.get_supported_protocol(protocol) + if not allowed_layer: + if logger: + logger.error("%s not a supported protocol." % protocol) + continue + + # Only keep the layer allowed if it contains allowed fields + if allowed_layer.fields: + allowed_layers.append(allowed_layer) + + assert allowed_layers, "Cannot evolve with no available packet layers!" + + SUPPORTED_LAYERS = allowed_layers + + if logger and allowed_layers: + logger.info("Strategies will only be allowed to use protocols: %s" % ", ".join([l.name for l in allowed_layers])) + + def dns_decompress(self, logger): + """ + Performs DNS decompression, if applicable. Returns a new packet. + """ + self.packet = actions.layer.DNSLayer.dns_decompress(self.packet, logger) + self.layers = self.setup_layers() + return self diff --git a/tests/test_packet.py b/tests/test_packet.py new file mode 100644 index 0000000..6048f33 --- /dev/null +++ b/tests/test_packet.py @@ -0,0 +1,544 @@ +import logging +import pytest + +import actions.packet +import actions.layer + +from scapy.all import IP, TCP, UDP, DNS, DNSQR, Raw, DNSRR + +logger = logging.getLogger("test") + + +def test_parse_layers(): + """ + Tests layer parsing. + """ + pkt = IP()/TCP()/Raw("") + packet = actions.packet.Packet(pkt) + layers = list(packet.read_layers()) + assert layers[0].name == "IP" + assert layers[1].name == "TCP" + + layers_dict = packet.setup_layers() + assert layers_dict["IP"] + assert layers_dict["TCP"] + + +def test_get_random(): + """ + Tests get random + """ + + tcplayer = actions.layer.TCPLayer(TCP()) + field, value = tcplayer.get_random() + assert field in actions.layer.TCPLayer.fields + + +def test_gen_random(): + """ + Tests gen random + """ + for i in range(0, 2000): + layer, field, value = actions.packet.Packet().gen_random() + assert layer in [DNS, TCP, UDP, IP, DNSQR] + + +def test_dnsqr(): + """ + Tests DNSQR. + """ + pkt = UDP()/DNS(ancount=1)/DNSQR() + pkt.show() + packet = actions.packet.Packet(pkt) + packet.show() + assert len(packet.layers) == 3 + assert "UDP" in packet.layers + assert "DNS" in packet.layers + assert "DNSQR" in packet.layers + pkt = IP()/UDP()/DNS()/DNSQR() + packet = actions.packet.Packet(pkt) + assert str(packet) + + +def test_load(): + """ + Tests loads. + """ + tcp = actions.layer.TCPLayer(TCP()) + assert tcp.gen("load") + pkt = IP()/"datadata" + p = actions.packet.Packet(pkt) + assert p.get("IP", "load") == "datadata" + p2 = actions.packet.Packet(IP(bytes(p))) + assert p2.get("IP", "load") == "datadata" + p2.set("IP", "load", "data2") + # Check p is unchanged + assert p.get("IP", "load") == "datadata" + assert p2.get("IP", "load") == "data2" + p2.show2() + # Check that we can dump + assert p2.show2(dump=True) + # Check that we can dump + assert p2.show(dump=True) + assert p2.get("IP", "chksum") == None + + pkt = IP()/TCP()/"datadata" + p = actions.packet.Packet(pkt) + assert p.get("TCP", "load") == "datadata" + p2 = actions.packet.Packet(IP(bytes(p))) + assert p2.get("TCP", "load") == "datadata" + p2.set("TCP", "load", "data2") + # Check p is unchanged + assert p.get("TCP", "load") == "datadata" + assert p2.get("TCP", "load") == "data2" + p2.show2() + assert p2.get("IP", "chksum") == None + + +def test_parse_load(log_level): + """ + Tests load parsing. + """ + pkt = actions.packet.Packet(IP()/TCP()/"TYPE A\r\n") + print("Parsed: %s" % pkt.get("TCP", "load")) + + strat = actions.utils.parse("[TCP:load:TYPE%20A%0D%0A]-drop-| \/", logger) + results = strat.act_on_packet(pkt, logger) + assert not results + + value = pkt.gen("TCP", "load") + " " + pkt.gen("TCP", "load") + pkt.set("TCP", "load", value) + assert " " not in pkt.get("TCP", "load"), "%s contained a space!" % pkt.get("TCP", "load") + + +def test_dns(): + """ + Tests DNS layer. + """ + dns = actions.layer.DNSLayer(DNS()) + print(dns.gen("id")) + assert dns.gen("id") + + p = actions.packet.Packet(DNS(id=0xabcd)) + p2 = actions.packet.Packet(DNS(bytes(p))) + assert p.get("DNS", "id") == 0xabcd + assert p2.get("DNS", "id") == 0xabcd + + p2.set("DNS", "id", 0x4321) + assert p.get("DNS", "id") == 0xabcd # Check p is unchanged + assert p2.get("DNS", "id") == 0x4321 + + dns = actions.packet.Packet(DNS(aa=1)) + assert dns.get("DNS", "aa") == 1 + aa = dns.gen("DNS", "aa") + assert aa == 0 or aa == 1 + assert dns.get("DNS", "aa") == 1 # Original value unchanged + + dns = actions.packet.Packet(DNS(opcode=15)) + assert dns.get("DNS", "opcode") == 15 + opcode = dns.gen("DNS", "opcode") + assert opcode >= 0 and opcode <= 15 + assert dns.get("DNS", "opcode") == 15 # Original value unchanged + + dns.set("DNS", "opcode", 3) + assert dns.get("DNS", "opcode") == 3 + + dns = actions.packet.Packet(DNS(qr=0)) + assert dns.get("DNS", "qr") == 0 + qr = dns.gen("DNS", "qr") + assert qr == 0 or qr == 1 + assert dns.get("DNS", "qr") == 0 # Original value unchanged + + dns.set("DNS", "qr", 1) + assert dns.get("DNS", "qr") == 1 + + dns = actions.packet.Packet(DNS(arcount=0xAABB)) + assert dns.get("DNS", "arcount") == 0xAABB + arcount = dns.gen("DNS", "arcount") + assert arcount >= 0 and arcount <= 0xffff + assert dns.get("DNS", "arcount") == 0xAABB # Original value unchanged + + dns.set("DNS", "arcount", 65432) + assert dns.get("DNS", "arcount") == 65432 + + dns = actions.layer.DNSLayer(DNS()/DNSQR(qname="example.com")) + assert isinstance(dns.get_next_layer(), DNSQR) + print(dns.gen("id")) + assert dns.gen("id") + + p = actions.packet.Packet(DNS(id=0xabcd)) + p2 = actions.packet.Packet(DNS(bytes(p))) + assert p.get("DNS", "id") == 0xabcd + assert p2.get("DNS", "id") == 0xabcd + + +def test_read_layers(): + """ + Tests the ability to read each layer + """ + packet = IP() / UDP() / TCP() / DNS() / DNSQR(qname="example.com") / DNSQR(qname="example2.com") / DNSQR(qname="example3.com") + packet_geneva = actions.packet.Packet(packet) + packet_geneva.setup_layers() + + i = 0 + for layer in packet_geneva.read_layers(): + if i == 0: + assert isinstance(layer, actions.layer.IPLayer) + elif i == 1: + assert isinstance(layer, actions.layer.UDPLayer) + elif i == 2: + assert isinstance(layer, actions.layer.TCPLayer) + elif i == 3: + assert isinstance(layer, actions.layer.DNSLayer) + elif i == 4: + assert isinstance(layer, actions.layer.DNSQRLayer) + assert layer.layer.qname == b"example.com" + elif i == 5: + assert isinstance(layer, actions.layer.DNSQRLayer) + assert layer.layer.qname == b"example2.com" + elif i == 6: + assert isinstance(layer, actions.layer.DNSQRLayer) + assert layer.layer.qname == b"example3.com" + i += 1 + +def test_multi_opts(): + """ + Tests various option getting/setting. + """ + pkt = IP()/TCP(options=[('MSS', 1460), ('SAckOK', b''), ('Timestamp', (4154603075, 0)), ('NOP', None), ('WScale', 7)]) + packet = actions.packet.Packet(pkt) + assert packet.get("TCP", "options-sackok") == '' + assert packet.get("TCP", "options-mss") == 1460 + assert packet.get("TCP", "options-timestamp") == 4154603075 + assert packet.get("TCP", "options-wscale") == 7 + packet.set("TCP", "options-timestamp", 400000000) + assert packet.get("TCP", "options-sackok") == '' + assert packet.get("TCP", "options-mss") == 1460 + assert packet.get("TCP", "options-timestamp") == 400000000 + assert packet.get("TCP", "options-wscale") == 7 + pkt = IP()/TCP(options=[('SAckOK', b''), ('Timestamp', (4154603075, 0)), ('NOP', None), ('WScale', 7)]) + packet = actions.packet.Packet(pkt) + # If the option isn't present, it will be returned as an empty string + assert packet.get("TCP", "options-mss") == '' + packet.set("TCP", "options-mss", "") + assert packet.get("TCP", "options-mss") == 0 + + +def test_options_eol(): + """ + Tests options-eol. + """ + pkt = TCP(options=[("EOL", None)]) + p = actions.packet.Packet(pkt) + assert p.get("TCP", "options-eol") == "" + p2 = actions.packet.Packet(TCP(bytes(p))) + assert p2.get("TCP", "options-eol") == "" + p = actions.packet.Packet(IP()/TCP(options=[])) + assert p.get("TCP", "options-eol") == "" + p.set("TCP", "options-eol", "") + p.show() + assert len(p["TCP"].options) == 1 + assert any(k == "EOL" for k, v in p["TCP"].options) + value = p.gen("TCP", "options-eol") + assert value == "", "eol cannot store data" + p.set("TCP", "options-eol", value) + p2 = TCP(bytes(p)) + assert any(k == "EOL" for k, v in p2["TCP"].options) + + +def test_options_mss(): + """ + Tests options-eol. + """ + pkt = TCP(options=[("MSS", 1440)]) + p = actions.packet.Packet(pkt) + assert p.get("TCP", "options-mss") == 1440 + p2 = actions.packet.Packet(TCP(bytes(p))) + assert p2.get("TCP", "options-mss") == 1440 + p = actions.packet.Packet(TCP(options=[])) + assert p.get("TCP", "options-mss") == "" + p.set("TCP", "options-mss", 2880) + p.show() + assert len(p["TCP"].options) == 1 + assert any(k == "MSS" for k, v in p["TCP"].options) + value = p.gen("TCP", "options-mss") + p.set("TCP", "options-mss", value) + p2 = TCP(bytes(p)) + assert any(k == "MSS" for k, v in p2["TCP"].options) + + +def check_get(protocol, field, value): + """ + Checks if the get method worked for this protocol, field, and value. + """ + pkt = protocol() + setattr(pkt, field, value) + packet = actions.packet.Packet(pkt) + assert packet.get(protocol.__name__, field) == value + + +def get_test_configs(): + """ + Generates test configurations for the getters. + """ + return [ + (IP, 'version', 4), + (IP, 'version', 6), + (IP, 'version', 0), + (IP, 'ihl', 0), + (IP, 'tos', 0), + (IP, 'len', 50), + (IP, 'len', 6), + (IP, 'flags', 'MF'), + (IP, 'flags', 'DF'), + (IP, 'flags', 'MF+DF'), + (IP, 'ttl', 25), + (IP, 'proto', 4), + (IP, 'chksum', 0x4444), + (IP, 'src', '127.0.0.1'), + (IP, 'dst', '127.0.0.1'), + (TCP, 'sport', 12345), + (TCP, 'dport', 55555), + (TCP, 'seq', 123123123), + (TCP, 'ack', 181818181), + (TCP, 'dataofs', 5), + (TCP, 'dataofs', 0), + (TCP, 'dataofs', 15), + (TCP, 'reserved', 0), + (TCP, 'window', 100), + (TCP, 'chksum', 0x4444), + (TCP, 'urgptr', 1), + + (DNS, 'id', 0xabcd), + (DNS, 'qr', 1), + (DNS, 'opcode', 9), + (DNS, 'aa', 0), + (DNS, 'tc', 1), + (DNS, 'rd', 0), + (DNS, 'ra', 1), + (DNS, 'z', 0), + (DNS, 'ad', 1), + (DNS, 'cd', 0), + (DNS, 'qdcount', 0x1234), + (DNS, 'ancount', 12345), + (DNS, 'nscount', 49870), + (DNS, 'arcount', 0xABCD), + + (DNSQR, 'qname', 'example.com.'), + (DNSQR, 'qtype', 1), + (DNSQR, 'qclass', 0), + ] + + +def get_custom_configs(): + """ + Generates test configurations that can use the custom getters. + """ + return [ + (IP, 'flags', ''), + (TCP, 'options-eol', ''), + (TCP, 'options-nop', ''), + (TCP, 'options-mss', 0), + (TCP, 'options-mss', 1440), + (TCP, 'options-mss', 5000), + (TCP, 'options-wscale', 20), + (TCP, 'options-sackok', ''), + (TCP, 'options-sack', ''), + (TCP, 'options-timestamp', 12345678), + (TCP, 'options-altchksum', 0x44), + (TCP, 'options-altchksumopt', ''), + (TCP, 'options-uto', 1), + #(TCP, 'options-md5header', 'deadc0ffee') + ] + + +@pytest.mark.parametrize("config", get_test_configs(), + ids=['%s-%s-%s' % (proto.__name__, field, str(val)) for proto, field, val in get_test_configs()]) +def test_get(config): + """ + Tests value retrieval. + """ + proto, field, val = config + check_get(proto, field, val) + + +def check_set_get(protocol, field, value): + """ + Checks if the get method worked for this protocol, field, and value. + """ + pkt = actions.packet.Packet(protocol()) + pkt.set(protocol.__name__, field, value) + assert pkt.get(protocol.__name__, field) == value + # Rebuild the packet to confirm the type survived + pkt2 = actions.packet.Packet(protocol(bytes(pkt))) + assert pkt2.get(protocol.__name__, field) == value, "Value %s for header %s didn't survive packet parsing." % (value, field) + + +@pytest.mark.parametrize("config", get_test_configs() + get_custom_configs(), + ids=['%s-%s-%s' % (proto.__name__, field, str(val)) for proto, field, val in get_test_configs() + get_custom_configs()]) +def test_set_get(config): + """ + Tests value retrieval. + """ + proto, field, value = config + check_set_get(proto, field, value) + + +def check_gen_set_get(protocol, field): + """ + Checks if the get method worked for this protocol, field, and value. + """ + pkt = actions.packet.Packet(protocol()) + new_value = pkt.gen(protocol.__name__, field) + pkt.set(protocol.__name__, field, new_value) + assert pkt.get(protocol.__name__, field) == new_value + # Rebuild the packet to confirm the type survived + pkt2 = actions.packet.Packet(protocol(bytes(pkt))) + assert pkt2.get(protocol.__name__, field) == new_value + + +@pytest.mark.parametrize("config", get_test_configs() + get_custom_configs(), + ids=['%s-%s' % (proto.__name__, field) for proto, field, _ in get_test_configs() + get_custom_configs()]) +def test_gen_set_get(config): + """ + Tests value retrieval. + """ + # Test each generator 50 times to hit a range of values + for i in range(0, 50): + proto, field, _ = config + check_gen_set_get(proto, field) + + +def test_custom_get(): + """ + Tests value retrieval for custom getters. + """ + pkt = IP()/TCP()/Raw(load="AAAA") + tcp = actions.packet.Packet(pkt) + assert tcp.get("TCP", "load") == "AAAA" + + +def test_restrict_fields(): + """ + Tests packet field restriction. + """ + actions.packet.SUPPORTED_LAYERS = [ + actions.layer.IPLayer, + actions.layer.TCPLayer, + actions.layer.UDPLayer + ] + tcpfields = actions.layer.TCPLayer.fields + udpfields = actions.layer.UDPLayer.fields + ipfields = actions.layer.IPLayer.fields + + actions.packet.Packet.restrict_fields(logger, ["TCP", "UDP"], [], []) + assert len(actions.packet.SUPPORTED_LAYERS) == 2 + assert actions.layer.TCPLayer in actions.packet.SUPPORTED_LAYERS + assert actions.layer.UDPLayer in actions.packet.SUPPORTED_LAYERS + assert not actions.layer.IPLayer in actions.packet.SUPPORTED_LAYERS + + pkt = IP()/TCP() + packet = actions.packet.Packet(pkt) + assert "TCP" in packet.layers + assert not "IP" in packet.layers + assert len(packet.layers) == 1 + + for i in range(0, 2000): + layer, proto, field = actions.packet.Packet().gen_random() + assert layer in [TCP, UDP] + + # Check we can't retrieve any IP fields + for field in actions.layer.IPLayer.fields: + with pytest.raises(AssertionError): + packet.get("IP", field) + + # Check we can get all the TCP fields + for field in actions.layer.TCPLayer.fields: + packet.get("TCP", field) + + actions.packet.Packet.restrict_fields(logger, ["TCP", "UDP"], ["flags"], []) + packet = actions.packet.Packet(pkt) + assert len(actions.packet.SUPPORTED_LAYERS) == 1 + assert actions.layer.TCPLayer in actions.packet.SUPPORTED_LAYERS + assert not actions.layer.UDPLayer in actions.packet.SUPPORTED_LAYERS + assert not actions.layer.IPLayer in actions.packet.SUPPORTED_LAYERS + assert actions.layer.TCPLayer.fields == ["flags"] + assert not actions.layer.UDPLayer.fields + + # Check we can't retrieve any IP fields + for field in actions.layer.IPLayer.fields: + with pytest.raises(AssertionError): + packet.get("IP", field) + + # Check we can get all the TCP fields + for field in tcpfields: + if field == "flags": + packet.get("TCP", field) + else: + with pytest.raises(AssertionError): + packet.get("TCP", field) + + for i in range(0, 2000): + layer, field, value = actions.packet.Packet().gen_random() + assert layer == TCP + assert field == "flags" + + actions.packet.Packet.reset_restrictions() + actions.packet.SUPPORTED_LAYERS = [ + actions.layer.IPLayer, + actions.layer.TCPLayer, + actions.layer.UDPLayer + ] + actions.packet.Packet.restrict_fields(logger, ["TCP", "IP"], [], ["sport", "dport", "seq", "src"]) + packet = actions.packet.Packet(pkt) + packet = packet.copy() + assert packet.has_supported_layers() + assert len(actions.packet.SUPPORTED_LAYERS) == 2 + assert actions.layer.TCPLayer in actions.packet.SUPPORTED_LAYERS + assert not actions.layer.UDPLayer in actions.packet.SUPPORTED_LAYERS + assert actions.layer.IPLayer in actions.packet.SUPPORTED_LAYERS + assert set(actions.layer.TCPLayer.fields) == set([f for f in tcpfields if f not in ["sport", "dport", "seq"]]) + assert set(actions.layer.IPLayer.fields) == set([f for f in ipfields if f not in ["src"]]) + + # Check we can't retrieve any IP fields + for field in actions.layer.IPLayer.fields: + if field == "src": + with pytest.raises(AssertionError): + packet.get("IP", field) + else: + packet.get("IP", field) + + # Check we can get all the TCP fields + for field in tcpfields: + if field in ["sport", "dport", "seq"]: + with pytest.raises(AssertionError): + packet.get("TCP", field) + else: + packet.get("TCP", field) + + for i in range(0, 2000): + layer, field, value = actions.packet.Packet().gen_random() + assert layer in [TCP, IP] + assert field not in ["sport", "dport", "seq", "src"] + + actions.packet.Packet.reset_restrictions() + actions.packet.SUPPORTED_LAYERS = [ + actions.layer.IPLayer, + actions.layer.TCPLayer, + actions.layer.UDPLayer + ] + + actions.packet.Packet.restrict_fields(logger, ["IP", "UDP", "DNS"], [], ["version"]) + packet = actions.packet.Packet(pkt) + proto, field, value = packet.get_random() + assert proto.__name__ in ["IP", "UDP"] + assert len(actions.packet.SUPPORTED_LAYERS) == 2 + assert not actions.layer.TCPLayer in actions.packet.SUPPORTED_LAYERS + assert actions.layer.UDPLayer in actions.packet.SUPPORTED_LAYERS + assert actions.layer.IPLayer in actions.packet.SUPPORTED_LAYERS + assert set(actions.layer.IPLayer.fields) == set([f for f in ipfields if f not in ["version"]]) + assert set(actions.layer.UDPLayer.fields) == set(udpfields) + + actions.packet.Packet.reset_restrictions() + for layer in actions.packet.SUPPORTED_LAYERS: + assert layer.fields, '%s has no fields - reset failed!' % str(layer)