Add references instead of instances to the tasklist

This commit is contained in:
Anders Ingemann 2013-09-22 17:31:43 +02:00
parent 8aefa93646
commit 738ba47c65
12 changed files with 93 additions and 102 deletions

View file

@ -1,4 +1,3 @@
from common.exceptions import TaskListError
class Task(object): class Task(object):
@ -6,27 +5,8 @@ class Task(object):
before = [] before = []
after = [] after = []
def __init__(self):
self._check_ordering()
def __str__(self): def __str__(self):
return '{module}.{task}'.format(module=self.__module__, task=self.__class__.__name__) return '{module}.{task}'.format(module=self.__module__, task=self.__class__.__name__)
def __repr__(self): def __repr__(self):
return self.__str__() return self.__str__()
def _check_ordering(self):
def name(ref):
return '{module}.{task}'.format(module=ref.__module__, task=ref.__class__.__name__)
for task in self.before:
if self.phase > task.phase:
msg = ("The task {self} is specified as running before {other}, "
"but its phase '{phase}' lies after the phase '{other_phase}'"
.format(self=type(self), other=task, phase=self.phase, other_phase=task.phase))
raise TaskListError(msg)
for task in self.after:
if self.phase < task.phase:
msg = ("The task {self} is specified as running after {other}, "
"but its phase '{phase}' lies before the phase '{other_phase}'"
.format(self=type(self), other=task, phase=self.phase, other_phase=task.phase))
raise TaskListError(msg)

View file

@ -13,23 +13,19 @@ class TaskList(object):
self.tasks.update(args) self.tasks.update(args)
def remove(self, *args): def remove(self, *args):
for task_type in args: for task in args:
task = self.get(task_type) self.tasks.discard(task)
if task is not None:
self.tasks.discard(task)
def replace(self, task, replacement): def replace(self, task, replacement):
self.remove(task) self.remove(task)
self.add(replacement) self.add(replacement)
def get(self, ref):
return next((task for task in self.tasks if type(task) is ref), None)
def run(self, bootstrap_info): def run(self, bootstrap_info):
task_list = self.create_list(self.tasks) task_list = self.create_list(self.tasks)
log.debug('Tasklist:\n\t{list}'.format(list='\n\t'.join(repr(task) for task in task_list))) log.debug('Tasklist:\n\t{list}'.format(list='\n\t'.join(repr(task) for task in task_list)))
for task in task_list: for task_type in task_list:
task = task_type()
if hasattr(task, 'description'): if hasattr(task, 'description'):
log.info(task.description) log.info(task.description)
else: else:
@ -41,12 +37,13 @@ class TaskList(object):
from common.phases import order from common.phases import order
graph = {} graph = {}
for task in tasks: for task in tasks:
successors = [] self.check_ordering(task)
successors.extend([self.get(succ) for succ in task.before]) successors = set()
successors.extend(filter(lambda succ: type(task) in succ.after, tasks)) successors.update(task.before)
successors.update(filter(lambda succ: task in succ.after, tasks))
succeeding_phases = order[order.index(task.phase) + 1:] succeeding_phases = order[order.index(task.phase) + 1:]
successors.extend(filter(lambda succ: succ.phase in succeeding_phases, tasks)) successors.update(filter(lambda succ: succ.phase in succeeding_phases, tasks))
graph[task] = filter(lambda succ: succ in self.tasks, successors) graph[task] = filter(lambda succ: succ in tasks, successors)
components = self.strongly_connected_components(graph) components = self.strongly_connected_components(graph)
cycles_found = 0 cycles_found = 0
@ -63,6 +60,20 @@ class TaskList(object):
return sorted_tasks return sorted_tasks
def check_ordering(self, task):
for successor in task.before:
if successor.phase > successor.phase:
msg = ("The task {task} is specified as running before {other}, "
"but its phase '{phase}' lies after the phase '{other_phase}'"
.format(task=task, other=successor, phase=task.phase, other_phase=successor.phase))
raise TaskListError(msg)
for predecessor in task.after:
if task.phase < predecessor.phase:
msg = ("The task {task} is specified as running after {other}, "
"but its phase '{phase}' lies before the phase '{other_phase}'"
.format(task=task, other=predecessor, phase=task.phase, other_phase=predecessor.phase))
raise TaskListError(msg)
def strongly_connected_components(self, graph): def strongly_connected_components(self, graph):
# Source: http://www.logarithmic.net/pfh-files/blog/01208083168/sort.py # Source: http://www.logarithmic.net/pfh-files/blog/01208083168/sort.py
# Find the strongly connected components in a graph using Tarjan's algorithm. # Find the strongly connected components in a graph using Tarjan's algorithm.

View file

