Order by phases as well

This commit is contained in:
Anders Ingemann 2013-06-23 17:54:25 +02:00
parent 0f29b3d0e2
commit a401f9edc0
5 changed files with 30 additions and 28 deletions

View file

@ -1,14 +1,21 @@
from functools import total_ordering
@total_ordering
class Phase(object):
description = None
def __init__(self):
pass
def __cmp__(self, other):
def __lt__(self, other):
from common.phases import order
return order.index(self) - order.index(other)
return order.index(self) < order.index(other)
def __eq__(self, other):
return self == other
def __str__(self):
return '{name}'.format(name=self.__class__.__name__)
def __repr__(self):
return self.__str__()

View file

@ -14,37 +14,24 @@ class Task(object):
def run(self, info):
pass
def _after(self, other):
return self.phase > other.phase or type(self) in other.before or type(other) in self.after
def _before(self, other):
return self.phase < other.phase or type(other) in self.before or type(self) in other.after
def __lt__(self, other):
return self._before(other) and not self._after(other)
def __gt__(self, other):
return not self._before(other) and self._after(other)
def __eq__(self, other):
return not self._before(other) and not self._after(other)
def __ne__(self, other):
return self._before(other) or self._after(other)
def __str__(self):
return '{module}.{task}'.format(module=self.__module__, task=self.__class__.__name__)
def __repr__(self):
return self.__str__()
def _check_ordering(self):
def name(ref):
return '{module}.{task}'.format(module=ref.__module__, task=ref.__class__.__name__)
for task in self.before:
if self.phase > task.phase:
msg = ("The task {self} is specified as running before {other}, "
"but its phase {phase} lies after the phase {other_phase}"
.format(self, other, self.phase, other.phase))
.format(self=type(self), other=task, phase=self.phase, other_phase=task.phase))
raise TaskListError(msg)
for task in self.after:
if self.phase < task.phase:
msg = ("The task {self} is specified as running after {other}, "
"but its phase {phase} lies before the phase {other_phase}"
.format(self=self, other=other, phase=self.phase, other_phase=other.phase))
.format(self=type(self), other=task, phase=self.phase, other_phase=task.phase))
raise TaskListError(msg)

View file

@ -23,30 +23,36 @@ class TaskList(object):
def run(self, bootstrap_info):
task_list = self.create_list()
for task in tasks:
log.debug('Tasklist:\n\t{list}'.format(list='\n\t'.join(repr(task) for task in task_list)))
for task in task_list:
log.info(task)
task.run(bootstrap_info)
def create_list(self):
from common.phases import order
graph = {}
for task in self.tasks:
graph[task] = [self.get(succ) for succ in task.before]
graph[task] = []
graph[task].extend([self.get(succ) for succ in task.before])
graph[task].extend([succ for succ in self.tasks if type(task) in succ.after])
succeeding_phases = order[order.index(task.phase)+1:]
graph[task].extend([succ for succ in self.tasks if succ.phase in succeeding_phases])
components = self.strongly_connected_components(graph)
cycles_found = 0
for component in components:
if len(component) > 1:
cycles_found += 1
log.debug('Cycle: {list}\n'.format(list=', '.join(str(task) for task in component)))
log.debug('Cycle: {list}\n'.format(list=', '.join(repr(task) for task in component)))
if cycles_found > 0:
msg = ('{0} cycles were found in the tasklist, '
'consult the logfile for more information.'.format(cycles_found))
raise TaskListError(msg)
sorted_tasks = self.topological_sort(graph)
log.debug('Tasklist:\n\t{list}\n'.format(list='\n\t'.join(str(task) for task in sorted_tasks)))
return sorted_tasks
def strongly_connected_components(self, graph):
# Source: http://www.logarithmic.net/pfh-files/blog/01208083168/sort.py

View file

@ -1,10 +1,12 @@
from base import Task
from common import phases
from providers.ec2.tasks.host import GetInfo
class PrintInfo(Task):
description = 'Printing `info\' to the console'
phase = phases.InstallOS
after = [GetInfo]
def run(self, info):
super(PrintInfo, self).run(info)

View file

@ -1,6 +1,6 @@
from base import Task
from common import phases
from providers.ec2.tasks import packages
import packages
class CheckPackages(Task):
description = 'Checking installed host packages'