AC-1071 Moved credential model to its own file. Added API support and tests for ssh_key_path field.

AC-1095 Added validation for SSH private keys.
This commit is contained in:
Chris Church
2014-03-26 16:05:05 -04:00
parent b47aed5bdb
commit bfb0159083
10 changed files with 491 additions and 286 deletions

View File

@@ -985,7 +985,7 @@ class CredentialSerializer(BaseSerializer):
class Meta: class Meta:
model = Credential model = Credential
fields = ('*', 'user', 'team', 'kind', 'cloud', 'username', fields = ('*', 'user', 'team', 'kind', 'cloud', 'username',
'password', 'ssh_key_data', 'ssh_key_unlock', 'password', 'ssh_key_data', 'ssh_key_path', 'ssh_key_unlock',
'sudo_username', 'sudo_password', 'vault_password') 'sudo_username', 'sudo_password', 'vault_password')
def to_native(self, obj): def to_native(self, obj):

View File

@@ -8,6 +8,7 @@ from django.conf import settings
from awx.main.models.base import * from awx.main.models.base import *
from awx.main.models.unified_jobs import * from awx.main.models.unified_jobs import *
from awx.main.models.organization import * from awx.main.models.organization import *
from awx.main.models.credential import *
from awx.main.models.projects import * from awx.main.models.projects import *
from awx.main.models.inventory import * from awx.main.models.inventory import *
from awx.main.models.jobs import * from awx.main.models.jobs import *

View File

