diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..9827159 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,39 @@ +sudo: required +dist: "bionic" + +language: python + +python: + - "3.6" + +install: + # Travis recently added systemd-resolvd to their VMs. Since full Geneva often runs its own DNS + # server to test DNS strategies, we need to disable system-resolvd. + # First disable the service + - sudo systemctl disable systemd-resolved.service + # Stop the service + - sudo systemctl stop systemd-resolved + # With systemd not running, our own hostname won't resolve - this causes issues with sudo. + # Add back our hostname to /etc/hosts/ so sudo does not complain + - echo $(hostname -I | cut -d\ -f1) $(hostname) | sudo tee -a /etc/hosts + # Replace the 127.0.0.53 nameserver with Google's + - sudo sed 's/nameserver.*/nameserver 8.8.8.8/' /etc/resolv.conf > /tmp/resolv.conf.new + - sudo mv /tmp/resolv.conf.new /etc/resolv.conf + # Now that systemd-resolv.conf is safely disabled, we can now setup for Geneva + - sudo apt-get clean # travis having mirror sync issues + # Install dependencies + - sudo apt-get update + - sudo apt-get -y install libnetfilter-queue-dev python3 python3-pip python3-setuptools graphviz + # Since sudo is required but travis does not set up the root environment, we must override the + # secure_path in sudoers in order for travis's setup to take effect for sudo commands + - printf "Defaults\tenv_reset\nDefaults\tmail_badpass\nDefaults\tsecure_path="/home/travis/virtualenv/python3.6.7/bin/:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/snap/bin"\nroot\tALL=(ALL:ALL) ALL\n#includedir /etc/sudoers.d\n" > /tmp/sudoers.tmp + # Verify the sudoers file + - sudo visudo -c -f /tmp/sudoers.tmp + # Copy in the sudoers file + - sudo cp /tmp/sudoers.tmp /etc/sudoers + # Now that sudo is good to go, finish installing dependencies + - sudo python3 -m pip install -r requirements.txt + - sudo python3 -m pip install slackclient pytest-cov + +script: + - sudo python3 -m pytest --cov=./ -sv tests/ --tb=short diff --git a/tests/test_tree.py b/tests/test_tree.py new file mode 100644 index 0000000..8bbc7c7 --- /dev/null +++ b/tests/test_tree.py @@ -0,0 +1,545 @@ +import logging +import os + +from scapy.all import IP, TCP +import actions.tree +import actions.drop +import actions.tamper +import actions.duplicate +import actions.utils + + +def test_init(): + """ + Tests initialization + """ + print(actions.action.Action.get_actions("out")) + + +def test_count_leaves(): + """ + Tests leaf count is correct. + """ + a = actions.tree.ActionTree("out") + logger = logging.getLogger("test") + + assert not a.parse("TCP:reserved:0tamper{TCP:flags:replace:S}-|", logger), "Tree parsed malformed DNA" + a.parse("[TCP:reserved:0]-tamper{TCP:flags:replace:S}-|", logger) + duplicate = actions.duplicate.DuplicateAction() + duplicate2 = actions.duplicate.DuplicateAction() + drop = actions.drop.DropAction() + + assert a.count_leaves() == 1 + assert a.remove_one() + a.add_action(duplicate) + assert a.count_leaves() == 1 + duplicate.left = duplicate2 + assert a.count_leaves() == 1 + duplicate.right = drop + assert a.count_leaves() == 2 + + +def test_check(): + """ + Tests action tree check function. + """ + a = actions.tree.ActionTree("out") + logger = logging.getLogger("test") + a.parse("[TCP:flags:RA]-tamper{TCP:flags:replace:S}-|", logger) + p = actions.packet.Packet(IP()/TCP(flags="A")) + assert not a.check(p, logger) + p = actions.packet.Packet(IP(ttl=64)/TCP(flags="RA")) + assert a.check(p, logger) + assert a.remove_one() + assert a.check(p, logger) + a.parse("[TCP:reserved:0]-tamper{TCP:flags:replace:S}-|", logger) + assert a.check(p, logger) + a.parse("[IP:ttl:64]-tamper{TCP:flags:replace:S}-|", logger) + assert a.check(p, logger) + p = actions.packet.Packet(IP(ttl=15)/TCP(flags="RA")) + assert not a.check(p, logger) + + +def test_scapy(): + """ + Tests misc. scapy aspects relevant to strategies. + """ + a = actions.tree.ActionTree("out") + logger = logging.getLogger("test") + a.parse("[TCP:reserved:0]-tamper{TCP:flags:replace:S}-|", logger) + p = actions.packet.Packet(IP()/TCP(flags="A")) + assert a.check(p, logger) + packets = a.run(p, logger) + assert packets[0][TCP].flags == "S" + p = actions.packet.Packet(IP()/TCP(flags="A")) + assert a.check(p, logger) + a.parse("[TCP:reserved:0]-tamper{TCP:chksum:corrupt}-|", logger) + packets = a.run(p, logger) + assert packets[0][TCP].chksum + assert a.check(p, logger) + + +def test_str(): + """ + Tests string representation. + """ + logger = logging.getLogger("test") + + t = actions.trigger.Trigger("field", "flags", "TCP") + a = actions.tree.ActionTree("out", trigger=t) + assert str(a).strip() == "[%s]-|" % str(t) + tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S") + tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + assert a.add_action(tamper) + assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}-|" + # Tree will not add a duplicate action + assert not a.add_action(tamper) + assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}-|" + assert a.add_action(tamper2) + assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R},)-|" + assert a.add_action(actions.duplicate.DuplicateAction()) + assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate,),)-|" + drop = actions.drop.DropAction() + assert a.add_action(drop) + assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate(drop,),),)-|" or \ + str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate(,drop),),)-|" + assert a.remove_action(drop) + assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate,),)-|" + # Cannot remove action that is not present + assert not a.remove_action(drop) + assert str(a).strip() == "[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(duplicate,),)-|" + + a = actions.tree.ActionTree("out", trigger=t) + orig = "[TCP:urgptr:15963]-duplicate(,drop)-|" + a.parse(orig, logger) + assert a.remove_one() + assert orig != str(a) + assert str(a) in ["[TCP:urgptr:15963]-drop-|", "[TCP:urgptr:15963]-duplicate-|"] + + +def test_pretty_print_send(): + t = actions.trigger.Trigger("field", "flags", "TCP") + a = actions.tree.ActionTree("out", trigger=t) + duplicate = actions.duplicate.DuplicateAction() + a.add_action(duplicate) + correct_string = "TCP:flags:0\nduplicate\n├── ===> \n└── ===> " + assert a.pretty_print() == correct_string + + +def test_pretty_print(): + """ + Print complex tree, although difficult to test + """ + t = actions.trigger.Trigger("field", "flags", "TCP") + a = actions.tree.ActionTree("out", trigger=t) + tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S") + tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + duplicate = actions.duplicate.DuplicateAction() + duplicate2 = actions.duplicate.DuplicateAction() + duplicate3 = actions.duplicate.DuplicateAction() + duplicate4 = actions.duplicate.DuplicateAction() + duplicate5 = actions.duplicate.DuplicateAction() + drop = actions.drop.DropAction() + drop2 = actions.drop.DropAction() + drop3 = actions.drop.DropAction() + drop4 = actions.drop.DropAction() + + duplicate.left = duplicate2 + duplicate.right = duplicate3 + duplicate2.left = tamper + duplicate2.right = drop + duplicate3.left = duplicate4 + duplicate3.right = drop2 + duplicate4.left = duplicate5 + duplicate4.right = drop3 + duplicate5.left = drop4 + duplicate5.right = tamper2 + + a.add_action(duplicate) + correct_string = "TCP:flags:0\nduplicate\n├── duplicate\n│ ├── tamper{TCP:flags:replace:S}\n│ │ └── ===> \n│ └── drop\n└── duplicate\n ├── duplicate\n │ ├── duplicate\n │ │ ├── drop\n │ │ └── tamper{TCP:flags:replace:R}\n │ │ └── ===> \n │ └── drop\n └── drop" + assert a.pretty_print() == correct_string + assert a.pretty_print(visual=True) + assert os.path.exists("tree.png") + os.remove("tree.png") + a.parse("[TCP:flags:0]-|", logging.getLogger("test")) + a.pretty_print(visual=True) # Empty action tree + assert not os.path.exists("tree.png") + +def test_pretty_print_order(): + """ + Tests the left/right ordering by reading in a new tree + """ + logger = logging.getLogger("test") + a = actions.tree.ActionTree("out") + assert a.parse("[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:replace:14239},),duplicate(tamper{TCP:flags:replace:S}(tamper{TCP:chksum:replace:14239},),))-|", logger) + correct_pretty_print = "TCP:flags:A\nduplicate\n├── tamper{TCP:flags:replace:R}\n│ └── tamper{TCP:chksum:replace:14239}\n│ └── ===> \n└── duplicate\n ├── tamper{TCP:flags:replace:S}\n │ └── tamper{TCP:chksum:replace:14239}\n │ └── ===> \n └── ===> " + assert a.pretty_print() == correct_pretty_print + +def test_parse(): + """ + Tests string parsing. + """ + logger = logging.getLogger("test") + t = actions.trigger.Trigger("field", "flags", "TCP") + a = actions.tree.ActionTree("out", trigger=t) + + base_t = actions.trigger.Trigger("field", "flags", "TCP") + base_a = actions.tree.ActionTree("out", trigger=base_t) + tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S") + tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + tamper3 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S") + tamper4 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + a.parse("[TCP:flags:0]-|", logger) + assert str(a) == str(base_a) + assert len(a) == 0 + + base_a.add_action(tamper) + + assert a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}-|", logger) + + assert str(a) == str(base_a) + assert len(a) == 1 + assert a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R},)-|", logging.getLogger("test")) + base_a.add_action(tamper2) + assert str(a) == str(base_a) + assert len(a) == 2 + + base_a.add_action(tamper3) + base_a.add_action(tamper4) + assert a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R}(tamper{TCP:flags:replace:S}(tamper{TCP:flags:replace:R},),),)-|", logging.getLogger("test")) + assert str(a) == str(base_a) + assert len(a) == 4 + + base_t = actions.trigger.Trigger("field", "flags", "TCP") + base_a = actions.tree.ActionTree("out", trigger=base_t) + duplicate = actions.duplicate.DuplicateAction() + assert a.parse("[TCP:flags:0]-duplicate-|", logger) + base_a.add_action(duplicate) + assert str(a) == str(base_a) + tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S") + tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + tamper3 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="A") + tamper4 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + duplicate.left = tamper + assert a.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},)-|", logger) + assert str(a) == str(base_a) + + duplicate.right = tamper2 + assert a.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},tamper{TCP:flags:replace:R})-|", logger) + assert str(a) == str(base_a) + + tamper2.left = tamper3 + assert a.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},tamper{TCP:flags:replace:R}(tamper{TCP:flags:replace:A},))-|", logger) + assert str(a) == str(base_a) + + strategy = actions.utils.parse("[TCP:flags:0]-duplicate(tamper{TCP:flags:replace:S},tamper{TCP:flags:replace:R})-| \/", logger) + assert strategy + assert len(strategy.out_actions[0]) == 3 + assert len(strategy.in_actions) == 0 + + assert not a.parse("[]", logger) # No valid trigger + assert not a.parse("[TCP:flags:0]-", logger) # No valid ending "|" + assert not a.parse("[TCP:]-|", logger) # invalid trigger + assert not a.parse("[TCP:flags:0]-foo-|", logger) # Non-existent action + assert not a.parse("[TCP:flags:0]--|", logger) # Empty action + assert not a.parse("[TCP:flags:0]-duplicate(,,,)-|", logger) # Bad tree + assert not a.parse("[TCP:flags:0]-duplicate()))-|", logger) # Bad tree + assert not a.parse("[TCP:flags:0]-duplicate(((()-|", logger) # Bad tree + assert not a.parse("[TCP:flags:0]-duplicate(,))))-|", logger) # Bad tree + assert not a.parse("[TCP:flags:0]-drop(duplicate,)-|", logger) # Terminal action with children + assert not a.parse("[TCP:flags:0]-drop(duplicate,duplicate)-|", logger) # Terminal action with children + assert not a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(,duplicate)-|", logger) # Non-branching action with right child + assert not a.parse("[TCP:flags:0]-tamper{TCP:flags:replace:S}(drop,duplicate)-|", logger) # Non-branching action with children + + +def test_tree(): + """ + Tests basic tree functionality. + """ + t = actions.trigger.Trigger(None, None, None) + a = actions.tree.ActionTree("out", trigger=t) + tamper = actions.tamper.TamperAction() + tamper2 = actions.tamper.TamperAction() + duplicate = actions.duplicate.DuplicateAction() + + a.add_action(None) + a.add_action(tamper) + assert a.get_slots() == 1 + a.add_action(tamper2) + assert a.get_slots() == 1 + a.add_action(duplicate) + assert a.get_slots() == 2 + + t = actions.trigger.Trigger(None, None, None) + a = actions.tree.ActionTree("out", trigger=t) + drop = actions.drop.DropAction() + a.add_action(drop) + assert a.get_slots() == 0 + add_success = a.add_action(tamper) + assert not add_success + assert a.get_slots() == 0 + + rep = "" + for s in a.string_repr(a.action_root): + rep += s + assert rep == "drop" + + print(str(a)) + + assert a.parse("[TCP:flags:A]-duplicate(tamper{TCP:seq:corrupt},)-|", logging.getLogger("test")) + for act in a: + print(str(a)) + assert len(a) == 2 + assert a.get_slots() == 2 + + +def test_remove(): + """ + Tests remove + """ + t = actions.trigger.Trigger(None, None, None) + a = actions.tree.ActionTree("out", trigger=t) + tamper = actions.tamper.TamperAction() + tamper2 = actions.tamper.TamperAction() + tamper3 = actions.tamper.TamperAction() + assert not a.remove_action(tamper) + a.add_action(tamper) + assert a.remove_action(tamper) + a.add_action(tamper) + a.add_action(tamper2) + a.add_action(tamper3) + assert a.remove_action(tamper2) + assert tamper2 not in a + assert tamper.left == tamper3 + assert not tamper.right + assert len(a) == 2 + a = actions.tree.ActionTree("out", trigger=t) + duplicate = actions.duplicate.DuplicateAction() + tamper = actions.tamper.TamperAction() + tamper2 = actions.tamper.TamperAction() + tamper3 = actions.tamper.TamperAction() + a.add_action(tamper) + assert a.action_root == tamper + duplicate.left = tamper2 + duplicate.right = tamper3 + a.add_action(duplicate) + assert len(a) == 4 + assert a.remove_action(duplicate) + assert duplicate not in a + assert tamper.left == tamper2 + assert not tamper.right + assert len(a) == 2 + + a.parse("[TCP:flags:A]-|", logging.getLogger("test")) + assert not a.remove_one(), "Cannot remove one with no action root" + + +def test_len(): + """ + Tests length calculation. + """ + t = actions.trigger.Trigger(None, None, None) + a = actions.tree.ActionTree("out", trigger=t) + tamper = actions.tamper.TamperAction() + tamper2 = actions.tamper.TamperAction() + assert len(a) == 0, "__len__ returned wrong length" + a.add_action(tamper) + assert len(a) == 1, "__len__ returned wrong length" + a.add_action(tamper) + assert len(a) == 1, "__len__ returned wrong length" + a.add_action(tamper2) + assert len(a) == 2, "__len__ returned wrong length" + duplicate = actions.duplicate.DuplicateAction() + a.add_action(duplicate) + assert len(a) == 3, "__len__ returned wrong length" + + +def test_contains(): + """ + Tests contains method + """ + t = actions.trigger.Trigger(None, None, None) + a = actions.tree.ActionTree("out", trigger=t) + tamper = actions.tamper.TamperAction() + tamper2 = actions.tamper.TamperAction() + tamper3 = actions.tamper.TamperAction() + + assert not a.contains(tamper), "contains incorrect behavior" + assert not a.contains(tamper2), "contains incorrect behavior" + a.add_action(tamper) + assert a.contains(tamper), "contains incorrect behavior" + assert not a.contains(tamper2), "contains incorrect behavior" + add_success = a.add_action(tamper) + assert not add_success, "added duplicate action" + assert a.contains(tamper), "contains incorrect behavior" + assert not a.contains(tamper2), "contains incorrect behavior" + a.add_action(tamper2) + assert a.contains(tamper), "contains incorrect behavior" + assert a.contains(tamper2), "contains incorrect behavior" + a.remove_action(tamper2) + assert a.contains(tamper), "contains incorrect behavior" + assert not a.contains(tamper2), "contains incorrect behavior" + a.add_action(tamper2) + assert a.contains(tamper), "contains incorrect behavior" + assert a.contains(tamper2), "contains incorrect behavior" + remove_success = a.remove_action(tamper) + assert remove_success + assert not a.contains(tamper), "contains incorrect behavior" + assert a.contains(tamper2), "contains incorrect behavior" + a.add_action(tamper3) + assert a.contains(tamper3), "contains incorrect behavior" + assert len(a) == 2, "len incorrect return" + remove_success = a.remove_action(tamper2) + assert remove_success + + +def test_iter(): + """ + Tests iterator. + """ + t = actions.trigger.Trigger(None, None, None) + a = actions.tree.ActionTree("out", trigger=t) + tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S") + tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + + assert a.add_action(tamper) + assert a.add_action(tamper2) + assert not a.add_action(tamper) + for node in a: + print(node) + + +def test_run(): + """ + Tests running packets through the chain. + """ + logger = logging.getLogger("test") + t = actions.trigger.Trigger(None, None, None) + a = actions.tree.ActionTree("out", trigger=t) + tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S") + tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + duplicate = actions.duplicate.DuplicateAction() + duplicate2 = actions.duplicate.DuplicateAction() + drop = actions.drop.DropAction() + + packet = actions.packet.Packet(IP()/TCP()) + a.add_action(tamper) + packets = a.run(packet, logging.getLogger("test")) + assert len(packets) == 1 + assert None not in packets + assert packets[0].get("TCP", "flags") == "S" + a.add_action(tamper2) + print(str(a)) + + packet = actions.packet.Packet(IP()/TCP()) + assert not a.add_action(tamper), "tree added duplicate action" + packets = a.run(packet, logging.getLogger("test")) + assert len(packets) == 1 + assert None not in packets + assert packets[0].get("TCP", "flags") == "R" + print(str(a)) + + a.remove_action(tamper2) + a.remove_action(tamper) + a.add_action(duplicate) + packet = actions.packet.Packet(IP()/TCP(flags="RA")) + packets = a.run(packet, logging.getLogger("test")) + assert len(packets) == 2 + assert None not in packets + assert packets[0][TCP].flags == "RA" + assert packets[1][TCP].flags == "RA" + print(str(a)) + + duplicate.left = tamper + duplicate.right = tamper2 + packet = actions.packet.Packet(IP()/TCP(flags="RA")) + packets = a.run(packet, logging.getLogger("test")) + assert len(packets) == 2 + assert None not in packets + print(str(a)) + print(str(packets[0])) + print(str(packets[1])) + assert packets[0][TCP].flags == "S" + assert packets[1][TCP].flags == "R" + print(str(a)) + + tamper.left = duplicate2 + packet = actions.packet.Packet(IP()/TCP(flags="RA")) + packets = a.run(packet, logging.getLogger("test")) + assert len(packets) == 3 + assert None not in packets + assert packets[0][TCP].flags == "S" + assert packets[1][TCP].flags == "S" + assert packets[2][TCP].flags == "R" + print(str(a)) + + tamper2.left = drop + packet = actions.packet.Packet(IP()/TCP(flags="RA")) + packets = a.run(packet, logging.getLogger("test")) + assert len(packets) == 2 + assert None not in packets + assert packets[0][TCP].flags == "S" + assert packets[1][TCP].flags == "S" + print(str(a)) + + assert a.remove_action(duplicate2) + tamper.left = actions.drop.DropAction() + packet = actions.packet.Packet(IP()/TCP(flags="RA")) + packets = a.run(packet, logger ) + assert len(packets) == 0 + print(str(a)) + + a.parse("[TCP:flags:A]-duplicate(tamper{TCP:flags:replace:R}(tamper{TCP:chksum:replace:14239},),duplicate(tamper{TCP:flags:replace:S},))-|", logger) + packet = actions.packet.Packet(IP()/TCP(flags="A")) + assert a.check(packet, logger) + packets = a.run(packet, logger) + assert len(packets) == 3 + assert packets[0][TCP].flags == "R" + assert packets[1][TCP].flags == "S" + assert packets[2][TCP].flags == "A" + + +def test_index(): + """ + Tests index + """ + a = actions.tree.ActionTree("out") + tamper = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="S") + tamper2 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="R") + tamper3 = actions.tamper.TamperAction(field="flags", tamper_type="replace", tamper_value="F") + + assert a.add_action(tamper) + assert a[0] == tamper + assert not a[1] + assert a.add_action(tamper2) + assert a[0] == tamper + assert a[1] == tamper2 + assert a[-1] == tamper2 + assert not a[10] + assert a.add_action(tamper3) + assert a[-1] == tamper3 + assert not a[-11] + + +def test_choose_one(): + """ + Tests choose_one functionality + """ + a = actions.tree.ActionTree("out") + drop = actions.drop.DropAction() + assert not a.choose_one() + assert a.add_action(drop) + assert a.choose_one() == drop + assert a.remove_action(drop) + assert not a.choose_one() + duplicate = actions.duplicate.DuplicateAction() + a.add_action(duplicate) + assert a.choose_one() == duplicate + duplicate.left = drop + assert a.choose_one() in [duplicate, drop] + # Make sure that both actions get chosen + chosen = set() + for i in range(0, 10000): + act = a.choose_one() + chosen.add(act) + assert chosen == set([duplicate, drop])