"""
Shared helpers: strategy parsing, logger setup, plugin discovery,
command-line rebuilding and fitness bookkeeping.
"""

import copy
import datetime
import importlib
import inspect
import json
import logging
import os
import string
import sys
import random
import urllib.parse
import re
import socket
import struct

import actions.action
import actions.trigger
import layers.packet
import plugins.plugin_client
import plugins.plugin_server

from scapy.all import TCP, IP, UDP, rdpcap
import netifaces


RUN_DIRECTORY = os.path.join("trials", datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S"))

# Hard coded options
FLAGFOLDER = "flags"

# Holds copy of console file handler's log level
CONSOLE_LOG_LEVEL = "debug"

BASEPATH = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(BASEPATH)


class SkipStrategyException(Exception):
    """
    Raised to signal that this strategy evaluation should be cut off.

    Attributes:
        - fitness: fitness value to pass back to whoever catches the exception
        - msg: human readable reason for skipping
    """
    def __init__(self, msg, fitness):
        """
        Creates the exception with the fitness to pass back
        """
        # Hand the message to Exception so str(exc) / tracebacks are not empty
        super().__init__(msg)
        self.fitness = fitness
        self.msg = msg


def parse(requested_trees, logger):
    """
    Parses a string representation of a solution into its object form.

    The out forest and in forest are separated by a "\\/" token and every
    action tree inside a forest is terminated by "|".

    Args:
        - requested_trees: string representation of the strategy
        - logger: logger handed to each ActionTree while it parses

    Returns:
        The populated Strategy object.

    Raises:
        - AssertionError: if an individual tree contains a space
        - actions.tree.ActionTreeParseError: if a tree fails to parse
    """
    # First, strip off any hanging quotes at beginning/end of the strategy
    if requested_trees.startswith("\""):
        requested_trees = requested_trees[1:]
    if requested_trees.endswith("\""):
        requested_trees = requested_trees[:-1]

    # Define a blank strategy to initialize with the user specified string
    strat = actions.strategy.Strategy([], [])

    # Actions for the in and out forest are separated by a "\/".
    # Split the given string by this token
    out_in_actions = requested_trees.split("\\/")

    # Specify that we're starting with the out forest before we parse the in forest
    out = True
    direction = "out"

    # For each string representation of the action directions, in or out
    for str_actions in out_in_actions:
        # Individual action trees always end in "|" to signify the end - split the
        # entire action sequence into individual trees
        # For each string representation of each tree in the forest
        for str_action in str_actions.split("|"):
            # Get rid of hanging whitespace from the splitting
            str_action = str_action.strip()

            # If it's an empty action, skip it
            if not str_action:
                continue

            assert " " not in str_action, "Strategy includes a space - malformed!"

            # ActionTree uses the last "|" as a sanity check for well-formed
            # strategies, so restore the "|" that was lost from the split
            str_action = str_action + "|"

            new_tree = actions.tree.ActionTree(direction)
            success = new_tree.parse(str_action, logger)
            if success is False:
                raise actions.tree.ActionTreeParseError("Failed to parse tree")

            # Once all the actions are parsed, add this tree to the
            # current direction of actions
            if out:
                strat.out_actions.append(new_tree)
            else:
                strat.in_actions.append(new_tree)

        # Change the flag to tell it to parse the IN direction during the next loop iteration
        out = False
        direction = "in"

    return strat


def get_logger(basepath, log_dir, logger_name, log_name, environment_id, log_level="DEBUG", file_log_level="DEBUG", demo_mode=False):
    """
    Configures and returns a logger.

    Creates <basepath>/<log_dir>/logs and <basepath>/<log_dir>/flags if needed,
    then attaches a file handler and a console handler to the logger named
    logger_name + environment_id. If that logger already has handlers, it is
    reused without adding new ones.

    Args:
        - basepath: base directory that log_dir is relative to
        - log_dir: output directory holding the "logs" and "flags" folders
        - logger_name: prefix of the logger name
        - log_name: name used for the log file and the message prefix
        - environment_id: ID appended to the logger name and used in the file name
        - log_level: console log level (name or numeric level)
        - file_log_level: file log level (name or numeric level)
        - demo_mode: if True, wrap the logger in a CustomAdapter that masks IPs

    Returns:
        The configured logger, or a CustomAdapter around it in demo mode.
    """
    global CONSOLE_LOG_LEVEL

    if isinstance(log_level, str):
        log_level = log_level.upper()
    if isinstance(file_log_level, str):
        file_log_level = file_log_level.upper()

    # exist_ok avoids a race when several processes set up the same directories
    full_path = os.path.join(basepath, log_dir, "logs")
    os.makedirs(full_path, exist_ok=True)
    flag_path = os.path.join(basepath, log_dir, "flags")
    os.makedirs(flag_path, exist_ok=True)

    # Set up a client logger
    logger = logging.getLogger(logger_name + environment_id)
    logger.setLevel("DEBUG")
    # Disable the root logger to avoid double printing
    logger.propagate = False

    # If we've already setup the handlers for this logger, just return it.
    # Demo mode still needs the adapter, otherwise real IPs would be logged.
    if logger.handlers:
        return CustomAdapter(logger, {}) if demo_mode else logger

    fh = logging.FileHandler(os.path.join(basepath, log_dir, "logs", "%s.%s.log" % (environment_id, log_name)))
    log_prefix = "[%s] " % log_name.upper()
    formatter = logging.Formatter("%(asctime)s %(levelname)s:" + log_prefix + "%(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    file_formatter = logging.Formatter(log_prefix + "%(asctime)s %(message)s")
    fh.setFormatter(file_formatter)
    fh.setLevel(file_log_level)
    logger.addHandler(fh)

    ch = logging.StreamHandler()
    ch.setFormatter(formatter)
    ch.setLevel(log_level)
    logger.addHandler(ch)

    # Numeric levels have no .lower(); translate them to their level name first
    level_name = log_level if isinstance(log_level, str) else logging.getLevelName(log_level)
    CONSOLE_LOG_LEVEL = level_name.lower()

    return CustomAdapter(logger, {}) if demo_mode else logger


class CustomAdapter(logging.LoggerAdapter):
    """
    Used for demo mode, to change sensitive IP addresses where necessary.
    Can be used (mostly) like a regular logger.

    Every dotted-quad found in a message or in a string argument is replaced
    with a random IP that stays fixed for the lifetime of this adapter.
    """
    regex = re.compile(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}")

    def __init__(self, logger, extras):
        super().__init__(logger, extras)
        # Expose the handlers so code expecting a plain logger keeps working
        self.handlers = logger.handlers
        # Maps real IP -> assigned random IP
        self.ips = {}

    def debug(self, msg, *args, **kwargs):
        """
        Print a debug message, uses logger.debug.
        """
        msg, args, kwargs = self.process(msg, args, kwargs)
        self.logger.debug(msg, *args, **kwargs)

    def info(self, msg, *args, **kwargs):
        """
        Print an info message, uses logger.info.
        """
        msg, args, kwargs = self.process(msg, args, kwargs)
        self.logger.info(msg, *args, **kwargs)

    def warning(self, msg, *args, **kwargs):
        """
        Print a warning message, uses logger.warning.
        """
        msg, args, kwargs = self.process(msg, args, kwargs)
        self.logger.warning(msg, *args, **kwargs)

    def error(self, msg, *args, **kwargs):
        """
        Print an error message, uses logger.error.
        """
        msg, args, kwargs = self.process(msg, args, kwargs)
        self.logger.error(msg, *args, **kwargs)

    def critical(self, msg, *args, **kwargs):
        """
        Print a critical message, uses logger.critical.
        """
        msg, args, kwargs = self.process(msg, args, kwargs)
        self.logger.critical(msg, *args, **kwargs)

    def log(self, level, msg, *args, **kwargs):
        """
        Print a message at an arbitrary level, uses logger.log.

        The inherited LoggerAdapter.log calls process(msg, kwargs), which does
        not match this class's three-argument process(); overriding it keeps
        log() and exception() usable. Only msg and args are rewritten.
        """
        msg, args, kwargs = self.process(msg, args, kwargs)
        self.logger.log(level, msg, *args, **kwargs)

    def get_ip(self, ip, announce=None):
        """
        Lookup the assigned random IP for a given real IP. If no random IP exists,
        a new one is created and a message is logged indicating it.

        Args:
            - ip: real IP address string
            - announce: optional logging callable used to report a new
              assignment (defaults to the wrapped logger's info)

        Returns:
            The random IP assigned to the given real IP.
        """
        if ip not in self.ips:
            random_ip = socket.inet_ntoa(struct.pack('>I', random.randint(1, 0xffffffff)))
            if announce is None:
                announce = self.logger.info
            announce("Registering new random IP: %s" % random_ip)
            self.ips[ip] = random_ip
        return self.ips[ip]

    def process(self, msg, args, kwargs):
        """
        Modify the log message to replace any instance of an IP in msg or args
        with its assigned random IP.

        Returns:
            Tuple of (msg, args, kwargs) with IPs replaced in msg and in every
            string argument; non-string arguments are passed through untouched.
        """
        new_args = []
        for arg in args:
            if isinstance(arg, str):
                for ip in self.regex.findall(arg):
                    arg = arg.replace(ip, self.get_ip(ip))
            new_args.append(arg)

        # logging stringifies the message anyway; do it up front so that
        # non-string messages (e.g. exception objects) can be scanned too
        if not isinstance(msg, str):
            msg = str(msg)

        for ip in self.regex.findall(msg):
            # New IPs discovered in the message body are announced at debug level
            msg = msg.replace(ip, self.get_ip(ip, announce=self.logger.debug))

        return msg, tuple(new_args), kwargs


def close_logger(logger):
    """
    Closes open file handles for a given logger.
    """
    # Close the file handles so we don't hold a ton of file descriptors open.
    # Iterate over a copy of the handler list.
    for handler in logger.handlers[:]:
        if isinstance(handler, logging.FileHandler):
            handler.close()


class Logger():
    """
    Logging class context manager, as a thin wrapper around the logging class to help
    handle closing open file descriptors.
    """
    def __init__(self, log_dir, logger_name, log_name, environment_id, log_level="DEBUG"):
        self.log_dir = log_dir
        self.logger_name = logger_name
        self.log_name = log_name
        self.environment_id = environment_id
        self.log_level = log_level
        # Populated by __enter__
        self.logger = None

    def __enter__(self):
        """
        Sets up a logger rooted at PROJECT_ROOT and returns it.
        """
        self.logger = get_logger(PROJECT_ROOT,
                                 self.log_dir,
                                 self.logger_name,
                                 self.log_name,
                                 self.environment_id,
                                 log_level=self.log_level)
        return self.logger

    def __exit__(self, exc_type, exc_value, tb):
        """
        Closes file handles. Exceptions raised inside the with-block propagate.
        """
        close_logger(self.logger)


def get_console_log_level():
    """
    Returns the log level of the console handler, as recorded by get_logger.
    """
    return CONSOLE_LOG_LEVEL


def get_plugins():
    """
    Iterates over the project's "plugins" directory to retrieve plugins.

    Returns:
        List of sub-directory names, excluding anything containing "__pycache__".
    """
    plugin_root = os.path.join(PROJECT_ROOT, "plugins")
    return [
        entry for entry in os.listdir(plugin_root)
        if os.path.isdir(os.path.join(plugin_root, entry)) and "__pycache__" not in entry
    ]


def import_plugin(plugin, side):
    """
    Imports given plugin.

    Args:
        - plugin: plugin to import (e.g. "http")
        - side: which side of the connection should be imported ("client" or "server")

    Returns:
        Tuple of (module path, plugin class).

    Raises:
        - AssertionError: if zero or more than one matching plugin class is found
    """
    # Define the full module for this plugin
    mod = "plugins.%s.%s" % (plugin, side)
    path = os.path.join(PROJECT_ROOT, "plugins", plugin)
    if path not in sys.path:
        sys.path.append(path)

    # Import the module
    importlib.import_module(mod)

    # Predicate to filter classmembers
    def check_plugin(obj):
        """
        Filters class members to ensure we get only enabled Plugin subclasses
        """
        base_classes = (plugins.plugin_client.ClientPlugin,
                        plugins.plugin_server.ServerPlugin,
                        plugins.plugin.Plugin)
        return (inspect.isclass(obj) and
                issubclass(obj, plugins.plugin.Plugin) and
                all(obj != base for base in base_classes) and
                obj(None).enabled)

    # Filter the class members of the imported module to find our Plugin subclass
    clsmembers = inspect.getmembers(sys.modules[mod], predicate=check_plugin)

    # Sanity check the class members we identified
    assert clsmembers, "Could not find plugin %s" % mod
    assert len(clsmembers) == 1, "Too many matching plugins found for %s" % mod

    # Extract the class - clsmembers[0] is a tuple of (name, class)
    _, cls = clsmembers[0]

    # Return the module path and class
    return mod, cls


def build_command(args):
    """
    Given a dictionary of arguments, build it back into a command line string.

    Args:
        - args: dictionary mapping option names (underscored) to values

    Returns:
        List of command line tokens, e.g. ["--log-level", "debug", "--verbose"].
    """
    cmd = []
    for opt, value in args.items():
        # Don't pass along store true args that are false.
        # NOTE(review): the membership test uses ==, so a value of 0 is
        # skipped as well - confirm that is intended.
        if value in [False, None]:
            continue
        cmd.append("--%s" % opt.replace("_", "-"))
        # If store true arg, we don't need to pass the value
        if value is True:
            continue
        # Compare by value; identity against a literal is implementation-dependent
        if value == '':
            cmd.append("''")
        elif " " in str(value):
            cmd.append("\"" + str(value) + "\"")
        else:
            cmd.append(str(value))
    return cmd


def string_to_protocol(protocol):
    """
    Converts string representations of scapy protocol objects to
    their actual objects. For example, "TCP" to the scapy TCP object.

    Returns:
        The scapy TCP, IP or UDP class (case-insensitive match), or None for
        any other name.
    """
    known = {"TCP": TCP, "IP": IP, "UDP": UDP}
    return known.get(protocol.upper())


def get_id():
    """
    Returns a random ID made of 8 lowercase letters and digits
    """
    alphabet = string.ascii_lowercase + string.digits
    return ''.join(random.choice(alphabet) for _ in range(8))


def setup_dirs(output_dir):
    """
    Sets up Geneva folder structure.

    Creates the logs, flags, packets, generations and data folders under
    output_dir.

    Returns:
        Path of the logs folder.
    """
    ga_log_dir = os.path.join(output_dir, "logs")
    ga_flags_dir = os.path.join(output_dir, "flags")
    ga_packets_dir = os.path.join(output_dir, "packets")
    ga_generations_dir = os.path.join(output_dir, "generations")
    ga_data_dir = os.path.join(output_dir, "data")
    for directory in [ga_log_dir, ga_flags_dir, ga_packets_dir, ga_generations_dir, ga_data_dir]:
        if not os.path.exists(directory):
            os.makedirs(directory, exist_ok=True)
    return ga_log_dir


def get_from_fuzzed_or_real_packet(environment_id, real_packet_probability, enable_options=True, enable_load=True):
    """
    Retrieves a protocol, field, and value from a fuzzed or real packet, depending on
    the given probability and if given packets is not None.

    Args:
        - environment_id: ID of the evaluation whose captured packets to read
        - real_packet_probability: chance in [0, 1] of sampling a captured packet
        - enable_options: currently unused by this function
        - enable_load: currently unused by this function
    """
    packets = actions.utils.read_packets(environment_id)
    if packets and random.random() < real_packet_probability:
        packet = random.choice(packets)
        return packet.get_random()
    return layers.packet.Packet().gen_random()


def read_packets(environment_id):
    """
    Reads the pcap file associated with the last evaluation of this strategy.
    Returns a list of Geneva Packet objects.

    Returns None if no environment ID is given or the pcap does not exist; if
    the pcap cannot be parsed, the failure is printed and an empty list is
    returned.
    """
    if not environment_id:
        return None

    packets_path = os.path.join(RUN_DIRECTORY, "packets", "original_" + str(environment_id) + ".pcap")
    if not os.path.exists(packets_path):
        return None

    parsed = []
    # Deliberately best-effort: a broken capture must not abort the caller
    try:
        packets = rdpcap(packets_path)
        parsed = [layers.packet.Packet(p) for p in packets]
    except Exception as e:
        print(e)
        print("FAILED TO PARSE!")

    return parsed


def punish_fitness(fitness, logger, eng):
    """
    Adjusts fitness based on additional optimizer functions.

    Applies the complexity and unused-action penalties when the engine has a
    strategy, then subtracts half of the engine's overhead if the fitness is
    still positive.

    Returns:
        The adjusted fitness (the original fitness if no engine is given).
    """
    if not eng:
        logger.warning("Requested fitness adjustment without an engine - returning original fitness.")
        return fitness

    logger.debug("Initiating fitness adjustment")
    if eng.strategy:
        fitness = punish_complexity(fitness, logger, eng.strategy)
        fitness = punish_unused(fitness, logger, eng.strategy)

    if fitness > 0:
        overhead = int(eng.overhead / 2)
        logger.debug("Punishing for overhead: %d" % overhead)
        fitness -= overhead

    return fitness


def punish_unused(fitness, logger, ind):
    """
    Punishes strategy for each action that was not run.

    Subtracts 10 for every tree in either forest whose "ran" flag is False.
    """
    if not ind:
        return fitness

    logger.debug("Punishing for unused actions")
    num_unused = [action_tree.ran for action_tree in ind.out_actions].count(False)
    fitness -= (num_unused * 10)
    logger.debug("   - Number of unused actions in out forest: %d" % num_unused)

    num_unused = [action_tree.ran for action_tree in ind.in_actions].count(False)
    fitness -= (num_unused * 10)
    logger.debug("   - Number of unused actions in in forest: %d" % num_unused)

    return fitness


def punish_complexity(fitness, logger, ind):
    """
    Reduces fitness based on number of actions - optimizes for simplicity.

    Only positive fitness values are reduced, by len(ind).
    """
    if not ind:
        return fitness

    # Punish for number of actions
    if fitness > 0:
        logger.debug("Punishing for complexity: %d" % len(ind))
        fitness -= len(ind)
    return fitness


def write_fitness(fitness, output_path, eid):
    """
    Writes fitness to disk.

    The value is written to <PROJECT_ROOT>/<output_path>/flags/<eid>.fitness.

    Raises:
        - ValueError / TypeError: if fitness cannot be converted to a float
    """
    try:
        float(fitness)
    except (ValueError, TypeError):
        # float(None) raises TypeError rather than ValueError; report both
        print("Given fitness (%r) is not a number!" % fitness)
        raise

    fitpath = os.path.join(PROJECT_ROOT, output_path, FLAGFOLDER, eid) + ".fitness"
    with open(fitpath, "w") as fitfile:
        fitfile.write(str(fitness))


def get_interface():
    """
    Chooses an interface on the machine to use for socket testing.

    Returns:
        Name of the first non-loopback interface with an IPv4 address, or None
        if there is none.
    """
    for iface in netifaces.interfaces():
        # Skip loopback ("lo", "lo0"). A substring test would also discard
        # interfaces that merely contain "lo", such as "wlo1".
        if iface.startswith("lo"):
            continue
        info = netifaces.ifaddresses(iface)
        # Filter for IPv4 addresses
        if netifaces.AF_INET in info:
            return iface
    return None


def get_worker(name, logger):
    """
    Returns information dictionary about a worker given its name.

    Args:
        - name: worker name (looked up as workers/<name>/worker.json) or a
          path to an existing worker file
        - logger: currently unused by this function

    Returns:
        Dictionary loaded from the worker JSON file, or None if it does not exist.
    """
    path = os.path.join("workers", name, "worker.json")
    # An existing path given directly takes precedence over the worker name lookup
    if os.path.exists(name):
        path = name

    if not os.path.exists(path):
        return None

    dirpath = os.path.dirname(path)
    with open(path, "r") as fd:
        data = json.load(fd)

    # If there is a private key, update the path to be relative to the project base
    if data.get("keyfile"):
        data["keyfile"] = os.path.join(dirpath, data["keyfile"])

    return data