diff --git a/base/phase.py b/base/phase.py index 8b50d1f..89f3ac5 100644 --- a/base/phase.py +++ b/base/phase.py @@ -1,14 +1,21 @@ +from functools import total_ordering - +@total_ordering class Phase(object): description = None def __init__(self): pass - def __cmp__(self, other): + def __lt__(self, other): from common.phases import order - return order.index(self) - order.index(other) + return order.index(self) < order.index(other) + + def __eq__(self, other): + return self == other def __str__(self): return '{name}'.format(name=self.__class__.__name__) + + def __repr__(self): + return self.__str__() diff --git a/base/task.py b/base/task.py index 7725865..56b8da2 100644 --- a/base/task.py +++ b/base/task.py @@ -14,37 +14,24 @@ class Task(object): def run(self, info): pass - def _after(self, other): - return self.phase > other.phase or type(self) in other.before or type(other) in self.after - - def _before(self, other): - return self.phase < other.phase or type(other) in self.before or type(self) in other.after - - def __lt__(self, other): - return self._before(other) and not self._after(other) - - def __gt__(self, other): - return not self._before(other) and self._after(other) - - def __eq__(self, other): - return not self._before(other) and not self._after(other) - - def __ne__(self, other): - return self._before(other) or self._after(other) - def __str__(self): return '{module}.{task}'.format(module=self.__module__, task=self.__class__.__name__) + def __repr__(self): + return self.__str__() + def _check_ordering(self): + def name(ref): + return '{module}.{task}'.format(module=ref.__module__, task=ref.__class__.__name__) for task in self.before: if self.phase > task.phase: msg = ("The task {self} is specified as running before {other}, " "but its phase {phase} lies after the phase {other_phase}" - .format(self, other, self.phase, other.phase)) + .format(self=type(self), other=task, phase=self.phase, other_phase=task.phase)) raise TaskListError(msg) for task in self.after: if self.phase < task.phase: msg = ("The task {self} is specified as running after {other}, " "but its phase {phase} lies before the phase {other_phase}" - .format(self=self, other=other, phase=self.phase, other_phase=other.phase)) + .format(self=type(self), other=task, phase=self.phase, other_phase=task.phase)) raise TaskListError(msg) diff --git a/base/tasklist.py b/base/tasklist.py index 0be9299..0eade0a 100644 --- a/base/tasklist.py +++ b/base/tasklist.py @@ -23,30 +23,36 @@ class TaskList(object): def run(self, bootstrap_info): task_list = self.create_list() - for task in tasks: + log.debug('Tasklist:\n\t{list}'.format(list='\n\t'.join(repr(task) for task in task_list))) + for task in task_list: log.info(task) task.run(bootstrap_info) def create_list(self): + from common.phases import order graph = {} for task in self.tasks: - graph[task] = [self.get(succ) for succ in task.before] + graph[task] = [] + graph[task].extend([self.get(succ) for succ in task.before]) graph[task].extend([succ for succ in self.tasks if type(task) in succ.after]) + succeeding_phases = order[order.index(task.phase)+1:] + graph[task].extend([succ for succ in self.tasks if succ.phase in succeeding_phases]) + components = self.strongly_connected_components(graph) cycles_found = 0 for component in components: if len(component) > 1: cycles_found += 1 - log.debug('Cycle: {list}\n'.format(list=', '.join(str(task) for task in component))) + log.debug('Cycle: {list}\n'.format(list=', '.join(repr(task) for task in component))) if cycles_found > 0: msg = ('{0} cycles were found in the tasklist, ' 'consult the logfile for more information.'.format(cycles_found)) raise TaskListError(msg) sorted_tasks = self.topological_sort(graph) - log.debug('Tasklist:\n\t{list}\n'.format(list='\n\t'.join(str(task) for task in sorted_tasks))) + return sorted_tasks def strongly_connected_components(self, graph): # Source: http://www.logarithmic.net/pfh-files/blog/01208083168/sort.py diff --git a/plugins/build_metadata/buildmetadata.py b/plugins/build_metadata/buildmetadata.py index 35834bc..43f1438 100644 --- a/plugins/build_metadata/buildmetadata.py +++ b/plugins/build_metadata/buildmetadata.py @@ -1,10 +1,12 @@ from base import Task from common import phases +from providers.ec2.tasks.host import GetInfo class PrintInfo(Task): description = 'Printing `info\' to the console' phase = phases.InstallOS + after = [GetInfo] def run(self, info): super(PrintInfo, self).run(info) diff --git a/providers/ec2/tasks/host.py b/providers/ec2/tasks/host.py index 515f56c..e5aa6f4 100644 --- a/providers/ec2/tasks/host.py +++ b/providers/ec2/tasks/host.py @@ -1,6 +1,6 @@ from base import Task from common import phases -from providers.ec2.tasks import packages +import packages class CheckPackages(Task): description = 'Checking installed host packages'