diff --git a/base/tasklist.py b/base/tasklist.py index 2803d7a..216c466 100644 --- a/base/tasklist.py +++ b/base/tasklist.py @@ -31,15 +31,26 @@ class TaskList(object): def create_list(self): from common.phases import order + # Get a hold of all tasks + tasks = self.get_all_tasks() + # Make sure the taskset is a subset of all the tasks we have gathered + self.tasks.issubset(tasks) + # Create a graph over all tasks by creating a map of each tasks successors graph = {} - for task in self.tasks: + for task in tasks: + # Do a sanity check first self.check_ordering(task) successors = set() + # Add all successors mentioned in the task successors.update(task.successors) - successors.update(filter(lambda succ: task in succ.predecessors, self.tasks)) + # Add all tasks that mention this task as a predecessor + successors.update(filter(lambda succ: task in succ.predecessors, tasks)) + # Create a list of phases that succeed the phase of this task succeeding_phases = order[order.index(task.phase) + 1:] - successors.update(filter(lambda succ: succ.phase in succeeding_phases, self.tasks)) - graph[task] = filter(lambda succ: succ in self.tasks, successors) + # Add all tasks that occur in above mentioned succeeding phases + successors.update(filter(lambda succ: succ.phase in succeeding_phases, tasks)) + # Map the successors to the task + graph[task] = successors components = self.strongly_connected_components(graph) cycles_found = 0 @@ -52,10 +63,41 @@ class TaskList(object): 'consult the logfile for more information.'.format(cycles_found)) raise TaskListError(msg) + # Run a topological sort on the graph, returning an ordered list sorted_tasks = self.topological_sort(graph) + # Filter out any tasks not in the tasklist + # We want to maintain ordering, so we don't use set intersection + sorted_tasks = filter(lambda task: task in self.tasks, sorted_tasks) return sorted_tasks + def get_all_tasks(self): + # Get a generator that returns all classes in the package + classes = self.get_all_classes('..') + + # lambda function to check whether a class is a task (excluding the superclass Task) + def is_task(obj): + from task import Task + return issubclass(obj, Task) and obj is not Task + return filter(is_task, classes) # Only return classes that are tasks + + # Given a path, retrieve all the classes in it + def get_all_classes(self, path=None): + import pkgutil + import importlib + import inspect + + def walk_error(module): + raise Exception('Unable to inspect module `{module}\''.format(module=module)) + walker = pkgutil.walk_packages(path, '', walk_error) + for _, module_name, _ in walker: + module = importlib.import_module(module_name) + classes = inspect.getmembers(module, inspect.isclass) + for class_name, obj in classes: + # We only want classes that are defined in the module, and not imported ones + if obj.__module__ == module_name: + yield obj + def check_ordering(self, task): for successor in task.successors: if successor.phase > successor.phase: