Task refactoring, updates to support ssh-agent and responding to password prompts. Needs test for using credentials.

This commit is contained in:
Chris Church
2013-04-24 11:35:30 -04:00
parent cc25d55121
commit d6d468633f
6 changed files with 195 additions and 139 deletions

View File

@@ -319,6 +319,7 @@ class JobEventInlineForJob(JobEventInline):
class JobAdmin(BaseModelAdmin): class JobAdmin(BaseModelAdmin):
list_display = ('name', 'job_template', 'project', 'playbook', 'status') list_display = ('name', 'job_template', 'project', 'playbook', 'status')
list_filter = ('status',)
fieldsets = ( fieldsets = (
(None, {'fields': ('name', 'job_template', 'description')}), (None, {'fields': ('name', 'job_template', 'description')}),
(_('Job Parameters'), {'fields': ('inventory', 'project', 'playbook', (_('Job Parameters'), {'fields': ('inventory', 'project', 'playbook',

View File

@@ -1107,7 +1107,7 @@ class Job(CommonModel):
pass pass
def start(self, **kwargs): def start(self, **kwargs):
from lib.main.tasks import run_job from lib.main.tasks import RunJob
if self.status != 'new': if self.status != 'new':
return False return False
@@ -1116,7 +1116,7 @@ class Job(CommonModel):
opts = {} opts = {}
self.status = 'pending' self.status = 'pending'
self.save(update_fields=['status']) self.save(update_fields=['status'])
task_result = run_job.delay(self.pk, **opts) task_result = RunJob().delay(self.pk, **opts)
# The TaskMeta instance in the database isn't created until the worker # The TaskMeta instance in the database isn't created until the worker
# starts processing the task, so we can only store the task ID here. # starts processing the task, so we can only store the task ID here.
self.celery_task_id = task_result.task_id self.celery_task_id = task_result.task_id

View File

@@ -185,8 +185,8 @@ class CredentialSerializer(BaseSerializer):
model = Credential model = Credential
fields = ( fields = (
'url', 'id', 'related', 'name', 'description', 'creation_date', 'url', 'id', 'related', 'name', 'description', 'creation_date',
'default_username', 'ssh_key_data', 'ssh_key_unlock', 'ssh_password', 'sudo_password', 'ssh_username', 'ssh_password', 'ssh_key_data', 'ssh_key_unlock',
'user', 'team' 'sudo_username', 'sudo_password', 'user', 'team',
) )
def get_related(self, obj): def get_related(self, obj):

View File

@@ -14,85 +14,54 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Ansible Commander. If not, see <http://www.gnu.org/licenses/>. # along with Ansible Commander. If not, see <http://www.gnu.org/licenses/>.
import cStringIO
import logging import logging
import os import os
import select import select
import subprocess import subprocess
import time import time
import traceback import traceback
from celery import task from celery import Task
from django.conf import settings from django.conf import settings
import pexpect
from lib.main.models import * from lib.main.models import *
__all__ = ['run_job'] __all__ = ['RunJob']
logger = logging.getLogger('lib.tasks') logger = logging.getLogger('lib.main.tasks')
class Timeout(object):
def __init__(self, duration=None):
# If initializing from another instance, create a new timeout from the
# remaining time on the other instance.
if isinstance(duration, Timeout):
duration = duration.remaining
self.reset(duration)
def __repr__(self):
if self._duration is None:
return 'Timeout(None)'
else:
return 'Timeout(%f)' % self._duration
def __hash__(self):
return self._duration
def __nonzero__(self):
return self.block
def reset(self, duration=False):
if duration is not False:
self._duration = float(max(0, duration)) if duration is not None else None
self._begin = time.time()
def expire(self):
self._begin = time.time() - max(0, self._duration or 0.0)
@property
def duration(self):
return self._duration
@property
def elapsed(self):
return float(max(0, time.time() - self._begin))
@property
def remaining(self):
if self._duration is None:
return None
else:
return float(max(0, self._duration + self._begin - time.time()))
@property
def block(self):
return bool(self.remaining or self.remaining is None)
@task(name='run_job') class RunJob(Task):
def run_job(job_pk, **kwargs): '''
job = Job.objects.get(pk=job_pk) Celery task to run a job using ansible-playbook.
job.status = 'running' '''
job.save(update_fields=['status'])
try: name = 'run_job'
status, stdout, stderr, tb = 'error', '', '', ''
plugin_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', def update_job(self, job_pk, **job_updates):
'plugins', 'callback')) '''
inventory_script = os.path.abspath(os.path.join(os.path.dirname(__file__), Reload Job from database and update the given fields.
'management', 'commands', '''
'acom_inventory.py')) job = Job.objects.get(pk=job_pk)
callback_script = os.path.abspath(os.path.join(os.path.dirname(__file__), if job_updates:
'management', 'commands', for field, value in job_updates.items():
'acom_callback_event.py')) setattr(job, field, value)
job.save(update_fields=job_updates.keys())
return job
def get_path_to(self, *args):
'''
Return absolute path relative to this file.
'''
return os.path.abspath(os.path.join(os.path.dirname(__file__), *args))
def build_env(self, job, **kwargs):
'''
Build environment dictionary for ansible-playbook.
'''
plugin_dir = self.get_path_to('..', 'plugins', 'callback')
callback_script = self.get_path_to('management', 'commands',
'acom_callback_event.py')
env = dict(os.environ.items()) env = dict(os.environ.items())
# question: when running over CLI, generate a random ID or grab next, etc? # question: when running over CLI, generate a random ID or grab next, etc?
# answer: TBD # answer: TBD
@@ -100,82 +69,111 @@ def run_job(job_pk, **kwargs):
env['ACOM_INVENTORY_ID'] = str(job.inventory.pk) env['ACOM_INVENTORY_ID'] = str(job.inventory.pk)
env['ANSIBLE_CALLBACK_PLUGINS'] = plugin_dir env['ANSIBLE_CALLBACK_PLUGINS'] = plugin_dir
env['ACOM_CALLBACK_EVENT_SCRIPT'] = callback_script env['ACOM_CALLBACK_EVENT_SCRIPT'] = callback_script
if hasattr(settings, 'ANSIBLE_TRANSPORT'): if hasattr(settings, 'ANSIBLE_TRANSPORT'):
env['ANSIBLE_TRANSPORT'] = getattr(settings, 'ANSIBLE_TRANSPORT') env['ANSIBLE_TRANSPORT'] = getattr(settings, 'ANSIBLE_TRANSPORT')
env['ANSIBLE_NOCOLOR'] = '1' # Prevent output of escape sequences.
return env
def build_args(self, job, **kwargs):
'''
Build command line argument list for running ansible-playbook,
optionally using ssh-agent for public/private key authentication.
'''
creds = job.credential creds = job.credential
username = creds.ssh_username use_ssh_agent = False
#sudo_username = job.credential.sudo_username if creds:
username = creds.ssh_username
sudo_username = creds.sudo_username
# FIXME: Do something with creds.
cwd = job.project.local_path inventory_script = self.get_path_to('management', 'commands',
'acom_inventory.py')
cmdline = ['ansible-playbook', '-i', inventory_script] args = ['ansible-playbook', '-i', inventory_script]
if job.job_type == 'check': if job.job_type == 'check':
cmdline.append('--check') args.append('--check')
if job.use_sudo: if job.use_sudo:
cmdline.append('--sudo') args.append('--sudo')
if job.forks: # FIXME: Max limit? if job.forks: # FIXME: Max limit?
cmdline.append('--forks=%d' % job.forks) args.append('--forks=%d' % job.forks)
if job.limit: if job.limit:
cmdline.append('--limit=%s' % job.limit) args.append('--limit=%s' % job.limit)
if job.verbosity: if job.verbosity:
cmdline.append('-%s' % ('v' * min(3, job.verbosity))) args.append('-%s' % ('v' * min(3, job.verbosity)))
if job.extra_vars: if job.extra_vars:
# FIXME: escaping! # FIXME: escaping!
extra_vars = ' '.join(['%s=%s' % (str(k), str(v)) for k,v in extra_vars = ' '.join(['%s=%s' % (str(k), str(v)) for k,v in
job.extra_vars.items()]) job.extra_vars.items()])
cmdline.append('-e', extra_vars) args.append('-e', extra_vars)
cmdline.append(job.playbook) # relative path to project.local_path args.append(job.playbook) # relative path to project.local_path
if use_ssh_agent:
key_path = 'myrsa' # FIXME
cmd = '; '.join([subprocess.list2cmdline(['ssh-add', keypath]),
subprocess.list2cmdline(args)])
return ['ssh-agent', 'sh', '-c', cmd]
else:
return args
proc = subprocess.Popen(cmdline, stdout=subprocess.PIPE, def build_passwords(self, job, **kwargs):
stderr=subprocess.PIPE, cwd=cwd, env=env) '''
# stdout, stderr = proc.communicate() Build a dictionary of passwords for SSH private key, SSH user and sudo.
'''
return {}
def capture_subprocess_output(self, proc, timeout=1.0):
'''
Capture stdout/stderr from the given process until the timeout expires.
'''
stdout, stderr = '', ''
until = time.time() + timeout
remaining = max(0, until - time.time())
while remaining > 0:
# FIXME: Probably want to use poll (when on Linux), needs to be tested.
if hasattr(select, 'poll') and False:
poll = select.poll()
poll.register(proc.stdout.fileno(), select.POLLIN or select.POLLPRI)
poll.register(proc.stderr.fileno(), select.POLLIN or select.POLLPRI)
fd_events = poll.poll(remaining)
if not fd_events:
break
for fd, evt in fd_events:
if fd == proc.stdout.fileno() and evt > 0:
stdout += proc.stdout.read(1)
elif fd == proc.stderr.fileno() and evt > 0:
stderr += proc.stderr.read(1)
else:
stdout_byte, stderr_byte = '', ''
fdlist = [proc.stdout.fileno(), proc.stderr.fileno()]
rwx = select.select(fdlist, [], [], remaining)
if proc.stdout.fileno() in rwx[0]:
stdout_byte = proc.stdout.read(1)
stdout += stdout_byte
if proc.stderr.fileno() in rwx[0]:
stderr_byte = proc.stderr.read(1)
stderr += stderr_byte
if not stdout_byte and not stderr_byte:
break
remaining = max(0, until - time.time())
return stdout, stderr
def run_subprocess(self, job_pk, args, cwd, env, passwords):
'''
Run the job using subprocess to capture stdout/stderr.
'''
status, stdout, stderr = 'error', '', ''
proc = subprocess.Popen(args, cwd=cwd, env=env,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
proc_canceled = False proc_canceled = False
while proc.returncode is None: while proc.poll() is None:
new_stdout, new_stderr = '', '' new_stdout, new_stderr = self.capture_subprocess_output(proc)
timeout = Timeout(1.0) job_updates = {}
while timeout:
# FIXME: Probably want to use poll (when on Linux), needs to be tested.
if hasattr(select, 'poll') and False:
poll = select.poll()
poll.register(proc.stdout.fileno(), select.POLLIN or select.POLLPRI)
poll.register(proc.stderr.fileno(), select.POLLIN or select.POLLPRI)
fd_events = poll.poll(1.0)
if not fd_events:
break
for fd, evt in fd_events:
if fd == proc.stdout.fileno() and evt > 0:
new_stdout += proc.stdout.read(1)
elif fd == proc.stderr.fileno() and evt > 0:
new_stderr += proc.stderr.read(1)
else:
stdout_byte, stderr_byte = '', ''
fdlist = [proc.stdout.fileno(), proc.stderr.fileno()]
rwx = select.select(fdlist, [], [], timeout.remaining)
if proc.stdout.fileno() in rwx[0]:
stdout_byte = proc.stdout.read(1)
new_stdout += stdout_byte
if proc.stderr.fileno() in rwx[0]:
stderr_byte = proc.stderr.read(1)
new_stderr += stderr_byte
if not stdout_byte and not stderr_byte:
break
job = Job.objects.get(pk=job_pk)
update_fields = []
if new_stdout: if new_stdout:
stdout += new_stdout stdout += new_stdout
job.result_stdout = stdout job_updates['result_stdout'] = stdout
update_fields.append('result_stdout')
if new_stderr: if new_stderr:
stderr += new_stderr stderr += new_stderr
job.result_stderr = stderr job_updates['result_stdout'] = stdout
update_fields.append('result_stderr') job = self.update_job(job_pk, **job_updates)
if update_fields:
job.save(update_fields=update_fields)
proc.poll()
if job.cancel_flag and not proc_canceled: if job.cancel_flag and not proc_canceled:
proc.terminate() proc.terminate()
proc_canceled = True proc_canceled = True
@@ -187,14 +185,69 @@ def run_job(job_pk, **kwargs):
status = 'successful' status = 'successful'
else: else:
status = 'failed' status = 'failed'
except Exception: return status, stdout, stderr
tb = traceback.format_exc()
def run_pexpect(self, job_pk, args, cwd, env, passwords):
# Reload from database before updating/saving. '''
job = Job.objects.get(pk=job_pk) Run the job using pexpect to capture output and provide passwords when
job.status = status requested.
job.result_stdout = stdout '''
job.result_stderr = stderr status, stdout, stderr = 'error', '', ''
job.result_traceback = tb logfile = cStringIO.StringIO()
job.save(update_fields=['status', 'result_stdout', 'result_stderr', logfile_pos = logfile.tell()
'result_traceback']) child = pexpect.spawn(args[0], args[1:], cwd=cwd, env=env)
child.logfile_read = logfile
job_canceled = False
while child.isalive():
expect_list = [
r'Enter passphrase for .*:',
r'Bad passphrase, try again for .*:',
r'sudo password.*:',
r'SSH password:',
pexpect.TIMEOUT,
pexpect.EOF,
]
result_id = child.expect(expect_list, timeout=2)
if result_id == 0:
child.sendline(passwords.get('ssh_unlock_key', ''))
elif result_id == 1:
child.sendline('')
elif result_id == 2:
child.sendline(passwords.get('sudo_password', ''))
elif result_id == 3:
child.sendline(passwords.get('ssh_password', ''))
job_updates = {}
if logfile_pos != logfile.tell():
job_updates['result_stdout'] = logfile.getvalue()
job = self.update_job(job_pk, **job_updates)
if job.cancel_flag:
child.close(True)
job_canceled = True
if job_canceled:
status = 'canceled'
elif child.exitstatus == 0:
status = 'successful'
else:
status = 'failed'
stdout = logfile.getvalue()
return status, stdout, stderr
def run(self, job_pk, **kwargs):
'''
Run the job using ansible-playbook and capture its output.
'''
job = self.update_job(job_pk, status='running')
try:
status, stdout, stderr, tb = 'error', '', '', ''
args = self.build_args(job, **kwargs)
cwd = job.project.local_path
env = self.build_env(job, **kwargs)
passwords = self.build_passwords(job, **kwargs)
#status, stdout, stderr = self.run_subprocess(job_pk, args, cwd,
# env, passwords)
status, stdout, stderr = self.run_pexpect(job_pk, args, cwd, env,
passwords)
except Exception:
tb = traceback.format_exc()
self.update_job(job_pk, status=status, result_stdout=stdout,
result_stderr=stderr, result_traceback=tb)

View File

@@ -78,10 +78,11 @@ class RunJobTest(BaseCeleryTest):
opts = { opts = {
'name': 'test-creds', 'name': 'test-creds',
'user': self.super_django_user, 'user': self.super_django_user,
'default_username': '', 'ssh_username': '',
'ssh_key_data': '', 'ssh_key_data': '',
'ssh_key_unlock': '', 'ssh_key_unlock': '',
'ssh_password': '', 'ssh_password': '',
'sudo_username': '',
'sudo_password': '', 'sudo_password': '',
} }
opts.update(kwargs) opts.update(kwargs)

View File

@@ -5,6 +5,7 @@ django-extensions==1.1.1
django-jsonfield==0.9.2 django-jsonfield==0.9.2
ipython==0.13.1 ipython==0.13.1
paramiko==1.10.0 paramiko==1.10.0
pexpect==2.4
# psycopg2==2.4.6 # psycopg2==2.4.6
python-dateutil==1.5 python-dateutil==1.5
PyYAML==3.10 PyYAML==3.10