mirror of https://github.com/Kkevsterrr/geneva
527 lines
16 KiB
Python
527 lines
16 KiB
Python
import copy
|
|
import datetime
|
|
import importlib
|
|
import inspect
|
|
import json
|
|
import logging
|
|
import os
|
|
import string
|
|
import sys
|
|
import random
|
|
import urllib.parse
|
|
import re
|
|
import socket
|
|
import random
|
|
import struct
|
|
|
|
import actions.action
|
|
import actions.trigger
|
|
import layers.packet
|
|
import plugins.plugin_client
|
|
import plugins.plugin_server
|
|
|
|
from scapy.all import TCP, IP, UDP, rdpcap
|
|
import netifaces
|
|
|
|
|
|
RUN_DIRECTORY = os.path.join("trials", datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S"))
|
|
|
|
# Hard coded options
|
|
FLAGFOLDER = "flags"
|
|
|
|
# Holds copy of console file handler's log level
|
|
CONSOLE_LOG_LEVEL = "debug"
|
|
|
|
|
|
BASEPATH = os.path.dirname(os.path.abspath(__file__))
|
|
PROJECT_ROOT = os.path.dirname(BASEPATH)
|
|
|
|
|
|
class SkipStrategyException(Exception):
|
|
"""
|
|
Raised to signal that this strategy evaluation should be cut off.
|
|
"""
|
|
def __init__(self, msg, fitness):
|
|
"""
|
|
Creates the exception with the fitness to pass back
|
|
"""
|
|
self.fitness = fitness
|
|
self.msg = msg
|
|
|
|
|
|
def parse(requested_trees, logger):
|
|
"""
|
|
Parses a string representation of a solution into its object form.
|
|
"""
|
|
# First, strip off any hanging quotes at beginning/end of the strategy
|
|
if requested_trees.startswith("\""):
|
|
requested_trees = requested_trees[1:]
|
|
if requested_trees.endswith("\""):
|
|
requested_trees = requested_trees[:-1]
|
|
|
|
# Define a blank strategy to initialize with the user specified string
|
|
strat = actions.strategy.Strategy([], [])
|
|
|
|
# Actions for the in and out forest are separated by a "\/".
|
|
# Split the given string by this token
|
|
out_in_actions = requested_trees.split("\\/")
|
|
|
|
# Specify that we're starting with the out forest before we parse the in forest
|
|
out = True
|
|
direction = "out"
|
|
# For each string representation of the action directions, in or out
|
|
for str_actions in out_in_actions:
|
|
# Individual action trees always end in "|" to signify the end - split the
|
|
# entire action sequence into individual trees
|
|
str_actions = str_actions.split("|")
|
|
|
|
# For each string representation of each tree in the forest
|
|
for str_action in str_actions:
|
|
# If it's an empty action, skip it
|
|
if not str_action.strip():
|
|
continue
|
|
|
|
assert " " not in str_action.strip(), "Strategy includes a space - malformed!"
|
|
|
|
# Get rid of hanging whitespace from the splitting
|
|
str_action = str_action.strip()
|
|
|
|
# ActionTree uses the last "|" as a sanity check for well-formed
|
|
# strategies, so restore the "|" that was lost from the split
|
|
str_action = str_action + "|"
|
|
new_tree = actions.tree.ActionTree(direction)
|
|
success = new_tree.parse(str_action, logger)
|
|
if success is False:
|
|
raise actions.tree.ActionTreeParseError("Failed to parse tree")
|
|
|
|
# Once all the actions are parsed, add this tree to the
|
|
# current direction of actions
|
|
if out:
|
|
strat.out_actions.append(new_tree)
|
|
else:
|
|
strat.in_actions.append(new_tree)
|
|
# Change the flag to tell it to parse the IN direction during the next loop iteration
|
|
out = False
|
|
direction = "in"
|
|
return strat
|
|
|
|
|
|
|
|
def get_logger(basepath, log_dir, logger_name, log_name, environment_id, log_level="DEBUG", file_log_level="DEBUG", demo_mode=False):
|
|
"""
|
|
Configures and returns a logger.
|
|
"""
|
|
if type(log_level) == str:
|
|
log_level = log_level.upper()
|
|
if type(file_log_level) == str:
|
|
file_log_level = file_log_level.upper()
|
|
global CONSOLE_LOG_LEVEL
|
|
full_path = os.path.join(basepath, log_dir, "logs")
|
|
if not os.path.exists(full_path):
|
|
os.makedirs(full_path)
|
|
flag_path = os.path.join(basepath, log_dir, "flags")
|
|
if not os.path.exists(flag_path):
|
|
os.makedirs(flag_path)
|
|
# Set up a client logger
|
|
logger = logging.getLogger(logger_name + environment_id)
|
|
logger.setLevel("DEBUG")
|
|
# Disable the root logger to avoid double printing
|
|
logger.propagate = False
|
|
|
|
# If we've already setup the handlers for this logger, just return it
|
|
if logger.handlers:
|
|
return logger
|
|
fh = logging.FileHandler(os.path.join(basepath, log_dir, "logs", "%s.%s.log" % (environment_id, log_name)))
|
|
|
|
log_prefix = "[%s] " % log_name.upper()
|
|
formatter = logging.Formatter("%(asctime)s %(levelname)s:" + log_prefix + "%(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
|
file_formatter = logging.Formatter(log_prefix + "%(asctime)s %(message)s")
|
|
fh.setFormatter(file_formatter)
|
|
logger.addHandler(fh)
|
|
|
|
ch = logging.StreamHandler()
|
|
ch.setFormatter(formatter)
|
|
ch.setLevel(log_level)
|
|
fh.setLevel(file_log_level)
|
|
CONSOLE_LOG_LEVEL = log_level.lower()
|
|
logger.addHandler(ch)
|
|
return CustomAdapter(logger, {}) if demo_mode else logger
|
|
|
|
class CustomAdapter(logging.LoggerAdapter):
|
|
"""
|
|
Used for demo mode, to change sensitive IP addresses where necessary. Can be used (mostly) like a regular logger.
|
|
"""
|
|
regex = re.compile(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}")
|
|
|
|
def __init__(self, logger, extras):
|
|
super().__init__(logger, extras)
|
|
self.handlers = logger.handlers
|
|
self.ips = {}
|
|
|
|
def debug(self, msg, *args, **kwargs):
|
|
"""
|
|
Print a debug message, uses logger.debug.
|
|
"""
|
|
msg, args, kwargs = self.process(msg, args, kwargs)
|
|
|
|
self.logger.debug(msg, *args, **kwargs)
|
|
|
|
def info(self, msg, *args, **kwargs):
|
|
"""
|
|
Print an info message, uses logger.info.
|
|
"""
|
|
msg, args, kwargs = self.process(msg, args, kwargs)
|
|
|
|
self.logger.info(msg, *args, **kwargs)
|
|
|
|
def warning(self, msg, *args, **kwargs):
|
|
"""
|
|
Print a warning message, uses logger.warning.
|
|
"""
|
|
msg, args, kwargs = self.process(msg, args, kwargs)
|
|
|
|
self.logger.warning(msg, *args, **kwargs)
|
|
|
|
def error(self, msg, *args, **kwargs):
|
|
"""
|
|
Print an error message, uses logger.error.
|
|
"""
|
|
msg, args, kwargs = self.process(msg, args, kwargs)
|
|
|
|
self.logger.error(msg, *args, **kwargs)
|
|
|
|
def critical(self, msg, *args, **kwargs):
|
|
"""
|
|
Print a critical message, uses logger.critical.
|
|
"""
|
|
msg, args, kwargs = self.process(msg, args, kwargs)
|
|
|
|
self.logger.critical(msg, *args, **kwargs)
|
|
|
|
def get_ip(self, ip):
|
|
"""
|
|
Lookup the assigned random IP for a given real IP.
|
|
If no random IP exists, a new one is created and a message is logged indicating it.
|
|
"""
|
|
if ip not in self.ips:
|
|
random_ip = socket.inet_ntoa(struct.pack('>I', random.randint(1, 0xffffffff)))
|
|
self.logger.info("Registering new random IP: %s" % random_ip)
|
|
self.ips[ip] = random_ip
|
|
|
|
def process(self, msg, args, kwargs):
|
|
"""
|
|
Modify the log message to replace any instance of an IP in msg or args with its assigned random IP.
|
|
"""
|
|
new_args = []
|
|
for arg in args:
|
|
if type(arg) == str:
|
|
for ip in self.regex.findall(arg):
|
|
new_ip = self.get_ip(ip)
|
|
|
|
arg = arg.replace(ip, self.ips[ip])
|
|
|
|
new_args.append(arg)
|
|
|
|
for ip in self.regex.findall(msg):
|
|
if ip not in self.ips:
|
|
random_ip = socket.inet_ntoa(struct.pack('>I', random.randint(1, 0xffffffff)))
|
|
self.logger.debug("Registering new random IP: %s" % random_ip)
|
|
self.ips[ip] = random_ip
|
|
new_ip = self.get_ip(ip)
|
|
|
|
msg = msg.replace(ip, self.ips[ip])
|
|
|
|
return msg, tuple(new_args), kwargs
|
|
|
|
def close_logger(logger):
|
|
"""
|
|
Closes open file handles for a given logger.
|
|
"""
|
|
# Close the file handles so we don't hold a ton of file descriptors open
|
|
handlers = logger.handlers[:]
|
|
for handler in handlers:
|
|
if isinstance(handler, logging.FileHandler):
|
|
handler.close()
|
|
|
|
|
|
class Logger():
|
|
"""
|
|
Logging class context manager, as a thin wrapper around the logging class to help
|
|
handle closing open file descriptors.
|
|
"""
|
|
def __init__(self, log_dir, logger_name, log_name, environment_id, log_level="DEBUG"):
|
|
self.log_dir = log_dir
|
|
self.logger_name = logger_name
|
|
self.log_name = log_name
|
|
self.environment_id = environment_id
|
|
self.log_level = log_level
|
|
self.logger = None
|
|
|
|
def __enter__(self):
|
|
"""
|
|
Sets up a logger.
|
|
"""
|
|
self.logger = get_logger(PROJECT_ROOT, self.log_dir, self.logger_name, self.log_name, self.environment_id, log_level=self.log_level)
|
|
return self.logger
|
|
|
|
def __exit__(self, exc_type, exc_value, tb):
|
|
"""
|
|
Closes file handles.
|
|
"""
|
|
close_logger(self.logger)
|
|
|
|
|
|
|
|
def get_console_log_level():
|
|
"""
|
|
returns log level of console handler
|
|
"""
|
|
return CONSOLE_LOG_LEVEL
|
|
|
|
|
|
def get_plugins():
|
|
"""
|
|
Iterates over this current directory to retrieve plugins.
|
|
"""
|
|
plugins = []
|
|
for f in os.listdir(os.path.join(PROJECT_ROOT, "plugins")):
|
|
if os.path.isdir(os.path.join(PROJECT_ROOT, "plugins", f)) and "__pycache__" not in f:
|
|
plugins.append(f)
|
|
return plugins
|
|
|
|
|
|
def import_plugin(plugin, side):
|
|
"""
|
|
Imports given plugin.
|
|
Args:
|
|
- plugin: plugin to import (e.g. "http")
|
|
- side: which side of the connection should be imported ("client" or "server")
|
|
"""
|
|
|
|
# Define the full module for this plugin
|
|
mod = "plugins.%s.%s" % (plugin, side)
|
|
|
|
path = os.path.join(PROJECT_ROOT, "plugins", plugin)
|
|
if path not in sys.path:
|
|
sys.path.append(path)
|
|
|
|
# Import the module
|
|
importlib.import_module(mod)
|
|
|
|
# Predicate to filter classmembers
|
|
def check_plugin(obj):
|
|
"""
|
|
Filters class members to ensure we get only enabled Plugin subclasses
|
|
"""
|
|
return inspect.isclass(obj) and \
|
|
issubclass(obj, plugins.plugin.Plugin) and \
|
|
(obj != plugins.plugin_client.ClientPlugin and \
|
|
obj != plugins.plugin_server.ServerPlugin and \
|
|
obj != plugins.plugin.Plugin) and \
|
|
obj(None).enabled
|
|
|
|
# Filter the class members of the imported module to find our Plugin subclass
|
|
clsmembers = inspect.getmembers(sys.modules[mod], predicate=check_plugin)
|
|
|
|
# Sanity check the class members we identified
|
|
assert clsmembers, "Could not find plugin %s" % mod
|
|
assert len(clsmembers) == 1, "Too many matching plugins found for %s" % mod
|
|
|
|
# Extract the class - clsmembers[0] is a tuple of (name, class)
|
|
_, cls = clsmembers[0]
|
|
|
|
# Return the module path and class
|
|
return mod, cls
|
|
|
|
|
|
def build_command(args):
|
|
"""
|
|
Given a dictionary of arguments, build it back into a command line string.
|
|
"""
|
|
cmd = []
|
|
for opt in args:
|
|
# Don't pass along store true args that are false
|
|
if args[opt] in [False, None]:
|
|
continue
|
|
cmd.append("--%s" % opt.replace("_", "-"))
|
|
# If store true arg, we don't need to pass the value
|
|
if args[opt] is True:
|
|
continue
|
|
|
|
if args[opt] is '':
|
|
cmd.append("''")
|
|
elif " " in str(args[opt]):
|
|
cmd.append("\"" + str(args[opt]) + "\"")
|
|
else:
|
|
cmd.append(str(args[opt]))
|
|
return cmd
|
|
|
|
|
|
def string_to_protocol(protocol):
|
|
"""
|
|
Converts string representations of scapy protocol objects to
|
|
their actual objects. For example, "TCP" to the scapy TCP object.
|
|
"""
|
|
if protocol.upper() == "TCP":
|
|
return TCP
|
|
elif protocol.upper() == "IP":
|
|
return IP
|
|
elif protocol.upper() == "UDP":
|
|
return UDP
|
|
|
|
|
|
def get_id():
|
|
"""
|
|
Returns a random ID
|
|
"""
|
|
return ''.join([random.choice(string.ascii_lowercase + string.digits) for k in range(8)])
|
|
|
|
|
|
def setup_dirs(output_dir):
|
|
"""
|
|
Sets up Geneva folder structure.
|
|
"""
|
|
ga_log_dir = os.path.join(output_dir, "logs")
|
|
ga_flags_dir = os.path.join(output_dir, "flags")
|
|
ga_packets_dir = os.path.join(output_dir, "packets")
|
|
ga_generations_dir = os.path.join(output_dir, "generations")
|
|
ga_data_dir = os.path.join(output_dir, "data")
|
|
for directory in [ga_log_dir, ga_flags_dir, ga_packets_dir, ga_generations_dir, ga_data_dir]:
|
|
if not os.path.exists(directory):
|
|
os.makedirs(directory, exist_ok=True)
|
|
return ga_log_dir
|
|
|
|
|
|
def get_from_fuzzed_or_real_packet(environment_id, real_packet_probability, enable_options=True, enable_load=True):
|
|
"""
|
|
Retrieves a protocol, field, and value from a fuzzed or real packet, depending on
|
|
the given probability and if given packets is not None.
|
|
"""
|
|
packets = actions.utils.read_packets(environment_id)
|
|
if packets and random.random() < real_packet_probability:
|
|
packet = random.choice(packets)
|
|
return packet.get_random()
|
|
return layers.packet.Packet().gen_random()
|
|
|
|
|
|
def read_packets(environment_id):
|
|
"""
|
|
Reads the pcap file associated with the last evaluation of this strategy.
|
|
Returns a list of Geneva Packet objects.
|
|
"""
|
|
if not environment_id:
|
|
return None
|
|
|
|
packets_path = os.path.join(RUN_DIRECTORY, "packets", "original_" + str(environment_id) + ".pcap")
|
|
if not os.path.exists(packets_path):
|
|
return None
|
|
|
|
parsed = []
|
|
try:
|
|
packets = rdpcap(packets_path)
|
|
parsed = [layers.packet.Packet(p) for p in packets]
|
|
except Exception as e:
|
|
print(e)
|
|
print("FAILED TO PARSE!")
|
|
|
|
return parsed
|
|
|
|
|
|
def punish_fitness(fitness, logger, eng):
|
|
"""
|
|
Adjusts fitness based on additional optimizer functions.
|
|
"""
|
|
if not eng:
|
|
logger.warning("Requested fitness adjustment without an engine - returning original fitness.")
|
|
return fitness
|
|
logger.debug("Initiating fitness adjustment")
|
|
if eng and eng.strategy:
|
|
fitness = punish_complexity(fitness, logger, eng.strategy)
|
|
fitness = punish_unused(fitness, logger, eng.strategy)
|
|
if fitness > 0:
|
|
overhead = int(eng.overhead / 2)
|
|
logger.debug("Punishing for overhead: %d" % overhead)
|
|
fitness -= overhead
|
|
|
|
return fitness
|
|
|
|
|
|
def punish_unused(fitness, logger, ind):
|
|
"""
|
|
Punishes strategy for each action that was not run.
|
|
"""
|
|
if not ind:
|
|
return fitness
|
|
logger.debug("Punishing for unused actions")
|
|
num_unused = [action_tree.ran for action_tree in ind.out_actions].count(False)
|
|
fitness -= (num_unused * 10)
|
|
logger.debug(" - Number of unused actions in out forest: %d" % num_unused)
|
|
num_unused = [action_tree.ran for action_tree in ind.in_actions].count(False)
|
|
fitness -= (num_unused * 10)
|
|
logger.debug(" - Number of unused actions in in forest: %d" % num_unused)
|
|
return fitness
|
|
|
|
|
|
def punish_complexity(fitness, logger, ind):
|
|
"""
|
|
Reduces fitness based on number of actions - optimizes for simplicity.
|
|
"""
|
|
if not ind:
|
|
return fitness
|
|
# Punish for number of actions
|
|
if fitness > 0:
|
|
logger.debug("Punishing for complexity: %d" % len(ind))
|
|
fitness -= len(ind)
|
|
return fitness
|
|
|
|
|
|
def write_fitness(fitness, output_path, eid):
|
|
"""
|
|
Writes fitness to disk.
|
|
"""
|
|
try:
|
|
float(fitness)
|
|
except ValueError:
|
|
print("Given fitness (%r) is not a number!" % fitness)
|
|
raise
|
|
fitpath = os.path.join(PROJECT_ROOT, output_path, FLAGFOLDER, eid) + ".fitness"
|
|
with open(fitpath, "w") as fitfile:
|
|
fitfile.write(str(fitness))
|
|
|
|
|
|
def get_interface():
|
|
"""
|
|
Chooses an interface on the machine to use for socket testing.
|
|
"""
|
|
ifaces = netifaces.interfaces()
|
|
for iface in ifaces:
|
|
if "lo" in iface:
|
|
continue
|
|
info = netifaces.ifaddresses(iface)
|
|
# Filter for IPv4 addresses
|
|
if netifaces.AF_INET in info:
|
|
return iface
|
|
|
|
|
|
def get_worker(name, logger):
|
|
"""
|
|
Returns information dictionary about a worker given its name.
|
|
"""
|
|
path = os.path.join("workers", name, "worker.json")
|
|
if os.path.exists(name):
|
|
path = name
|
|
|
|
dirpath = os.path.dirname(path)
|
|
|
|
if not os.path.exists(path):
|
|
return None
|
|
|
|
with open(path, "r") as fd:
|
|
data = json.load(fd)
|
|
|
|
# If there is a private key, update the path to be relative to the project base
|
|
if data.get("keyfile"):
|
|
data["keyfile"] = os.path.join(dirpath, data["keyfile"])
|
|
|
|
return data
|