@ -2,11 +2,11 @@
def tasks(tasklist, manifest): def tasks(tasklist, manifest):
import tasks import tasks
tasklist.add(tasks.AddSudoPackage()) tasklist.add(tasks.AddSudoPackage.
tasklist.add(tasks.CreateAdminUser()) tasks.CreateAdminUser,
tasklist.add(tasks.PasswordlessSudo()) tasks.PasswordlessSudo,
tasklist.add(tasks.AdminUserCredentials()) tasks.AdminUserCredentials,
tasklist.add(tasks.DisableRootLogin()) tasks.DisableRootLogin)
def validate_manifest(data, schema_validate): def validate_manifest(data, schema_validate):

View file

@ -2,8 +2,8 @@
def tasks(tasklist, manifest): def tasks(tasklist, manifest):
import tasks import tasks
tasklist.add(tasks.AptSourcesBackports()) tasklist.add(tasks.AptSourcesBackports,
tasklist.add(tasks.AddBackportsPackages()) tasks.AddBackportsPackages)
def validate_manifest(data, schema_validate): def validate_manifest(data, schema_validate):

View file

@ -2,4 +2,4 @@
def tasks(tasklist, manifest): def tasks(tasklist, manifest):
from tasks import WriteMetadata from tasks import WriteMetadata
tasklist.add(WriteMetadata()) tasklist.add(WriteMetadata

View file

@ -2,7 +2,7 @@
def tasks(tasklist, manifest): def tasks(tasklist, manifest):
from tasks import ConvertImage from tasks import ConvertImage
tasklist.add(ConvertImage()) tasklist.add(ConvertImage)
def validate_manifest(data, schema_validate): def validate_manifest(data, schema_validate):

View file

@ -2,4 +2,4 @@
def tasks(tasklist, manifest): def tasks(tasklist, manifest):
import tasks import tasks
tasklist.add(tasks.OpenNebulaContext()) tasklist.add(tasks.OpenNebulaContext)

View file

@ -21,16 +21,16 @@ def tasks(tasklist, manifest):
bootstrap.Bootstrap] bootstrap.Bootstrap]
if manifest.volume['backing'] == 'ebs': if manifest.volume['backing'] == 'ebs':
if 'snapshot' in settings and settings['snapshot'] is not None: if 'snapshot' in settings and settings['snapshot'] is not None:
tasklist.replace(ebs.Create, CreateFromSnapshot()) tasklist.replace(ebs.Create, CreateFromSnapshot)
tasklist.remove(*skip_tasks) tasklist.remove(*skip_tasks)
else: else:
tasklist.add(Snapshot()) tasklist.add(Snapshot)
else: else:
if 'image' in settings and settings['image'] is not None: if 'image' in settings and settings['image'] is not None:
tasklist.replace(loopback.Create, CreateFromImage()) tasklist.replace(loopback.Create, CreateFromImage)
tasklist.remove(*skip_tasks) tasklist.remove(*skip_tasks)
else: else:
tasklist.add(CopyImage()) tasklist.add(CopyImage)
def rollback_tasks(tasklist, tasks_completed, manifest): def rollback_tasks(tasklist, tasks_completed, manifest):
@ -38,7 +38,7 @@ def rollback_tasks(tasklist, tasks_completed, manifest):
def counter_task(task, counter): def counter_task(task, counter):
if task in completed and counter not in completed: if task in completed and counter not in completed:
tasklist.add(counter()) tasklist.add(counter)
if manifest.volume['backing'] == 'ebs': if manifest.volume['backing'] == 'ebs':
counter_task(CreateFromSnapshot, volume.Delete) counter_task(CreateFromSnapshot, volume.Delete)

View file

@ -3,7 +3,7 @@
def tasks(tasklist, manifest): def tasks(tasklist, manifest):
from common.tasks.security import DisableSSHPasswordAuthentication from common.tasks.security import DisableSSHPasswordAuthentication
from tasks import SetRootPassword from tasks import SetRootPassword
tasklist.replace(DisableSSHPasswordAuthentication, SetRootPassword()) tasklist.replace(DisableSSHPasswordAuthentication, SetRootPassword)
def validate_manifest(data, schema_validate): def validate_manifest(data, schema_validate):

View file

@ -2,8 +2,8 @@
def tasks(tasklist, manifest): def tasks(tasklist, manifest):
import tasks import tasks
tasklist.add(tasks.AddUnattendedUpgradesPackage()) tasklist.add(tasks.AddUnattendedUpgradesPackage,
tasklist.add(tasks.EnablePeriodicUpgrades()) tasks.EnablePeriodicUpgrades)
def validate_manifest(data, schema_validate): def validate_manifest(data, schema_validate):

View file

@ -2,5 +2,5 @@
def tasks(tasklist, manifest): def tasks(tasklist, manifest):
from user_packages import AddUserPackages, AddLocalUserPackages from user_packages import AddUserPackages, AddLocalUserPackages
tasklist.add(AddUserPackages()) tasklist.add(AddUserPackages,
tasklist.add(AddLocalUserPackages()) AddLocalUserPackages)

View file

