From d6d468633f78a66535373466c134512775a414ca Mon Sep 17 00:00:00 2001 From: Chris Church Date: Wed, 24 Apr 2013 11:35:30 -0400 Subject: [PATCH] Task refactoring, updates to support ssh-agent and responding to password prompts. Needs test for using credentials. --- lib/main/admin.py | 1 + lib/main/models/__init__.py | 4 +- lib/main/serializers.py | 4 +- lib/main/tasks.py | 321 +++++++++++++++++++++--------------- lib/main/tests/tasks.py | 3 +- requirements.txt | 1 + 6 files changed, 195 insertions(+), 139 deletions(-) diff --git a/lib/main/admin.py b/lib/main/admin.py index d8a15d4db3..74908bff10 100644 --- a/lib/main/admin.py +++ b/lib/main/admin.py @@ -319,6 +319,7 @@ class JobEventInlineForJob(JobEventInline): class JobAdmin(BaseModelAdmin): list_display = ('name', 'job_template', 'project', 'playbook', 'status') + list_filter = ('status',) fieldsets = ( (None, {'fields': ('name', 'job_template', 'description')}), (_('Job Parameters'), {'fields': ('inventory', 'project', 'playbook', diff --git a/lib/main/models/__init__.py b/lib/main/models/__init__.py index eddf439227..3eefe3e745 100644 --- a/lib/main/models/__init__.py +++ b/lib/main/models/__init__.py @@ -1107,7 +1107,7 @@ class Job(CommonModel): pass def start(self, **kwargs): - from lib.main.tasks import run_job + from lib.main.tasks import RunJob if self.status != 'new': return False @@ -1116,7 +1116,7 @@ class Job(CommonModel): opts = {} self.status = 'pending' self.save(update_fields=['status']) - task_result = run_job.delay(self.pk, **opts) + task_result = RunJob().delay(self.pk, **opts) # The TaskMeta instance in the database isn't created until the worker # starts processing the task, so we can only store the task ID here. self.celery_task_id = task_result.task_id diff --git a/lib/main/serializers.py b/lib/main/serializers.py index 526c478ab1..738a50822d 100644 --- a/lib/main/serializers.py +++ b/lib/main/serializers.py @@ -185,8 +185,8 @@ class CredentialSerializer(BaseSerializer): model = Credential fields = ( 'url', 'id', 'related', 'name', 'description', 'creation_date', - 'default_username', 'ssh_key_data', 'ssh_key_unlock', 'ssh_password', 'sudo_password', - 'user', 'team' + 'ssh_username', 'ssh_password', 'ssh_key_data', 'ssh_key_unlock', + 'sudo_username', 'sudo_password', 'user', 'team', ) def get_related(self, obj): diff --git a/lib/main/tasks.py b/lib/main/tasks.py index 63435514f6..c3f3c24b78 100644 --- a/lib/main/tasks.py +++ b/lib/main/tasks.py @@ -14,85 +14,54 @@ # You should have received a copy of the GNU General Public License # along with Ansible Commander. If not, see . +import cStringIO import logging import os import select import subprocess import time import traceback -from celery import task +from celery import Task from django.conf import settings +import pexpect from lib.main.models import * -__all__ = ['run_job'] +__all__ = ['RunJob'] -logger = logging.getLogger('lib.tasks') - -class Timeout(object): - - def __init__(self, duration=None): - # If initializing from another instance, create a new timeout from the - # remaining time on the other instance. - if isinstance(duration, Timeout): - duration = duration.remaining - self.reset(duration) - - def __repr__(self): - if self._duration is None: - return 'Timeout(None)' - else: - return 'Timeout(%f)' % self._duration - - def __hash__(self): - return self._duration - - def __nonzero__(self): - return self.block - - def reset(self, duration=False): - if duration is not False: - self._duration = float(max(0, duration)) if duration is not None else None - self._begin = time.time() - - def expire(self): - self._begin = time.time() - max(0, self._duration or 0.0) - - @property - def duration(self): - return self._duration - - @property - def elapsed(self): - return float(max(0, time.time() - self._begin)) - - @property - def remaining(self): - if self._duration is None: - return None - else: - return float(max(0, self._duration + self._begin - time.time())) - - @property - def block(self): - return bool(self.remaining or self.remaining is None) +logger = logging.getLogger('lib.main.tasks') -@task(name='run_job') -def run_job(job_pk, **kwargs): - job = Job.objects.get(pk=job_pk) - job.status = 'running' - job.save(update_fields=['status']) +class RunJob(Task): + ''' + Celery task to run a job using ansible-playbook. + ''' - try: - status, stdout, stderr, tb = 'error', '', '', '' - plugin_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', - 'plugins', 'callback')) - inventory_script = os.path.abspath(os.path.join(os.path.dirname(__file__), - 'management', 'commands', - 'acom_inventory.py')) - callback_script = os.path.abspath(os.path.join(os.path.dirname(__file__), - 'management', 'commands', - 'acom_callback_event.py')) + name = 'run_job' + + def update_job(self, job_pk, **job_updates): + ''' + Reload Job from database and update the given fields. + ''' + job = Job.objects.get(pk=job_pk) + if job_updates: + for field, value in job_updates.items(): + setattr(job, field, value) + job.save(update_fields=job_updates.keys()) + return job + + def get_path_to(self, *args): + ''' + Return absolute path relative to this file. + ''' + return os.path.abspath(os.path.join(os.path.dirname(__file__), *args)) + + def build_env(self, job, **kwargs): + ''' + Build environment dictionary for ansible-playbook. + ''' + plugin_dir = self.get_path_to('..', 'plugins', 'callback') + callback_script = self.get_path_to('management', 'commands', + 'acom_callback_event.py') env = dict(os.environ.items()) # question: when running over CLI, generate a random ID or grab next, etc? # answer: TBD @@ -100,82 +69,111 @@ def run_job(job_pk, **kwargs): env['ACOM_INVENTORY_ID'] = str(job.inventory.pk) env['ANSIBLE_CALLBACK_PLUGINS'] = plugin_dir env['ACOM_CALLBACK_EVENT_SCRIPT'] = callback_script - if hasattr(settings, 'ANSIBLE_TRANSPORT'): env['ANSIBLE_TRANSPORT'] = getattr(settings, 'ANSIBLE_TRANSPORT') + env['ANSIBLE_NOCOLOR'] = '1' # Prevent output of escape sequences. + return env + def build_args(self, job, **kwargs): + ''' + Build command line argument list for running ansible-playbook, + optionally using ssh-agent for public/private key authentication. + ''' creds = job.credential - username = creds.ssh_username - #sudo_username = job.credential.sudo_username - - - - cwd = job.project.local_path - - cmdline = ['ansible-playbook', '-i', inventory_script] + use_ssh_agent = False + if creds: + username = creds.ssh_username + sudo_username = creds.sudo_username + # FIXME: Do something with creds. + inventory_script = self.get_path_to('management', 'commands', + 'acom_inventory.py') + args = ['ansible-playbook', '-i', inventory_script] if job.job_type == 'check': - cmdline.append('--check') + args.append('--check') if job.use_sudo: - cmdline.append('--sudo') + args.append('--sudo') if job.forks: # FIXME: Max limit? - cmdline.append('--forks=%d' % job.forks) + args.append('--forks=%d' % job.forks) if job.limit: - cmdline.append('--limit=%s' % job.limit) + args.append('--limit=%s' % job.limit) if job.verbosity: - cmdline.append('-%s' % ('v' * min(3, job.verbosity))) + args.append('-%s' % ('v' * min(3, job.verbosity))) if job.extra_vars: # FIXME: escaping! extra_vars = ' '.join(['%s=%s' % (str(k), str(v)) for k,v in job.extra_vars.items()]) - cmdline.append('-e', extra_vars) - cmdline.append(job.playbook) # relative path to project.local_path + args.append('-e', extra_vars) + args.append(job.playbook) # relative path to project.local_path + if use_ssh_agent: + key_path = 'myrsa' # FIXME + cmd = '; '.join([subprocess.list2cmdline(['ssh-add', keypath]), + subprocess.list2cmdline(args)]) + return ['ssh-agent', 'sh', '-c', cmd] + else: + return args - proc = subprocess.Popen(cmdline, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, cwd=cwd, env=env) - # stdout, stderr = proc.communicate() + def build_passwords(self, job, **kwargs): + ''' + Build a dictionary of passwords for SSH private key, SSH user and sudo. + ''' + return {} + + def capture_subprocess_output(self, proc, timeout=1.0): + ''' + Capture stdout/stderr from the given process until the timeout expires. + ''' + stdout, stderr = '', '' + until = time.time() + timeout + remaining = max(0, until - time.time()) + while remaining > 0: + # FIXME: Probably want to use poll (when on Linux), needs to be tested. + if hasattr(select, 'poll') and False: + poll = select.poll() + poll.register(proc.stdout.fileno(), select.POLLIN or select.POLLPRI) + poll.register(proc.stderr.fileno(), select.POLLIN or select.POLLPRI) + fd_events = poll.poll(remaining) + if not fd_events: + break + for fd, evt in fd_events: + if fd == proc.stdout.fileno() and evt > 0: + stdout += proc.stdout.read(1) + elif fd == proc.stderr.fileno() and evt > 0: + stderr += proc.stderr.read(1) + else: + stdout_byte, stderr_byte = '', '' + fdlist = [proc.stdout.fileno(), proc.stderr.fileno()] + rwx = select.select(fdlist, [], [], remaining) + if proc.stdout.fileno() in rwx[0]: + stdout_byte = proc.stdout.read(1) + stdout += stdout_byte + if proc.stderr.fileno() in rwx[0]: + stderr_byte = proc.stderr.read(1) + stderr += stderr_byte + if not stdout_byte and not stderr_byte: + break + remaining = max(0, until - time.time()) + return stdout, stderr + + def run_subprocess(self, job_pk, args, cwd, env, passwords): + ''' + Run the job using subprocess to capture stdout/stderr. + ''' + status, stdout, stderr = 'error', '', '' + proc = subprocess.Popen(args, cwd=cwd, env=env, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) proc_canceled = False - while proc.returncode is None: - new_stdout, new_stderr = '', '' - timeout = Timeout(1.0) - while timeout: - # FIXME: Probably want to use poll (when on Linux), needs to be tested. - if hasattr(select, 'poll') and False: - poll = select.poll() - poll.register(proc.stdout.fileno(), select.POLLIN or select.POLLPRI) - poll.register(proc.stderr.fileno(), select.POLLIN or select.POLLPRI) - fd_events = poll.poll(1.0) - if not fd_events: - break - for fd, evt in fd_events: - if fd == proc.stdout.fileno() and evt > 0: - new_stdout += proc.stdout.read(1) - elif fd == proc.stderr.fileno() and evt > 0: - new_stderr += proc.stderr.read(1) - else: - stdout_byte, stderr_byte = '', '' - fdlist = [proc.stdout.fileno(), proc.stderr.fileno()] - rwx = select.select(fdlist, [], [], timeout.remaining) - if proc.stdout.fileno() in rwx[0]: - stdout_byte = proc.stdout.read(1) - new_stdout += stdout_byte - if proc.stderr.fileno() in rwx[0]: - stderr_byte = proc.stderr.read(1) - new_stderr += stderr_byte - if not stdout_byte and not stderr_byte: - break - job = Job.objects.get(pk=job_pk) - update_fields = [] + while proc.poll() is None: + new_stdout, new_stderr = self.capture_subprocess_output(proc) + job_updates = {} if new_stdout: stdout += new_stdout - job.result_stdout = stdout - update_fields.append('result_stdout') + job_updates['result_stdout'] = stdout if new_stderr: stderr += new_stderr - job.result_stderr = stderr - update_fields.append('result_stderr') - if update_fields: - job.save(update_fields=update_fields) - proc.poll() + job_updates['result_stdout'] = stdout + job = self.update_job(job_pk, **job_updates) if job.cancel_flag and not proc_canceled: proc.terminate() proc_canceled = True @@ -187,14 +185,69 @@ def run_job(job_pk, **kwargs): status = 'successful' else: status = 'failed' - except Exception: - tb = traceback.format_exc() - - # Reload from database before updating/saving. - job = Job.objects.get(pk=job_pk) - job.status = status - job.result_stdout = stdout - job.result_stderr = stderr - job.result_traceback = tb - job.save(update_fields=['status', 'result_stdout', 'result_stderr', - 'result_traceback']) + return status, stdout, stderr + + def run_pexpect(self, job_pk, args, cwd, env, passwords): + ''' + Run the job using pexpect to capture output and provide passwords when + requested. + ''' + status, stdout, stderr = 'error', '', '' + logfile = cStringIO.StringIO() + logfile_pos = logfile.tell() + child = pexpect.spawn(args[0], args[1:], cwd=cwd, env=env) + child.logfile_read = logfile + job_canceled = False + while child.isalive(): + expect_list = [ + r'Enter passphrase for .*:', + r'Bad passphrase, try again for .*:', + r'sudo password.*:', + r'SSH password:', + pexpect.TIMEOUT, + pexpect.EOF, + ] + result_id = child.expect(expect_list, timeout=2) + if result_id == 0: + child.sendline(passwords.get('ssh_unlock_key', '')) + elif result_id == 1: + child.sendline('') + elif result_id == 2: + child.sendline(passwords.get('sudo_password', '')) + elif result_id == 3: + child.sendline(passwords.get('ssh_password', '')) + job_updates = {} + if logfile_pos != logfile.tell(): + job_updates['result_stdout'] = logfile.getvalue() + job = self.update_job(job_pk, **job_updates) + if job.cancel_flag: + child.close(True) + job_canceled = True + if job_canceled: + status = 'canceled' + elif child.exitstatus == 0: + status = 'successful' + else: + status = 'failed' + stdout = logfile.getvalue() + return status, stdout, stderr + + def run(self, job_pk, **kwargs): + ''' + Run the job using ansible-playbook and capture its output. + ''' + job = self.update_job(job_pk, status='running') + try: + status, stdout, stderr, tb = 'error', '', '', '' + args = self.build_args(job, **kwargs) + cwd = job.project.local_path + env = self.build_env(job, **kwargs) + passwords = self.build_passwords(job, **kwargs) + #status, stdout, stderr = self.run_subprocess(job_pk, args, cwd, + # env, passwords) + status, stdout, stderr = self.run_pexpect(job_pk, args, cwd, env, + passwords) + except Exception: + tb = traceback.format_exc() + self.update_job(job_pk, status=status, result_stdout=stdout, + result_stderr=stderr, result_traceback=tb) diff --git a/lib/main/tests/tasks.py b/lib/main/tests/tasks.py index 2f3ed837f9..55c5c4fdd4 100644 --- a/lib/main/tests/tasks.py +++ b/lib/main/tests/tasks.py @@ -78,10 +78,11 @@ class RunJobTest(BaseCeleryTest): opts = { 'name': 'test-creds', 'user': self.super_django_user, - 'default_username': '', + 'ssh_username': '', 'ssh_key_data': '', 'ssh_key_unlock': '', 'ssh_password': '', + 'sudo_username': '', 'sudo_password': '', } opts.update(kwargs) diff --git a/requirements.txt b/requirements.txt index 36c449094b..eebba10140 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ django-extensions==1.1.1 django-jsonfield==0.9.2 ipython==0.13.1 paramiko==1.10.0 +pexpect==2.4 # psycopg2==2.4.6 python-dateutil==1.5 PyYAML==3.10