@@ -0,0 +1,342 @@
# Copyright (c) 2014 AnsibleWorks, Inc.
# All Rights Reserved.
# Python
import base64
import re
# Django
from django.conf import settings
from django.db import models
from django.utils.translation import ugettext_lazy as _
from django.core.exceptions import ValidationError, NON_FIELD_ERRORS
from django.core.urlresolvers import reverse
# AWX
from awx.main.utils import encrypt_field, decrypt_field
from awx.main.models.base import *
__all__ = ['Credential']
class Credential(CommonModelNameNotUnique):
'''
A credential contains information about how to talk to a remote resource
Usually this is a SSH key location, and possibly an unlock password.
If used with sudo, a sudo password should be set if required.
'''
KIND_CHOICES = [
('ssh', _('Machine')),
('scm', _('SCM')),
('aws', _('AWS')),
('rax', _('Rackspace')),
]
PASSWORD_FIELDS = ('password', 'ssh_key_data', 'ssh_key_unlock',
'sudo_password', 'vault_password')
class Meta:
app_label = 'main'
unique_together = [('user', 'team', 'kind', 'name')]
user = models.ForeignKey(
'auth.User',
null=True,
default=None,
blank=True,
on_delete=models.CASCADE,
related_name='credentials',
)
team = models.ForeignKey(
'Team',
null=True,
default=None,
blank=True,
on_delete=models.CASCADE,
related_name='credentials',
)
kind = models.CharField(
max_length=32,
choices=KIND_CHOICES,
default='ssh',
)
cloud = models.BooleanField(
default=False,
editable=False,
)
username = models.CharField(
blank=True,
default='',
max_length=1024,
verbose_name=_('Username'),
help_text=_('Username for this credential.'),
)
password = models.CharField(
blank=True,
default='',
max_length=1024,
verbose_name=_('Password'),
help_text=_('Password for this credential (or "ASK" to prompt the '
'user for machine credentials).'),
)
ssh_key_data = models.TextField(
blank=True,
default='',
verbose_name=_('SSH private key'),
help_text=_('RSA or DSA private key to be used instead of password.'),
)
ssh_key_path = models.CharField(
max_length=1024,
blank=True,
default='',
verbose_name=_('SSH key path'),
help_text=_('Path to SSH private key file.'),
)
ssh_key_unlock = models.CharField(
max_length=1024,
blank=True,
default='',
verbose_name=_('SSH key unlock'),
help_text=_('Passphrase to unlock SSH private key if encrypted (or '
'"ASK" to prompt the user for machine credentials).'),
)
sudo_username = models.CharField(
max_length=1024,
blank=True,
default='',
help_text=_('Sudo username for a job using this credential.'),
)
sudo_password = models.CharField(
max_length=1024,
blank=True,
default='',
help_text=_('Sudo password (or "ASK" to prompt the user).'),
)
vault_password = models.CharField(
max_length=1024,
blank=True,
default='',
help_text=_('Vault password (or "ASK" to prompt the user).'),
)
@property
def needs_password(self):
return self.kind == 'ssh' and self.password == 'ASK'
@property
def needs_ssh_key_unlock(self):
ssh_key_data = ''
if self.kind == 'ssh' and self.ssh_key_unlock == 'ASK':
if self.ssh_key_data:
if self.pk:
ssh_key_data = decrypt_field(self, 'ssh_key_data')
else:
ssh_key_data = self.ssh_key_data
elif self.ssh_key_path:
try:
ssh_key_data = file(self.ssh_key_path).read(2**15)
except IOError:
pass
return 'ENCRYPTED' in ssh_key_data
@property
def needs_sudo_password(self):
return self.kind == 'ssh' and self.sudo_password == 'ASK'
@property
def needs_vault_password(self):
return self.kind == 'ssh' and self.vault_password == 'ASK'
@property
def passwords_needed(self):
needed = []
for field in ('password', 'sudo_password', 'ssh_key_unlock', 'vault_password'):
if getattr(self, 'needs_%s' % field):
needed.append(field)
return needed
def get_absolute_url(self):
return reverse('api:credential_detail', args=(self.pk,))
def clean_username(self):
username = self.username or ''
if not username and self.kind == 'aws':
raise ValidationError('Access key required for "aws" credential')
if not username and self.kind == 'rax':
raise ValidationError('Username required for "rax" credential')
return username
def clean_password(self):
password = self.password or ''
if not password and self.kind == 'aws':
raise ValidationError('Secret key required for "aws" credential')
if not password and self.kind == 'rax':
raise ValidationError('API key required for "rax" credential')
return password
def _validate_ssh_private_key(self, data):
validation_error = ValidationError('Invalid SSH private key')
begin_re = re.compile(r'^(-{4,})\s*?BEGIN\s([A-Z0-9]+?)\sPRIVATE\sKEY\s*?(-{4,})$')
header_re = re.compile(r'^(.+?):\s*?(.+?)(\\??)$')
end_re = re.compile(r'^(-{4,})\s*?END\s([A-Z0-9]+?)\sPRIVATE\sKEY\s*?(-{4,})$')
lines = data.strip().splitlines()
if not lines:
raise validation_error
begin_match = begin_re.match(lines[0])
end_match = end_re.match(lines[-1])
if not begin_match or not end_match:
raise validation_error
dashes = set([begin_match.groups()[0], begin_match.groups()[2],
end_match.groups()[0], end_match.groups()[2]])
if len(dashes) != 1:
raise validation_error
if begin_match.groups()[1] != end_match.groups()[1]:
raise validation_error
line_continues = False
base64_data = ''
for line in lines[1:-1]:
line = line.strip()
if not line:
continue
if line_continues:
line_continues = line.endswith('\\')
continue
line_match = header_re.match(line)
if line_match:
line_continues = line.endswith('\\')
continue
base64_data += line
try:
decoded_data = base64.b64decode(base64_data)
if not decoded_data:
raise validation_error
except TypeError:
raise validation_error
def clean_ssh_key_data(self):
if self.pk:
ssh_key_data = decrypt_field(self, 'ssh_key_data')
else:
ssh_key_data = self.ssh_key_data
if ssh_key_data:
self._validate_ssh_private_key(ssh_key_data)
return self.ssh_key_data # No need to return decrypted version here.
def clean_ssh_key_path(self):
ssh_key_path = self.ssh_key_path or ''
if ssh_key_path:
try:
ssh_key_data = file(ssh_key_path).read(2**15)
except IOError, e:
raise ValidationError(e.strerror or 'Unable to read SSH key path')
self._validate_ssh_private_key(ssh_key_data)
return ssh_key_path
def clean_ssh_key_unlock(self):
ssh_key_data = ''
if self.ssh_key_data:
if self.pk:
ssh_key_data = decrypt_field(self, 'ssh_key_data')
else:
ssh_key_data = self.ssh_key_data
elif self.ssh_key_path:
try:
ssh_key_data = file(self.ssh_key_path).read(2**15)
except IOError:
pass
if 'ENCRYPTED' in ssh_key_data and not self.ssh_key_unlock:
raise ValidationError('SSH key unlock must be set when SSH key '
'is encrypted')
return self.ssh_key_unlock
def clean(self):
if self.user and self.team:
raise ValidationError('Credential cannot be assigned to both a user and team')
if self.ssh_key_data and self.ssh_key_path:
raise ValidationError('Only one of SSH key data or path should be provided')
def _validate_unique_together_with_null(self, unique_check, exclude=None):
# Based on existing Django model validation code, except it doesn't
# skip the check for unique violations when a field is None. See:
# https://github.com/django/django/blob/stable/1.5.x/django/db/models/base.py#L792
errors = {}
model_class = self.__class__
if set(exclude or []) & set(unique_check):
return
lookup_kwargs = {}
for field_name in unique_check:
f = self._meta.get_field(field_name)
lookup_value = getattr(self, f.attname)
if f.primary_key and not self._state.adding:
# no need to check for unique primary key when editing
continue
lookup_kwargs[str(field_name)] = lookup_value
if len(unique_check) != len(lookup_kwargs):
return
qs = model_class._default_manager.filter(**lookup_kwargs)
# Exclude the current object from the query if we are editing an
# instance (as opposed to creating a new one)
# Note that we need to use the pk as defined by model_class, not
# self.pk. These can be different fields because model inheritance
# allows single model to have effectively multiple primary keys.
# Refs #17615.
model_class_pk = self._get_pk_val(model_class._meta)
if not self._state.adding and model_class_pk is not None:
qs = qs.exclude(pk=model_class_pk)
if qs.exists():
key = NON_FIELD_ERRORS
errors.setdefault(key, []).append( \
self.unique_error_message(model_class, unique_check))
if errors:
raise ValidationError(errors)
def validate_unique(self, exclude=None):
errors = {}
try:
super(Credential, self).validate_unique(exclude)
except ValidationError, e:
errors = e.update_error_dict(errors)
try:
unique_fields = ('user', 'team', 'kind', 'name')
self._validate_unique_together_with_null(unique_fields, exclude)
except ValidationError, e:
errors = e.update_error_dict(errors)
if errors:
raise ValidationError(errors)
def save(self, *args, **kwargs):
new_instance = not bool(self.pk)
update_fields = kwargs.get('update_fields', [])
# When first saving to the database, don't store any password field
# values, but instead save them until after the instance is created.
if new_instance:
for field in self.PASSWORD_FIELDS:
value = getattr(self, field, '')
setattr(self, '_saved_%s' % field, value)
setattr(self, field, '')
# Otherwise, store encrypted values to the database.
else:
# If update_fields has been specified, add our field names to it,
# if hit hasn't been specified, then we're just doing a normal save.
for field in self.PASSWORD_FIELDS:
ask = bool(self.kind == 'ssh' and field != 'ssh_key_data')
encrypted = encrypt_field(self, field, ask)
setattr(self, field, encrypted)
if field not in update_fields:
update_fields.append(field)
cloud = self.kind in ('aws', 'rax')
if self.cloud != cloud:
self.cloud = cloud
if 'cloud' not in update_fields:
update_fields.append('cloud')
super(Credential, self).save(*args, **kwargs)
# After saving a new instance for the first time, set the password
# fields and save again.
if new_instance:
update_fields=[]
for field in self.PASSWORD_FIELDS:
saved_value = getattr(self, '_saved_%s' % field, '')
setattr(self, field, saved_value)
update_fields.append(field)
self.save(update_fields=update_fields)