@ -23,54 +23,54 @@ def initialize():
def tasks(tasklist, manifest): def tasks(tasklist, manifest):
tasklist.add(workspace.CreateWorkspace(), tasklist.add(workspace.CreateWorkspace,
packages.HostPackages(), packages.HostPackages,
common_packages.HostPackages(), common_packages.HostPackages,
packages.ImagePackages(), packages.ImagePackages,
common_packages.ImagePackages(), common_packages.ImagePackages,
host.CheckPackages(), host.CheckPackages,
loopback.Create(), loopback.Create,
volume_tasks.Attach(), volume_tasks.Attach,
partitioning.PartitionVolume(), partitioning.PartitionVolume,
partitioning.MapPartitions(), partitioning.MapPartitions,
filesystem.Format(), filesystem.Format,
filesystem.CreateMountDir(), filesystem.CreateMountDir,
filesystem.MountRoot(), filesystem.MountRoot,
bootstrap.Bootstrap(), bootstrap.Bootstrap,
filesystem.MountSpecials(), filesystem.MountSpecials,
locale.GenerateLocale(), locale.GenerateLocale,
locale.SetTimezone(), locale.SetTimezone,
apt.DisableDaemonAutostart(), apt.DisableDaemonAutostart,
apt.AptSources(), apt.AptSources,
apt.AptUpgrade(), apt.AptUpgrade,
boot.ConfigureGrub(), boot.ConfigureGrub,
filesystem.FStab(), filesystem.FStab,
common_boot.BlackListModules(), common_boot.BlackListModules,
common_boot.DisableGetTTYs(), common_boot.DisableGetTTYs,
security.EnableShadowConfig(), security.EnableShadowConfig,
network.RemoveDNSInfo(), network.RemoveDNSInfo,
network.ConfigureNetworkIF(), network.ConfigureNetworkIF,
network.RemoveHostname(), network.RemoveHostname,
initd.ResolveInitScripts(), initd.ResolveInitScripts,
initd.InstallInitScripts(), initd.InstallInitScripts,
cleanup.ClearMOTD(), cleanup.ClearMOTD,
cleanup.CleanTMP(), cleanup.CleanTMP,
apt.PurgeUnusedPackages(), apt.PurgeUnusedPackages,
apt.AptClean(), apt.AptClean,
apt.EnableDaemonAutostart(), apt.EnableDaemonAutostart,
filesystem.UnmountSpecials(), filesystem.UnmountSpecials,
filesystem.UnmountRoot(), filesystem.UnmountRoot,
partitioning.UnmapPartitions(), partitioning.UnmapPartitions,
volume_tasks.Detach(), volume_tasks.Detach,
filesystem.DeleteMountDir(), filesystem.DeleteMountDir,
loopback.MoveImage(), loopback.MoveImage,
workspace.DeleteWorkspace()) workspace.DeleteWorkspace)
if manifest.bootstrapper.get('tarball', False): if manifest.bootstrapper.get('tarball', False):
tasklist.add(bootstrap.MakeTarball()) tasklist.add(bootstrap.MakeTarball)
partitions = manifest.volume['partitions'] partitions = manifest.volume['partitions']
import re import re
@ -78,19 +78,19 @@ def tasks(tasklist, manifest):
if key not in partitions: if key not in partitions:
continue continue
if re.match('^ext[2-4]$', partitions[key]['filesystem']) is not None: if re.match('^ext[2-4]$', partitions[key]['filesystem']) is not None:
tasklist.add(filesystem.TuneVolumeFS()) tasklist.add(filesystem.TuneVolumeFS)
break break
for key in ['boot', 'root']: for key in ['boot', 'root']:
if key not in partitions: if key not in partitions:
continue continue
if partitions[key]['filesystem'] == 'xfs': if partitions[key]['filesystem'] == 'xfs':
tasklist.add(filesystem.AddXFSProgs()) tasklist.add(filesystem.AddXFSProgs)
break break
if 'boot' in manifest.volume['partitions']: if 'boot' in manifest.volume['partitions']:
tasklist.add(filesystem.CreateBootMountDir(), tasklist.add(filesystem.CreateBootMountDir,
filesystem.MountBoot(), filesystem.MountBoot,
filesystem.UnmountBoot()) filesystem.UnmountBoot)
def rollback_tasks(tasklist, tasks_completed, manifest): def rollback_tasks(tasklist, tasks_completed, manifest):
@ -98,7 +98,7 @@ def rollback_tasks(tasklist, tasks_completed, manifest):
def counter_task(task, counter): def counter_task(task, counter):
if task in completed and counter not in completed: if task in completed and counter not in completed:
tasklist.add(counter()) tasklist.add(counter)
counter_task(loopback.Create, volume_tasks.Delete) counter_task(loopback.Create, volume_tasks.Delete)
counter_task(filesystem.CreateMountDir, filesystem.DeleteMountDir) counter_task(filesystem.CreateMountDir, filesystem.DeleteMountDir)