geneva/tests/test_tree.py
2020-06-24 05:20:51 -07:00

686 lines
27 KiB
Python

import logging
import os
from scapy.all import IP, TCP
import actions.tree
import actions.drop
import actions.tamper
import actions.duplicate
import actions.utils
import layers.packet
def test_init():
"""
Tests initialization
"""
print(actions.action.Action.get_actions("out"))
def test_count_leaves():
"""
Tests leaf count is correct.
"""
a = actions.tree.ActionTree("out")
logger = logging.getLogger("test")
assert not a.parse("TCP:reserved:0tamper{TCP:flags:replace:S}-|", logger), "Tree parsed malformed DNA"
a.parse("[TCP:reserved:0]-tamper{TCP:flags:replace:S}-|", logger)
duplicate = actions.duplicate.DuplicateAction()
duplicate2 = actions.duplicate.DuplicateAction()
drop = actions.drop.DropAction()
assert a.count_leaves() == 1
assert a.remove_one()
a.add_action(duplicate)
assert a.count_leaves() == 1
duplicate.left = duplicate2
assert a.count_leaves() == 1
duplicate.right = drop
assert a.count_leaves() == 2
def test_check():
"""
Tests action tree check function.
"""
a = actions.tree.ActionTree("out")
logger = logging.getLogger("test")
a.parse("[TCP:flags:RA]-tamper{TCP:flags:replace:S}-|", logger)
p = layers.packet.Packet(IP()/TCP(flags="A"))
assert not a.check(p, logger)
p = layers.packet.Packet(IP(ttl=64)/TCP(flags="RA"))
assert a.check(p, logger)
assert a.remove_one()
assert a.check(p, logger)
a.parse("[TCP:reserved:0]-tamper{TCP:flags:replace:S}-|", logger)
assert a.check(p, logger)
a.parse("[IP:ttl:64]-tamper{TCP:flags:replace:S}-|", logger)
assert a.check(p, logger)
p = layers.packet.Packet(IP(ttl=15)/TCP(flags="RA"))
assert not a.check(p, logger)
def test_scapy():
"""
Tests misc. scapy aspects relevant to strategies.
"""
a = actions.tree.ActionTree("out")
logger = logging.getLogger("test")
a.parse("[TCP:reserved:0]-tamper{TCP:flags:replace:S}-|", logger)
p = layers.packet.Packet(IP()/TCP(flags="A"))
assert a.check(p, logger)
packets = a.run(p, logger)
assert packets[0][TCP].flags == "S"
p = layers.packet.Packet(IP()/TCP(flags="A"))
assert a.check(p, logger)
a.parse("[TCP:reserved:0]-tamper{TCP:chksum:corrupt}-|", logger)
packets = a.run(p, logger)
assert packets[0][TCP].chksum
assert a.check(p, logger)
def test_str():
"""
Tests string representation.
"""
logger = logging.getLogger("test")
t = actions.trigger.Trigger("field", "flags", "TCP")
a = actions.tree.ActionTree("out", trigger=t)
assert str(a).strip() == "[%s]-|" % str(t)
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
assert a.add_action(tamper)
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}-|"
# Tree will not add a duplicate action
assert not a.add_action(tamper)
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}-|"
assert a.add_action(tamper2)
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R},)-|"
assert a.add_action(actions.duplicate.DuplicateAction())
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate,),)-|"
drop = actions.drop.DropAction()
assert a.add_action(drop)
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate(drop,),),)-|" or \
str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate(,drop),),)-|"
assert a.remove_action(drop)
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate,),)-|"
# Cannot remove action that is not present
assert not a.remove_action(drop)
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate,),)-|"
a = actions.tree.ActionTree("out", trigger=t)
orig = "[TCP:urgptr:15963]-duplicate(,drop)-|"
a.parse(orig, logger)
assert a.remove_one()
assert orig != str(a)
assert str(a) in ["[TCP:urgptr:15963]-drop-|", "[TCP:urgptr:15963]-duplicate-|"]
def test_pretty_print_send():
t = actions.trigger.Trigger("field", "flags", "TCP")
a = actions.tree.ActionTree("out", trigger=t)
duplicate = actions.duplicate.DuplicateAction()
a.add_action(duplicate)
correct_string = "TCP:flags:0\nduplicate\n├── ===> \n└── ===> "
assert a.pretty_print() == correct_string
def test_pretty_print(logger):
"""
Print complex tree, although difficult to test
"""
t = actions.trigger.Trigger("field", "flags", "TCP")
a = actions.tree.ActionTree("out", trigger=t)
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
duplicate = actions.duplicate.DuplicateAction()
duplicate2 = actions.duplicate.DuplicateAction()
duplicate3 = actions.duplicate.DuplicateAction()
duplicate4 = actions.duplicate.DuplicateAction()
duplicate5 = actions.duplicate.DuplicateAction()
drop = actions.drop.DropAction()
drop2 = actions.drop.DropAction()
drop3 = actions.drop.DropAction()
drop4 = actions.drop.DropAction()
duplicate.left = duplicate2
duplicate.right = duplicate3
duplicate2.left = tamper
duplicate2.right = drop
duplicate3.left = duplicate4
duplicate3.right = drop2
duplicate4.left = duplicate5
duplicate4.right = drop3
duplicate5.left = drop4
duplicate5.right = tamper2
a.add_action(duplicate)
correct_string = "TCP:flags:0\nduplicate\n├── duplicate\n│ ├── tamper{TCP:flags:replace:S}\n│ │ └── ===> \n│ └── drop\n└── duplicate\n ├── duplicate\n │ ├── duplicate\n │ │ ├── drop\n │ │ └── tamper{TCP:flags:replace:R}\n │ │ └── ===> \n │ └── drop\n └── drop"
assert a.pretty_print() == correct_string
assert a.pretty_print(visual=True)
assert os.path.exists("tree.png")
os.remove("tree.png")
a.parse("[TCP:flags:0]-|", logger)
a.pretty_print(visual=True) # Empty action tree
assert not os.path.exists("tree.png")
def test_pretty_print_order():
"""
Tests the left/right ordering by reading in a new tree
"""
logger = logging.getLogger("test")
a = actions.tree.ActionTree("out")
assert a.parse("[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:replace:14239},),duplicate(tamper{TCP:flags:replace:S}(tamper{TCP:chksum:replace:14239},),))-|", logger)
correct_pretty_print = "TCP:flags:A\nduplicate\n├── tamper{TCP:flags:replace:R}\n│ └── tamper{TCP:chksum:replace:14239}\n│ └── ===> \n└── duplicate\n ├── tamper{TCP:flags:replace:S}\n │ └── tamper{TCP:chksum:replace:14239}\n │ └── ===> \n └── ===> "
assert a.pretty_print() == correct_pretty_print
def test_parse():
"""
Tests string parsing.
"""
logger = logging.getLogger("test")
t = actions.trigger.Trigger("field", "flags", "TCP")
a = actions.tree.ActionTree("out", trigger=t)
base_t = actions.trigger.Trigger("field", "flags", "TCP")
base_a = actions.tree.ActionTree("out", trigger=base_t)
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
tamper3 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper4 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
a.parse("[TCP:flags:0]-|", logger)
assert str(a) == str(base_a)
assert len(a) == 0
base_a.add_action(tamper)
assert a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}-|", logger)
assert str(a) == str(base_a)
assert len(a) == 1
assert a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R},)-|", logging.getLogger("test"))
base_a.add_action(tamper2)
assert str(a) == str(base_a)
assert len(a) == 2
base_a.add_action(tamper3)
base_a.add_action(tamper4)
assert a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R},),),)-|", logging.getLogger("test"))
assert str(a) == str(base_a)
assert len(a) == 4
base_t = actions.trigger.Trigger("field", "flags", "TCP")
base_a = actions.tree.ActionTree("out", trigger=base_t)
duplicate = actions.duplicate.DuplicateAction()
assert a.parse("[TCP:flags:0]-duplicate-|", logger)
base_a.add_action(duplicate)
assert str(a) == str(base_a)
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
tamper3 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="A")
tamper4 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
duplicate.left = tamper
assert a.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},)-|", logger)
assert str(a) == str(base_a)
duplicate.right = tamper2
assert a.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},tamper{TCP:flags:replace:R})-|", logger)
assert str(a) == str(base_a)
tamper2.left = tamper3
assert a.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},tamper{TCP:flags:replace:R}(tamper{TCP:flags:replace:A},))-|", logger)
assert str(a) == str(base_a)
strategy = actions.utils.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},tamper{TCP:flags:replace:R})-| \/", logger)
assert strategy
assert len(strategy.out_actions[0]) == 3
assert len(strategy.in_actions) == 0
assert not a.parse("[]", logger) # No valid trigger
assert not a.parse("[TCP:flags:0]-", logger) # No valid ending "|"
assert not a.parse("[TCP:]-|", logger) # invalid trigger
assert not a.parse("[TCP:flags:0]-foo-|", logger) # Non-existent action
assert not a.parse("[TCP:flags:0]--|", logger) # Empty action
assert not a.parse("[TCP:flags:0]-duplicate(,,,)-|", logger) # Bad tree
assert not a.parse("[TCP:flags:0]-duplicate()))-|", logger) # Bad tree
assert not a.parse("[TCP:flags:0]-duplicate(((()-|", logger) # Bad tree
assert not a.parse("[TCP:flags:0]-duplicate(,))))-|", logger) # Bad tree
assert not a.parse("[TCP:flags:0]-drop(duplicate,)-|", logger) # Terminal action with children
assert not a.parse("[TCP:flags:0]-drop(duplicate,duplicate)-|", logger) # Terminal action with children
assert not a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(,duplicate)-|", logger) # Non-branching action with right child
assert not a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(drop,duplicate)-|", logger) # Non-branching action with children
def test_tree():
"""
Tests basic tree functionality.
"""
t = actions.trigger.Trigger(None, None, None)
a = actions.tree.ActionTree("out", trigger=t)
tamper = actions.tamper.TamperAction()
tamper2 = actions.tamper.TamperAction()
duplicate = actions.duplicate.DuplicateAction()
a.add_action(None)
a.add_action(tamper)
assert a.get_slots() == 1
a.add_action(tamper2)
assert a.get_slots() == 1
a.add_action(duplicate)
assert a.get_slots() == 2
t = actions.trigger.Trigger(None, None, None)
a = actions.tree.ActionTree("out", trigger=t)
drop = actions.drop.DropAction()
a.add_action(drop)
assert a.get_slots() == 0
add_success = a.add_action(tamper)
assert not add_success
assert a.get_slots() == 0
rep = ""
for s in a.string_repr(a.action_root):
rep += s
assert rep == "drop"
print(str(a))
assert a.parse("[TCP:flags:A]-duplicate(tamper{TCP:seq:corrupt},)-|", logging.getLogger("test"))
for act in a:
print(str(a))
assert len(a) == 2
assert a.get_slots() == 2
for _ in range(100):
assert str(a.get_rand_action("out", request="DropAction")) == "drop"
def test_remove():
"""
Tests remove
"""
t = actions.trigger.Trigger(None, None, None)
a = actions.tree.ActionTree("out", trigger=t)
tamper = actions.tamper.TamperAction()
tamper2 = actions.tamper.TamperAction()
tamper3 = actions.tamper.TamperAction()
assert not a.remove_action(tamper)
a.add_action(tamper)
assert a.remove_action(tamper)
a.add_action(tamper)
a.add_action(tamper2)
a.add_action(tamper3)
assert a.remove_action(tamper2)
assert tamper2 not in a
assert tamper.left == tamper3
assert not tamper.right
assert len(a) == 2
a = actions.tree.ActionTree("out", trigger=t)
duplicate = actions.duplicate.DuplicateAction()
tamper = actions.tamper.TamperAction()
tamper2 = actions.tamper.TamperAction()
tamper3 = actions.tamper.TamperAction()
a.add_action(tamper)
assert a.action_root == tamper
duplicate.left = tamper2
duplicate.right = tamper3
a.add_action(duplicate)
assert len(a) == 4
assert a.remove_action(duplicate)
assert duplicate not in a
assert tamper.left == tamper2
assert not tamper.right
assert len(a) == 2
a.parse("[TCP:flags:A]-|", logging.getLogger("test"))
assert not a.remove_one(), "Cannot remove one with no action root"
def test_len():
"""
Tests length calculation.
"""
t = actions.trigger.Trigger(None, None, None)
a = actions.tree.ActionTree("out", trigger=t)
tamper = actions.tamper.TamperAction()
tamper2 = actions.tamper.TamperAction()
assert len(a) == 0, "__len__ returned wrong length"
a.add_action(tamper)
assert len(a) == 1, "__len__ returned wrong length"
a.add_action(tamper)
assert len(a) == 1, "__len__ returned wrong length"
a.add_action(tamper2)
assert len(a) == 2, "__len__ returned wrong length"
duplicate = actions.duplicate.DuplicateAction()
a.add_action(duplicate)
assert len(a) == 3, "__len__ returned wrong length"
def test_contains():
"""
Tests contains method
"""
t = actions.trigger.Trigger(None, None, None)
a = actions.tree.ActionTree("out", trigger=t)
tamper = actions.tamper.TamperAction()
tamper2 = actions.tamper.TamperAction()
tamper3 = actions.tamper.TamperAction()
assert not a.contains(tamper), "contains incorrect behavior"
assert not a.contains(tamper2), "contains incorrect behavior"
a.add_action(tamper)
assert a.contains(tamper), "contains incorrect behavior"
assert not a.contains(tamper2), "contains incorrect behavior"
add_success = a.add_action(tamper)
assert not add_success, "added duplicate action"
assert a.contains(tamper), "contains incorrect behavior"
assert not a.contains(tamper2), "contains incorrect behavior"
a.add_action(tamper2)
assert a.contains(tamper), "contains incorrect behavior"
assert a.contains(tamper2), "contains incorrect behavior"
a.remove_action(tamper2)
assert a.contains(tamper), "contains incorrect behavior"
assert not a.contains(tamper2), "contains incorrect behavior"
a.add_action(tamper2)
assert a.contains(tamper), "contains incorrect behavior"
assert a.contains(tamper2), "contains incorrect behavior"
remove_success = a.remove_action(tamper)
assert remove_success
assert not a.contains(tamper), "contains incorrect behavior"
assert a.contains(tamper2), "contains incorrect behavior"
a.add_action(tamper3)
assert a.contains(tamper3), "contains incorrect behavior"
assert len(a) == 2, "len incorrect return"
remove_success = a.remove_action(tamper2)
assert remove_success
def test_iter():
"""
Tests iterator.
"""
t = actions.trigger.Trigger(None, None, None)
a = actions.tree.ActionTree("out", trigger=t)
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
assert a.add_action(tamper)
assert a.add_action(tamper2)
assert not a.add_action(tamper)
for node in a:
print(node)
def test_run():
"""
Tests running packets through the chain.
"""
logger = logging.getLogger("test")
t = actions.trigger.Trigger(None, None, None)
a = actions.tree.ActionTree("out", trigger=t)
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
duplicate = actions.duplicate.DuplicateAction()
duplicate2 = actions.duplicate.DuplicateAction()
drop = actions.drop.DropAction()
packet = layers.packet.Packet(IP()/TCP())
a.add_action(tamper)
packets = a.run(packet, logging.getLogger("test"))
assert len(packets) == 1
assert None not in packets
assert packets[0].get("TCP", "flags") == "S"
a.add_action(tamper2)
print(str(a))
packet = layers.packet.Packet(IP()/TCP())
assert not a.add_action(tamper), "tree added duplicate action"
packets = a.run(packet, logging.getLogger("test"))
assert len(packets) == 1
assert None not in packets
assert packets[0].get("TCP", "flags") == "R"
print(str(a))
a.remove_action(tamper2)
a.remove_action(tamper)
a.add_action(duplicate)
packet = layers.packet.Packet(IP()/TCP(flags="RA"))
packets = a.run(packet, logging.getLogger("test"))
assert len(packets) == 2
assert None not in packets
assert packets[0][TCP].flags == "RA"
assert packets[1][TCP].flags == "RA"
print(str(a))
duplicate.left = tamper
duplicate.right = tamper2
packet = layers.packet.Packet(IP()/TCP(flags="RA"))
print("ABUT TO RUN")
packets = a.run(packet, logging.getLogger("test"))
assert len(packets) == 2
assert None not in packets
print(str(a))
print(str(packets[0]))
print(str(packets[1]))
assert packets[0][TCP].flags == "S"
assert packets[1][TCP].flags == "R"
print(str(a))
tamper.left = duplicate2
packet = layers.packet.Packet(IP()/TCP(flags="RA"))
packets = a.run(packet, logging.getLogger("test"))
assert len(packets) == 3
assert None not in packets
assert packets[0][TCP].flags == "S"
assert packets[1][TCP].flags == "S"
assert packets[2][TCP].flags == "R"
print(str(a))
tamper2.left = drop
packet = layers.packet.Packet(IP()/TCP(flags="RA"))
packets = a.run(packet, logging.getLogger("test"))
assert len(packets) == 2
assert None not in packets
assert packets[0][TCP].flags == "S"
assert packets[1][TCP].flags == "S"
print(str(a))
assert a.remove_action(duplicate2)
tamper.left = actions.drop.DropAction()
packet = layers.packet.Packet(IP()/TCP(flags="RA"))
packets = a.run(packet, logger )
assert len(packets) == 0
print(str(a))
a.parse("[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:replace:14239},),duplicate(tamper{TCP:flags:replace:S},))-|", logger)
packet = layers.packet.Packet(IP()/TCP(flags="A"))
assert a.check(packet, logger)
packets = a.run(packet, logger)
assert len(packets) == 3
assert packets[0][TCP].flags == "R"
assert packets[1][TCP].flags == "S"
assert packets[2][TCP].flags == "A"
def test_index():
"""
Tests index
"""
a = actions.tree.ActionTree("out")
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
tamper3 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="F")
assert a.add_action(tamper)
assert a[0] == tamper
assert not a[1]
assert a.add_action(tamper2)
assert a[0] == tamper
assert a[1] == tamper2
assert a[-1] == tamper2
assert not a[10]
assert a.add_action(tamper3)
assert a[-1] == tamper3
assert not a[-11]
def test_mate():
"""
Tests mate primitive
"""
logger = logging.getLogger("test")
t = actions.trigger.Trigger("field", "flags", "TCP")
a = actions.tree.ActionTree("out", trigger=t)
assert not a.choose_one()
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
duplicate = actions.duplicate.DuplicateAction()
duplicate2 = actions.duplicate.DuplicateAction()
drop = actions.drop.DropAction()
other_a = actions.tree.ActionTree("out", trigger=t)
assert not a.mate(other_a), "Can't mate empty trees"
assert a.add_action(tamper)
assert other_a.add_action(tamper2)
assert a.choose_one() == tamper
assert other_a.choose_one() == tamper2
assert a.get_parent(tamper) == (None, None)
assert other_a.get_parent(tamper2) == (None, None)
assert a.add_action(duplicate)
assert a.get_parent(duplicate) == (tamper, "left")
duplicate.right = drop
assert a.get_parent(drop) == (duplicate, "right")
assert other_a.add_action(duplicate2)
# Test mating a full tree with a full tree
assert str(a) == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(duplicate(,drop),)-|"
assert str(other_a) == "[TCP:flags:0]-tamper{TCP:flags:replace:R}(duplicate,)-|"
assert a.swap(duplicate, other_a, duplicate2)
assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(duplicate,)-|"
assert str(other_a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:R}(duplicate(,drop),)-|"
assert len(a) == 2
assert len(other_a) == 3
assert duplicate2 not in other_a
assert duplicate not in a
assert tamper.left == duplicate2
assert tamper2.left == duplicate
assert other_a.get_parent(duplicate) == (tamper2, "left")
assert a.get_parent(duplicate2) == (tamper, "left")
assert other_a.get_parent(drop) == (duplicate, "right")
assert a.get_parent(None) == (None, None)
# Test mating two trees with just root nodes
t = actions.trigger.Trigger("field", "flags", "TCP")
a = actions.tree.ActionTree("out", trigger=t)
assert not a.choose_one()
tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S")
tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R")
duplicate = actions.duplicate.DuplicateAction()
duplicate2 = actions.duplicate.DuplicateAction()
drop = actions.drop.DropAction()
other_a = actions.tree.ActionTree("out", trigger=t)
assert not a.mate(other_a)
assert a.add_action(duplicate)
assert other_a.add_action(duplicate2)
assert a.mate(other_a)
assert a.action_root == duplicate2
assert other_a.action_root == duplicate
assert not duplicate.left and not duplicate.right
assert not duplicate2.left and not duplicate2.right
# Confirm that no nodes have been aliased or connected between the trees
for node in a:
for other_node in other_a:
assert not node.left == other_node
assert not node.right == other_node
# Test mating two trees where one is empty
assert a.remove_action(duplicate2)
# This should swap the duplicate action to be the action root of the other tree
assert str(a) == "[TCP:flags:0]-|"
assert str(other_a) == "[TCP:flags:0]-duplicate-|"
assert a.mate(other_a)
assert not other_a.action_root
assert a.action_root == duplicate
assert len(a) == 1
assert len(other_a) == 0
# Confirm that no nodes have been aliased or connected between the trees
for node in a:
for other_node in other_a:
if other_node:
assert not node.left == other_node
assert not node.right == other_node
assert a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(duplicate(,drop),)-|", logger)
drop = a.action_root.left.right
assert str(drop) == "drop"
# Note that this will return a valid ActionTree, but because it is empty,
# it is technically a False-y value, as it's length is 0
assert other_a.parse("[TCP:flags:0]-|", logger) == other_a
a.swap(drop, other_a, None)
assert other_a.action_root == drop
assert not a.action_root.left.right
assert str(other_a) == "[TCP:flags:0]-drop-|"
assert str(a) == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(duplicate,)-|"
other_a.swap(drop, a, a.action_root.left)
# Confirm that no nodes have been aliased or connected between the trees
for node in a:
for other_node in other_a:
if other_node:
assert not node.left == other_node
assert not node.right == other_node
assert str(other_a) == "[TCP:flags:0]-duplicate-|"
assert str(a) == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(drop,)-|"
a.parse("[TCP:flags:0]-drop-|", logger)
other_a.parse("[TCP:flags:0]-duplicate(drop,drop)-|", logger)
a_drop = a.action_root
other_duplicate = other_a.action_root
a.swap(a_drop, other_a, other_duplicate)
print(str(a))
print(str(other_a))
assert str(other_a) == "[TCP:flags:0]-drop-|"
assert str(a) == "[TCP:flags:0]-duplicate(drop,drop)-|"
duplicate = actions.duplicate.DuplicateAction()
duplicate2 = actions.duplicate.DuplicateAction()
drop = actions.drop.DropAction()
drop2 = actions.drop.DropAction()
drop3 = actions.drop.DropAction()
a = actions.tree.ActionTree("out", trigger=t)
a.add_action(duplicate)
a.add_action(drop)
a.add_action(drop2)
assert str(a) == "[TCP:flags:0]-duplicate(drop,drop)-|"
assert a.get_slots() == 0
other_a = actions.tree.ActionTree("out", trigger=t)
other_a.add_action(drop3)
a.swap(drop, other_a, drop3)
assert str(a) == "[TCP:flags:0]-duplicate(drop,drop)-|"
a.swap(drop3, other_a, drop)
assert str(a) == "[TCP:flags:0]-duplicate(drop,drop)-|"
assert a.mate(other_a)
def test_choose_one():
"""
Tests choose_one functionality
"""
a = actions.tree.ActionTree("out")
drop = actions.drop.DropAction()
assert not a.choose_one()
assert a.add_action(drop)
assert a.choose_one() == drop
assert a.remove_action(drop)
assert not a.choose_one()
duplicate = actions.duplicate.DuplicateAction()
a.add_action(duplicate)
assert a.choose_one() == duplicate
duplicate.left = drop
assert a.choose_one() in [duplicate, drop]
# Make sure that both actions get chosen
chosen = set()
for i in range(0, 10000):
act = a.choose_one()
chosen.add(act)
assert chosen == set([duplicate, drop])