View File

@@ -5,36 +5,21 @@
import datetime import datetime
import hashlib import hashlib
import hmac import hmac
import json
import logging
import os
import re
import shlex
import uuid import uuid
# PyYAML
import yaml
# Django # Django
from django.conf import settings from django.conf import settings
from django.db import models from django.db import models
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.core.exceptions import ValidationError, NON_FIELD_ERRORS
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.utils.timezone import now, make_aware, get_default_timezone from django.utils.timezone import now
# Django-JSONField
from jsonfield import JSONField
# AWX # AWX
from awx.lib.compat import slugify
from awx.main.fields import AutoOneToOneField from awx.main.fields import AutoOneToOneField
from awx.main.utils import encrypt_field, decrypt_field
from awx.main.models.base import * from awx.main.models.base import *
__all__ = ['Organization', 'Team', 'Permission', 'Credential', 'Profile', __all__ = ['Organization', 'Team', 'Permission', 'Profile', 'AuthToken']
'AuthToken']
class Organization(CommonModel): class Organization(CommonModel):
@@ -149,251 +134,6 @@ class Permission(CommonModelNameNotUnique):
return reverse('api:permission_detail', args=(self.pk,)) return reverse('api:permission_detail', args=(self.pk,))
class Credential(CommonModelNameNotUnique):
'''
A credential contains information about how to talk to a remote resource
Usually this is a SSH key location, and possibly an unlock password.
If used with sudo, a sudo password should be set if required.
'''
KIND_CHOICES = [
('ssh', _('Machine')),
('scm', _('SCM')),
('aws', _('AWS')),
('rax', _('Rackspace')),
]
PASSWORD_FIELDS = ('password', 'ssh_key_data', 'ssh_key_unlock',
'sudo_password', 'vault_password')
class Meta:
app_label = 'main'
unique_together = [('user', 'team', 'kind', 'name')]
user = models.ForeignKey(
'auth.User',
null=True,
default=None,
blank=True,
on_delete=models.CASCADE,
related_name='credentials',
)
team = models.ForeignKey(
'Team',
null=True,
default=None,
blank=True,
on_delete=models.CASCADE,
related_name='credentials',
)
kind = models.CharField(
max_length=32,
choices=KIND_CHOICES,
default='ssh',
)
cloud = models.BooleanField(
default=False,
editable=False,
)
username = models.CharField(
blank=True,
default='',
max_length=1024,
verbose_name=_('Username'),
help_text=_('Username for this credential.'),
)
password = models.CharField(
blank=True,
default='',
max_length=1024,
verbose_name=_('Password'),
help_text=_('Password for this credential (or "ASK" to prompt the '
'user for machine credentials).'),
)
ssh_key_data = models.TextField(
blank=True,
default='',
verbose_name=_('SSH private key'),
help_text=_('RSA or DSA private key to be used instead of password.'),
)
ssh_key_path = models.CharField(
max_length=1024,
blank=True,
default='',
verbose_name=_('SSH key path'),
help_text=_('Path to SSH private key file.'),
)
ssh_key_unlock = models.CharField(
max_length=1024,
blank=True,
default='',
verbose_name=_('SSH key unlock'),
help_text=_('Passphrase to unlock SSH private key if encrypted (or '
'"ASK" to prompt the user for machine credentials).'),
)
sudo_username = models.CharField(
max_length=1024,
blank=True,
default='',
help_text=_('Sudo username for a job using this credential.'),
)
sudo_password = models.CharField(
max_length=1024,
blank=True,
default='',
help_text=_('Sudo password (or "ASK" to prompt the user).'),
)
vault_password = models.CharField(
max_length=1024,
blank=True,
default='',
help_text=_('Vault password (or "ASK" to prompt the user).'),
)
@property
def needs_password(self):
return self.kind == 'ssh' and self.password == 'ASK'
@property
def needs_ssh_key_unlock(self):
return self.kind == 'ssh' and self.ssh_key_unlock == 'ASK' and \
'ENCRYPTED' in decrypt_field(self, 'ssh_key_data') # FIXME: Support ssh_key_path
@property
def needs_sudo_password(self):
return self.kind == 'ssh' and self.sudo_password == 'ASK'
@property
def needs_vault_password(self):
return self.kind == 'ssh' and self.vault_password == 'ASK'
@property
def passwords_needed(self):
needed = []
for field in ('password', 'sudo_password', 'ssh_key_unlock', 'vault_password'):
if getattr(self, 'needs_%s' % field):
needed.append(field)
return needed
def get_absolute_url(self):
return reverse('api:credential_detail', args=(self.pk,))
def clean_username(self):
username = self.username or ''
if not username and self.kind == 'aws':
raise ValidationError('Access key required for "aws" credential')
if not username and self.kind == 'rax':
raise ValidationError('Username required for "rax" credential')
return username
def clean_password(self):
password = self.password or ''
if not password and self.kind == 'aws':
raise ValidationError('Secret key required for "aws" credential')
if not password and self.kind == 'rax':
raise ValidationError('API key required for "rax" credential')
return password
def clean_ssh_key_unlock(self):
if self.pk:
ssh_key_data = decrypt_field(self, 'ssh_key_data')
else:
ssh_key_data = self.ssh_key_data
if 'ENCRYPTED' in ssh_key_data and not self.ssh_key_unlock:
raise ValidationError('SSH key unlock must be set when SSH key '
'data is encrypted')
return self.ssh_key_unlock
def clean(self):
if self.user and self.team:
raise ValidationError('Credential cannot be assigned to both a user and team')
def _validate_unique_together_with_null(self, unique_check, exclude=None):
# Based on existing Django model validation code, except it doesn't
# skip the check for unique violations when a field is None. See:
# https://github.com/django/django/blob/stable/1.5.x/django/db/models/base.py#L792
errors = {}
model_class = self.__class__
if set(exclude or []) & set(unique_check):
return
lookup_kwargs = {}
for field_name in unique_check:
f = self._meta.get_field(field_name)
lookup_value = getattr(self, f.attname)
if f.primary_key and not self._state.adding:
# no need to check for unique primary key when editing
continue
lookup_kwargs[str(field_name)] = lookup_value
if len(unique_check) != len(lookup_kwargs):
return
qs = model_class._default_manager.filter(**lookup_kwargs)
# Exclude the current object from the query if we are editing an
# instance (as opposed to creating a new one)
# Note that we need to use the pk as defined by model_class, not
# self.pk. These can be different fields because model inheritance
# allows single model to have effectively multiple primary keys.
# Refs #17615.
model_class_pk = self._get_pk_val(model_class._meta)
if not self._state.adding and model_class_pk is not None:
qs = qs.exclude(pk=model_class_pk)
if qs.exists():
key = NON_FIELD_ERRORS
errors.setdefault(key, []).append( \
self.unique_error_message(model_class, unique_check))
if errors:
raise ValidationError(errors)
def validate_unique(self, exclude=None):
errors = {}
try:
super(Credential, self).validate_unique(exclude)
except ValidationError, e:
errors = e.update_error_dict(errors)
try:
unique_fields = ('user', 'team', 'kind', 'name')
self._validate_unique_together_with_null(unique_fields, exclude)
except ValidationError, e:
errors = e.update_error_dict(errors)
if errors:
raise ValidationError(errors)
def save(self, *args, **kwargs):
new_instance = not bool(self.pk)
update_fields = kwargs.get('update_fields', [])
# When first saving to the database, don't store any password field
# values, but instead save them until after the instance is created.
if new_instance:
for field in self.PASSWORD_FIELDS:
value = getattr(self, field, '')
setattr(self, '_saved_%s' % field, value)
setattr(self, field, '')
# Otherwise, store encrypted values to the database.
else:
# If update_fields has been specified, add our field names to it,
# if hit hasn't been specified, then we're just doing a normal save.
for field in self.PASSWORD_FIELDS:
ask = bool(self.kind == 'ssh' and field != 'ssh_key_data')
encrypted = encrypt_field(self, field, ask)
setattr(self, field, encrypted)
if field not in update_fields:
update_fields.append(field)
cloud = self.kind in ('aws', 'rax')
if self.cloud != cloud:
self.cloud = cloud
if 'cloud' not in update_fields:
update_fields.append('cloud')
super(Credential, self).save(*args, **kwargs)
# After saving a new instance for the first time, set the password
# fields and save again.
if new_instance:
update_fields=[]
for field in self.PASSWORD_FIELDS:
saved_value = getattr(self, '_saved_%s' % field, '')
setattr(self, field, saved_value)
update_fields.append(field)
self.save(update_fields=update_fields)
class Profile(CreatedModifiedModel): class Profile(CreatedModifiedModel):
''' '''
Profile model related to User object. Currently stores LDAP DN for users Profile model related to User object. Currently stores LDAP DN for users

