diff --git a/awx/main/models/credential.py b/awx/main/models/credential.py index 8f0d32d87a..82e0f576e1 100644 --- a/awx/main/models/credential.py +++ b/awx/main/models/credential.py @@ -165,7 +165,7 @@ class Credential(PasswordFieldsModel, CommonModelNameNotUnique): else: ssh_key_data = self.ssh_key_data try: - key_data = self._validate_ssh_private_key(ssh_key_data) + key_data = validate_ssh_private_key(ssh_key_data) except ValidationError: return False else: @@ -238,128 +238,6 @@ class Credential(PasswordFieldsModel, CommonModelNameNotUnique): raise ValidationError('Project name required for OpenStack credential.') return project - def _validate_ssh_private_key(self, data): - """Validate that the given SSH private key or certificate is, - in fact, valid. - """ - # Map the X in BEGIN X PRIVATE KEY to the key type (ssh-keygen -t). - # Tower jobs using OPENSSH format private keys may still fail if the - # system SSH implementation lacks support for this format. - key_types = { - 'RSA': 'rsa', - 'DSA': 'dsa', - 'EC': 'ecdsa', - 'OPENSSH': 'ed25519', - '': 'rsa1', - } - # Key properties to return if valid. - key_data = { - 'key_type': None, # Key type (from above mapping). - 'key_seg': '', # Key segment (all text including begin/end). - 'key_b64': '', # Key data as base64. - 'key_bin': '', # Key data as binary. - 'key_enc': None, # Boolean, whether key is encrypted. - 'cert_seg': '', # Cert segment (all text including begin/end). - 'cert_b64': '', # Cert data as base64. - 'cert_bin': '', # Cert data as binary. - } - data = data.strip() - validation_error = ValidationError('Invalid private key') - - # Sanity check: We may potentially receive a full PEM certificate, - # and we want to accept these. - cert_begin_re = r'(-{4,})\s*BEGIN\s+CERTIFICATE\s*(-{4,})' - cert_end_re = r'(-{4,})\s*END\s+CERTIFICATE\s*(-{4,})' - cert_begin_match = re.search(cert_begin_re, data) - cert_end_match = re.search(cert_end_re, data) - if cert_begin_match and not cert_end_match: - raise validation_error - elif not cert_begin_match and cert_end_match: - raise validation_error - elif cert_begin_match and cert_end_match: - cert_dashes = set([cert_begin_match.groups()[0], cert_begin_match.groups()[1], - cert_end_match.groups()[0], cert_end_match.groups()[1]]) - if len(cert_dashes) != 1: - raise validation_error - key_data['cert_seg'] = data[cert_begin_match.start():cert_end_match.end()] - - # Find the private key, and also ensure that it internally matches - # itself. - # Set up the valid private key header and footer. - begin_re = r'(-{4,})\s*BEGIN\s+([A-Z0-9]+)?\s*PRIVATE\sKEY\s*(-{4,})' - end_re = r'(-{4,})\s*END\s+([A-Z0-9]+)?\s*PRIVATE\sKEY\s*(-{4,})' - begin_match = re.search(begin_re, data) - end_match = re.search(end_re, data) - if not begin_match or not end_match: - raise validation_error - - # Ensure that everything, such as dash counts and key type, lines up, - # and raise an error if it does not. - dashes = set([begin_match.groups()[0], begin_match.groups()[2], - end_match.groups()[0], end_match.groups()[2]]) - if len(dashes) != 1: - raise validation_error - if begin_match.groups()[1] != end_match.groups()[1]: - raise validation_error - key_type = begin_match.groups()[1] - try: - key_data['key_type'] = key_types[key_type] - except KeyError: - raise ValidationError('Invalid private key: unsupported type %s' % key_type) - - # The private key data begins and ends with the private key. - key_data['key_seg'] = data[begin_match.start():end_match.end()] - - # Establish that we are able to base64 decode the private key; - # if we can't, then it's not a valid key. - # - # If we got a certificate, validate that also, in the same way. - header_re = re.compile(r'^(.+?):\s*?(.+?)(\\??)$') - for segment_name in ('cert', 'key'): - segment_to_validate = key_data['%s_seg' % segment_name] - # If we have nothing; skip this one. - # We've already validated that we have a private key above, - # so we don't need to do it again. - if not segment_to_validate: - continue - - # Ensure that this segment is valid base64 data. - base64_data = '' - line_continues = False - lines = segment_to_validate.splitlines() - for line in lines[1:-1]: - line = line.strip() - if not line: - continue - if line_continues: - line_continues = line.endswith('\\') - continue - line_match = header_re.match(line) - if line_match: - line_continues = line.endswith('\\') - continue - base64_data += line - try: - decoded_data = base64.b64decode(base64_data) - if not decoded_data: - raise validation_error - key_data['%s_b64' % segment_name] = base64_data - key_data['%s_bin' % segment_name] = decoded_data - except TypeError: - raise validation_error - - # Determine if key is encrypted. - if key_data['key_type'] == 'ed25519': - # See https://github.com/openssh/openssh-portable/blob/master/sshkey.c#L3218 - # Decoded key data starts with magic string (null-terminated), four byte - # length field, followed by the ciphername -- if ciphername is anything - # other than 'none' the key is encrypted. - key_data['key_enc'] = not bool(key_data['key_bin'].startswith('openssh-key-v1\x00\x00\x00\x00\x04none')) - else: - key_data['key_enc'] = bool('ENCRYPTED' in key_data['key_seg']) - - return key_data - def clean_ssh_key_data(self): if self.pk: ssh_key_data = decrypt_field(self, 'ssh_key_data') @@ -379,7 +257,7 @@ class Credential(PasswordFieldsModel, CommonModelNameNotUnique): # Validate the private key to ensure that it looks like something # that we can accept. - self._validate_ssh_private_key(ssh_key_data) + validate_ssh_private_key(ssh_key_data) return self.ssh_key_data # No need to return decrypted version here. def clean_ssh_key_unlock(self): @@ -471,3 +349,124 @@ class Credential(PasswordFieldsModel, CommonModelNameNotUnique): update_fields.append('cloud') super(Credential, self).save(*args, **kwargs) +def validate_ssh_private_key(data): + """Validate that the given SSH private key or certificate is, + in fact, valid. + """ + # Map the X in BEGIN X PRIVATE KEY to the key type (ssh-keygen -t). + # Tower jobs using OPENSSH format private keys may still fail if the + # system SSH implementation lacks support for this format. + key_types = { + 'RSA': 'rsa', + 'DSA': 'dsa', + 'EC': 'ecdsa', + 'OPENSSH': 'ed25519', + '': 'rsa1', + } + # Key properties to return if valid. + key_data = { + 'key_type': None, # Key type (from above mapping). + 'key_seg': '', # Key segment (all text including begin/end). + 'key_b64': '', # Key data as base64. + 'key_bin': '', # Key data as binary. + 'key_enc': None, # Boolean, whether key is encrypted. + 'cert_seg': '', # Cert segment (all text including begin/end). + 'cert_b64': '', # Cert data as base64. + 'cert_bin': '', # Cert data as binary. + } + data = data.strip() + validation_error = ValidationError('Invalid private key') + + # Sanity check: We may potentially receive a full PEM certificate, + # and we want to accept these. + cert_begin_re = r'(-{4,})\s*BEGIN\s+CERTIFICATE\s*(-{4,})' + cert_end_re = r'(-{4,})\s*END\s+CERTIFICATE\s*(-{4,})' + cert_begin_match = re.search(cert_begin_re, data) + cert_end_match = re.search(cert_end_re, data) + if cert_begin_match and not cert_end_match: + raise validation_error + elif not cert_begin_match and cert_end_match: + raise validation_error + elif cert_begin_match and cert_end_match: + cert_dashes = set([cert_begin_match.groups()[0], cert_begin_match.groups()[1], + cert_end_match.groups()[0], cert_end_match.groups()[1]]) + if len(cert_dashes) != 1: + raise validation_error + key_data['cert_seg'] = data[cert_begin_match.start():cert_end_match.end()] + + # Find the private key, and also ensure that it internally matches + # itself. + # Set up the valid private key header and footer. + begin_re = r'(-{4,})\s*BEGIN\s+([A-Z0-9]+)?\s*PRIVATE\sKEY\s*(-{4,})' + end_re = r'(-{4,})\s*END\s+([A-Z0-9]+)?\s*PRIVATE\sKEY\s*(-{4,})' + begin_match = re.search(begin_re, data) + end_match = re.search(end_re, data) + if not begin_match or not end_match: + raise validation_error + + # Ensure that everything, such as dash counts and key type, lines up, + # and raise an error if it does not. + dashes = set([begin_match.groups()[0], begin_match.groups()[2], + end_match.groups()[0], end_match.groups()[2]]) + if len(dashes) != 1: + raise validation_error + if begin_match.groups()[1] != end_match.groups()[1]: + raise validation_error + key_type = begin_match.groups()[1] or '' + try: + key_data['key_type'] = key_types[key_type] + except KeyError: + raise ValidationError('Invalid private key: unsupported type %s' % key_type) + + # The private key data begins and ends with the private key. + key_data['key_seg'] = data[begin_match.start():end_match.end()] + + # Establish that we are able to base64 decode the private key; + # if we can't, then it's not a valid key. + # + # If we got a certificate, validate that also, in the same way. + header_re = re.compile(r'^(.+?):\s*?(.+?)(\\??)$') + for segment_name in ('cert', 'key'): + segment_to_validate = key_data['%s_seg' % segment_name] + # If we have nothing; skip this one. + # We've already validated that we have a private key above, + # so we don't need to do it again. + if not segment_to_validate: + continue + + # Ensure that this segment is valid base64 data. + base64_data = '' + line_continues = False + lines = segment_to_validate.splitlines() + for line in lines[1:-1]: + line = line.strip() + if not line: + continue + if line_continues: + line_continues = line.endswith('\\') + continue + line_match = header_re.match(line) + if line_match: + line_continues = line.endswith('\\') + continue + base64_data += line + try: + decoded_data = base64.b64decode(base64_data) + if not decoded_data: + raise validation_error + key_data['%s_b64' % segment_name] = base64_data + key_data['%s_bin' % segment_name] = decoded_data + except TypeError: + raise validation_error + + # Determine if key is encrypted. + if key_data['key_type'] == 'ed25519': + # See https://github.com/openssh/openssh-portable/blob/master/sshkey.c#L3218 + # Decoded key data starts with magic string (null-terminated), four byte + # length field, followed by the ciphername -- if ciphername is anything + # other than 'none' the key is encrypted. + key_data['key_enc'] = not bool(key_data['key_bin'].startswith('openssh-key-v1\x00\x00\x00\x00\x04none')) + else: + key_data['key_enc'] = bool('ENCRYPTED' in key_data['key_seg']) + + return key_data diff --git a/awx/main/tasks.py b/awx/main/tasks.py index 4276ac85b0..0ac7776547 100644 --- a/awx/main/tasks.py +++ b/awx/main/tasks.py @@ -1219,6 +1219,7 @@ class RunInventoryUpdate(BaseTask): env['GCE_EMAIL'] = passwords.get('source_username', '') env['GCE_PROJECT'] = passwords.get('source_project', '') env['GCE_PEM_FILE_PATH'] = cloud_credential + env['GCE_ZONE'] = inventory_source.source_regions elif inventory_update.source == 'openstack': env['OS_CLIENT_CONFIG_FILE'] = cloud_credential elif inventory_update.source == 'file': diff --git a/awx/main/tests/unit/test_credentials.py b/awx/main/tests/unit/test_credentials.py new file mode 100644 index 0000000000..7445d28fda --- /dev/null +++ b/awx/main/tests/unit/test_credentials.py @@ -0,0 +1,56 @@ +from django.core.exceptions import ValidationError +from awx.main.models.credential import validate_ssh_private_key + +import pytest + +def test_valid_rsa_key(): + begin = """-----BEGIN RSA PRIVATE KEY-----""" + end = """-----END RSA PRIVATE KEY-----""" + unvalidated_key = build_key(begin, body, end) + key_data = validate_ssh_private_key(unvalidated_key) + assert key_data['key_type'] == 'rsa' + +def test_invalid_key(): + unvalidated_key = build_key(key_begin, body, "END KEY") + with pytest.raises(ValidationError): + validate_ssh_private_key(unvalidated_key) + +def test_key_type_empty(): + unvalidated_key = build_key(key_begin, body, key_end) + key_data = validate_ssh_private_key(unvalidated_key) + assert key_data['key_type'] == 'rsa1' + + +def build_key(begin, body, end): + return """%s%s%s""" % (begin, body, end) + +key_begin = """-----BEGIN PRIVATE KEY-----""" +key_end = """-----END PRIVATE KEY-----""" + +body = """ +uFZFyag7VVqI+q/oGnQu+wj/pMi5ox+Qz5L3W0D745DzwgDXOeObAfNlr9NtIKbn +sZ5E0+rYB4Q/U0CYr5juNJQV1dbxq2Em1160axboe2QbvX6wE6Sm6wW9b9cr+PoF +MoYQebUnCY0ObrLbrRugSfZc17lyxK0ZGRgPXKhpMg6Ecv8XpvhjUYU9Esyqfuco +/p26Q140/HsHeHYNma0dQHCEjMr/qEzOY1qguHj+hRf3SARtM9Q+YNgpxchcDDVS +O+n+8Ljd/p82bpEJwxmpXealeWbI6gB9/R6wcCL+ZyCZpnHJd/NJ809Vtu47ZdDi +E6jvqS/3AQhuQKhJlLSDIzezB2VKKrHwOvHkg/+uLoCqHN34Gk6Qio7x69SvXy88 +a7q9D1l/Zx60o08FyZyqlo7l0l/r8EY+36cuI/lvAvfxc5VHVEOvKseUjFRBiCv9 +MkKNxaScoYsPwY7SIS6gD93tg3eM5pA0nfMfya9u1+uq/QCM1gNG3mm6Zd8YG4c/ +Dx4bmsj8cp5ni/Ffl/sKzKYq1THunJEFGXOZRibdxk/Fal3SQrRAwy7CgLQL8SMh +IWqcFm25OtSOP1r1LE25t5pQsMdmp0IP2fEF0t/pXPm1ZfrTurPMqpo4FGm2hkki +U3sH/o6nrkSOjklOLWlwtTkkL4dWPlNwc8OYj8zFizXJkAfv1spzhv3lRouNkw4N +Mm22W7us2f3Ob0H5C07k26h6VuXX+0AybD4tIIcUXCLoNTqA0HvqhKpEuHu3Ck10 +RaB8xHTxgwdhGVaNHMfy9B9l4tNs3Tb5k0LyeRRGVDhWCFo6axYULYebkj+hFLLY ++JE5RzPDFpTf1xbuT+e56H/lLFCUdDu0bn+D0W4ifXaVFegak4r6O4B53CbMqr+R +t6qDPKLUIuVJXK0J6Ay6XgmheXJGbgKh4OtDsc06gsTCE1nY4f/Z82AQahPBfTtF +J2z+NHdsLPn//HlxspGQtmLpuS7Wx0HYXZ+kPRSiE/vmITw85R2u8JSHQicVNN4C +2rlUo15TIU3tTx+WUIrHKHPidUNNotRb2p9n9FoSidU6upKnQHAT/JNv/zcvaia3 +Bhl/wagheWTDnFKSmJ4HlKxplM/32h6MfHqsMVOl4F6eZWKaKgSgN8doXyFJo+sc +yAC6S0gJlD2gQI24iTI4Du1+UGh2MGb69eChvi5mbbdesaZrlR1dRqZpHG+6ob4H +nYLndRvobXS5l6pgGTDRYoUgSbQe21a7Uf3soGl5jHqLWc1zEPwrxV7Wr31mApr6 +8VtGZcLSr0691Q1NLO3eIfuhbMN2mssX/Sl4t+4BibaucNIMfmhKQi8uHtwAXb47 ++TMFlG2EQhZULFM4fLdF1vaizInU3cBk8lsz8i71tDc+5VQTEwoEB7Gksy/XZWEt +6SGHxXUDtNYa+G2O+sQhgqBjLIkVTV6KJOpvNZM+s8Vzv8qoFnD7isKBBrRvF1bP +GOXEG1jd7nSR0WSwcMCHGOrFEELDQPw3k5jqEdPFgVODoZPr+drZVnVz5SAGBk5Y +wsCNaDW+1dABYFlqRTepP5rrSu9wHnRAZ3ZGv+DHoGqenIC5IBR0sQ== +""" diff --git a/awx/plugins/inventory/gce.py b/awx/plugins/inventory/gce.py index b13c194a6e..498511d635 100755 --- a/awx/plugins/inventory/gce.py +++ b/awx/plugins/inventory/gce.py @@ -117,8 +117,10 @@ class GceInventory(object): pretty=self.args.pretty)) sys.exit(0) + zones = self.parse_env_zones() + # Otherwise, assume user wants all instances grouped - print(self.json_format_dict(self.group_instances(), + print(self.json_format_dict(self.group_instances(zones), pretty=self.args.pretty)) sys.exit(0) @@ -190,6 +192,14 @@ class GceInventory(object): ) return gce + def parse_env_zones(self): + '''returns a list of comma seperated zones parsed from the GCE_ZONE environment variable. + If provided, this will be used to filter the results of the grouped_instances call''' + import csv + reader = csv.reader([os.environ.get('GCE_ZONE',"")], skipinitialspace=True) + zones = [r for r in reader] + return [z for z in zones[0]] + def parse_cli_args(self): ''' Command line argument processing ''' @@ -240,7 +250,7 @@ class GceInventory(object): except Exception as e: return None - def group_instances(self): + def group_instances(self, zones=None): '''Group all instances''' groups = {} meta = {} @@ -252,6 +262,12 @@ class GceInventory(object): meta["hostvars"][name] = self.node_to_dict(node) zone = node.extra['zone'].name + + # To avoid making multiple requests per zone + # we list all nodes and then filter the results + if zones and zone not in zones: + continue + if groups.has_key(zone): groups[zone].append(name) else: groups[zone] = [name]