From 8e6020436c8fcd1654e38593c5e05defed529ab7 Mon Sep 17 00:00:00 2001 From: AlanCoding Date: Wed, 26 Apr 2017 10:48:24 -0400 Subject: [PATCH] modularization of inventory_import command This separates file parsing logic that was mixed in with other important code inside of the inventory import command. The logic around MemObject data structures was moved to utils, and the file parsing was moved to a legacy module. As of this commit, that module can operate within the Tower environment but it will be removed. Also refactor the loggers to fix old bug and work inside of the different contexts - the Loader classes, mem objects, and hopefully the inventory modules eventually. --- awx/api/serializers.py | 2 +- .../management/commands/inventory_import.py | 745 ++++++------------ awx/main/migrations/0038_v320_release.py | 12 + awx/main/models/inventory.py | 14 +- awx/main/tasks.py | 15 +- .../tests/unit/models/test_survey_models.py | 3 + .../plugins/test_tower_inventory_legacy.py | 88 +++ .../tests/unit/utils/test_mem_inventory.py | 128 +++ awx/main/utils/formatters.py | 10 + awx/main/utils/mem_inventory.py | 315 ++++++++ awx/plugins/ansible_inventory/legacy.py | 253 ++++++ awx/settings/defaults.py | 13 + pytest.ini | 2 + 13 files changed, 1084 insertions(+), 516 deletions(-) create mode 100644 awx/main/tests/unit/plugins/test_tower_inventory_legacy.py create mode 100644 awx/main/tests/unit/utils/test_mem_inventory.py create mode 100644 awx/main/utils/mem_inventory.py create mode 100755 awx/plugins/ansible_inventory/legacy.py diff --git a/awx/api/serializers.py b/awx/api/serializers.py index 5cc40b4e65..52ae82098d 100644 --- a/awx/api/serializers.py +++ b/awx/api/serializers.py @@ -1458,7 +1458,7 @@ class InventorySourceOptionsSerializer(BaseSerializer): class Meta: fields = ('*', 'source', 'source_path', 'source_script', 'source_vars', 'credential', 'source_regions', 'instance_filters', 'group_by', 'overwrite', 'overwrite_vars', - 'timeout') + 'timeout', 'verbosity') def get_related(self, obj): res = super(InventorySourceOptionsSerializer, self).get_related(obj) diff --git a/awx/main/management/commands/inventory_import.py b/awx/main/management/commands/inventory_import.py index 7dc8a5f7bb..8c42303edf 100644 --- a/awx/main/management/commands/inventory_import.py +++ b/awx/main/management/commands/inventory_import.py @@ -2,22 +2,16 @@ # All Rights Reserved. # Python -import glob import json import logging from optparse import make_option import os import re -import shlex -import string import subprocess import sys import time import traceback -# PyYAML -import yaml - # Django from django.conf import settings from django.core.management.base import NoArgsCommand, CommandError @@ -27,7 +21,12 @@ from django.utils.encoding import smart_text # AWX from awx.main.models import * # noqa from awx.main.task_engine import TaskEnhancer -from awx.main.utils import ignore_inventory_computed_fields, check_proot_installed, wrap_args_with_proot +from awx.main.utils import ( + ignore_inventory_computed_fields, + check_proot_installed, + wrap_args_with_proot +) +from awx.main.utils.mem_inventory import MemInventory, dict_to_mem_data from awx.main.signals import disable_activity_stream logger = logging.getLogger('awx.main.commands.inventory_import') @@ -49,432 +48,159 @@ Demo mode free license count exceeded, would bring available instances to %(new_ See http://www.ansible.com/renew for licensing information.''' -class MemObject(object): - ''' - Common code shared between in-memory groups and hosts. - ''' - - def __init__(self, name, source_dir): - assert name, 'no name' - assert source_dir, 'no source dir' - self.name = name - self.source_dir = source_dir - - def load_vars(self, base_path): - all_vars = {} - files_found = 0 - for suffix in ('', '.yml', '.yaml', '.json'): - path = ''.join([base_path, suffix]).encode("utf-8") - if not os.path.exists(path): - continue - if not os.path.isfile(path): - continue - files_found += 1 - if files_found > 1: - raise RuntimeError('Multiple variable files found. There should only be one. %s ' % self.name) - vars_name = os.path.basename(os.path.dirname(path)) - logger.debug('Loading %s from %s', vars_name, path) - try: - v = yaml.safe_load(file(path, 'r').read()) - if hasattr(v, 'items'): # is a dict - all_vars.update(v) - except yaml.YAMLError as e: - if hasattr(e, 'problem_mark'): - logger.error('Invalid YAML in %s:%s col %s', path, - e.problem_mark.line + 1, - e.problem_mark.column + 1) - else: - logger.error('Error loading YAML from %s', path) - raise - return all_vars - - -class MemGroup(MemObject): - ''' - In-memory representation of an inventory group. - ''' - - def __init__(self, name, source_dir): - super(MemGroup, self).__init__(name, source_dir) - self.children = [] - self.hosts = [] - self.variables = {} - self.parents = [] - # Used on the "all" group in place of previous global variables. - # maps host and group names to hosts to prevent redudant additions - self.all_hosts = {} - self.all_groups = {} - group_vars = os.path.join(source_dir, 'group_vars', self.name) - self.variables = self.load_vars(group_vars) - logger.debug('Loaded group: %s', self.name) - - def child_group_by_name(self, name, loader): - if name == 'all': - return - logger.debug('Looking for %s as child group of %s', name, self.name) - # slight hack here, passing in 'self' for all_group but child=True won't use it - group = loader.get_group(name, self, child=True) - if group: - # don't add to child groups if already there - for g in self.children: - if g.name == name: - return g - logger.debug('Adding child group %s to group %s', group.name, self.name) - self.children.append(group) - return group - - def add_child_group(self, group): - assert group.name is not 'all', 'group name is all' - assert isinstance(group, MemGroup), 'not MemGroup instance' - logger.debug('Adding child group %s to parent %s', group.name, self.name) - if group not in self.children: - self.children.append(group) - if self not in group.parents: - group.parents.append(self) - - def add_host(self, host): - assert isinstance(host, MemHost), 'not MemHost instance' - logger.debug('Adding host %s to group %s', host.name, self.name) - if host not in self.hosts: - self.hosts.append(host) - - def debug_tree(self, group_names=None): - group_names = group_names or set() - if self.name in group_names: - return - logger.debug('Dumping tree for group "%s":', self.name) - logger.debug('- Vars: %r', self.variables) - for h in self.hosts: - logger.debug('- Host: %s, %r', h.name, h.variables) - for g in self.children: - logger.debug('- Child: %s', g.name) - logger.debug('----') - group_names.add(self.name) - for g in self.children: - g.debug_tree(group_names) - - -class MemHost(MemObject): - ''' - In-memory representation of an inventory host. - ''' - - def __init__(self, name, source_dir, port=None): - super(MemHost, self).__init__(name, source_dir) - self.variables = {} - self.instance_id = None - self.name = name - if port: - self.variables['ansible_ssh_port'] = port - host_vars = os.path.join(source_dir, 'host_vars', name) - self.variables.update(self.load_vars(host_vars)) - logger.debug('Loaded host: %s', self.name) - - -class BaseLoader(object): - ''' - Common functions for an inventory loader from a given source. - ''' - - def __init__(self, source, all_group=None, group_filter_re=None, host_filter_re=None, is_custom=False): - self.source = source - self.source_dir = os.path.dirname(self.source) - self.all_group = all_group or MemGroup('all', self.source_dir) - self.group_filter_re = group_filter_re - self.host_filter_re = host_filter_re - self.ipv6_port_re = re.compile(r'^\[([A-Fa-f0-9:]{3,})\]:(\d+?)$') - self.is_custom = is_custom - - def get_host(self, name): - ''' - Return a MemHost instance from host name, creating if needed. If name - contains brackets, they will NOT be interpreted as a host pattern. - ''' - m = self.ipv6_port_re.match(name) - if m: - host_name = m.groups()[0] - port = int(m.groups()[1]) - elif name.count(':') == 1: - host_name = name.split(':')[0] - try: - port = int(name.split(':')[1]) - except (ValueError, UnicodeDecodeError): - logger.warning(u'Invalid port "%s" for host "%s"', - name.split(':')[1], host_name) - port = None - else: - host_name = name - port = None - if self.host_filter_re and not self.host_filter_re.match(host_name): - logger.debug('Filtering host %s', host_name) - return None - host = None - if host_name not in self.all_group.all_hosts: - host = MemHost(host_name, self.source_dir, port) - self.all_group.all_hosts[host_name] = host - return self.all_group.all_hosts[host_name] - - def get_hosts(self, name): - ''' - Return iterator over one or more MemHost instances from host name or - host pattern. - ''' - def iternest(*args): - if args: - for i in args[0]: - for j in iternest(*args[1:]): - yield ''.join([str(i), j]) - else: - yield '' - if self.ipv6_port_re.match(name): - yield self.get_host(name) - return - pattern_re = re.compile(r'(\[(?:(?:\d+\:\d+)|(?:[A-Za-z]\:[A-Za-z]))(?:\:\d+)??\])') - iters = [] - for s in re.split(pattern_re, name): - if re.match(pattern_re, s): - start, end, step = (s[1:-1] + ':1').split(':')[:3] - mapfunc = str - if start in string.ascii_letters: - istart = string.ascii_letters.index(start) - iend = string.ascii_letters.index(end) + 1 - if istart >= iend: - raise ValueError('invalid host range specified') - seq = string.ascii_letters[istart:iend:int(step)] - else: - if start[0] == '0' and len(start) > 1: - if len(start) != len(end): - raise ValueError('invalid host range specified') - mapfunc = lambda x: str(x).zfill(len(start)) - seq = xrange(int(start), int(end) + 1, int(step)) - iters.append(map(mapfunc, seq)) - elif re.search(r'[\[\]]', s): - raise ValueError('invalid host range specified') - elif s: - iters.append([s]) - for iname in iternest(*iters): - yield self.get_host(iname) - - def get_group(self, name, all_group=None, child=False): - ''' - Return a MemGroup instance from group name, creating if needed. - ''' - all_group = all_group or self.all_group - if name == 'all': - return all_group - if self.group_filter_re and not self.group_filter_re.match(name): - logger.debug('Filtering group %s', name) - return None - if name not in self.all_group.all_groups: - group = MemGroup(name, self.source_dir) - if not child: - all_group.add_child_group(group) - self.all_group.all_groups[name] = group - return self.all_group.all_groups[name] - - def load(self): - raise NotImplementedError - - -class IniLoader(BaseLoader): - ''' - Loader to read inventory from an INI-formatted text file. - ''' - - def load(self): - logger.info('Reading INI source: %s', self.source) - group = self.all_group - input_mode = 'host' - for line in file(self.source, 'r'): - line = line.split('#')[0].strip() - if not line: - continue - elif line.startswith('[') and line.endswith(']'): - # Mode change, possible new group name - line = line[1:-1].strip() - if line.endswith(':vars'): - input_mode = 'vars' - line = line[:-5] - elif line.endswith(':children'): - input_mode = 'children' - line = line[:-9] - else: - input_mode = 'host' - group = self.get_group(line) - elif group: - # If group is None, we are skipping this group and shouldn't - # capture any children/variables/hosts under it. - # Add hosts with inline variables, or variables/children to - # an existing group. - tokens = shlex.split(line) - if input_mode == 'host': - for host in self.get_hosts(tokens[0]): - if not host: - continue - if len(tokens) > 1: - for t in tokens[1:]: - k,v = t.split('=', 1) - host.variables[k] = v - group.add_host(host) - elif input_mode == 'children': - group.child_group_by_name(line, self) - elif input_mode == 'vars': - for t in tokens: - k, v = t.split('=', 1) - group.variables[k] = v - # TODO: expansion patterns are probably not going to be supported. YES THEY ARE! - - -# from API documentation: -# -# if called with --list, inventory outputs like so: -# -# { -# "databases" : { -# "hosts" : [ "host1.example.com", "host2.example.com" ], -# "vars" : { -# "a" : true -# } -# }, -# "webservers" : [ "host2.example.com", "host3.example.com" ], -# "atlanta" : { -# "hosts" : [ "host1.example.com", "host4.example.com", "host5.example.com" ], -# "vars" : { -# "b" : false -# }, -# "children": [ "marietta", "5points" ], -# }, -# "marietta" : [ "host6.example.com" ], -# "5points" : [ "host7.example.com" ] -# } -# +# if called with --list, inventory outputs a JSON representing everything +# in the inventory. Supported cases are maintained in tests. # if called with --host outputs JSON for that host -class ExecutableJsonLoader(BaseLoader): +class BaseLoader(object): + use_proot = True + + def __init__(self, source, group_filter_re=None, host_filter_re=None): + self.source = source + self.exe_dir = os.path.dirname(source) + self.inventory = MemInventory( + group_filter_re=group_filter_re, host_filter_re=host_filter_re) + + def build_env(self): + # Use ansible venv if it's available and setup to use + env = dict(os.environ.items()) + if settings.ANSIBLE_USE_VENV: + env['VIRTUAL_ENV'] = settings.ANSIBLE_VENV_PATH + # env['VIRTUAL_ENV'] += settings.ANSIBLE_VENV_PATH + env['PATH'] = os.path.join(settings.ANSIBLE_VENV_PATH, "bin") + ":" + env['PATH'] + # env['PATH'] += os.path.join(settings.ANSIBLE_VENV_PATH, "bin") + ":" + env['PATH'] + venv_libdir = os.path.join(settings.ANSIBLE_VENV_PATH, "lib") + env.pop('PYTHONPATH', None) # default to none if no python_ver matches + for python_ver in ["python2.7", "python2.6"]: + if os.path.isdir(os.path.join(venv_libdir, python_ver)): + env['PYTHONPATH'] = os.path.join(venv_libdir, python_ver, "site-packages") + ":" + break + env['PYTHONPATH'] += os.path.abspath(os.path.join(settings.BASE_DIR, '..')) + ":" + return env def command_to_json(self, cmd): data = {} stdout, stderr = '', '' try: - if self.is_custom and getattr(settings, 'AWX_PROOT_ENABLED', False): + if self.use_proot and getattr(settings, 'AWX_PROOT_ENABLED', False): if not check_proot_installed(): raise RuntimeError("proot is not installed but is configured for use") - kwargs = {'proot_temp_dir': self.source_dir} # TODO: Remove proot dir - cmd = wrap_args_with_proot(cmd, self.source_dir, **kwargs) - # Use ansible venv if it's available and setup to use - env = dict(os.environ.items()) - if settings.ANSIBLE_USE_VENV: - env['VIRTUAL_ENV'] = settings.ANSIBLE_VENV_PATH - env['PATH'] = os.path.join(settings.ANSIBLE_VENV_PATH, "bin") + ":" + env['PATH'] - venv_libdir = os.path.join(settings.ANSIBLE_VENV_PATH, "lib") - env.pop('PYTHONPATH', None) # default to none if no python_ver matches - for python_ver in ["python2.7", "python2.6"]: - if os.path.isdir(os.path.join(venv_libdir, python_ver)): - env['PYTHONPATH'] = os.path.join(venv_libdir, python_ver, "site-packages") + ":" - break + kwargs = {'proot_temp_dir': self.exe_dir} # TODO: Remove proot dir + cmd = wrap_args_with_proot(cmd, self.exe_dir, **kwargs) + env = self.build_env() proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) stdout, stderr = proc.communicate() if proc.returncode != 0: raise RuntimeError('%r failed (rc=%d) with output: %s' % (cmd, proc.returncode, stderr)) - try: - data = json.loads(stdout) - except ValueError: + data = json.loads(stdout) + if not isinstance(data, dict): raise TypeError('Returned JSON must be a dictionary, got %s instead' % str(type(data))) except: logger.error('Failed to load JSON from: %s', stdout) raise return data + def build_base_args(self): + raise NotImplementedError + def load(self): - logger.info('Reading executable JSON source: %s', self.source) - data = self.command_to_json([self.source, '--list']) - _meta = data.pop('_meta', {}) + base_args = self.build_base_args() + logger.info('Reading executable JSON source: %s', ' '.join(base_args)) + data = self.command_to_json(base_args + ['--list']) + self.has_hostvars = '_meta' in data and 'hostvars' in data['_meta'] - for k,v in data.iteritems(): - group = self.get_group(k) - if not group: - continue + inventory = dict_to_mem_data(data, inventory=self.inventory) - # Load group hosts/vars/children from a dictionary. - if isinstance(v, dict): - # Process hosts within a group. - hosts = v.get('hosts', {}) - if isinstance(hosts, dict): - for hk, hv in hosts.iteritems(): - host = self.get_host(hk) - if not host: - continue - if isinstance(hv, dict): - host.variables.update(hv) - else: - self.logger.warning('Expected dict of vars for ' - 'host "%s", got %s instead', - hk, str(type(hv))) - group.add_host(host) - elif isinstance(hosts, (list, tuple)): - for hk in hosts: - host = self.get_host(hk) - if not host: - continue - group.add_host(host) - else: - logger.warning('Expected dict or list of "hosts" for ' - 'group "%s", got %s instead', k, - str(type(hosts))) - # Process group variables. - vars = v.get('vars', {}) - if isinstance(vars, dict): - group.variables.update(vars) - else: - self.logger.warning('Expected dict of vars for ' - 'group "%s", got %s instead', - k, str(type(vars))) - # Process child groups. - children = v.get('children', []) - if isinstance(children, (list, tuple)): - for c in children: - child = self.get_group(c, self.all_group, child=True) - if child: - group.add_child_group(child) - else: - self.logger.warning('Expected list of children for ' - 'group "%s", got %s instead', - k, str(type(children))) + return inventory - # Load host names from a list. - elif isinstance(v, (list, tuple)): - for h in v: - host = self.get_host(h) - if not host: - continue - group.add_host(host) - else: - logger.warning('') - self.logger.warning('Expected dict or list for group "%s", ' - 'got %s instead', k, str(type(v))) - if k != 'all': - self.all_group.add_child_group(group) +class AnsibleInventoryLoader(BaseLoader): + ''' + Given executable `source` directory, executable, or file, this will + use the ansible-inventory CLI utility to convert it into in-memory + representational objects. Example: + /usr/bin/ansible/ansible-inventory -i hosts --list + ''' + + def build_base_args(self): + # Need absolute path of anisble-inventory in order to run inside + # of bubblewrap, inside of Popen + for path in os.environ["PATH"].split(os.pathsep): + potential_path = os.path.join(path.strip('"'), 'ansible-inventory') + if os.path.isfile(potential_path) and os.access(potential_path, os.X_OK): + return [potential_path] + raise RuntimeError( + 'ImproperlyConfigured: Called with modern method but ' + 'not detect ansible-inventory on this system. ' + 'Check to see that system Ansible is 2.4 or higher.') + + +# TODO: delete after Ansible 2.3 is deprecated +class InventoryPluginLoader(BaseLoader): + ''' + Implements a different use pattern for loading JSON content from an + Ansible inventory module, example: + /path/ansible_inventory_module.py -i my_inventory.ini --list + ''' + + def __init__(self, source, module, *args, **kwargs): + super(InventoryPluginLoader, self).__init__(source, *args, **kwargs) + assert module in ['legacy', 'backport'], ( + 'Supported modules are `legacy` and `backport`, received {}'.format(module)) + self.module = module + # self.use_proot = False + + def build_env(self): + if self.module == 'legacy': + # legacy script does not rely on Ansible imports + return dict(os.environ.items()) + return super(InventoryPluginLoader, self).build_env() + + def build_base_args(self): + abs_module_path = os.path.abspath(os.path.join( + os.path.dirname(__file__), '..', '..', '..', 'plugins', + 'ansible_inventory', '{}.py'.format(self.module))) + return [abs_module_path, '-i', self.source] + + +# TODO: delete after Ansible 2.3 is deprecated +class ExecutableJsonLoader(BaseLoader): + ''' + Directly calls an inventory script, example: + /path/ec2.py --list + ''' + + def __init__(self, source, is_custom=False, **kwargs): + super(ExecutableJsonLoader, self).__init__(source, **kwargs) + self.use_proot = is_custom + + def build_base_args(self): + return [self.source] + + def load(self): + inventory = super(ExecutableJsonLoader, self).load() # Invoke the executable once for each host name we've built up # to set their variables - for k,v in self.all_group.all_hosts.iteritems(): - if 'hostvars' not in _meta: - data = self.command_to_json([self.source, '--host', k.encode("utf-8")]) - else: - data = _meta['hostvars'].get(k, {}) - if isinstance(data, dict): - v.variables.update(data) - else: - self.logger.warning('Expected dict of vars for ' - 'host "%s", got %s instead', - k, str(type(data))) + if not self.has_hostvars: + base_args = self.build_base_args() + for k,v in self.inventory.all_group.all_hosts.iteritems(): + host_data = self.command_to_json( + base_args + ['--host', k.encode("utf-8")]) + if isinstance(host_data, dict): + v.variables.update(host_data) + else: + logger.warning('Expected dict of vars for ' + 'host "%s", got %s instead', + k, str(type(host_data))) + return inventory -def load_inventory_source(source, all_group=None, group_filter_re=None, - host_filter_re=None, exclude_empty_groups=False, is_custom=False): +def load_inventory_source(source, group_filter_re=None, + host_filter_re=None, exclude_empty_groups=False, + is_custom=False, method='legacy'): ''' Load inventory from given source directory or file. ''' @@ -484,40 +210,42 @@ def load_inventory_source(source, all_group=None, group_filter_re=None, source = source.replace('satellite6.py', 'foreman.py') source = source.replace('vmware.py', 'vmware_inventory.py') logger.debug('Analyzing type of source: %s', source) - original_all_group = all_group if not os.path.exists(source): raise IOError('Source does not exist: %s' % source) source = os.path.join(os.getcwd(), os.path.dirname(source), os.path.basename(source)) source = os.path.normpath(os.path.abspath(source)) - if os.path.isdir(source): - all_group = all_group or MemGroup('all', source) - for filename in glob.glob(os.path.join(source, '*')): - if filename.endswith(".ini") or os.path.isdir(filename): - continue - load_inventory_source(filename, all_group, group_filter_re, - host_filter_re, is_custom=is_custom) + + # TODO: delete options for 'legacy' and 'backport' after Ansible 2.3 deprecated + if method == 'modern': + inventory = AnsibleInventoryLoader( + source=source, + group_filter_re=group_filter_re, + host_filter_re=host_filter_re).load() + + elif method == 'legacy' and (os.access(source, os.X_OK) and not os.path.isdir(source)): + # Legacy method of loading executable files + inventory = ExecutableJsonLoader( + source=source, + group_filter_re=group_filter_re, + host_filter_re=host_filter_re, + is_custom=is_custom).load() + else: - all_group = all_group or MemGroup('all', os.path.dirname(source)) - if os.access(source, os.X_OK): - ExecutableJsonLoader(source, all_group, group_filter_re, host_filter_re, is_custom).load() - else: - IniLoader(source, all_group, group_filter_re, host_filter_re).load() + # Load using specified ansible-inventory module + inventory = InventoryPluginLoader( + source=source, + module=method, + group_filter_re=group_filter_re, + host_filter_re=host_filter_re).load() logger.debug('Finished loading from source: %s', source) # Exclude groups that are completely empty. - if original_all_group is None and exclude_empty_groups: - for name, group in all_group.all_groups.items(): - if not group.children and not group.hosts and not group.variables: - logger.debug('Removing empty group %s', name) - for parent in group.parents: - if group in parent.children: - parent.children.remove(group) - del all_group.all_groups[name] - if original_all_group is None: - logger.info('Loaded %d groups, %d hosts', len(all_group.all_groups), - len(all_group.all_hosts)) - return all_group + if exclude_empty_groups: + inventory.delete_empty_groups() + logger.info('Loaded %d groups, %d hosts', len(inventory.all_group.all_groups), + len(inventory.all_group.all_hosts)) + return inventory.all_group class Command(NoArgsCommand): @@ -571,24 +299,18 @@ class Command(NoArgsCommand): default=None, metavar='v', help='host variable that ' 'specifies the unique, immutable instance ID, may be ' 'specified as "foo.bar" to traverse nested dicts.'), + # TODO: remove --method option when Ansible 2.3 is deprecated + make_option('--method', dest='method', type='choice', + choices=['modern', 'backport', 'legacy'], + default='legacy', help='Method for importing inventory ' + 'to use, distinguishing whether to use `ansible-inventory`, ' + 'its backport, or the legacy algorithms.'), ) - def init_logging(self): + def set_logging_level(self): log_levels = dict(enumerate([logging.WARNING, logging.INFO, logging.DEBUG, 0])) - self.logger = logging.getLogger('awx.main.commands.inventory_import') - self.logger.setLevel(log_levels.get(self.verbosity, 0)) - handler = logging.StreamHandler() - - class Formatter(logging.Formatter): - def format(self, record): - record.relativeSeconds = record.relativeCreated / 1000.0 - return logging.Formatter.format(self, record) - - formatter = Formatter('%(relativeSeconds)9.3f %(levelname)-8s %(message)s') - handler.setFormatter(formatter) - self.logger.addHandler(handler) - self.logger.propagate = False + logger.setLevel(log_levels.get(self.verbosity, 0)) def _get_instance_id(self, from_dict, default=''): ''' @@ -650,8 +372,8 @@ class Command(NoArgsCommand): raise CommandError('Inventory with %s = %s cannot be found' % q.items()[0]) except Inventory.MultipleObjectsReturned: raise CommandError('Inventory with %s = %s returned multiple results' % q.items()[0]) - self.logger.info('Updating inventory %d: %s' % (self.inventory.pk, - self.inventory.name)) + logger.info('Updating inventory %d: %s' % (self.inventory.pk, + self.inventory.name)) # Load inventory source if specified via environment variable (when # inventory_import is called from an InventoryUpdate task). @@ -727,8 +449,8 @@ class Command(NoArgsCommand): for mem_host in self.all_group.all_hosts.values(): instance_id = self._get_instance_id(mem_host.variables) if not instance_id: - self.logger.warning('Host "%s" has no "%s" variable', - mem_host.name, self.instance_id_var) + logger.warning('Host "%s" has no "%s" variable', + mem_host.name, self.instance_id_var) continue mem_host.instance_id = instance_id self.mem_instance_id_map[instance_id] = mem_host.name @@ -768,11 +490,11 @@ class Command(NoArgsCommand): for host in hosts_qs.filter(pk__in=del_pks): host_name = host.name host.delete() - self.logger.info('Deleted host "%s"', host_name) + logger.info('Deleted host "%s"', host_name) if settings.SQL_DEBUG: - self.logger.warning('host deletions took %d queries for %d hosts', - len(connection.queries) - queries_before, - len(all_del_pks)) + logger.warning('host deletions took %d queries for %d hosts', + len(connection.queries) - queries_before, + len(all_del_pks)) def _delete_groups(self): ''' @@ -799,11 +521,11 @@ class Command(NoArgsCommand): group_name = group.name with ignore_inventory_computed_fields(): group.delete() - self.logger.info('Group "%s" deleted', group_name) + logger.info('Group "%s" deleted', group_name) if settings.SQL_DEBUG: - self.logger.warning('group deletions took %d queries for %d groups', - len(connection.queries) - queries_before, - len(all_del_pks)) + logger.warning('group deletions took %d queries for %d groups', + len(connection.queries) - queries_before, + len(all_del_pks)) def _delete_group_children_and_hosts(self): ''' @@ -831,8 +553,8 @@ class Command(NoArgsCommand): for db_child in db_children.filter(pk__in=child_group_pks): group_group_count += 1 db_group.children.remove(db_child) - self.logger.info('Group "%s" removed from group "%s"', - db_child.name, db_group.name) + logger.info('Group "%s" removed from group "%s"', + db_child.name, db_group.name) # FIXME: Inventory source group relationships # Delete group/host relationships not present in imported data. db_hosts = db_group.hosts @@ -859,12 +581,12 @@ class Command(NoArgsCommand): if db_host not in db_group.hosts.all(): continue db_group.hosts.remove(db_host) - self.logger.info('Host "%s" removed from group "%s"', - db_host.name, db_group.name) + logger.info('Host "%s" removed from group "%s"', + db_host.name, db_group.name) if settings.SQL_DEBUG: - self.logger.warning('group-group and group-host deletions took %d queries for %d relationships', - len(connection.queries) - queries_before, - group_group_count + group_host_count) + logger.warning('group-group and group-host deletions took %d queries for %d relationships', + len(connection.queries) - queries_before, + group_group_count + group_host_count) def _update_inventory(self): ''' @@ -884,11 +606,11 @@ class Command(NoArgsCommand): all_obj.variables = json.dumps(db_variables) all_obj.save(update_fields=['variables']) if self.overwrite_vars: - self.logger.info('%s variables replaced from "all" group', all_name.capitalize()) + logger.info('%s variables replaced from "all" group', all_name.capitalize()) else: - self.logger.info('%s variables updated from "all" group', all_name.capitalize()) + logger.info('%s variables updated from "all" group', all_name.capitalize()) else: - self.logger.info('%s variables unmodified', all_name.capitalize()) + logger.info('%s variables unmodified', all_name.capitalize()) def _create_update_groups(self): ''' @@ -920,11 +642,11 @@ class Command(NoArgsCommand): group.variables = json.dumps(db_variables) group.save(update_fields=['variables']) if self.overwrite_vars: - self.logger.info('Group "%s" variables replaced', group.name) + logger.info('Group "%s" variables replaced', group.name) else: - self.logger.info('Group "%s" variables updated', group.name) + logger.info('Group "%s" variables updated', group.name) else: - self.logger.info('Group "%s" variables unmodified', group.name) + logger.info('Group "%s" variables unmodified', group.name) existing_group_names.add(group.name) self._batch_add_m2m(self.inventory_source.groups, group) for group_name in all_group_names: @@ -932,13 +654,13 @@ class Command(NoArgsCommand): continue mem_group = self.all_group.all_groups[group_name] group = self.inventory.groups.create(name=group_name, variables=json.dumps(mem_group.variables), description='imported') - self.logger.info('Group "%s" added', group.name) + logger.info('Group "%s" added', group.name) self._batch_add_m2m(self.inventory_source.groups, group) self._batch_add_m2m(self.inventory_source.groups, flush=True) if settings.SQL_DEBUG: - self.logger.warning('group updates took %d queries for %d groups', - len(connection.queries) - queries_before, - len(self.all_group.all_groups)) + logger.warning('group updates took %d queries for %d groups', + len(connection.queries) - queries_before, + len(self.all_group.all_groups)) def _update_db_host_from_mem_host(self, db_host, mem_host): # Update host variables. @@ -971,24 +693,24 @@ class Command(NoArgsCommand): if update_fields: db_host.save(update_fields=update_fields) if 'name' in update_fields: - self.logger.info('Host renamed from "%s" to "%s"', old_name, mem_host.name) + logger.info('Host renamed from "%s" to "%s"', old_name, mem_host.name) if 'instance_id' in update_fields: if old_instance_id: - self.logger.info('Host "%s" instance_id updated', mem_host.name) + logger.info('Host "%s" instance_id updated', mem_host.name) else: - self.logger.info('Host "%s" instance_id added', mem_host.name) + logger.info('Host "%s" instance_id added', mem_host.name) if 'variables' in update_fields: if self.overwrite_vars: - self.logger.info('Host "%s" variables replaced', mem_host.name) + logger.info('Host "%s" variables replaced', mem_host.name) else: - self.logger.info('Host "%s" variables updated', mem_host.name) + logger.info('Host "%s" variables updated', mem_host.name) else: - self.logger.info('Host "%s" variables unmodified', mem_host.name) + logger.info('Host "%s" variables unmodified', mem_host.name) if 'enabled' in update_fields: if enabled: - self.logger.info('Host "%s" is now enabled', mem_host.name) + logger.info('Host "%s" is now enabled', mem_host.name) else: - self.logger.info('Host "%s" is now disabled', mem_host.name) + logger.info('Host "%s" is now disabled', mem_host.name) self._batch_add_m2m(self.inventory_source.hosts, db_host) def _create_update_hosts(self): @@ -1062,17 +784,17 @@ class Command(NoArgsCommand): host_attrs['instance_id'] = instance_id db_host = self.inventory.hosts.create(**host_attrs) if enabled is False: - self.logger.info('Host "%s" added (disabled)', mem_host_name) + logger.info('Host "%s" added (disabled)', mem_host_name) else: - self.logger.info('Host "%s" added', mem_host_name) + logger.info('Host "%s" added', mem_host_name) self._batch_add_m2m(self.inventory_source.hosts, db_host) self._batch_add_m2m(self.inventory_source.hosts, flush=True) if settings.SQL_DEBUG: - self.logger.warning('host updates took %d queries for %d hosts', - len(connection.queries) - queries_before, - len(self.all_group.all_hosts)) + logger.warning('host updates took %d queries for %d hosts', + len(connection.queries) - queries_before, + len(self.all_group.all_hosts)) def _create_update_group_children(self): ''' @@ -1092,14 +814,14 @@ class Command(NoArgsCommand): child_names = all_child_names[offset2:(offset2 + self._batch_size)] db_children_qs = self.inventory.groups.filter(name__in=child_names) for db_child in db_children_qs.filter(children__id=db_group.id): - self.logger.info('Group "%s" already child of group "%s"', db_child.name, db_group.name) + logger.info('Group "%s" already child of group "%s"', db_child.name, db_group.name) for db_child in db_children_qs.exclude(children__id=db_group.id): self._batch_add_m2m(db_group.children, db_child) - self.logger.info('Group "%s" added as child of "%s"', db_child.name, db_group.name) + logger.info('Group "%s" added as child of "%s"', db_child.name, db_group.name) self._batch_add_m2m(db_group.children, flush=True) if settings.SQL_DEBUG: - self.logger.warning('Group-group updates took %d queries for %d group-group relationships', - len(connection.queries) - queries_before, group_group_count) + logger.warning('Group-group updates took %d queries for %d group-group relationships', + len(connection.queries) - queries_before, group_group_count) def _create_update_group_hosts(self): # For each host in a mem group, add it to the parent(s) to which it @@ -1118,23 +840,23 @@ class Command(NoArgsCommand): host_names = all_host_names[offset2:(offset2 + self._batch_size)] db_hosts_qs = self.inventory.hosts.filter(name__in=host_names) for db_host in db_hosts_qs.filter(groups__id=db_group.id): - self.logger.info('Host "%s" already in group "%s"', db_host.name, db_group.name) + logger.info('Host "%s" already in group "%s"', db_host.name, db_group.name) for db_host in db_hosts_qs.exclude(groups__id=db_group.id): self._batch_add_m2m(db_group.hosts, db_host) - self.logger.info('Host "%s" added to group "%s"', db_host.name, db_group.name) + logger.info('Host "%s" added to group "%s"', db_host.name, db_group.name) all_instance_ids = sorted([h.instance_id for h in mem_group.hosts if h.instance_id]) for offset2 in xrange(0, len(all_instance_ids), self._batch_size): instance_ids = all_instance_ids[offset2:(offset2 + self._batch_size)] db_hosts_qs = self.inventory.hosts.filter(instance_id__in=instance_ids) for db_host in db_hosts_qs.filter(groups__id=db_group.id): - self.logger.info('Host "%s" already in group "%s"', db_host.name, db_group.name) + logger.info('Host "%s" already in group "%s"', db_host.name, db_group.name) for db_host in db_hosts_qs.exclude(groups__id=db_group.id): self._batch_add_m2m(db_group.hosts, db_host) - self.logger.info('Host "%s" added to group "%s"', db_host.name, db_group.name) + logger.info('Host "%s" added to group "%s"', db_host.name, db_group.name) self._batch_add_m2m(db_group.hosts, flush=True) if settings.SQL_DEBUG: - self.logger.warning('Group-host updates took %d queries for %d group-host relationships', - len(connection.queries) - queries_before, group_host_count) + logger.warning('Group-host updates took %d queries for %d group-host relationships', + len(connection.queries) - queries_before, group_host_count) def load_into_database(self): ''' @@ -1159,14 +881,14 @@ class Command(NoArgsCommand): def check_license(self): license_info = TaskEnhancer().validate_enhancements() if license_info.get('license_key', 'UNLICENSED') == 'UNLICENSED': - self.logger.error(LICENSE_NON_EXISTANT_MESSAGE) + logger.error(LICENSE_NON_EXISTANT_MESSAGE) raise CommandError('No Tower license found!') available_instances = license_info.get('available_instances', 0) free_instances = license_info.get('free_instances', 0) time_remaining = license_info.get('time_remaining', 0) new_count = Host.objects.active_count() if time_remaining <= 0 and not license_info.get('demo', False): - self.logger.error(LICENSE_EXPIRED_MESSAGE) + logger.error(LICENSE_EXPIRED_MESSAGE) raise CommandError("License has expired!") if free_instances < 0: d = { @@ -1174,9 +896,9 @@ class Command(NoArgsCommand): 'available_instances': available_instances, } if license_info.get('demo', False): - self.logger.error(DEMO_LICENSE_MESSAGE % d) + logger.error(DEMO_LICENSE_MESSAGE % d) else: - self.logger.error(LICENSE_MESSAGE % d) + logger.error(LICENSE_MESSAGE % d) raise CommandError('License count exceeded!') def mark_license_failure(self, save=True): @@ -1185,7 +907,7 @@ class Command(NoArgsCommand): def handle_noargs(self, **options): self.verbosity = int(options.get('verbosity', 1)) - self.init_logging() + self.set_logging_level() self.inventory_name = options.get('inventory_name', None) self.inventory_id = options.get('inventory_id', None) self.overwrite = bool(options.get('overwrite', False)) @@ -1224,7 +946,7 @@ class Command(NoArgsCommand): TODO: Remove this deprecation when we remove support for rax.py ''' if self.source == "rax.py": - self.logger.info("Rackspace inventory sync is Deprecated in Tower 3.1.0 and support for Rackspace will be removed in a future release.") + logger.info("Rackspace inventory sync is Deprecated in Tower 3.1.0 and support for Rackspace will be removed in a future release.") begin = time.time() self.load_inventory_from_database() @@ -1249,11 +971,12 @@ class Command(NoArgsCommand): self.inventory_update.save() # Load inventory from source. - self.all_group = load_inventory_source(self.source, None, + self.all_group = load_inventory_source(self.source, self.group_filter_re, self.host_filter_re, self.exclude_empty_groups, - self.is_custom) + self.is_custom, + options.get('method')) self.all_group.debug_tree() with batch_role_ancestor_rebuilding(): @@ -1262,7 +985,7 @@ class Command(NoArgsCommand): with transaction.atomic(): # Merge/overwrite inventory into database. if settings.SQL_DEBUG: - self.logger.warning('loading into database...') + logger.warning('loading into database...') with ignore_inventory_computed_fields(): if getattr(settings, 'ACTIVITY_STREAM_ENABLED_FOR_INVENTORY_SYNC', True): self.load_into_database() @@ -1273,8 +996,8 @@ class Command(NoArgsCommand): queries_before2 = len(connection.queries) self.inventory.update_computed_fields() if settings.SQL_DEBUG: - self.logger.warning('update computed fields took %d queries', - len(connection.queries) - queries_before2) + logger.warning('update computed fields took %d queries', + len(connection.queries) - queries_before2) try: self.check_license() except CommandError as e: @@ -1282,11 +1005,11 @@ class Command(NoArgsCommand): raise e if settings.SQL_DEBUG: - self.logger.warning('Inventory import completed for %s in %0.1fs', - self.inventory_source.name, time.time() - begin) + logger.warning('Inventory import completed for %s in %0.1fs', + self.inventory_source.name, time.time() - begin) else: - self.logger.info('Inventory import completed for %s in %0.1fs', - self.inventory_source.name, time.time() - begin) + logger.info('Inventory import completed for %s in %0.1fs', + self.inventory_source.name, time.time() - begin) status = 'successful' # If we're in debug mode, then log the queries and time @@ -1294,9 +1017,9 @@ class Command(NoArgsCommand): if settings.SQL_DEBUG: queries_this_import = connection.queries[queries_before:] sqltime = sum(float(x['time']) for x in queries_this_import) - self.logger.warning('Inventory import required %d queries ' - 'taking %0.3fs', len(queries_this_import), - sqltime) + logger.warning('Inventory import required %d queries ' + 'taking %0.3fs', len(queries_this_import), + sqltime) except Exception as e: if isinstance(e, KeyboardInterrupt): status = 'canceled' diff --git a/awx/main/migrations/0038_v320_release.py b/awx/main/migrations/0038_v320_release.py index 60c38009e2..009e9432ea 100644 --- a/awx/main/migrations/0038_v320_release.py +++ b/awx/main/migrations/0038_v320_release.py @@ -114,4 +114,16 @@ class Migration(migrations.Migration): name='notificationtemplate', unique_together=set([('organization', 'name')]), ), + + # Add verbosity option to inventory updates + migrations.AddField( + model_name='inventorysource', + name='verbosity', + field=models.PositiveIntegerField(default=1, blank=True, choices=[(0, b'0 (WARNING)'), (1, b'1 (INFO)'), (2, b'2 (DEBUG)')]), + ), + migrations.AddField( + model_name='inventoryupdate', + name='verbosity', + field=models.PositiveIntegerField(default=1, blank=True, choices=[(0, b'0 (WARNING)'), (1, b'1 (INFO)'), (2, b'2 (DEBUG)')]), + ), ] diff --git a/awx/main/models/inventory.py b/awx/main/models/inventory.py index c0d6af6a28..c5a0d83dfd 100644 --- a/awx/main/models/inventory.py +++ b/awx/main/models/inventory.py @@ -733,6 +733,13 @@ class InventorySourceOptions(BaseModel): ('custom', _('Custom Script')), ] + # From the options of the Django management base command + INVENTORY_UPDATE_VERBOSITY_CHOICES = [ + (0, '0 (WARNING)'), + (1, '1 (INFO)'), + (2, '2 (DEBUG)'), + ] + # Use tools/scripts/get_ec2_filter_names.py to build this list. INSTANCE_FILTER_NAMES = [ "architecture", @@ -879,6 +886,11 @@ class InventorySourceOptions(BaseModel): blank=True, default=0, ) + verbosity = models.PositiveIntegerField( + choices=INVENTORY_UPDATE_VERBOSITY_CHOICES, + blank=True, + default=1, + ) @classmethod def get_ec2_region_choices(cls): @@ -1116,7 +1128,7 @@ class InventorySource(UnifiedJobTemplate, InventorySourceOptions): def _get_unified_job_field_names(cls): return ['name', 'description', 'source', 'source_path', 'source_script', 'source_vars', 'schedule', 'credential', 'source_regions', 'instance_filters', 'group_by', 'overwrite', 'overwrite_vars', - 'timeout', 'launch_type', 'scm_project_update',] + 'timeout', 'verbosity', 'launch_type', 'scm_project_update',] def save(self, *args, **kwargs): # If update_fields has been specified, add our field names to it, diff --git a/awx/main/tasks.py b/awx/main/tasks.py index 58bf153316..37d83f4815 100644 --- a/awx/main/tasks.py +++ b/awx/main/tasks.py @@ -1761,6 +1761,16 @@ class RunInventoryUpdate(BaseTask): elif inventory_update.source == 'file': args.append(inventory_update.get_actual_source_path()) + if hasattr(settings, 'ANSIBLE_INVENTORY_MODULE'): + module_name = settings.ANSIBLE_INVENTORY_MODULE + else: + module_name = 'backport' + v = get_ansible_version() + if Version(v) > Version('2.4'): + module_name = 'modern' + elif Version(v) < Version('2.2'): + module_name = 'legacy' + args.extend(['--method', module_name]) elif inventory_update.source == 'custom': runpath = tempfile.mkdtemp(prefix='ansible_tower_launch_') handle, path = tempfile.mkstemp(dir=runpath) @@ -1770,11 +1780,10 @@ class RunInventoryUpdate(BaseTask): f.write(inventory_update.source_script.script.encode('utf-8')) f.close() os.chmod(path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR) - args.append(runpath) + args.append(path) args.append("--custom") self.custom_dir_path.append(runpath) - verbosity = getattr(settings, 'INVENTORY_UPDATE_VERBOSITY', 1) - args.append('-v%d' % verbosity) + args.append('-v%d' % inventory_update.verbosity) if settings.DEBUG: args.append('--traceback') return args diff --git a/awx/main/tests/unit/models/test_survey_models.py b/awx/main/tests/unit/models/test_survey_models.py index 502ee5c5d9..af2d0e3735 100644 --- a/awx/main/tests/unit/models/test_survey_models.py +++ b/awx/main/tests/unit/models/test_survey_models.py @@ -79,6 +79,7 @@ def test_job_args_unredacted_passwords(job): assert extra_vars['secret_key'] == 'my_password' +@pytest.mark.survey def test_update_kwargs_survey_invalid_default(survey_spec_factory): spec = survey_spec_factory('var2') spec['spec'][0]['required'] = False @@ -91,6 +92,7 @@ def test_update_kwargs_survey_invalid_default(survey_spec_factory): assert json.loads(defaulted_extra_vars['extra_vars'])['var2'] == 2 +@pytest.mark.survey @pytest.mark.parametrize("question_type,default,expect_use,expect_value", [ ("multiplechoice", "", False, 'N/A'), # historical bug ("multiplechoice", "zeb", False, 'N/A'), # zeb not in choices @@ -125,6 +127,7 @@ def test_optional_survey_question_defaults( assert 'c' not in defaulted_extra_vars['extra_vars'] +@pytest.mark.survey class TestWorkflowSurveys: def test_update_kwargs_survey_defaults(self, survey_spec_factory): "Assure that the survey default over-rides a JT variable" diff --git a/awx/main/tests/unit/plugins/test_tower_inventory_legacy.py b/awx/main/tests/unit/plugins/test_tower_inventory_legacy.py new file mode 100644 index 0000000000..75300583d5 --- /dev/null +++ b/awx/main/tests/unit/plugins/test_tower_inventory_legacy.py @@ -0,0 +1,88 @@ +import pytest +import mock + +# python +import os +import sys + +# AWX main +from awx.main.utils.mem_inventory import MemGroup + +# Add awx/plugins to sys.path so we can use the plugin +TEST_DIR = os.path.dirname(__file__) +path = os.path.abspath(os.path.join( + TEST_DIR, '..', '..', '..', '..', 'plugins', 'ansible_inventory')) +if path not in sys.path: + sys.path.insert(0, path) + +# AWX plugin +from legacy import IniLoader # noqa + + +@pytest.fixture +def loader(): + return IniLoader(TEST_DIR, MemGroup('all')) + + +@pytest.mark.inventory_import +class TestHostPatterns: + + def test_simple_host_pattern(self, loader): + assert [h.name for h in loader.get_host_names_from_entry('server[1:3].io')] == [ + 'server1.io', 'server2.io', 'server3.io'] + + def test_host_with_port(self, loader): + assert [h.name for h in loader.get_host_names_from_entry('server.com:8080')] == ['server.com'] + assert [h.variables['ansible_port'] for h in loader.get_host_names_from_entry('server.com:8080')] == [8080] + + def test_host_pattern_with_step(self, loader): + assert [h.name for h in loader.get_host_names_from_entry('server[0:10:5].io')] == [ + 'server0.io', 'server5.io', 'server10.io'] + + def test_invalid_host_pattern_with_step(self, loader): + with pytest.raises(ValueError): + print [h.name for h in loader.get_host_names_from_entry('server[00:010:5].io')] + + def test_alphanumeric_pattern(self, loader): + assert [h.name for h in loader.get_host_names_from_entry('server[a:c].io')] == [ + 'servera.io', 'serverb.io', 'serverc.io'] + + def test_invalid_alphanumeric_pattern(self, loader): + with pytest.raises(ValueError): + print [h.name for h in loader.get_host_names_from_entry('server[c:a].io')] + + +@pytest.mark.inventory_import +class TestLoader: + + def test_group_and_host(self, loader): + group_and_host = mock.MagicMock(return_value=[ + '[my_group]', + 'my_host' + ]) + with mock.patch.object(loader, 'file_line_iterable', group_and_host): + inventory = loader.load() + g = inventory.all_group.children[0] + assert g.name == 'my_group' + assert g.hosts[0].name + + def test_host_comment(self, loader): + group_and_host = mock.MagicMock(return_value=['my_host # and a comment']) + with mock.patch.object(loader, 'file_line_iterable', group_and_host): + inventory = loader.load() + assert inventory.all_group.hosts[0].name == 'my_host' + + def test_group_parentage(self, loader): + group_and_host = mock.MagicMock(return_value=[ + '[my_group] # and a comment', + '[my_group:children] # and a comment', + 'child_group # and a comment' + ]) + with mock.patch.object(loader, 'file_line_iterable', group_and_host): + inventory = loader.load() + g = inventory.get_group('my_group') + assert g.name == 'my_group' + child = g.children[0] + assert child.name == 'child_group' + # We can not list non-root-level groups in the all_group + assert child not in inventory.all_group.children diff --git a/awx/main/tests/unit/utils/test_mem_inventory.py b/awx/main/tests/unit/utils/test_mem_inventory.py new file mode 100644 index 0000000000..078d303323 --- /dev/null +++ b/awx/main/tests/unit/utils/test_mem_inventory.py @@ -0,0 +1,128 @@ +# AWX utils +from awx.main.utils.mem_inventory import ( + MemInventory, + mem_data_to_dict, dict_to_mem_data +) + +import pytest +import json + + +@pytest.fixture +def memory_inventory(): + inventory = MemInventory() + h = inventory.get_host('my_host') + h.variables = {'foo': 'bar'} + g = inventory.get_group('my_group') + g.variables = {'foobar': 'barfoo'} + h2 = inventory.get_host('group_host') + g.add_host(h2) + return inventory + + +@pytest.fixture +def JSON_of_inv(): + # Implemented as fixture becuase it may be change inside of tests + return { + "_meta": { + "hostvars": { + "group_host": {}, + "my_host": {"foo": "bar"} + } + }, + "all": {"children": ["my_group", "ungrouped"]}, + "my_group": { + "hosts": ["group_host"], + "vars": {"foobar": "barfoo"} + }, + "ungrouped": {"hosts": ["my_host"]} + } + + +# Structure mentioned in official docs +# https://docs.ansible.com/ansible/dev_guide/developing_inventory.html +@pytest.fixture +def JSON_with_lists(): + docs_example = '''{ + "databases" : { + "hosts" : [ "host1.example.com", "host2.example.com" ], + "vars" : { + "a" : true + } + }, + "webservers" : [ "host2.example.com", "host3.example.com" ], + "atlanta" : { + "hosts" : [ "host1.example.com", "host4.example.com", "host5.example.com" ], + "vars" : { + "b" : false + }, + "children": [ "marietta", "5points" ] + }, + "marietta" : [ "host6.example.com" ], + "5points" : [ "host7.example.com" ] + }''' + return json.loads(docs_example) + + +# MemObject basic operations tests + +@pytest.mark.inventory_import +def test_inventory_create_all_group(): + inventory = MemInventory() + assert inventory.all_group.name == 'all' + + +@pytest.mark.inventory_import +def test_create_child_group(): + inventory = MemInventory() + g1 = inventory.get_group('g1') + # Create new group by name as child of g1 + g2 = inventory.get_group('g2', g1) + # Check that child is in the children of the parent group + assert g1.children == [g2] + # Check that _only_ the parent group is listed as a root group + assert inventory.all_group.children == [g1] + # Check that _both_ are tracked by the global `all_groups` dict + assert set(inventory.all_group.all_groups.values()) == set([g1, g2]) + + +@pytest.mark.inventory_import +def test_ungrouped_mechanics(): + # ansible-inventory returns a group called `ungrouped` + # we can safely treat this the same as the `all_group` + inventory = MemInventory() + ug = inventory.get_group('ungrouped') + assert ug is inventory.all_group + + +# MemObject --> JSON tests + +@pytest.mark.inventory_import +def test_convert_memory_to_JSON_with_vars(memory_inventory): + data = mem_data_to_dict(memory_inventory) + # Assertions about the variables on the objects + assert data['_meta']['hostvars']['my_host'] == {'foo': 'bar'} + assert data['my_group']['vars'] == {'foobar': 'barfoo'} + # Orphan host should be found in ungrouped false group + assert data['ungrouped']['hosts'] == ['my_host'] + + +# JSON --> MemObject tests + +@pytest.mark.inventory_import +def test_convert_JSON_to_memory_with_vars(JSON_of_inv): + inventory = dict_to_mem_data(JSON_of_inv) + # Assertions about the variables on the objects + assert inventory.get_host('my_host').variables == {'foo': 'bar'} + assert inventory.get_group('my_group').variables == {'foobar': 'barfoo'} + # Host should be child of group + assert inventory.get_host('group_host') in inventory.get_group('my_group').hosts + + +@pytest.mark.inventory_import +def test_host_lists_accepted(JSON_with_lists): + inventory = dict_to_mem_data(JSON_with_lists) + assert inventory.get_group('marietta').name == 'marietta' + # Check that marietta's hosts was saved + h = inventory.get_host('host6.example.com') + assert h.name == 'host6.example.com' diff --git a/awx/main/utils/formatters.py b/awx/main/utils/formatters.py index 868f1c50ee..5ee26062f7 100644 --- a/awx/main/utils/formatters.py +++ b/awx/main/utils/formatters.py @@ -5,6 +5,16 @@ from logstash.formatter import LogstashFormatterVersion1 from copy import copy import json import time +import logging + + +class TimeFormatter(logging.Formatter): + ''' + Custom log formatter used for inventory imports + ''' + def format(self, record): + record.relativeSeconds = record.relativeCreated / 1000.0 + return logging.Formatter.format(self, record) class LogstashFormatter(LogstashFormatterVersion1): diff --git a/awx/main/utils/mem_inventory.py b/awx/main/utils/mem_inventory.py new file mode 100644 index 0000000000..b7530fd358 --- /dev/null +++ b/awx/main/utils/mem_inventory.py @@ -0,0 +1,315 @@ +# Copyright (c) 2017 Ansible by Red Hat +# All Rights Reserved. + +# Python +import re +import logging +from collections import OrderedDict + + +# Logger is used for any data-related messages so that the log level +# can be adjusted on command invocation +logger = logging.getLogger('awx.main.commands.inventory_import') + + +__all__ = ['MemHost', 'MemGroup', 'MemInventory', + 'mem_data_to_dict', 'dict_to_mem_data'] + + +ipv6_port_re = re.compile(r'^\[([A-Fa-f0-9:]{3,})\]:(\d+?)$') + + +# Models for in-memory objects that represent an inventory + + +class MemObject(object): + ''' + Common code shared between in-memory groups and hosts. + ''' + + def __init__(self, name): + assert name, 'no name' + self.name = name + + +class MemGroup(MemObject): + ''' + In-memory representation of an inventory group. + ''' + + def __init__(self, name): + super(MemGroup, self).__init__(name) + self.children = [] + self.hosts = [] + self.variables = {} + self.parents = [] + # Used on the "all" group in place of previous global variables. + # maps host and group names to hosts to prevent redudant additions + self.all_hosts = {} + self.all_groups = {} + self.variables = {} + logger.debug('Loaded group: %s', self.name) + + def __repr__(self): + return '<_in-memory-group_ `{}`>'.format(self.name) + + def add_child_group(self, group): + assert group.name is not 'all', 'group name is all' + assert isinstance(group, MemGroup), 'not MemGroup instance' + logger.debug('Adding child group %s to parent %s', group.name, self.name) + if group not in self.children: + self.children.append(group) + if self not in group.parents: + group.parents.append(self) + + def add_host(self, host): + assert isinstance(host, MemHost), 'not MemHost instance' + logger.debug('Adding host %s to group %s', host.name, self.name) + if host not in self.hosts: + self.hosts.append(host) + + def debug_tree(self, group_names=None): + group_names = group_names or set() + if self.name in group_names: + return + logger.debug('Dumping tree for group "%s":', self.name) + logger.debug('- Vars: %r', self.variables) + for h in self.hosts: + logger.debug('- Host: %s, %r', h.name, h.variables) + for g in self.children: + logger.debug('- Child: %s', g.name) + logger.debug('----') + group_names.add(self.name) + for g in self.children: + g.debug_tree(group_names) + + +class MemHost(MemObject): + ''' + In-memory representation of an inventory host. + ''' + + def __init__(self, name, port=None): + super(MemHost, self).__init__(name) + self.variables = {} + self.instance_id = None + self.name = name + if port: + # was `ansible_ssh_port` in older Ansible/Tower versions + self.variables['ansible_port'] = port + logger.debug('Loaded host: %s', self.name) + + def __repr__(self): + return '<_in-memory-host_ `{}`>'.format(self.name) + + +class MemInventory(object): + ''' + Common functions for an inventory loader from a given source. + ''' + def __init__(self, all_group=None, group_filter_re=None, host_filter_re=None): + if all_group: + assert isinstance(all_group, MemGroup), '{} is not MemGroup instance'.format(all_group) + self.all_group = all_group + else: + self.all_group = self.create_group('all') + self.group_filter_re = group_filter_re + self.host_filter_re = host_filter_re + + def create_host(self, host_name, port): + host = MemHost(host_name, port) + self.all_group.all_hosts[host_name] = host + return host + + def get_host(self, name): + ''' + Return a MemHost instance from host name, creating if needed. If name + contains brackets, they will NOT be interpreted as a host pattern. + ''' + m = ipv6_port_re.match(name) + if m: + host_name = m.groups()[0] + port = int(m.groups()[1]) + elif name.count(':') == 1: + host_name = name.split(':')[0] + try: + port = int(name.split(':')[1]) + except (ValueError, UnicodeDecodeError): + logger.warning(u'Invalid port "%s" for host "%s"', + name.split(':')[1], host_name) + port = None + else: + host_name = name + port = None + if self.host_filter_re and not self.host_filter_re.match(host_name): + logger.debug('Filtering host %s', host_name) + return None + if host_name not in self.all_group.all_hosts: + self.create_host(host_name, port) + return self.all_group.all_hosts[host_name] + + def create_group(self, group_name): + group = MemGroup(group_name) + if group_name not in ['all', 'ungrouped']: + self.all_group.all_groups[group_name] = group + return group + + def get_group(self, name, all_group=None, child=False): + ''' + Return a MemGroup instance from group name, creating if needed. + ''' + all_group = all_group or self.all_group + if name in ['all', 'ungrouped']: + return all_group + if self.group_filter_re and not self.group_filter_re.match(name): + logger.debug('Filtering group %s', name) + return None + if name not in self.all_group.all_groups: + group = self.create_group(name) + if not child: + all_group.add_child_group(group) + return self.all_group.all_groups[name] + + def delete_empty_groups(self): + for name, group in self.all_group.all_groups.items(): + if not group.children and not group.hosts and not group.variables: + logger.debug('Removing empty group %s', name) + for parent in group.parents: + if group in parent.children: + parent.children.remove(group) + del self.all_group.all_groups[name] + + +# Conversion utilities + +def mem_data_to_dict(inventory): + ''' + Given an in-memory construct of an inventory, returns a dictionary that + follows Ansible guidelines on the structure of dynamic inventory sources + + May be replaced by removing in-memory constructs within this file later + ''' + all_group = inventory.all_group + inventory_data = OrderedDict([]) + # Save hostvars to _meta + inventory_data['_meta'] = OrderedDict([]) + hostvars = OrderedDict([]) + for name, host_obj in all_group.all_hosts.items(): + hostvars[name] = host_obj.variables + inventory_data['_meta']['hostvars'] = hostvars + # Save children of `all` group + inventory_data['all'] = OrderedDict([]) + if all_group.variables: + inventory_data['all']['vars'] = all_group.variables + inventory_data['all']['children'] = [c.name for c in all_group.children] + inventory_data['all']['children'].append('ungrouped') + # Save details of declared groups individually + ungrouped_hosts = set(all_group.all_hosts.keys()) + for name, group_obj in all_group.all_groups.items(): + group_host_names = [h.name for h in group_obj.hosts] + group_children_names = [c.name for c in group_obj.children] + group_data = OrderedDict([]) + if group_host_names: + group_data['hosts'] = group_host_names + ungrouped_hosts.difference_update(group_host_names) + if group_children_names: + group_data['children'] = group_children_names + if group_obj.variables: + group_data['vars'] = group_obj.variables + inventory_data[name] = group_data + # Save ungrouped hosts + inventory_data['ungrouped'] = OrderedDict([]) + if ungrouped_hosts: + inventory_data['ungrouped']['hosts'] = list(ungrouped_hosts) + return inventory_data + + +def dict_to_mem_data(data, inventory=None): + ''' + In-place operation on `inventory`, adds contents from `data` to the + in-memory representation of memory. + May be destructive on `data` + ''' + assert isinstance(data, dict), 'Expected dict, received {}'.format(type(data)) + if inventory is None: + inventory = MemInventory() + + _meta = data.pop('_meta', {}) + + for k,v in data.iteritems(): + group = inventory.get_group(k) + if not group: + continue + + # Load group hosts/vars/children from a dictionary. + if isinstance(v, dict): + # Process hosts within a group. + hosts = v.get('hosts', {}) + if isinstance(hosts, dict): + for hk, hv in hosts.iteritems(): + host = inventory.get_host(hk) + if not host: + continue + if isinstance(hv, dict): + host.variables.update(hv) + else: + logger.warning('Expected dict of vars for ' + 'host "%s", got %s instead', + hk, str(type(hv))) + group.add_host(host) + elif isinstance(hosts, (list, tuple)): + for hk in hosts: + host = inventory.get_host(hk) + if not host: + continue + group.add_host(host) + else: + logger.warning('Expected dict or list of "hosts" for ' + 'group "%s", got %s instead', k, + str(type(hosts))) + # Process group variables. + vars = v.get('vars', {}) + if isinstance(vars, dict): + group.variables.update(vars) + else: + logger.warning('Expected dict of vars for ' + 'group "%s", got %s instead', + k, str(type(vars))) + # Process child groups. + children = v.get('children', []) + if isinstance(children, (list, tuple)): + for c in children: + child = inventory.get_group(c, inventory.all_group, child=True) + if child and c != 'ungrouped': + group.add_child_group(child) + else: + logger.warning('Expected list of children for ' + 'group "%s", got %s instead', + k, str(type(children))) + + # Load host names from a list. + elif isinstance(v, (list, tuple)): + for h in v: + host = inventory.get_host(h) + if not host: + continue + group.add_host(host) + else: + logger.warning('') + logger.warning('Expected dict or list for group "%s", ' + 'got %s instead', k, str(type(v))) + + if k not in ['all', 'ungrouped']: + inventory.all_group.add_child_group(group) + + if _meta: + for k,v in inventory.all_group.all_hosts.iteritems(): + meta_hostvars = _meta['hostvars'].get(k, {}) + if isinstance(meta_hostvars, dict): + v.variables.update(meta_hostvars) + else: + logger.warning('Expected dict of vars for ' + 'host "%s", got %s instead', + k, str(type(meta_hostvars))) + + return inventory diff --git a/awx/plugins/ansible_inventory/legacy.py b/awx/plugins/ansible_inventory/legacy.py new file mode 100755 index 0000000000..a9e1c549b5 --- /dev/null +++ b/awx/plugins/ansible_inventory/legacy.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python + +# Copyright (c) 2017 Ansible by Red Hat +# All Rights Reserved. + +# Python +import glob +import json +import logging +import os +import shlex +import argparse +import re +import string +import yaml + +# import sys +# # Add awx/plugins to sys.path so we can use the plugin +# TEST_DIR = os.path.dirname(__file__) +# path = os.path.abspath(os.path.join(TEST_DIR, '..', '..', 'main', 'utils')) +# if path not in sys.path: +# sys.path.insert(0, path) + +# AWX +from awx.main.utils.mem_inventory import ( + MemGroup, MemInventory, mem_data_to_dict, ipv6_port_re +) # NOQA + + +# Logger is used for any data-related messages so that the log level +# can be adjusted on command invocation +# logger = logging.getLogger('awx.plugins.ansible_inventory.tower_inventory_legacy') +logger = logging.getLogger('awx.main.management.commands.inventory_import') + + +class FileMemInventory(MemInventory): + ''' + Adds on file-specific actions + ''' + def __init__(self, source_dir, all_group, group_filter_re, host_filter_re, **kwargs): + super(FileMemInventory, self).__init__(all_group, group_filter_re, host_filter_re, **kwargs) + self.source_dir = source_dir + + def load_vars(self, mem_object, dir_path): + all_vars = {} + files_found = 0 + for suffix in ('', '.yml', '.yaml', '.json'): + path = ''.join([dir_path, suffix]).encode("utf-8") + if not os.path.exists(path): + continue + if not os.path.isfile(path): + continue + files_found += 1 + if files_found > 1: + raise RuntimeError( + 'Multiple variable files found. There should only ' + 'be one. %s ' % self.name) + vars_name = os.path.basename(os.path.dirname(path)) + logger.debug('Loading %s from %s', vars_name, path) + try: + v = yaml.safe_load(file(path, 'r').read()) + if hasattr(v, 'items'): # is a dict + all_vars.update(v) + except yaml.YAMLError as e: + if hasattr(e, 'problem_mark'): + logger.error('Invalid YAML in %s:%s col %s', path, + e.problem_mark.line + 1, + e.problem_mark.column + 1) + else: + logger.error('Error loading YAML from %s', path) + raise + return all_vars + + def create_host(self, host_name, port): + host = super(FileMemInventory, self).create_host(host_name, port) + host_vars_dir = os.path.join(self.source_dir, 'host_vars', host.name) + host.variables.update(self.load_vars(host, host_vars_dir)) + return host + + def create_group(self, group_name): + group = super(FileMemInventory, self).create_group(group_name) + group_vars_dir = os.path.join(self.source_dir, 'group_vars', group.name) + group.variables.update(self.load_vars(group, group_vars_dir)) + return group + + +class IniLoader(object): + ''' + Loader to read inventory from an INI-formatted text file. + ''' + def __init__(self, source, all_group=None, group_filter_re=None, host_filter_re=None): + self.source = source + self.source_dir = os.path.dirname(self.source) + self.inventory = FileMemInventory( + self.source_dir, all_group, + group_filter_re=group_filter_re, host_filter_re=host_filter_re) + + def get_host_names_from_entry(self, name): + ''' + Given an entry in an Ansible inventory file, return an iterable of + the resultant host names, accounting for expansion patterns. + Examples: + web1.server.com -> web1.server.com + web[1:2].server.com -> web1.server.com, web2.server.com + ''' + def iternest(*args): + if args: + for i in args[0]: + for j in iternest(*args[1:]): + yield ''.join([str(i), j]) + else: + yield '' + if ipv6_port_re.match(name): + yield self.inventory.get_host(name) + return + pattern_re = re.compile(r'(\[(?:(?:\d+\:\d+)|(?:[A-Za-z]\:[A-Za-z]))(?:\:\d+)??\])') + iters = [] + for s in re.split(pattern_re, name): + if re.match(pattern_re, s): + start, end, step = (s[1:-1] + ':1').split(':')[:3] + mapfunc = str + if start in string.ascii_letters: + istart = string.ascii_letters.index(start) + iend = string.ascii_letters.index(end) + 1 + if istart >= iend: + raise ValueError('invalid host range specified') + seq = string.ascii_letters[istart:iend:int(step)] + else: + if start[0] == '0' and len(start) > 1: + if len(start) != len(end): + raise ValueError('invalid host range specified') + mapfunc = lambda x: str(x).zfill(len(start)) + seq = xrange(int(start), int(end) + 1, int(step)) + iters.append(map(mapfunc, seq)) + elif re.search(r'[\[\]]', s): + raise ValueError('invalid host range specified') + elif s: + iters.append([s]) + for iname in iternest(*iters): + yield self.inventory.get_host(iname) + + @staticmethod + def file_line_iterable(filename): + return file(filename, 'r') + + def load(self): + logger.info('Reading INI source: %s', self.source) + group = self.inventory.all_group + input_mode = 'host' + for line in self.file_line_iterable(self.source): + line = line.split('#')[0].strip() + if not line: + continue + elif line.startswith('[') and line.endswith(']'): + # Mode change, possible new group name + line = line[1:-1].strip() + if line.endswith(':vars'): + input_mode = 'vars' + line = line[:-5] + elif line.endswith(':children'): + input_mode = 'children' + line = line[:-9] + else: + input_mode = 'host' + group = self.inventory.get_group(line) + elif group: + # If group is None, we are skipping this group and shouldn't + # capture any children/variables/hosts under it. + # Add hosts with inline variables, or variables/children to + # an existing group. + tokens = shlex.split(line) + if input_mode == 'host': + for host in self.get_host_names_from_entry(tokens[0]): + if not host: + continue + if len(tokens) > 1: + for t in tokens[1:]: + k,v = t.split('=', 1) + host.variables[k] = v + group.add_host(host) + elif input_mode == 'children': + self.inventory.get_group(line, group) + elif input_mode == 'vars': + for t in tokens: + k, v = t.split('=', 1) + group.variables[k] = v + return self.inventory + + +def load_inventory_source(source, all_group=None, group_filter_re=None, + host_filter_re=None, exclude_empty_groups=False): + ''' + Load inventory from given source directory or file. + ''' + original_all_group = all_group + if not os.path.exists(source): + raise IOError('Source does not exist: %s' % source) + source = os.path.join(os.getcwd(), os.path.dirname(source), + os.path.basename(source)) + source = os.path.normpath(os.path.abspath(source)) + if os.path.isdir(source): + all_group = all_group or MemGroup('all') + for filename in glob.glob(os.path.join(source, '*')): + if filename.endswith(".ini") or os.path.isdir(filename): + continue + load_inventory_source(filename, all_group, group_filter_re, + host_filter_re) + elif os.access(source, os.X_OK): + raise NotImplementedError( + 'Source has been marked as executable, but script-based sources ' + 'are not supported by the legacy file import plugin. ' + 'This problem may be solved by upgrading to use `ansible-inventory`.') + else: + all_group = all_group or MemGroup('all', os.path.dirname(source)) + IniLoader(source, all_group, group_filter_re, host_filter_re).load() + + logger.debug('Finished loading from source: %s', source) + # Exclude groups that are completely empty. + if original_all_group is None and exclude_empty_groups: + for name, group in all_group.all_groups.items(): + if not group.children and not group.hosts and not group.variables: + logger.debug('Removing empty group %s', name) + for parent in group.parents: + if group in parent.children: + parent.children.remove(group) + del all_group.all_groups[name] + if original_all_group is None: + logger.info('Loaded %d groups, %d hosts', len(all_group.all_groups), + len(all_group.all_hosts)) + return all_group + + +def parse_args(): + parser = argparse.ArgumentParser(description='Ansible Inventory Import Plugin - Fallback Option') + parser.add_argument( + '-i', '--inventory-file', dest='inventory', required=True, + help="Specify inventory host path (does not support CSV host paths)") + parser.add_argument( + '--list', action='store_true', dest='list', default=None, required=True, + help='Output all hosts info, works as inventory script') + # --host and --graph and not supported + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + source = args.inventory + memory_data = load_inventory_source( + source, group_filter_re=None, + host_filter_re=None, exclude_empty_groups=False) + mem_inventory = MemInventory(all_group=memory_data) + inventory_dict = mem_data_to_dict(mem_inventory) + print json.dumps(inventory_dict, indent=4) diff --git a/awx/settings/defaults.py b/awx/settings/defaults.py index 856c883aba..bd8a62c66d 100644 --- a/awx/settings/defaults.py +++ b/awx/settings/defaults.py @@ -907,6 +907,10 @@ LOGGING = { }, 'json': { '()': 'awx.main.utils.formatters.LogstashFormatter' + }, + 'timed_import': { + '()': 'awx.main.utils.formatters.TimeFormatter', + 'format': '%(relativeSeconds)9.3f %(levelname)-8s %(message)s' } }, 'handlers': { @@ -958,6 +962,11 @@ LOGGING = { 'backupCount': 5, 'formatter':'simple', }, + 'inventory_import': { + 'level': 'DEBUG', + 'class':'logging.StreamHandler', + 'formatter': 'timed_import', + }, 'task_system': { 'level': 'INFO', 'class':'logging.handlers.RotatingFileHandler', @@ -1029,6 +1038,10 @@ LOGGING = { 'awx.main.commands.run_callback_receiver': { 'handlers': ['callback_receiver'], }, + 'awx.main.commands.inventory_import': { + 'handlers': ['inventory_import'], + 'propagate': False + }, 'awx.main.tasks': { 'handlers': ['task_system'], }, diff --git a/pytest.ini b/pytest.ini index 2993b1f577..4884ec4897 100644 --- a/pytest.ini +++ b/pytest.ini @@ -8,3 +8,5 @@ markers = ac: access control test license_feature: ensure license features are accessible or not depending on license mongo_db: drop mongodb test database before test runs + survey: tests related to survey feature + inventory_import: tests of code used by inventory import command