View File

@@ -371,10 +371,11 @@ class RunJob(BaseTask):
def build_private_data(self, job, **kwargs): def build_private_data(self, job, **kwargs):
''' '''
Return SSH private key data needed for this job. Return SSH private key data needed for this job (only if stored in DB
as ssh_key_data).
''' '''
credential = getattr(job, 'credential', None) credential = getattr(job, 'credential', None)
if credential: if credential and credential.ssh_key_data:
return decrypt_field(credential, 'ssh_key_data') or None return decrypt_field(credential, 'ssh_key_data') or None
def build_passwords(self, job, **kwargs): def build_passwords(self, job, **kwargs):
@@ -472,6 +473,13 @@ class RunJob(BaseTask):
except ValueError: except ValueError:
pass pass
# If private key isn't encrypted, pass the path on the command line.
ssh_key_path = kwargs.get('private_data_file', '')
ssh_key_path = ssh_key_path or (creds and creds.ssh_key_path) or ''
use_ssh_agent = 'ssh_key_unlock' in kwargs.get('passwords', {})
if ssh_key_path and not use_ssh_agent:
args.append('--private-key=%s' % ssh_key_path)
if job.forks: # FIXME: Max limit? if job.forks: # FIXME: Max limit?
args.append('--forks=%d' % job.forks) args.append('--forks=%d' % job.forks)
if job.limit: if job.limit:
@@ -483,11 +491,13 @@ class RunJob(BaseTask):
if job.job_tags: if job.job_tags:
args.extend(['-t', job.job_tags]) args.extend(['-t', job.job_tags])
args.append(job.playbook) # relative path to project.local_path args.append(job.playbook) # relative path to project.local_path
ssh_key_path = kwargs.get('private_data_file', '')
if ssh_key_path: # If ssh unlock password is needed, run using ssh-agent.
if ssh_key_path and use_ssh_agent:
cmd = ' '.join([self.args2cmdline('ssh-add', ssh_key_path), cmd = ' '.join([self.args2cmdline('ssh-add', ssh_key_path),
'&&', self.args2cmdline(*args)]) '&&', self.args2cmdline(*args)])
args = ['ssh-agent', 'sh', '-c', cmd] args = ['ssh-agent', 'sh', '-c', cmd]
return args return args
def build_cwd(self, job, **kwargs): def build_cwd(self, job, **kwargs):

