diff --git a/bootstrapvz/remote/callback.py b/bootstrapvz/remote/build_servers/callback.py similarity index 80% rename from bootstrapvz/remote/callback.py rename to bootstrapvz/remote/build_servers/callback.py index 2c385cc..79cf41b 100644 --- a/bootstrapvz/remote/callback.py +++ b/bootstrapvz/remote/build_servers/callback.py @@ -14,19 +14,19 @@ class CallbackServer(object): self.daemon.register(self) self.abort = False - def start(self): + def __enter__(self): def serve(): self.daemon.requestLoop() from threading import Thread self.thread = Thread(target=serve) - log.debug('Starting the callback server') + log.debug('Starting callback server') self.thread.start() + return self - def stop(self): - if hasattr(self, 'daemon'): - self.daemon.shutdown() - if hasattr(self, 'thread'): - self.thread.join() + def __exit__(self, type, value, traceback): + log.debug('Shutting down callback server') + self.daemon.shutdown() + self.thread.join() @Pyro4.expose def handle_log(self, pickled_record): diff --git a/bootstrapvz/remote/build_servers/remote.py b/bootstrapvz/remote/build_servers/remote.py index c14c23f..6e13b26 100644 --- a/bootstrapvz/remote/build_servers/remote.py +++ b/bootstrapvz/remote/build_servers/remote.py @@ -1,5 +1,6 @@ from build_server import BuildServer from bootstrapvz.common.tools import log_check_call +from contextlib import contextmanager import logging log = logging.getLogger(__name__) @@ -15,18 +16,27 @@ class RemoteBuildServer(BuildServer): self.keyfile = settings['keyfile'] self.server_bin = settings['server_bin'] + @contextmanager + def connect(self): + with self.spawn_server() as forwards: + with connect_pyro('localhost', forwards['local_server_port']) as connection: + from callback import CallbackServer + args = {'listen_port': forwards['local_callback_port'], + 'remote_port': forwards['remote_callback_port']} + with CallbackServer(**args) as callback_server: + connection.set_callback_server(callback_server) + yield (connection, callback_server) + + @contextmanager + def spawn_server(self): from . import getNPorts # We can't use :0 for the forwarding ports because # A: It's quite hard to retrieve the port on the remote after the daemon has started # B: SSH doesn't accept 0:localhost:0 as a port forwarding option - [self.local_server_port, self.local_callback_port] = getNPorts(2) - [self.remote_server_port, self.remote_callback_port] = getNPorts(2) + [local_server_port, local_callback_port] = getNPorts(2) + [remote_server_port, remote_callback_port] = getNPorts(2) - def connect(self): - log.debug('Opening SSH connection to build server `{name}\''.format(name=self.name)) - import subprocess - - server_cmd = ['sudo', self.server_bin, '--listen', str(self.remote_server_port)] + server_cmd = ['sudo', self.server_bin, '--listen', str(remote_server_port)] def set_process_group(): # Changes the process group of a command so that any SIGINT @@ -38,46 +48,28 @@ class RemoteBuildServer(BuildServer): addr_arg = '{user}@{host}'.format(user=self.username, host=self.address) ssh_cmd = ['ssh', '-i', self.keyfile, '-p', str(self.port), - '-L' + str(self.local_server_port) + ':localhost:' + str(self.remote_server_port), - '-R' + str(self.remote_callback_port) + ':localhost:' + str(self.local_callback_port), + '-L' + str(local_server_port) + ':localhost:' + str(remote_server_port), + '-R' + str(remote_callback_port) + ':localhost:' + str(local_callback_port), addr_arg] full_cmd = ssh_cmd + ['--'] + server_cmd + + log.debug('Opening SSH connection to build server `{name}\''.format(name=self.name)) import sys - self.ssh_process = subprocess.Popen(args=full_cmd, stdout=sys.stderr, stderr=sys.stderr, - preexec_fn=set_process_group) - - # Check that we can connect to the server + import subprocess + ssh_process = subprocess.Popen(args=full_cmd, stdout=sys.stderr, stderr=sys.stderr, + preexec_fn=set_process_group) try: - import Pyro4 - server_uri = 'PYRO:server@localhost:{server_port}'.format(server_port=self.local_server_port) - self.connection = Pyro4.Proxy(server_uri) - - log.debug('Connecting to the RPC daemon on build server `{name}\''.format(name=self.name)) - remaining_retries = 5 - while True: - try: - self.connection.ping() - break - except (Pyro4.errors.ConnectionClosedError, Pyro4.errors.CommunicationError): - if remaining_retries > 0: - remaining_retries -= 1 - from time import sleep - sleep(2) - else: - raise + yield {'local_server_port': local_server_port, + 'local_callback_port': local_callback_port, + 'remote_server_port': remote_server_port, + 'remote_callback_port': remote_callback_port} except (Exception, KeyboardInterrupt): - self.ssh_process.terminate() + log.debug('Forcefully terminating SSH connection to the build server') + ssh_process.terminate() raise - return self.connection - - def disconnect(self): - if hasattr(self, 'connection'): - log.debug('Stopping the RPC daemon on build server `{name}\''.format(name=self.name)) - self.connection.stop() - self.connection._pyroRelease() - if hasattr(self, 'ssh_process'): - log.debug('Waiting for SSH connection to build server `{name}\' to terminate'.format(name=self.name)) - self.ssh_process.wait() + else: + log.debug('Waiting for SSH connection to the build server to close') + ssh_process.wait() def download(self, src, dst): log.debug('Downloading file `{src}\' from ' @@ -103,3 +95,31 @@ class RemoteBuildServer(BuildServer): def run(self, manifest): from bootstrapvz.remote.main import run return run(manifest, self) + + +@contextmanager +def connect_pyro(host, port): + import Pyro4 + server_uri = 'PYRO:server@{host}:{port}'.format(host=host, port=port) + connection = Pyro4.Proxy(server_uri) + + log.debug('Connecting to the RPC daemon') + remaining_retries = 5 + while True: + try: + connection.ping() + break + except (Pyro4.errors.ConnectionClosedError, Pyro4.errors.CommunicationError): + if remaining_retries > 0: + remaining_retries -= 1 + from time import sleep + sleep(2) + else: + raise + + try: + yield connection + finally: + log.debug('Stopping the RPC daemon') + connection.stop() + connection._pyroRelease() diff --git a/bootstrapvz/remote/main.py b/bootstrapvz/remote/main.py index eb247f3..6508136 100644 --- a/bootstrapvz/remote/main.py +++ b/bootstrapvz/remote/main.py @@ -75,40 +75,22 @@ def run(manifest, build_server, debug=False, dry_run=False): on the other side and initiates a remote bootstrapping procedure """ bootstrap_info = None - try: - # Connect to the build server - connection = build_server.connect() - # Start a callback server on this side, so that we may receive log entries - from callback import CallbackServer - callback_server = CallbackServer(listen_port=build_server.local_callback_port, - remote_port=build_server.remote_callback_port) - try: - # Start the callback server (in a background thread) - callback_server.start() - # Tell the RPC daemon about the callback server - connection.set_callback_server(callback_server) + with build_server.connect() as (connection, callback_server): + # Replace the standard SIGINT handler with a remote call to the server + # so that it may abort the run. + def abort(signum, frame): + import logging + logging.getLogger(__name__).warn('SIGINT received, asking remote to abort.') + callback_server.abort_run() + import signal + orig_sigint = signal.signal(signal.SIGINT, abort) - # Replace the standard SIGINT handler with a remote call to the server - # so that it may abort the run. - def abort(signum, frame): - import logging - logging.getLogger(__name__).warn('SIGINT received, asking remote to abort.') - callback_server.abort_run() - import signal - orig_sigint = signal.signal(signal.SIGINT, abort) - - # Everything has been set up, begin the bootstrapping process - bootstrap_info = connection.run(manifest, - debug=debug, - # We can't pause the bootstrapping process remotely, yet... - pause_on_error=False, - dry_run=dry_run) - # Restore the old SIGINT handler - signal.signal(signal.SIGINT, orig_sigint) - finally: - # Stop the callback server - callback_server.stop() - finally: - # Stop the RPC daemon and close the SSH connection - build_server.disconnect() + # Everything has been set up, begin the bootstrapping process + bootstrap_info = connection.run(manifest, + debug=debug, + # We can't pause the bootstrapping process remotely, yet... + pause_on_error=False, + dry_run=dry_run) + # Restore the old SIGINT handler + signal.signal(signal.SIGINT, orig_sigint) return bootstrap_info