View File

@@ -38,7 +38,7 @@ class BaseTestMixin(object):
def setUp(self): def setUp(self):
super(BaseTestMixin, self).setUp() super(BaseTestMixin, self).setUp()
self.object_ctr = 0 self.object_ctr = 0
self._temp_project_dirs = [] self._temp_paths = []
self._current_auth = None self._current_auth = None
self._user_passwords = {} self._user_passwords = {}
self.ansible_version = get_ansible_version() self.ansible_version = get_ansible_version()
@@ -63,18 +63,18 @@ class BaseTestMixin(object):
callback_port = random.randint(55700, 55799) callback_port = random.randint(55700, 55799)
settings.CALLBACK_CONSUMER_PORT = 'tcp://127.0.0.1:%d' % callback_port settings.CALLBACK_CONSUMER_PORT = 'tcp://127.0.0.1:%d' % callback_port
callback_queue_path = '/tmp/callback_receiver_test_%d.ipc' % callback_port callback_queue_path = '/tmp/callback_receiver_test_%d.ipc' % callback_port
self._temp_project_dirs.append(callback_queue_path) self._temp_paths.append(callback_queue_path)
settings.CALLBACK_QUEUE_PORT = 'ipc://%s' % callback_queue_path settings.CALLBACK_QUEUE_PORT = 'ipc://%s' % callback_queue_path
settings.TASK_COMMAND_PORT = 'ipc:///tmp/task_command_receiver_%d.ipc' % callback_port settings.TASK_COMMAND_PORT = 'ipc:///tmp/task_command_receiver_%d.ipc' % callback_port
# Make temp job status directory for unit tests. # Make temp job status directory for unit tests.
job_status_dir = tempfile.mkdtemp() job_status_dir = tempfile.mkdtemp()
self._temp_project_dirs.append(job_status_dir) self._temp_paths.append(job_status_dir)
settings.JOBOUTPUT_ROOT = os.path.abspath(job_status_dir) settings.JOBOUTPUT_ROOT = os.path.abspath(job_status_dir)
self._start_time = time.time() self._start_time = time.time()
def tearDown(self): def tearDown(self):
super(BaseTestMixin, self).tearDown() super(BaseTestMixin, self).tearDown()
for project_dir in self._temp_project_dirs: for project_dir in self._temp_paths:
if os.path.exists(project_dir): if os.path.exists(project_dir):
if os.path.isdir(project_dir): if os.path.isdir(project_dir):
shutil.rmtree(project_dir, True) shutil.rmtree(project_dir, True)
@@ -131,7 +131,7 @@ class BaseTestMixin(object):
os.makedirs(settings.PROJECTS_ROOT) os.makedirs(settings.PROJECTS_ROOT)
# Create temp project directory. # Create temp project directory.
project_dir = tempfile.mkdtemp(dir=settings.PROJECTS_ROOT) project_dir = tempfile.mkdtemp(dir=settings.PROJECTS_ROOT)
self._temp_project_dirs.append(project_dir) self._temp_paths.append(project_dir)
# Create temp playbook in project (if playbook content is given). # Create temp playbook in project (if playbook content is given).
if playbook_content: if playbook_content:
handle, playbook_path = tempfile.mkstemp(suffix='.yml', handle, playbook_path = tempfile.mkstemp(suffix='.yml',

View File

@@ -463,7 +463,7 @@ class InventoryImportTest(BaseCommandMixin, BaseLiveServerTest):
def create_test_dir(self, hostnames=None): def create_test_dir(self, hostnames=None):
hostnames = hostnames or [] hostnames = hostnames or []
self.inv_dir = tempfile.mkdtemp() self.inv_dir = tempfile.mkdtemp()
self._temp_project_dirs.append(self.inv_dir) self._temp_paths.append(self.inv_dir)
self.create_test_ini(self.inv_dir) self.create_test_ini(self.inv_dir)
group_vars = os.path.join(self.inv_dir, 'group_vars') group_vars = os.path.join(self.inv_dir, 'group_vars')
os.makedirs(group_vars) os.makedirs(group_vars)

View File

@@ -24,7 +24,7 @@ from django.utils.timezone import now
# AWX # AWX
from awx.main.models import * from awx.main.models import *
from awx.main.tests.base import BaseTest, BaseTransactionTest from awx.main.tests.base import BaseTest, BaseTransactionTest
from awx.main.tests.tasks import TEST_SSH_KEY_DATA_LOCKED, TEST_SSH_KEY_DATA_UNLOCK from awx.main.tests.tasks import TEST_SSH_KEY_DATA, TEST_SSH_KEY_DATA_LOCKED, TEST_SSH_KEY_DATA_UNLOCK
from awx.main.utils import decrypt_field, update_scm_url from awx.main.utils import decrypt_field, update_scm_url
TEST_PLAYBOOK = '''- hosts: mygroup TEST_PLAYBOOK = '''- hosts: mygroup
@@ -221,7 +221,7 @@ class ProjectsTest(BaseTest):
# can add projects (super user) # can add projects (super user)
project_dir = tempfile.mkdtemp(dir=settings.PROJECTS_ROOT) project_dir = tempfile.mkdtemp(dir=settings.PROJECTS_ROOT)
self._temp_project_dirs.append(project_dir) self._temp_paths.append(project_dir)
project_data = { project_data = {
'name': 'My Test Project', 'name': 'My Test Project',
'description': 'Does amazing things', 'description': 'Does amazing things',
@@ -452,8 +452,8 @@ class ProjectsTest(BaseTest):
name = 'credential', name = 'credential',
project = Project.objects.order_by('pk')[0].pk, project = Project.objects.order_by('pk')[0].pk,
default_username = 'foo', default_username = 'foo',
ssh_key_data = 'bar', ssh_key_data = TEST_SSH_KEY_DATA_LOCKED,
ssh_key_unlock = 'baz', ssh_key_unlock = TEST_SSH_KEY_DATA_UNLOCK,
ssh_password = 'narf', ssh_password = 'narf',
sudo_password = 'troz' sudo_password = 'troz'
) )
@@ -532,6 +532,54 @@ class ProjectsTest(BaseTest):
data['ssh_key_unlock'] = TEST_SSH_KEY_DATA_UNLOCK data['ssh_key_unlock'] = TEST_SSH_KEY_DATA_UNLOCK
self.post(url, data, expect=201) self.post(url, data, expect=201)
# Test with invalid ssh key data.
with self.current_user(self.super_django_user):
bad_key_data = TEST_SSH_KEY_DATA.replace('PRIVATE', 'PUBLIC')
data = dict(name='wyx', user=self.super_django_user.pk, kind='ssh',
ssh_key_data=bad_key_data)
self.post(url, data, expect=400)
data['ssh_key_data'] = TEST_SSH_KEY_DATA.replace('-', '=')
self.post(url, data, expect=400)
data['ssh_key_data'] = '\n'.join(TEST_SSH_KEY_DATA.splitlines()[1:-1])
self.post(url, data, expect=400)
data['ssh_key_data'] = TEST_SSH_KEY_DATA.replace('--B', '---B')
self.post(url, data, expect=400)
data['ssh_key_data'] = TEST_SSH_KEY_DATA
self.post(url, data, expect=201)
# Test with ssh_key_path (invalid path, bad data, then valid key).
handle, ssh_key_path = tempfile.mkstemp(suffix='.key')
self._temp_paths.append(ssh_key_path)
ssh_key_file = os.fdopen(handle, 'w')
ssh_key_file.write(TEST_SSH_KEY_DATA)
ssh_key_file.close()
handle, invalid_ssh_key_path = tempfile.mkstemp(suffix='.key')
self._temp_paths.append(invalid_ssh_key_path)
invalid_ssh_key_file = os.fdopen(handle, 'w')
invalid_ssh_key_file.write('not a valid key')
invalid_ssh_key_file.close()
with self.current_user(self.super_django_user):
data = dict(name='yzv', user=self.super_django_user.pk, kind='ssh',
ssh_key_path=ssh_key_path + '.moo')
self.post(url, data, expect=400)
data['ssh_key_path'] = invalid_ssh_key_path
self.post(url, data, expect=400)
data['ssh_key_path'] = ssh_key_path
self.post(url, data, expect=201)
# Test with encrypted key on ssh_key_path.
handle, enc_ssh_key_path = tempfile.mkstemp(suffix='.key')
self._temp_paths.append(enc_ssh_key_path)
enc_ssh_key_file = os.fdopen(handle, 'w')
enc_ssh_key_file.write(TEST_SSH_KEY_DATA_LOCKED)
enc_ssh_key_file.close()
with self.current_user(self.super_django_user):
data = dict(name='wvz', user=self.super_django_user.pk, kind='ssh',
ssh_key_path=enc_ssh_key_path)
self.post(url, data, expect=400)
data['ssh_key_unlock'] = TEST_SSH_KEY_DATA_UNLOCK
self.post(url, data, expect=201)
# Test post as organization admin where team is part of org, but user # Test post as organization admin where team is part of org, but user
# creating credential is not a member of the team. UI may pass user # creating credential is not a member of the team. UI may pass user
# as an empty string instead of None. # as an empty string instead of None.
@@ -719,7 +767,7 @@ class ProjectUpdatesTest(BaseTransactionTest):
kwargs['credential'] = credential kwargs['credential'] = credential
project = Project.objects.create(**kwargs) project = Project.objects.create(**kwargs)
project_path = project.get_project_path(check_if_exists=False) project_path = project.get_project_path(check_if_exists=False)
self._temp_project_dirs.append(project_path) self._temp_paths.append(project_path)
return project return project
def test_update_scm_url(self): def test_update_scm_url(self):
@@ -1313,7 +1361,7 @@ class ProjectUpdatesTest(BaseTransactionTest):
def create_local_git_repo(self): def create_local_git_repo(self):
repo_dir = tempfile.mkdtemp() repo_dir = tempfile.mkdtemp()
self._temp_project_dirs.append(repo_dir) self._temp_paths.append(repo_dir)
handle, playbook_path = tempfile.mkstemp(suffix='.yml', dir=repo_dir) handle, playbook_path = tempfile.mkstemp(suffix='.yml', dir=repo_dir)
test_playbook_file = os.fdopen(handle, 'w') test_playbook_file = os.fdopen(handle, 'w')
test_playbook_file.write(TEST_PLAYBOOK) test_playbook_file.write(TEST_PLAYBOOK)
@@ -1408,7 +1456,7 @@ class ProjectUpdatesTest(BaseTransactionTest):
def create_local_hg_repo(self): def create_local_hg_repo(self):
repo_dir = tempfile.mkdtemp() repo_dir = tempfile.mkdtemp()
self._temp_project_dirs.append(repo_dir) self._temp_paths.append(repo_dir)
handle, playbook_path = tempfile.mkstemp(suffix='.yml', dir=repo_dir) handle, playbook_path = tempfile.mkstemp(suffix='.yml', dir=repo_dir)
test_playbook_file = os.fdopen(handle, 'w') test_playbook_file = os.fdopen(handle, 'w')
test_playbook_file.write(TEST_PLAYBOOK) test_playbook_file.write(TEST_PLAYBOOK)
@@ -1477,7 +1525,7 @@ class ProjectUpdatesTest(BaseTransactionTest):
def create_local_svn_repo(self): def create_local_svn_repo(self):
repo_dir = tempfile.mkdtemp() repo_dir = tempfile.mkdtemp()
self._temp_project_dirs.append(repo_dir) self._temp_paths.append(repo_dir)
subprocess.check_call(['svnadmin', 'create', '.'], cwd=repo_dir, subprocess.check_call(['svnadmin', 'create', '.'], cwd=repo_dir,
stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout=subprocess.PIPE, stderr=subprocess.PIPE)
handle, playbook_path = tempfile.mkstemp(suffix='.yml', dir=repo_dir) handle, playbook_path = tempfile.mkstemp(suffix='.yml', dir=repo_dir)

View File

@@ -728,7 +728,7 @@ class RunJobTest(BaseCeleryTest):
self.check_job_result(job, 'failed') self.check_job_result(job, 'failed')
self.assertTrue('-l' in job.job_args) self.assertTrue('-l' in job.job_args)
def test_limit_option_with_group_pattern_and_ssh_agent(self): def test_limit_option_with_group_pattern_and_ssh_key(self):
self.create_test_credential(ssh_key_data=TEST_SSH_KEY_DATA) self.create_test_credential(ssh_key_data=TEST_SSH_KEY_DATA)
self.create_test_project(TEST_PLAYBOOK) self.create_test_project(TEST_PLAYBOOK)
job_template = self.create_test_job_template(limit='test-group:&test-group2') job_template = self.create_test_job_template(limit='test-group:&test-group2')
@@ -738,7 +738,8 @@ class RunJobTest(BaseCeleryTest):
self.assertTrue(job.signal_start()) self.assertTrue(job.signal_start())
job = Job.objects.get(pk=job.pk) job = Job.objects.get(pk=job.pk)
self.check_job_result(job, 'successful') self.check_job_result(job, 'successful')
self.assertTrue('ssh-agent' in job.job_args) self.assertTrue('--private-key=' in job.job_args)
self.assertFalse('ssh-agent' in job.job_args)
def test_ssh_username_and_password(self): def test_ssh_username_and_password(self):
self.create_test_credential(username='sshuser', password='sshpass') self.create_test_credential(username='sshuser', password='sshpass')
@@ -810,7 +811,8 @@ class RunJobTest(BaseCeleryTest):
self.assertTrue(job.signal_start()) self.assertTrue(job.signal_start())
job = Job.objects.get(pk=job.pk) job = Job.objects.get(pk=job.pk)
self.check_job_result(job, 'successful') self.check_job_result(job, 'successful')
self.assertTrue('ssh-agent' in job.job_args) self.assertTrue('--private-key=' in job.job_args)
self.assertFalse('ssh-agent' in job.job_args)
def test_locked_ssh_key_with_password(self): def test_locked_ssh_key_with_password(self):
self.create_test_credential(ssh_key_data=TEST_SSH_KEY_DATA_LOCKED, self.create_test_credential(ssh_key_data=TEST_SSH_KEY_DATA_LOCKED,
@@ -860,6 +862,68 @@ class RunJobTest(BaseCeleryTest):
self.assertTrue('ssh-agent' in job.job_args) self.assertTrue('ssh-agent' in job.job_args)
self.assertTrue('Bad passphrase' not in job.result_stdout) self.assertTrue('Bad passphrase' not in job.result_stdout)
def test_unlocked_ssh_key_path(self):
handle, ssh_key_path = tempfile.mkstemp(suffix='.key')
self._temp_paths.append(ssh_key_path)
ssh_key_file = os.fdopen(handle, 'w')
ssh_key_file.write(TEST_SSH_KEY_DATA)
ssh_key_file.close()
self.create_test_credential(ssh_key_path=ssh_key_path)
self.create_test_project(TEST_PLAYBOOK)
job_template = self.create_test_job_template()
job = self.create_test_job(job_template=job_template)
self.assertEqual(job.status, 'new')
self.assertFalse(job.passwords_needed_to_start)
self.assertTrue(job.signal_start())
job = Job.objects.get(pk=job.pk)
self.check_job_result(job, 'successful')
self.assertTrue('--private-key=' in job.job_args)
self.assertFalse('ssh-agent' in job.job_args)
def test_locked_ssh_key_path_with_password(self):
handle, ssh_key_path = tempfile.mkstemp(suffix='.key')
self._temp_paths.append(ssh_key_path)
ssh_key_file = os.fdopen(handle, 'w')
ssh_key_file.write(TEST_SSH_KEY_DATA_LOCKED)
ssh_key_file.close()
self.create_test_credential(ssh_key_path=ssh_key_path,
ssh_key_unlock=TEST_SSH_KEY_DATA_UNLOCK)
self.create_test_project(TEST_PLAYBOOK)
job_template = self.create_test_job_template()
job = self.create_test_job(job_template=job_template)
self.assertEqual(job.status, 'new')
self.assertFalse(job.passwords_needed_to_start)
self.assertTrue(job.signal_start())
job = Job.objects.get(pk=job.pk)
self.check_job_result(job, 'successful')
self.assertTrue('ssh-agent' in job.job_args)
self.assertTrue('Bad passphrase' not in job.result_stdout)
def test_locked_ssh_key_path_ask_password(self):
handle, ssh_key_path = tempfile.mkstemp(suffix='.key')
self._temp_paths.append(ssh_key_path)
ssh_key_file = os.fdopen(handle, 'w')
ssh_key_file.write(TEST_SSH_KEY_DATA_LOCKED)
ssh_key_file.close()
self.create_test_credential(ssh_key_path=ssh_key_path,
ssh_key_unlock='ASK')
self.create_test_project(TEST_PLAYBOOK)
job_template = self.create_test_job_template()
job = self.create_test_job(job_template=job_template)
self.assertEqual(job.status, 'new')
self.assertTrue(job.passwords_needed_to_start)
self.assertTrue('ssh_key_unlock' in job.passwords_needed_to_start)
self.assertFalse(job.signal_start())
job.status = 'failed'
job.save()
job = self.create_test_job(job_template=job_template)
self.assertEqual(job.status, 'new')
self.assertTrue(job.signal_start(ssh_key_unlock=TEST_SSH_KEY_DATA_UNLOCK))
job = Job.objects.get(pk=job.pk)
self.check_job_result(job, 'successful')
self.assertTrue('ssh-agent' in job.job_args)
self.assertTrue('Bad passphrase' not in job.result_stdout)
def test_vault_password(self): def test_vault_password(self):
self.create_test_credential(vault_password=TEST_VAULT_PASSWORD) self.create_test_credential(vault_password=TEST_VAULT_PASSWORD)
self.create_test_project(TEST_VAULT_PLAYBOOK) self.create_test_project(TEST_VAULT_PLAYBOOK)

View File

@@ -233,7 +233,7 @@ def model_instance_diff(old, new, serializer_mapping=None):
When provided, read-only fields will not be included in the resulting dictionary When provided, read-only fields will not be included in the resulting dictionary
""" """
from django.db.models import Model from django.db.models import Model
from awx.main.models.organization import Credential from awx.main.models.credential import Credential
if not(old is None or isinstance(old, Model)): if not(old is None or isinstance(old, Model)):
raise TypeError('The supplied old instance is not a valid model instance.') raise TypeError('The supplied old instance is not a valid model instance.')
@@ -281,7 +281,7 @@ def model_to_dict(obj, serializer_mapping=None):
serializer_mapping are used to determine read-only fields. serializer_mapping are used to determine read-only fields.
When provided, read-only fields will not be included in the resulting dictionary When provided, read-only fields will not be included in the resulting dictionary
""" """
from awx.main.models.organization import Credential from awx.main.models.credential import Credential
attr_d = {} attr_d = {}
if serializer_mapping is not None and obj.__class__ in serializer_mapping: if serializer_mapping is not None and obj.__class__ in serializer_mapping:
serializer_actual = serializer_mapping[obj.__class__]() serializer_actual = serializer_mapping[obj.__class__]()